CRSplineReg.cpp 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. /**
  2. * @file CRSplineReg.cpp
  3. * @brief Implementation of Catmull-Rom-Splines for regression purposes
  4. * @author Frank Prüfer
  5. * @date 09/03/2013
  6. */
  7. #ifdef NICE_USELIB_OPENMP
  8. #include <omp.h>
  9. #endif
  10. #include <iostream>
  11. #include "vislearning/regression/splineregression/CRSplineReg.h"
  12. #include "vislearning/regression/linregression/LinRegression.h"
  13. #include "vislearning/math/mathbase/FullVector.h"
  14. using namespace OBJREC;
  15. using namespace std;
  16. using namespace NICE;
  17. CRSplineReg::CRSplineReg ( const NICE::Config *_conf )
  18. {
  19. tau = _conf->gD("CRSplineReg","tau",0.5);
  20. sortDim = _conf->gI("CRSplineReg","sortDim",0);
  21. }
  22. CRSplineReg::CRSplineReg ( uint sDim )
  23. {
  24. sortDim = sDim;
  25. }
  26. CRSplineReg::CRSplineReg ( const CRSplineReg & src ) : RegressionAlgorithm ( src )
  27. {
  28. tau = src.tau;
  29. dataSet = src.dataSet;
  30. labelSet = src.labelSet;
  31. sortDim = src.sortDim;
  32. }
  33. CRSplineReg::~CRSplineReg()
  34. {
  35. }
  36. void CRSplineReg::teach ( const NICE::VVector & _dataSet, const NICE::Vector & _labelSet)
  37. {
  38. fprintf (stderr, "teach using all !\n");
  39. //NOTE this is crucial if we clear _teachSet afterwards!
  40. //therefore, take care NOT to call _techSet.clear() somewhere out of this method
  41. this->dataSet = _dataSet;
  42. this->labelSet = _labelSet.std_vector();
  43. std::cerr << "number of known training samples: " << this->dataSet.size() << std::endl;
  44. }
  45. void CRSplineReg::teach ( const NICE::Vector & x, const double & y )
  46. {
  47. std::cerr << "CRSplineReg::teach one new example" << std::endl;
  48. for ( size_t i = 0 ; i < x.size() ; i++ )
  49. if ( isnan(x[i]) )
  50. {
  51. fprintf (stderr, "There is a NAN value in within this vector: x[%d] = %f\n", (int)i, x[i]);
  52. cerr << x << endl;
  53. exit(-1);
  54. }
  55. dataSet.push_back ( x );
  56. labelSet.push_back ( y );
  57. std::cerr << "number of known training samples: " << dataSet.size()<< std::endl;
  58. }
  59. double CRSplineReg::predict ( const NICE::Vector & x )
  60. {
  61. if ( dataSet.size() <= 0 ) {
  62. fprintf (stderr, "CRSplineReg: please use the train method first\n");
  63. exit(-1);
  64. }
  65. int dimension = dataSet[0].size();
  66. FullVector data ( dataSet.size()+1 );
  67. #pragma omp parallel for
  68. for ( uint i = 0; i < dataSet.size(); i++ ){
  69. data[i] = dataSet[i][sortDim];
  70. }
  71. data[dataSet.size()] = x[sortDim];
  72. std::vector<int> sortedInd;
  73. data.getSortedIndices(sortedInd);
  74. int index;
  75. for ( uint i = 0; i < sortedInd.size(); i++ ){
  76. if ( sortedInd[i] == dataSet.size() ){
  77. index = i;
  78. break;
  79. }
  80. }
  81. NICE::Matrix points (4,dimension+1,0.0);
  82. if ( index >= 2 && index < (sortedInd.size() - 2) ){ //everything is okay
  83. points.setRow(0,dataSet[sortedInd[index-2]]);
  84. points(0,dimension) = labelSet[sortedInd[index-2]];
  85. points.setRow(1,dataSet[sortedInd[index-1]]);
  86. points(1,dimension) = labelSet[sortedInd[index-1]];
  87. points.setRow(2,dataSet[sortedInd[index+1]]);
  88. points(2,dimension) = labelSet[sortedInd[index+1]];
  89. points.setRow(3,dataSet[sortedInd[index+2]]);
  90. points(3,dimension) = labelSet[sortedInd[index+2]];
  91. }
  92. else if ( index == 1 ){ //just one point left from x
  93. points.setRow(0,dataSet[sortedInd[index-1]]);
  94. points(0,dimension) = labelSet[sortedInd[index-1]];
  95. points.setRow(1,dataSet[sortedInd[index-1]]);
  96. points(1,dimension) = labelSet[sortedInd[index-1]];
  97. points.setRow(2,dataSet[sortedInd[index+1]]);
  98. points(2,dimension) = labelSet[sortedInd[index+1]];
  99. points.setRow(3,dataSet[sortedInd[index+2]]);
  100. points(3,dimension) = labelSet[sortedInd[index+2]];
  101. }
  102. else if ( index == 0 ){ //x is the farthest left point
  103. points.setRow(0,dataSet[sortedInd[index+1]]);
  104. points(0,dimension) = labelSet[sortedInd[index+1]];
  105. points.setRow(1,dataSet[sortedInd[index+1]]);
  106. points(1,dimension) = labelSet[sortedInd[index+1]];
  107. points.setRow(2,dataSet[sortedInd[index+1]]);
  108. points(2,dimension) = labelSet[sortedInd[index+1]];
  109. points.setRow(3,dataSet[sortedInd[index+2]]);
  110. points(3,dimension) = labelSet[sortedInd[index+2]];
  111. }
  112. else if ( index == (sortedInd.size() - 2) ){ //just one point right from x
  113. points.setRow(0,dataSet[sortedInd[index-2]]);
  114. points(0,dimension) = labelSet[sortedInd[index-2]];
  115. points.setRow(1,dataSet[sortedInd[index-1]]);
  116. points(1,dimension) = labelSet[sortedInd[index-1]];
  117. points.setRow(2,dataSet[sortedInd[index+1]]);
  118. points(2,dimension) = labelSet[sortedInd[index+1]];
  119. points.setRow(3,dataSet[sortedInd[index+1]]);
  120. points(3,dimension) = labelSet[sortedInd[index+1]];
  121. }
  122. else if ( index == (sortedInd.size() - 1) ){ //x is the farthest right point
  123. points.setRow(0,dataSet[sortedInd[index-2]]);
  124. points(0,dimension) = labelSet[sortedInd[index-2]];
  125. points.setRow(1,dataSet[sortedInd[index-1]]);
  126. points(1,dimension) = labelSet[sortedInd[index-1]];
  127. points.setRow(2,dataSet[sortedInd[index-1]]);
  128. points(2,dimension) = labelSet[sortedInd[index-1]];
  129. points.setRow(3,dataSet[sortedInd[index-1]]);
  130. points(3,dimension) = labelSet[sortedInd[index-1]];
  131. }
  132. double t = (x[sortDim]-points(1,sortDim)) / (points(2,sortDim)-points(1,sortDim)); //this is just some kind of heuristic
  133. if ( t != t || t < 0 || t > 1){ //check if t is NAN, -inf or inf (happens in the farthest right or left case from above)
  134. t = 0.5;
  135. }
  136. //P(t) = b0*P0 + b1*P1 + b2*P2 + b3*P3
  137. NICE::Vector P(dimension);
  138. double y;
  139. double b0,b1,b2,b3;
  140. b0 = tau * (-(t*t*t) + 2*t*t - t);
  141. b1 = tau * (3*t*t*t - 5*t*t + 2);
  142. b2 = tau * (-3*t*t*t + 4*t*t + t);
  143. b3 = tau * (t*t*t - t*t);
  144. #pragma omp parallel for
  145. for ( uint i = 0; i < dimension; i++ ){
  146. P[i] = b0*points(0,i) + b1*points(1,i) + b2*points(2,i) + b3*points(3,i);
  147. }
  148. double diff1 = (P-x).normL2();
  149. uint counter = 1;
  150. while ( diff1 > 1e-5 && counter <= 21){ //adjust t to fit data better
  151. double tmp = t;;
  152. if (tmp > 0.5)
  153. tmp = 1 - tmp;
  154. t += tmp/counter;
  155. b0 = tau * (-(t*t*t) + 2*t*t - t);
  156. b1 = tau * (3*t*t*t - 5*t*t + 2);
  157. b2 = tau * (-3*t*t*t + 4*t*t + t);
  158. b3 = tau * (t*t*t - t*t);
  159. for ( uint i = 0; i < dimension; i++ ){
  160. P[i] = b0*points(0,i) + b1*points(1,i) + b2*points(2,i) + b3*points(3,i);
  161. }
  162. double diff2 = (P-x).normL2();
  163. if ( diff2 > diff1 && t > 0) {
  164. t -= 2*tmp/counter;
  165. b0 = tau * (-(t*t*t) + 2*t*t - t);
  166. b1 = tau * (3*t*t*t - 5*t*t + 2);
  167. b2 = tau * (-3*t*t*t + 4*t*t + t);
  168. b3 = tau * (t*t*t - t*t);
  169. #pragma omp parallel for
  170. for ( uint i = 0; i < dimension; i++ ){
  171. P[i] = b0*points(0,i) + b1*points(1,i) + b2*points(2,i) + b3*points(3,i);
  172. }
  173. diff1 = (P-x).normL2();
  174. }
  175. counter++;
  176. }
  177. y = b0*points(0,dimension) + b1*points(1,dimension) + b2*points(2,dimension) + b3*points(3,dimension);
  178. return y;
  179. }