/** * @file CRSplineReg.cpp * @brief Implementation of Catmull-Rom-Splines for regression purposes * @author Frank Prüfer * @date 09/03/2013 */ #ifdef NICE_USELIB_OPENMP #include #endif #include #include "vislearning/regression/splineregression/CRSplineReg.h" #include "vislearning/regression/linregression/LinRegression.h" #include "vislearning/math/mathbase/FullVector.h" using namespace OBJREC; using namespace std; using namespace NICE; CRSplineReg::CRSplineReg ( const NICE::Config *_conf ) { tau = _conf->gD("CRSplineReg","tau",0.5); sortDim = _conf->gI("CRSplineReg","sortDim",0); } CRSplineReg::CRSplineReg ( uint sDim ) { sortDim = sDim; } CRSplineReg::CRSplineReg ( const CRSplineReg & src ) : RegressionAlgorithm ( src ) { tau = src.tau; dataSet = src.dataSet; labelSet = src.labelSet; sortDim = src.sortDim; } CRSplineReg::~CRSplineReg() { } CRSplineReg* CRSplineReg::clone ( void ) const { return new CRSplineReg(*this); } void CRSplineReg::teach ( const NICE::VVector & _dataSet, const NICE::Vector & _labelSet) { fprintf (stderr, "teach using all !\n"); //NOTE this is crucial if we clear _teachSet afterwards! //therefore, take care NOT to call _techSet.clear() somewhere out of this method this->dataSet = _dataSet; this->labelSet = _labelSet.std_vector(); std::cerr << "number of known training samples: " << this->dataSet.size() << std::endl; } void CRSplineReg::teach ( const NICE::Vector & x, const double & y ) { std::cerr << "CRSplineReg::teach one new example" << std::endl; for ( size_t i = 0 ; i < x.size() ; i++ ) if ( isnan(x[i]) ) { fprintf (stderr, "There is a NAN value in within this vector: x[%d] = %f\n", (int)i, x[i]); cerr << x << endl; exit(-1); } dataSet.push_back ( x ); labelSet.push_back ( y ); std::cerr << "number of known training samples: " << dataSet.size()<< std::endl; } double CRSplineReg::predict ( const NICE::Vector & x ) { if ( dataSet.size() <= 0 ) { fprintf (stderr, "CRSplineReg: please use the train method first\n"); exit(-1); } int dimension = dataSet[0].size(); FullVector data ( dataSet.size()+1 ); #pragma omp parallel for for ( uint i = 0; i < dataSet.size(); i++ ){ data[i] = dataSet[i][sortDim]; } data[dataSet.size()] = x[sortDim]; std::vector sortedInd; data.getSortedIndices(sortedInd); int index; for ( uint i = 0; i < sortedInd.size(); i++ ){ if ( sortedInd[i] == (int)dataSet.size() ){ index = i; break; } } NICE::Matrix points (4,dimension+1,0.0); if ( index >= 2 && index < (int)(sortedInd.size() - 2) ){ //everything is okay points.setRow(0,dataSet[sortedInd[index-2]]); points(0,dimension) = labelSet[sortedInd[index-2]]; points.setRow(1,dataSet[sortedInd[index-1]]); points(1,dimension) = labelSet[sortedInd[index-1]]; points.setRow(2,dataSet[sortedInd[index+1]]); points(2,dimension) = labelSet[sortedInd[index+1]]; points.setRow(3,dataSet[sortedInd[index+2]]); points(3,dimension) = labelSet[sortedInd[index+2]]; } else if ( index == 1 ){ //just one point left from x points.setRow(0,dataSet[sortedInd[index-1]]); points(0,dimension) = labelSet[sortedInd[index-1]]; points.setRow(1,dataSet[sortedInd[index-1]]); points(1,dimension) = labelSet[sortedInd[index-1]]; points.setRow(2,dataSet[sortedInd[index+1]]); points(2,dimension) = labelSet[sortedInd[index+1]]; points.setRow(3,dataSet[sortedInd[index+2]]); points(3,dimension) = labelSet[sortedInd[index+2]]; } else if ( index == 0 ){ //x is the farthest left point points.setRow(0,dataSet[sortedInd[index+1]]); points(0,dimension) = labelSet[sortedInd[index+1]]; points.setRow(1,dataSet[sortedInd[index+1]]); points(1,dimension) = labelSet[sortedInd[index+1]]; points.setRow(2,dataSet[sortedInd[index+1]]); points(2,dimension) = labelSet[sortedInd[index+1]]; points.setRow(3,dataSet[sortedInd[index+2]]); points(3,dimension) = labelSet[sortedInd[index+2]]; } else if ( index == (int)(sortedInd.size() - 2) ){ //just one point right from x points.setRow(0,dataSet[sortedInd[index-2]]); points(0,dimension) = labelSet[sortedInd[index-2]]; points.setRow(1,dataSet[sortedInd[index-1]]); points(1,dimension) = labelSet[sortedInd[index-1]]; points.setRow(2,dataSet[sortedInd[index+1]]); points(2,dimension) = labelSet[sortedInd[index+1]]; points.setRow(3,dataSet[sortedInd[index+1]]); points(3,dimension) = labelSet[sortedInd[index+1]]; } else if ( index == (int)(sortedInd.size() - 1) ){ //x is the farthest right point points.setRow(0,dataSet[sortedInd[index-2]]); points(0,dimension) = labelSet[sortedInd[index-2]]; points.setRow(1,dataSet[sortedInd[index-1]]); points(1,dimension) = labelSet[sortedInd[index-1]]; points.setRow(2,dataSet[sortedInd[index-1]]); points(2,dimension) = labelSet[sortedInd[index-1]]; points.setRow(3,dataSet[sortedInd[index-1]]); points(3,dimension) = labelSet[sortedInd[index-1]]; } double t = (x[sortDim]-points(1,sortDim)) / (points(2,sortDim)-points(1,sortDim)); //this is just some kind of heuristic if ( t != t || t < 0 || t > 1){ //check if t is NAN, -inf or inf (happens in the farthest right or left case from above) t = 0.5; } //P(t) = b0*P0 + b1*P1 + b2*P2 + b3*P3 NICE::Vector P(dimension); double y; double b0,b1,b2,b3; b0 = tau * (-(t*t*t) + 2*t*t - t); b1 = tau * (3*t*t*t - 5*t*t + 2); b2 = tau * (-3*t*t*t + 4*t*t + t); b3 = tau * (t*t*t - t*t); #pragma omp parallel for for ( uint i = 0; i < (uint)dimension; i++ ){ P[i] = b0*points(0,i) + b1*points(1,i) + b2*points(2,i) + b3*points(3,i); } double diff1 = (P-x).normL2(); uint counter = 1; while ( diff1 > 1e-5 && counter <= 21){ //adjust t to fit data better double tmp = t;; if (tmp > 0.5) tmp = 1 - tmp; t += tmp/counter; b0 = tau * (-(t*t*t) + 2*t*t - t); b1 = tau * (3*t*t*t - 5*t*t + 2); b2 = tau * (-3*t*t*t + 4*t*t + t); b3 = tau * (t*t*t - t*t); for ( uint i = 0; i < (uint)dimension; i++ ){ P[i] = b0*points(0,i) + b1*points(1,i) + b2*points(2,i) + b3*points(3,i); } double diff2 = (P-x).normL2(); if ( diff2 > diff1 && t > 0) { t -= 2*tmp/counter; b0 = tau * (-(t*t*t) + 2*t*t - t); b1 = tau * (3*t*t*t - 5*t*t + 2); b2 = tau * (-3*t*t*t + 4*t*t + t); b3 = tau * (t*t*t - t*t); #pragma omp parallel for for ( uint i = 0; i < (uint)dimension; i++ ){ P[i] = b0*points(0,i) + b1*points(1,i) + b2*points(2,i) + b3*points(3,i); } diff1 = (P-x).normL2(); } counter++; } y = b0*points(0,dimension) + b1*points(1,dimension) + b2*points(2,dimension) + b3*points(3,dimension); return y; }