RANSACReg.cpp 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. /**
  2. * @file RANSACReg.cpp
  3. * @brief Implementation of RANSAC (RANdom SAmple Consensus) for regression purposes
  4. * @author Frank Prüfer
  5. * @date 09/10/2013
  6. */
  #ifdef NICE_USELIB_OPENMP
  #include <omp.h>
  #endif

  #include <algorithm>
  #include <cmath>
  #include <cstdlib>
  #include <ctime>
  #include <iostream>
  #include <vector>

  #include "vislearning/regression/linregression/LinRegression.h"
  #include "vislearning/regression/linregression/RANSACReg.h"

  using namespace OBJREC;
  using namespace std;
  using namespace NICE;
  17. RANSACReg::RANSACReg ( const Config *_conf )
  18. {
  19. if ( _conf->gB("RANSACReg","start_random_generator" ) )
  20. std::srand ( unsigned ( std::time(0) ) );
  21. threshold = _conf->gD("RANSACReg","threshold",0.5);
  22. iter = _conf->gI("RANSACReg","iterations",10);
  23. }
  24. RANSACReg::RANSACReg ( const RANSACReg & src ) : RegressionAlgorithm ( src )
  25. {
  26. threshold = src.threshold;
  27. n = src.n;
  28. iter = src.iter;
  29. dataSet = src.dataSet;
  30. labelSet = src.labelSet;
  31. modelParams = src.modelParams;
  32. }
  33. RANSACReg::~RANSACReg()
  34. {
  35. }
  36. RANSACReg* RANSACReg::clone ( void ) const
  37. {
  38. return new RANSACReg(*this);
  39. }
  40. void RANSACReg::teach ( const NICE::VVector & dataSet, const NICE::Vector & labelSet )
  41. {
  42. NICE::VVector best_CS(0,0);
  43. std::vector<double> best_labelCS;
  44. vector<int> indices;
  45. for ( uint i = 0; i < dataSet.size(); i++ )
  46. indices.push_back(i);
  47. n = dataSet[0].size()+1;
  48. for ( uint i = 0; i < iter; i++ ){
  49. random_shuffle( indices.begin(), indices.end() );
  50. NICE::VVector randDataSubset;
  51. std::vector<double> randLabelSubset;
  52. for ( uint j = 0; j < n; j++ ){ //choose random subset of n points
  53. randDataSubset.push_back( dataSet[indices[j]] );
  54. randLabelSubset.push_back( labelSet[indices[j]] );
  55. }
  56. LinRegression *linReg = new LinRegression ();
  57. linReg->teach ( randDataSubset, (NICE::Vector)randLabelSubset ); //do LinRegression on subset
  58. std::vector<double> tmp_modelParams = linReg->getModelParams();
  59. NICE::VVector current_CS;
  60. std::vector<double> current_labelCS;
  61. #pragma omp parallel for
  62. for ( uint j = n; j < indices.size(); j++ ){ //compute distance between each datapoint and current model
  63. double lengthNormalVector = 0;
  64. double sum = 0;
  65. for ( uint k = 0; k < tmp_modelParams.size(); k++ ){
  66. sum += tmp_modelParams[k] * dataSet[indices[j]][k];
  67. lengthNormalVector += tmp_modelParams[k] * tmp_modelParams[k];
  68. }
  69. lengthNormalVector = sqrt(lengthNormalVector);
  70. double distance = ( sum - labelSet[indices[j]] ) / lengthNormalVector;
  71. #pragma omp critical
  72. if ( abs(distance) < threshold ){ //if point is close to model, it belongs to consensus set
  73. current_CS.push_back ( dataSet[indices[j]] );
  74. current_labelCS.push_back ( labelSet[indices[j]] );
  75. }
  76. }
  77. if ( current_CS.size() > best_CS.size() ){ //if consensus set contains more points than any previous one, take this model as best_model
  78. best_CS = current_CS;
  79. best_labelCS = current_labelCS;
  80. }
  81. }
  82. LinRegression *best_linReg = new LinRegression (); //compute best_model again with all points of best_consensusSet
  83. best_linReg->teach ( best_CS, (NICE::Vector)best_labelCS );
  84. modelParams = best_linReg->getModelParams();
  85. }
  86. double RANSACReg::predict ( const NICE::Vector & x )
  87. {
  88. NICE::Vector nModel(modelParams);
  89. NICE:: Vector xTmp(1,1.0);
  90. xTmp.append(x);
  91. double y = xTmp.scalarProduct(nModel);
  92. return y;
  93. }