/** * @file CRSplineReg.cpp * @brief Implementation of Catmull-Rom-Splines for regression purposes * @author Frank Prüfer * @date 09/03/2013 */ #include #include "vislearning/regression/splineregression/CRSplineReg.h" #include "vislearning/math/mathbase/FullVector.h" using namespace OBJREC; using namespace std; using namespace NICE; CRSplineReg::CRSplineReg ( ) { tau = 0.5; } CRSplineReg::CRSplineReg ( const CRSplineReg & src ) : RegressionAlgorithm ( src ) { tau = src.tau; dataSet = src.dataSet; labelSet = src.labelSet; } CRSplineReg::~CRSplineReg() { } void CRSplineReg::teach ( const NICE::VVector & _dataSet, const NICE::Vector & _labelSet) { fprintf (stderr, "teach using all !\n"); //NOTE this is crucial if we clear _teachSet afterwards! //therefore, take care NOT to call _techSet.clear() somewhere out of this method this->dataSet = _dataSet; this->labelSet = _labelSet.std_vector(); std::cerr << "number of known training samples: " << this->dataSet.size() << std::endl; } void CRSplineReg::teach ( const NICE::Vector & x, const double & y ) { std::cerr << "CRSplineReg::teach one new example" << std::endl; for ( size_t i = 0 ; i < x.size() ; i++ ) if ( isnan(x[i]) ) { fprintf (stderr, "There is a NAN value in within this vector: x[%d] = %f\n", (int)i, x[i]); cerr << x << endl; exit(-1); } dataSet.push_back ( x ); labelSet.push_back ( y ); std::cerr << "number of known training samples: " << dataSet.size()<< std::endl; } double CRSplineReg::predict ( const NICE::Vector & x ) { if ( dataSet.size() <= 0 ) { fprintf (stderr, "CRSplineReg: please use the train method first\n"); exit(-1); } if ( dataSet[0].size() == 1 ){ //one-dimensional case FullVector data ( dataSet.size()+1 ); for ( uint i = 0; i < dataSet.size(); i++ ){ data[i] = dataSet[i][0]; } cerr<<"data x: "< ind; data.getSortedIndices(ind); int index; for ( uint i = 0; i < ind.size(); i++ ){ if ( ind[i] == dataSet.size() ){ index = i; break; } } NICE::Matrix points (4,2,0.0); if ( index >= 2 && index < (ind.size() - 2) ){ //everything is okay points(0,0) = data[ind[index-2]]; points(0,1) = labelSet[ind[index-2]]; points(1,0) = data[ind[index-1]]; points(1,1) = labelSet[ind[index-1]]; points(2,0) = data[ind[index+1]]; points(2,1) = labelSet[ind[index+1]]; points(3,0) = data[ind[index+2]]; points(3,1) = labelSet[ind[index+2]]; } else if ( index == 1 ){ //just one point left from x points(0,0) = data[ind[index-1]]; points(0,1) = labelSet[ind[index-1]]; points(1,0) = data[ind[index-1]]; points(1,1) = labelSet[ind[index-1]]; points(2,0) = data[ind[index+1]]; points(2,1) = labelSet[ind[index+1]]; points(3,0) = data[ind[index+2]]; points(3,1) = labelSet[ind[index+2]]; } else if ( index == 0 ){ //x is the farthest left point //do linear approximation } else if ( index == (ind.size() - 2) ){ //just one point right from x points(0,0) = data[ind[index-2]]; points(0,1) = labelSet[ind[index-2]]; points(1,0) = data[ind[index-1]]; points(1,1) = labelSet[ind[index-1]]; points(2,0) = data[ind[index+1]]; points(2,1) = labelSet[ind[index+1]]; points(3,0) = data[ind[index+1]]; points(3,1) = labelSet[ind[index+1]]; } else if ( index == (ind.size() - 1) ){ //x is the farthest right point //do linear approximation } double t = (x[0] - points(1,0)) / (points(2,0) - points(1,0)); cerr<<"t: "<