RegKNN.cpp 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. /**
  2. * @file RegKNN.cpp
  3. * @brief Implementation of k-Nearest-Neighbor algorithm for regression purposes
  4. * @author Frank Prüfer
  5. * @date 08/29/2013
  6. */
#ifdef NICE_USELIB_OPENMP
#include <omp.h>
#endif

#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <iostream>

#include "vislearning/regression/npregression/RegKNN.h"
#include "vislearning/math/mathbase/FullVector.h"
  13. using namespace OBJREC;
  14. using namespace NICE;
  15. RegKNN::RegKNN ( const Config *_conf, NICE::VectorDistance<double> *_distancefunc ) : distancefunc (_distancefunc)
  16. {
  17. K = _conf->gI("RegKNN", "K", 1 );
  18. if ( _distancefunc == NULL )
  19. distancefunc = new EuclidianDistance<double>();
  20. }
  21. RegKNN::RegKNN ( const RegKNN & src ) : RegressionAlgorithm ( src )
  22. {
  23. dataSet = src.dataSet;
  24. labelSet = src.labelSet;
  25. distancefunc = src.distancefunc;
  26. K = src.K;
  27. }
  28. RegKNN::~RegKNN ()
  29. {
  30. }
  31. RegKNN* RegKNN::clone ( void ) const
  32. {
  33. return new RegKNN(*this);
  34. }
  35. void RegKNN::teach ( const NICE::VVector & _dataSet, const NICE::Vector & _labelSet)
  36. {
  37. fprintf (stderr, "teach using all !\n");
  38. //NOTE this is crucial if we clear _teachSet afterwards!
  39. //therefore, take care NOT to call _techSet.clear() somewhere out of this method
  40. this->dataSet = _dataSet;
  41. this->labelSet = _labelSet.std_vector();
  42. std::cerr << "number of known training samples: " << this->dataSet.size() << std::endl;
  43. }
  44. void RegKNN::teach ( const NICE::Vector & x, const double & y )
  45. {
  46. std::cerr << "RegKNN::teach one new example" << std::endl;
  47. for ( size_t i = 0 ; i < x.size() ; i++ )
  48. if ( isnan(x[i]) )
  49. {
  50. fprintf (stderr, "There is a NAN value within this vector: x[%d] = %f\n", (int)i, x[i]);
  51. std::cerr << x << std::endl;
  52. exit(-1);
  53. }
  54. dataSet.push_back ( x );
  55. labelSet.push_back ( y );
  56. std::cerr << "number of known training samples: " << dataSet.size()<< std::endl;
  57. }
  58. double RegKNN::predict ( const NICE::Vector & x )
  59. {
  60. FullVector distances(dataSet.size());
  61. if ( dataSet.size() <= 0 )
  62. {
  63. fprintf (stderr, "RegKNN: please use the teach method first\n");
  64. exit(-1);
  65. }
  66. #pragma omp parallel for
  67. for(uint i = 0; i < dataSet.size(); i++)
  68. {
  69. double distance = distancefunc->calculate (x,dataSet[i]);
  70. if ( isnan(distance) )
  71. {
  72. fprintf (stderr, "RegKNN::predict: NAN value found !!\n");
  73. std::cerr << x << std::endl;
  74. std::cerr << dataSet[i] << std::endl;
  75. }
  76. // #pragma omp critical
  77. distances[i] = distance;
  78. }
  79. std::vector<int> ind;
  80. distances.getSortedIndices(ind);
  81. double response = 0.0;
  82. if ( dataSet.size() < K )
  83. {
  84. std::cerr << K << std::endl;
  85. K = dataSet.size();
  86. std::cerr<<"RegKNN: Not enough datapoints! Setting K to: "<< K << std::endl;
  87. }
  88. if ( distances[ind[0]] == 0.0 ) {
  89. std::cerr<<"RegKNN: Warning: datapoint was already seen during training... using its label as prediction."<< std::endl;
  90. return labelSet[ind[0]];
  91. }
  92. double maxElement = distances.max(); //normalize distances
  93. distances.multiply(1.0/maxElement);
  94. double weightSum = 0.0;
  95. for(uint i = 0; i < K; i++)
  96. {
  97. response += 1.0/distances[ind[i]] * labelSet[ind[i]];
  98. weightSum += 1.0/distances[ind[i]];
  99. }
  100. return ( response / weightSum );
  101. }