CRSplineReg.cpp 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. /**
  2. * @file CRSplineReg.cpp
  3. * @brief Implementation of Catmull-Rom-Splines for regression purposes
  4. * @author Frank Prüfer
  5. * @date 09/03/2013
  6. */
  7. #ifdef NICE_USELIB_OPENMP
  8. #include <omp.h>
  9. #endif
  10. #include <iostream>
  11. #include "vislearning/regression/splineregression/CRSplineReg.h"
  12. #include "vislearning/regression/linregression/LinRegression.h"
  13. #include "vislearning/math/mathbase/FullVector.h"
  14. using namespace OBJREC;
  15. using namespace std;
  16. using namespace NICE;
  17. CRSplineReg::CRSplineReg ( const NICE::Config *_conf )
  18. {
  19. tau = _conf->gD("CRSplineReg","tau",0.5);
  20. sortDim = _conf->gI("CRSplineReg","sortDim",0);
  21. }
  22. CRSplineReg::CRSplineReg ( uint sDim )
  23. {
  24. sortDim = sDim;
  25. }
  26. CRSplineReg::CRSplineReg ( const CRSplineReg & src ) : RegressionAlgorithm ( src )
  27. {
  28. tau = src.tau;
  29. dataSet = src.dataSet;
  30. labelSet = src.labelSet;
  31. sortDim = src.sortDim;
  32. }
  33. CRSplineReg::~CRSplineReg()
  34. {
  35. }
  36. CRSplineReg* CRSplineReg::clone ( void ) const
  37. {
  38. return new CRSplineReg(*this);
  39. }
  40. void CRSplineReg::teach ( const NICE::VVector & _dataSet, const NICE::Vector & _labelSet)
  41. {
  42. fprintf (stderr, "teach using all !\n");
  43. //NOTE this is crucial if we clear _teachSet afterwards!
  44. //therefore, take care NOT to call _techSet.clear() somewhere out of this method
  45. this->dataSet = _dataSet;
  46. this->labelSet = _labelSet.std_vector();
  47. std::cerr << "number of known training samples: " << this->dataSet.size() << std::endl;
  48. }
  49. void CRSplineReg::teach ( const NICE::Vector & x, const double & y )
  50. {
  51. std::cerr << "CRSplineReg::teach one new example" << std::endl;
  52. for ( size_t i = 0 ; i < x.size() ; i++ )
  53. if ( isnan(x[i]) )
  54. {
  55. fprintf (stderr, "There is a NAN value in within this vector: x[%d] = %f\n", (int)i, x[i]);
  56. cerr << x << endl;
  57. exit(-1);
  58. }
  59. dataSet.push_back ( x );
  60. labelSet.push_back ( y );
  61. std::cerr << "number of known training samples: " << dataSet.size()<< std::endl;
  62. }
  63. double CRSplineReg::predict ( const NICE::Vector & x )
  64. {
  65. if ( dataSet.size() <= 0 ) {
  66. fprintf (stderr, "CRSplineReg: please use the train method first\n");
  67. exit(-1);
  68. }
  69. int dimension = dataSet[0].size();
  70. FullVector data ( dataSet.size()+1 );
  71. #pragma omp parallel for
  72. for ( uint i = 0; i < dataSet.size(); i++ ){
  73. data[i] = dataSet[i][sortDim];
  74. }
  75. data[dataSet.size()] = x[sortDim];
  76. std::vector<int> sortedInd;
  77. data.getSortedIndices(sortedInd);
  78. int index;
  79. for ( uint i = 0; i < sortedInd.size(); i++ ){
  80. if ( sortedInd[i] == (int)dataSet.size() ){
  81. index = i;
  82. break;
  83. }
  84. }
  85. NICE::Matrix points (4,dimension+1,0.0);
  86. if ( index >= 2 && index < (int)(sortedInd.size() - 2) ){ //everything is okay
  87. points.setRow(0,dataSet[sortedInd[index-2]]);
  88. points(0,dimension) = labelSet[sortedInd[index-2]];
  89. points.setRow(1,dataSet[sortedInd[index-1]]);
  90. points(1,dimension) = labelSet[sortedInd[index-1]];
  91. points.setRow(2,dataSet[sortedInd[index+1]]);
  92. points(2,dimension) = labelSet[sortedInd[index+1]];
  93. points.setRow(3,dataSet[sortedInd[index+2]]);
  94. points(3,dimension) = labelSet[sortedInd[index+2]];
  95. }
  96. else if ( index == 1 ){ //just one point left from x
  97. points.setRow(0,dataSet[sortedInd[index-1]]);
  98. points(0,dimension) = labelSet[sortedInd[index-1]];
  99. points.setRow(1,dataSet[sortedInd[index-1]]);
  100. points(1,dimension) = labelSet[sortedInd[index-1]];
  101. points.setRow(2,dataSet[sortedInd[index+1]]);
  102. points(2,dimension) = labelSet[sortedInd[index+1]];
  103. points.setRow(3,dataSet[sortedInd[index+2]]);
  104. points(3,dimension) = labelSet[sortedInd[index+2]];
  105. }
  106. else if ( index == 0 ){ //x is the farthest left point
  107. points.setRow(0,dataSet[sortedInd[index+1]]);
  108. points(0,dimension) = labelSet[sortedInd[index+1]];
  109. points.setRow(1,dataSet[sortedInd[index+1]]);
  110. points(1,dimension) = labelSet[sortedInd[index+1]];
  111. points.setRow(2,dataSet[sortedInd[index+1]]);
  112. points(2,dimension) = labelSet[sortedInd[index+1]];
  113. points.setRow(3,dataSet[sortedInd[index+2]]);
  114. points(3,dimension) = labelSet[sortedInd[index+2]];
  115. }
  116. else if ( index == (int)(sortedInd.size() - 2) ){ //just one point right from x
  117. points.setRow(0,dataSet[sortedInd[index-2]]);
  118. points(0,dimension) = labelSet[sortedInd[index-2]];
  119. points.setRow(1,dataSet[sortedInd[index-1]]);
  120. points(1,dimension) = labelSet[sortedInd[index-1]];
  121. points.setRow(2,dataSet[sortedInd[index+1]]);
  122. points(2,dimension) = labelSet[sortedInd[index+1]];
  123. points.setRow(3,dataSet[sortedInd[index+1]]);
  124. points(3,dimension) = labelSet[sortedInd[index+1]];
  125. }
  126. else if ( index == (int)(sortedInd.size() - 1) ){ //x is the farthest right point
  127. points.setRow(0,dataSet[sortedInd[index-2]]);
  128. points(0,dimension) = labelSet[sortedInd[index-2]];
  129. points.setRow(1,dataSet[sortedInd[index-1]]);
  130. points(1,dimension) = labelSet[sortedInd[index-1]];
  131. points.setRow(2,dataSet[sortedInd[index-1]]);
  132. points(2,dimension) = labelSet[sortedInd[index-1]];
  133. points.setRow(3,dataSet[sortedInd[index-1]]);
  134. points(3,dimension) = labelSet[sortedInd[index-1]];
  135. }
  136. double t = (x[sortDim]-points(1,sortDim)) / (points(2,sortDim)-points(1,sortDim)); //this is just some kind of heuristic
  137. if ( t != t || t < 0 || t > 1){ //check if t is NAN, -inf or inf (happens in the farthest right or left case from above)
  138. t = 0.5;
  139. }
  140. //P(t) = b0*P0 + b1*P1 + b2*P2 + b3*P3
  141. NICE::Vector P(dimension);
  142. double y;
  143. double b0,b1,b2,b3;
  144. b0 = tau * (-(t*t*t) + 2*t*t - t);
  145. b1 = tau * (3*t*t*t - 5*t*t + 2);
  146. b2 = tau * (-3*t*t*t + 4*t*t + t);
  147. b3 = tau * (t*t*t - t*t);
  148. #pragma omp parallel for
  149. for ( uint i = 0; i < (uint)dimension; i++ ){
  150. P[i] = b0*points(0,i) + b1*points(1,i) + b2*points(2,i) + b3*points(3,i);
  151. }
  152. double diff1 = (P-x).normL2();
  153. uint counter = 1;
  154. while ( diff1 > 1e-5 && counter <= 21){ //adjust t to fit data better
  155. double tmp = t;;
  156. if (tmp > 0.5)
  157. tmp = 1 - tmp;
  158. t += tmp/counter;
  159. b0 = tau * (-(t*t*t) + 2*t*t - t);
  160. b1 = tau * (3*t*t*t - 5*t*t + 2);
  161. b2 = tau * (-3*t*t*t + 4*t*t + t);
  162. b3 = tau * (t*t*t - t*t);
  163. for ( uint i = 0; i < (uint)dimension; i++ ){
  164. P[i] = b0*points(0,i) + b1*points(1,i) + b2*points(2,i) + b3*points(3,i);
  165. }
  166. double diff2 = (P-x).normL2();
  167. if ( diff2 > diff1 && t > 0) {
  168. t -= 2*tmp/counter;
  169. b0 = tau * (-(t*t*t) + 2*t*t - t);
  170. b1 = tau * (3*t*t*t - 5*t*t + 2);
  171. b2 = tau * (-3*t*t*t + 4*t*t + t);
  172. b3 = tau * (t*t*t - t*t);
  173. #pragma omp parallel for
  174. for ( uint i = 0; i < (uint)dimension; i++ ){
  175. P[i] = b0*points(0,i) + b1*points(1,i) + b2*points(2,i) + b3*points(3,i);
  176. }
  177. diff1 = (P-x).normL2();
  178. }
  179. counter++;
  180. }
  181. y = b0*points(0,dimension) + b1*points(1,dimension) + b2*points(2,dimension) + b3*points(3,dimension);
  182. return y;
  183. }