CRSplineReg.cpp 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  1. /**
  2. * @file CRSplineReg.cpp
  3. * @brief Implementation of Catmull-Rom-Splines for regression purposes
  4. * @author Frank Prüfer
  5. * @date 09/03/2013
  6. */
  7. #ifdef NICE_USELIB_OPENMP
  8. #include <omp.h>
  9. #endif
  10. #include <iostream>
  11. #include "vislearning/regression/splineregression/CRSplineReg.h"
  12. #include "vislearning/regression/linregression/LinRegression.h"
  13. #include "vislearning/math/mathbase/FullVector.h"
  14. using namespace OBJREC;
  15. using namespace NICE;
  16. CRSplineReg::CRSplineReg ( const NICE::Config *_conf )
  17. {
  18. tau = _conf->gD("CRSplineReg","tau",0.5);
  19. sortDim = _conf->gI("CRSplineReg","sortDim",0);
  20. }
  21. CRSplineReg::CRSplineReg ( uint sDim )
  22. {
  23. sortDim = sDim;
  24. }
  25. CRSplineReg::CRSplineReg ( const CRSplineReg & src ) : RegressionAlgorithm ( src )
  26. {
  27. tau = src.tau;
  28. dataSet = src.dataSet;
  29. labelSet = src.labelSet;
  30. sortDim = src.sortDim;
  31. }
  32. CRSplineReg::~CRSplineReg()
  33. {
  34. }
  35. CRSplineReg* CRSplineReg::clone ( void ) const
  36. {
  37. return new CRSplineReg(*this);
  38. }
  39. void CRSplineReg::teach ( const NICE::VVector & _dataSet, const NICE::Vector & _labelSet)
  40. {
  41. fprintf (stderr, "teach using all !\n");
  42. //NOTE this is crucial if we clear _teachSet afterwards!
  43. //therefore, take care NOT to call _techSet.clear() somewhere out of this method
  44. this->dataSet = _dataSet;
  45. this->labelSet = _labelSet.std_vector();
  46. std::cerr << "number of known training samples: " << this->dataSet.size() << std::endl;
  47. }
  48. void CRSplineReg::teach ( const NICE::Vector & x, const double & y )
  49. {
  50. std::cerr << "CRSplineReg::teach one new example" << std::endl;
  51. for ( size_t i = 0 ; i < x.size() ; i++ )
  52. if ( isnan(x[i]) )
  53. {
  54. fprintf (stderr, "There is a NAN value in within this vector: x[%d] = %f\n", (int)i, x[i]);
  55. std::cerr << x << std::endl;
  56. exit(-1);
  57. }
  58. dataSet.push_back ( x );
  59. labelSet.push_back ( y );
  60. std::cerr << "number of known training samples: " << dataSet.size()<< std::endl;
  61. }
  62. double CRSplineReg::predict ( const NICE::Vector & x )
  63. {
  64. if ( dataSet.size() <= 0 ) {
  65. fprintf (stderr, "CRSplineReg: please use the train method first\n");
  66. exit(-1);
  67. }
  68. int dimension = dataSet[0].size();
  69. FullVector data ( dataSet.size()+1 );
  70. #pragma omp parallel for
  71. for ( uint i = 0; i < dataSet.size(); i++ ){
  72. data[i] = dataSet[i][sortDim];
  73. }
  74. data[dataSet.size()] = x[sortDim];
  75. std::vector<int> sortedInd;
  76. data.getSortedIndices(sortedInd);
  77. int index;
  78. for ( uint i = 0; i < sortedInd.size(); i++ ){
  79. if ( sortedInd[i] == (int)dataSet.size() ){
  80. index = i;
  81. break;
  82. }
  83. }
  84. NICE::Matrix points (4,dimension+1,0.0);
  85. if ( index >= 2 && index < (int)(sortedInd.size() - 2) ){ //everything is okay
  86. points.setRow(0,dataSet[sortedInd[index-2]]);
  87. points(0,dimension) = labelSet[sortedInd[index-2]];
  88. points.setRow(1,dataSet[sortedInd[index-1]]);
  89. points(1,dimension) = labelSet[sortedInd[index-1]];
  90. points.setRow(2,dataSet[sortedInd[index+1]]);
  91. points(2,dimension) = labelSet[sortedInd[index+1]];
  92. points.setRow(3,dataSet[sortedInd[index+2]]);
  93. points(3,dimension) = labelSet[sortedInd[index+2]];
  94. }
  95. else if ( index == 1 ){ //just one point left from x
  96. points.setRow(0,dataSet[sortedInd[index-1]]);
  97. points(0,dimension) = labelSet[sortedInd[index-1]];
  98. points.setRow(1,dataSet[sortedInd[index-1]]);
  99. points(1,dimension) = labelSet[sortedInd[index-1]];
  100. points.setRow(2,dataSet[sortedInd[index+1]]);
  101. points(2,dimension) = labelSet[sortedInd[index+1]];
  102. points.setRow(3,dataSet[sortedInd[index+2]]);
  103. points(3,dimension) = labelSet[sortedInd[index+2]];
  104. }
  105. else if ( index == 0 ){ //x is the farthest left point
  106. points.setRow(0,dataSet[sortedInd[index+1]]);
  107. points(0,dimension) = labelSet[sortedInd[index+1]];
  108. points.setRow(1,dataSet[sortedInd[index+1]]);
  109. points(1,dimension) = labelSet[sortedInd[index+1]];
  110. points.setRow(2,dataSet[sortedInd[index+1]]);
  111. points(2,dimension) = labelSet[sortedInd[index+1]];
  112. points.setRow(3,dataSet[sortedInd[index+2]]);
  113. points(3,dimension) = labelSet[sortedInd[index+2]];
  114. }
  115. else if ( index == (int)(sortedInd.size() - 2) ){ //just one point right from x
  116. points.setRow(0,dataSet[sortedInd[index-2]]);
  117. points(0,dimension) = labelSet[sortedInd[index-2]];
  118. points.setRow(1,dataSet[sortedInd[index-1]]);
  119. points(1,dimension) = labelSet[sortedInd[index-1]];
  120. points.setRow(2,dataSet[sortedInd[index+1]]);
  121. points(2,dimension) = labelSet[sortedInd[index+1]];
  122. points.setRow(3,dataSet[sortedInd[index+1]]);
  123. points(3,dimension) = labelSet[sortedInd[index+1]];
  124. }
  125. else if ( index == (int)(sortedInd.size() - 1) ){ //x is the farthest right point
  126. points.setRow(0,dataSet[sortedInd[index-2]]);
  127. points(0,dimension) = labelSet[sortedInd[index-2]];
  128. points.setRow(1,dataSet[sortedInd[index-1]]);
  129. points(1,dimension) = labelSet[sortedInd[index-1]];
  130. points.setRow(2,dataSet[sortedInd[index-1]]);
  131. points(2,dimension) = labelSet[sortedInd[index-1]];
  132. points.setRow(3,dataSet[sortedInd[index-1]]);
  133. points(3,dimension) = labelSet[sortedInd[index-1]];
  134. }
  135. double t = (x[sortDim]-points(1,sortDim)) / (points(2,sortDim)-points(1,sortDim)); //this is just some kind of heuristic
  136. if ( t != t || t < 0 || t > 1){ //check if t is NAN, -inf or inf (happens in the farthest right or left case from above)
  137. t = 0.5;
  138. }
  139. //P(t) = b0*P0 + b1*P1 + b2*P2 + b3*P3
  140. NICE::Vector P(dimension);
  141. double y;
  142. double b0,b1,b2,b3;
  143. b0 = tau * (-(t*t*t) + 2*t*t - t);
  144. b1 = tau * (3*t*t*t - 5*t*t + 2);
  145. b2 = tau * (-3*t*t*t + 4*t*t + t);
  146. b3 = tau * (t*t*t - t*t);
  147. #pragma omp parallel for
  148. for ( uint i = 0; i < (uint)dimension; i++ ){
  149. P[i] = b0*points(0,i) + b1*points(1,i) + b2*points(2,i) + b3*points(3,i);
  150. }
  151. double diff1 = (P-x).normL2();
  152. uint counter = 1;
  153. while ( diff1 > 1e-5 && counter <= 21){ //adjust t to fit data better
  154. double tmp = t;;
  155. if (tmp > 0.5)
  156. tmp = 1 - tmp;
  157. t += tmp/counter;
  158. b0 = tau * (-(t*t*t) + 2*t*t - t);
  159. b1 = tau * (3*t*t*t - 5*t*t + 2);
  160. b2 = tau * (-3*t*t*t + 4*t*t + t);
  161. b3 = tau * (t*t*t - t*t);
  162. for ( uint i = 0; i < (uint)dimension; i++ ){
  163. P[i] = b0*points(0,i) + b1*points(1,i) + b2*points(2,i) + b3*points(3,i);
  164. }
  165. double diff2 = (P-x).normL2();
  166. if ( diff2 > diff1 && t > 0) {
  167. t -= 2*tmp/counter;
  168. b0 = tau * (-(t*t*t) + 2*t*t - t);
  169. b1 = tau * (3*t*t*t - 5*t*t + 2);
  170. b2 = tau * (-3*t*t*t + 4*t*t + t);
  171. b3 = tau * (t*t*t - t*t);
  172. #pragma omp parallel for
  173. for ( uint i = 0; i < (uint)dimension; i++ ){
  174. P[i] = b0*points(0,i) + b1*points(1,i) + b2*points(2,i) + b3*points(3,i);
  175. }
  176. diff1 = (P-x).normL2();
  177. }
  178. counter++;
  179. }
  180. y = b0*points(0,dimension) + b1*points(1,dimension) + b2*points(2,dimension) + b3*points(3,dimension);
  181. return y;
  182. }