CRSplineReg.cpp 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. /**
  2. * @file CRSplineReg.cpp
  3. * @brief Implementation of Catmull-Rom-Splines for regression purposes
  4. * @author Frank Prüfer
  5. * @date 09/03/2013
  6. */
  7. #ifdef NICE_USELIB_OPENMP
  8. #include <omp.h>
  9. #endif
  10. #include <iostream>
  11. #include "vislearning/regression/splineregression/CRSplineReg.h"
  12. #include "vislearning/regression/linregression/LinRegression.h"
  13. #include "vislearning/math/mathbase/FullVector.h"
  14. using namespace OBJREC;
  15. using namespace std;
  16. using namespace NICE;
  17. CRSplineReg::CRSplineReg ( )
  18. {
  19. tau = 0.5;
  20. }
  21. CRSplineReg::CRSplineReg ( const CRSplineReg & src ) : RegressionAlgorithm ( src )
  22. {
  23. tau = src.tau;
  24. dataSet = src.dataSet;
  25. labelSet = src.labelSet;
  26. }
  27. CRSplineReg::~CRSplineReg()
  28. {
  29. }
  30. void CRSplineReg::teach ( const NICE::VVector & _dataSet, const NICE::Vector & _labelSet)
  31. {
  32. fprintf (stderr, "teach using all !\n");
  33. //NOTE this is crucial if we clear _teachSet afterwards!
  34. //therefore, take care NOT to call _techSet.clear() somewhere out of this method
  35. this->dataSet = _dataSet;
  36. this->labelSet = _labelSet.std_vector();
  37. std::cerr << "number of known training samples: " << this->dataSet.size() << std::endl;
  38. }
  39. void CRSplineReg::teach ( const NICE::Vector & x, const double & y )
  40. {
  41. std::cerr << "CRSplineReg::teach one new example" << std::endl;
  42. for ( size_t i = 0 ; i < x.size() ; i++ )
  43. if ( isnan(x[i]) )
  44. {
  45. fprintf (stderr, "There is a NAN value in within this vector: x[%d] = %f\n", (int)i, x[i]);
  46. cerr << x << endl;
  47. exit(-1);
  48. }
  49. dataSet.push_back ( x );
  50. labelSet.push_back ( y );
  51. std::cerr << "number of known training samples: " << dataSet.size()<< std::endl;
  52. }
  53. double CRSplineReg::predict ( const NICE::Vector & x )
  54. {
  55. if ( dataSet.size() <= 0 ) {
  56. fprintf (stderr, "CRSplineReg: please use the train method first\n");
  57. exit(-1);
  58. }
  59. int dimension = dataSet[0].size();
  60. int sortDim = 0;
  61. FullVector data ( dataSet.size()+1 );
  62. #pragma omp parallel for
  63. for ( uint i = 0; i < dataSet.size(); i++ ){
  64. data[i] = dataSet[i][sortDim];
  65. }
  66. data[dataSet.size()] = x[sortDim];
  67. std::vector<int> sortedInd;
  68. data.getSortedIndices(sortedInd);
  69. int index;
  70. for ( uint i = 0; i < sortedInd.size(); i++ ){
  71. if ( sortedInd[i] == dataSet.size() ){
  72. index = i;
  73. break;
  74. }
  75. }
  76. NICE::Matrix points (4,dimension+1,0.0);
  77. if ( index >= 2 && index < (sortedInd.size() - 2) ){ //everything is okay
  78. points.setRow(0,dataSet[sortedInd[index-2]]);
  79. points(0,dimension) = labelSet[sortedInd[index-2]];
  80. points.setRow(1,dataSet[sortedInd[index-1]]);
  81. points(1,dimension) = labelSet[sortedInd[index-1]];
  82. points.setRow(2,dataSet[sortedInd[index+1]]);
  83. points(2,dimension) = labelSet[sortedInd[index+1]];
  84. points.setRow(3,dataSet[sortedInd[index+2]]);
  85. points(3,dimension) = labelSet[sortedInd[index+2]];
  86. }
  87. else if ( index == 1 ){ //just one point left from x
  88. points.setRow(0,dataSet[sortedInd[index-1]]);
  89. points(0,dimension) = labelSet[sortedInd[index-1]];
  90. points.setRow(1,dataSet[sortedInd[index-1]]);
  91. points(1,dimension) = labelSet[sortedInd[index-1]];
  92. points.setRow(2,dataSet[sortedInd[index+1]]);
  93. points(2,dimension) = labelSet[sortedInd[index+1]];
  94. points.setRow(3,dataSet[sortedInd[index+2]]);
  95. points(3,dimension) = labelSet[sortedInd[index+2]];
  96. }
  97. else if ( index == 0 ){ //x is the farthest left point
  98. points.setRow(0,dataSet[sortedInd[index+1]]);
  99. points(0,dimension) = labelSet[sortedInd[index+1]];
  100. points.setRow(1,dataSet[sortedInd[index+1]]);
  101. points(1,dimension) = labelSet[sortedInd[index+1]];
  102. points.setRow(2,dataSet[sortedInd[index+1]]);
  103. points(2,dimension) = labelSet[sortedInd[index+1]];
  104. points.setRow(3,dataSet[sortedInd[index+2]]);
  105. points(3,dimension) = labelSet[sortedInd[index+2]];
  106. }
  107. else if ( index == (sortedInd.size() - 2) ){ //just one point right from x
  108. points.setRow(0,dataSet[sortedInd[index-2]]);
  109. points(0,dimension) = labelSet[sortedInd[index-2]];
  110. points.setRow(1,dataSet[sortedInd[index-1]]);
  111. points(1,dimension) = labelSet[sortedInd[index-1]];
  112. points.setRow(2,dataSet[sortedInd[index+1]]);
  113. points(2,dimension) = labelSet[sortedInd[index+1]];
  114. points.setRow(3,dataSet[sortedInd[index+1]]);
  115. points(3,dimension) = labelSet[sortedInd[index+1]];
  116. }
  117. else if ( index == (sortedInd.size() - 1) ){ //x is the farthest right point
  118. points.setRow(0,dataSet[sortedInd[index-2]]);
  119. points(0,dimension) = labelSet[sortedInd[index-2]];
  120. points.setRow(1,dataSet[sortedInd[index-1]]);
  121. points(1,dimension) = labelSet[sortedInd[index-1]];
  122. points.setRow(2,dataSet[sortedInd[index-1]]);
  123. points(2,dimension) = labelSet[sortedInd[index-1]];
  124. points.setRow(3,dataSet[sortedInd[index-1]]);
  125. points(3,dimension) = labelSet[sortedInd[index-1]];
  126. }
  127. // NICE::Vector P(dimension);
  128. // double y;
  129. // double diff,bestT;
  130. // double bestDiff = std::numeric_limits<double>::max();
  131. // double b0,b1,b2,b3;
  132. //
  133. // for ( double t = 0.0; t <= 1.0; t += 0.02 ){
  134. // b0 = tau * (-(t*t*t) + 2*t*t - t);
  135. // b1 = tau * (3*t*t*t - 5*t*t + 2);
  136. // b2 = tau * (-3*t*t*t + 4*t*t + t);
  137. // b3 = tau * (t*t*t - t*t);
  138. //
  139. // #pragma omp parallel for
  140. // for ( uint i = 0; i < dimension; i++ ){
  141. // P[i] = b0*points(0,i) + b1*points(1,i) + b2*points(2,i) + b3*points(3,i);
  142. // }
  143. //
  144. // diff = (P-x).normL2();
  145. // if ( diff < bestDiff ){
  146. // bestDiff = diff;
  147. // bestT = t;
  148. // }
  149. // }
  150. // b0 = tau * (-(bestT*bestT*bestT) + 2*bestT*bestT - bestT);
  151. // b1 = tau * (3*bestT*bestT*bestT - 5*bestT*bestT + 2);
  152. // b2 = tau * (-3*bestT*bestT*bestT + 4*bestT*bestT + bestT);
  153. // b3 = tau * (bestT*bestT*bestT - bestT*bestT);
  154. double t = (x[sortDim]-points(1,sortDim)) / (points(2,sortDim)-points(1,sortDim));
  155. // cerr<<"t: "<<t<<endl;
  156. //P(t) = b0*P0 + b1*P1 + b2*P2 + b3*P3
  157. NICE::Vector P(dimension);
  158. double y;
  159. double b0,b1,b2,b3;
  160. b0 = tau * (-(t*t*t) + 2*t*t - t);
  161. b1 = tau * (3*t*t*t - 5*t*t + 2);
  162. b2 = tau * (-3*t*t*t + 4*t*t + t);
  163. b3 = tau * (t*t*t - t*t);
  164. #pragma omp parallel for
  165. for ( uint i = 0; i < dimension; i++ ){
  166. P[i] = b0*points(0,i) + b1*points(1,i) + b2*points(2,i) + b3*points(3,i);
  167. }
  168. double diff1 = (P-x).normL2();
  169. uint counter = 1;
  170. while ( diff1 > 1e-5 && counter <= 21){
  171. double tmp = t;;
  172. if (tmp > 0.5)
  173. tmp = 1 - tmp;
  174. t += tmp/counter;
  175. b0 = tau * (-(t*t*t) + 2*t*t - t);
  176. b1 = tau * (3*t*t*t - 5*t*t + 2);
  177. b2 = tau * (-3*t*t*t + 4*t*t + t);
  178. b3 = tau * (t*t*t - t*t);
  179. for ( uint i = 0; i < dimension; i++ ){
  180. P[i] = b0*points(0,i) + b1*points(1,i) + b2*points(2,i) + b3*points(3,i);
  181. }
  182. double diff2 = (P-x).normL2();
  183. if ( diff2 > diff1 && t > 0) {
  184. t -= 2*tmp/counter;
  185. b0 = tau * (-(t*t*t) + 2*t*t - t);
  186. b1 = tau * (3*t*t*t - 5*t*t + 2);
  187. b2 = tau * (-3*t*t*t + 4*t*t + t);
  188. b3 = tau * (t*t*t - t*t);
  189. #pragma omp parallel for
  190. for ( uint i = 0; i < dimension; i++ ){
  191. P[i] = b0*points(0,i) + b1*points(1,i) + b2*points(2,i) + b3*points(3,i);
  192. }
  193. diff1 = (P-x).normL2();
  194. }
  195. counter++;
  196. }
  197. y = b0*points(0,dimension) + b1*points(1,dimension) + b2*points(2,dimension) + b3*points(3,dimension);
  198. return y;
  199. }