RANSACReg.cpp 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. /**
  2. * @file RANSACReg.cpp
  3. * @brief Implementation of RANSAC (RANdom SAmple Consensus) for regression purposes
  4. * @author Frank Prüfer
  5. * @date 09/10/2013
  6. */
#ifdef NICE_USELIB_OPENMP
#include <omp.h>
#endif
#include <algorithm>
#include <cmath>
#include <iostream>
#include <stdexcept>
#include <vector>
#include "vislearning/regression/linregression/LinRegression.h"
#include "vislearning/regression/linregression/RANSACReg.h"
using namespace OBJREC;
using namespace std;
using namespace NICE;
  16. RANSACReg::RANSACReg ( const Config *_conf )
  17. {
  18. threshold = _conf->gD("RANSACReg","threshold",0.5);
  19. iter = _conf->gI("RANSACReg","iterations",10);
  20. }
  21. RANSACReg::RANSACReg ( const RANSACReg & src ) : RegressionAlgorithm ( src )
  22. {
  23. threshold = src.threshold;
  24. n = src.n;
  25. iter = src.iter;
  26. dataSet = src.dataSet;
  27. labelSet = src.labelSet;
  28. modelParams = src.modelParams;
  29. }
  30. RANSACReg::~RANSACReg()
  31. {
  32. }
  33. void RANSACReg::teach ( const NICE::VVector & dataSet, const NICE::Vector & labelSet )
  34. {
  35. //for iter iterations do
  36. //choose random subset of n points (n = dataSet[0].size()+1)
  37. //do LinRegression on subset
  38. //get modelParameters
  39. //test how many points, which are not in subset, are close to model (use threshold and distancefunc here) -> these points are consneus set
  40. //if consensus set contains more points than any previous one, take this model as best_model
  41. //maybe compute best_model again with all points of best_consensusSet
  42. //store best_model and maybe best_consensusSet
  43. NICE::VVector best_CS(0,0);
  44. std::vector<double> best_labelCS;
  45. cerr<<"Size of training data: "<<dataSet.size()<<endl;
  46. vector<int> indices;
  47. for ( uint i = 0; i < dataSet.size(); i++ )
  48. indices.push_back(i);
  49. n = dataSet[0].size()+1;
  50. for ( uint i = 0; i < iter; i++ ){
  51. random_shuffle( indices.begin(), indices.end() );
  52. NICE::VVector randDataSubset;
  53. std::vector<double> randLabelSubset;
  54. for ( uint j = 0; j < n; j++ ){ //choose random subset of n points
  55. randDataSubset.push_back( dataSet[indices[j]] );
  56. randLabelSubset.push_back( labelSet[indices[j]] );
  57. }
  58. LinRegression *linReg = new LinRegression ();
  59. linReg->teach ( randDataSubset, (NICE::Vector)randLabelSubset ); //do LinRegression on subset
  60. std::vector<double> tmp_modelParams = linReg->getModelParams();
  61. NICE::VVector current_CS;
  62. std::vector<double> current_labelCS;
  63. for ( uint j = n; j < indices.size(); j++ ){ //compute distance between each datapoint and current model
  64. double lengthNormalVector = 0;
  65. double sum = 0;
  66. for ( uint k = 0; k < tmp_modelParams.size(); k++ ){
  67. sum += tmp_modelParams[k] * dataSet[indices[j]][k];
  68. lengthNormalVector += tmp_modelParams[k] * tmp_modelParams[k];
  69. }
  70. lengthNormalVector = sqrt(lengthNormalVector);
  71. double distance = ( sum - labelSet[indices[j]] )/ lengthNormalVector;
  72. // cerr<<"distance: "<<distance<<endl;
  73. if ( abs(distance) < threshold ){ //if point is close to model, it belongs to consensus set
  74. current_CS.push_back ( dataSet[indices[j]] );
  75. current_labelCS.push_back ( labelSet[indices[j]] );
  76. }
  77. }
  78. if ( current_CS.size() > best_CS.size() ){ //if consensus set contains more points than any previous one, take this model as best_model
  79. best_CS = current_CS;
  80. best_labelCS = current_labelCS;
  81. }
  82. }
  83. cerr<<"Size of best_CS: "<<best_CS.size()<<endl;
  84. LinRegression *best_linReg = new LinRegression (); //compute best_model again with all points of best_consensusSet
  85. best_linReg->teach ( best_CS, (NICE::Vector)best_labelCS );
  86. modelParams = best_linReg->getModelParams();
  87. }
  88. double RANSACReg::predict ( const NICE::Vector & x )
  89. {
  90. NICE::Vector nModel(modelParams);
  91. NICE:: Vector xTmp(1,1.0);
  92. xTmp.append(x);
  93. double y = xTmp.scalarProduct(nModel);
  94. return y;
  95. }