CRSplineReg.cpp 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. /**
  2. * @file CRSplineReg.cpp
  3. * @brief Implementation of Catmull-Rom-Splines for regression purposes
  4. * @author Frank Prüfer
  5. * @date 09/03/2013
  6. */
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <iostream>

#include "vislearning/regression/splineregression/CRSplineReg.h"
#include "vislearning/math/mathbase/FullVector.h"

using namespace OBJREC;
using namespace std;
using namespace NICE;
  13. CRSplineReg::CRSplineReg ( )
  14. {
  15. tau = 0.5;
  16. }
  17. CRSplineReg::CRSplineReg ( const CRSplineReg & src ) : RegressionAlgorithm ( src )
  18. {
  19. tau = src.tau;
  20. dataSet = src.dataSet;
  21. labelSet = src.labelSet;
  22. }
  23. CRSplineReg::~CRSplineReg()
  24. {
  25. }
  26. void CRSplineReg::teach ( const NICE::VVector & _dataSet, const NICE::Vector & _labelSet)
  27. {
  28. fprintf (stderr, "teach using all !\n");
  29. //NOTE this is crucial if we clear _teachSet afterwards!
  30. //therefore, take care NOT to call _techSet.clear() somewhere out of this method
  31. this->dataSet = _dataSet;
  32. this->labelSet = _labelSet.std_vector();
  33. std::cerr << "number of known training samples: " << this->dataSet.size() << std::endl;
  34. }
  35. void CRSplineReg::teach ( const NICE::Vector & x, const double & y )
  36. {
  37. std::cerr << "CRSplineReg::teach one new example" << std::endl;
  38. for ( size_t i = 0 ; i < x.size() ; i++ )
  39. if ( isnan(x[i]) )
  40. {
  41. fprintf (stderr, "There is a NAN value in within this vector: x[%d] = %f\n", (int)i, x[i]);
  42. cerr << x << endl;
  43. exit(-1);
  44. }
  45. dataSet.push_back ( x );
  46. labelSet.push_back ( y );
  47. std::cerr << "number of known training samples: " << dataSet.size()<< std::endl;
  48. }
  49. double CRSplineReg::predict ( const NICE::Vector & x )
  50. {
  51. if ( dataSet.size() <= 0 ) {
  52. fprintf (stderr, "CRSplineReg: please use the train method first\n");
  53. exit(-1);
  54. }
  55. if ( dataSet[0].size() == 1 ){ //one-dimensional case
  56. FullVector data ( dataSet.size()+1 );
  57. for ( uint i = 0; i < dataSet.size(); i++ ){
  58. data[i] = dataSet[i][0];
  59. }
  60. cerr<<"data x: "<<x[0]<<endl;
  61. data[dataSet.size()] = x[0];
  62. std::vector<int> ind;
  63. data.getSortedIndices(ind);
  64. int index;
  65. for ( uint i = 0; i < ind.size(); i++ ){
  66. if ( ind[i] == dataSet.size() ){
  67. index = i;
  68. break;
  69. }
  70. }
  71. NICE::Matrix points (4,2,0.0);
  72. if ( index >= 2 && index < (ind.size() - 2) ){ //everything is okay
  73. points(0,0) = data[ind[index-2]];
  74. points(0,1) = labelSet[ind[index-2]];
  75. points(1,0) = data[ind[index-1]];
  76. points(1,1) = labelSet[ind[index-1]];
  77. points(2,0) = data[ind[index+1]];
  78. points(2,1) = labelSet[ind[index+1]];
  79. points(3,0) = data[ind[index+2]];
  80. points(3,1) = labelSet[ind[index+2]];
  81. }
  82. else if ( index == 1 ){ //just one point left from x
  83. points(0,0) = data[ind[index-1]];
  84. points(0,1) = labelSet[ind[index-1]];
  85. points(1,0) = data[ind[index-1]];
  86. points(1,1) = labelSet[ind[index-1]];
  87. points(2,0) = data[ind[index+1]];
  88. points(2,1) = labelSet[ind[index+1]];
  89. points(3,0) = data[ind[index+2]];
  90. points(3,1) = labelSet[ind[index+2]];
  91. }
  92. else if ( index == 0 ){ //x is the farthest left point
  93. //do linear approximation
  94. }
  95. else if ( index == (ind.size() - 2) ){ //just one point right from x
  96. points(0,0) = data[ind[index-2]];
  97. points(0,1) = labelSet[ind[index-2]];
  98. points(1,0) = data[ind[index-1]];
  99. points(1,1) = labelSet[ind[index-1]];
  100. points(2,0) = data[ind[index+1]];
  101. points(2,1) = labelSet[ind[index+1]];
  102. points(3,0) = data[ind[index+1]];
  103. points(3,1) = labelSet[ind[index+1]];
  104. }
  105. else if ( index == (ind.size() - 1) ){ //x is the farthest right point
  106. //do linear approximation
  107. }
  108. double t = (x[0] - points(1,0)) / (points(2,0) - points(1,0));
  109. cerr<<"t: "<<t<<endl;
  110. // NICE::Vector vecT(4,1.0);
  111. //
  112. // vecT[1] = t;
  113. // vecT[2] = t*t;
  114. // vecT[3] = t*t*t;
  115. //
  116. // Matrix coeffMatrix (4,4,0.0); // M = (0 2 0 0
  117. // coeffMatrix(0,1) = 2.0; // -1 0 1 0
  118. // coeffMatrix(1,0) = -1.0; // 2 -5 4 -1
  119. // coeffMatrix(1,2) = 1.0; // -1 3 -3 1)
  120. // coeffMatrix(2,0) = 2.0;
  121. // coeffMatrix(2,1) = -5.0;
  122. // coeffMatrix(2,2) = 4.0;
  123. // coeffMatrix(2,3) = -1.0;
  124. // coeffMatrix(3,0) = -1.0;
  125. // coeffMatrix(3,1) = 3.0;
  126. // coeffMatrix(3,2) = -3.0;
  127. // coeffMatrix(3,3) = 1.0;
  128. //
  129. // // P(t) = tau * vecT * coeffMatrix * points;
  130. // NICE::Vector P;
  131. // NICE::Matrix tmp;
  132. // tmp.multiply(coeffMatrix,points);
  133. // P.multiply(vecT,tmp);
  134. // P *= tau;
  135. //P(t) = b0*P0 + b1*P1 + b2*P2 + b3*P3
  136. NICE::Vector P(2);
  137. double b0,b1,b2,b3;
  138. b0 = tau * (-(t*t*t) + 2*t*t - t);
  139. b1 = tau * (3*t*t*t - 5*t*t + 2);
  140. b2 = tau * (-3*t*t*t + 4*t*t + t);
  141. b3 = tau * (t*t*t - t*t);
  142. P[0] = b0*points(0,0) + b1*points(1,0) + b2*points(2,0) + b3*points(3,0);
  143. P[1] = b0*points(0,1) + b1*points(1,1) + b2*points(2,1) + b3*points(3,1);
  144. cerr<<"Response x: "<<P[0]<<endl;
  145. cerr<<"Response y: "<<P[1]<<endl;
  146. return P[1];
  147. }
  148. }