RegKNN.cpp 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. /**
  2. * @file RegKNN.cpp
  3. * @brief Implementation of k-Nearest-Neighbor algorithm for regression purposes
  4. * @author Frank Prüfer
  5. * @date 08/29/2013
  6. */
  7. #include <iostream>
  8. #include "vislearning/regression/npregression/RegKNN.h"
  9. #include "vislearning/math/mathbase/FullVector.h"
  10. using namespace OBJREC;
  11. using namespace std;
  12. using namespace NICE;
  13. RegKNN::RegKNN ( const Config *_conf, NICE::VectorDistance<double> *_distancefunc ) : distancefunc (_distancefunc)
  14. {
  15. K = _conf->gI("RegKNN", "K", 1 );
  16. if ( _distancefunc == NULL )
  17. distancefunc = new EuclidianDistance<double>();
  18. }
  19. RegKNN::~RegKNN()
  20. {
  21. }
  22. void RegKNN::teach ( const NICE::VVector & _dataSet, const NICE::Vector & _labelSet)
  23. {
  24. fprintf (stderr, "teach using all !\n");
  25. //NOTE this is crucial if we clear _teachSet afterwards!
  26. //therefore, take care NOT to call _techSet.clear() somewhere out of this method
  27. this->dataSet = _dataSet;
  28. this->labelSet = _labelSet;
  29. std::cerr << "number of known training samples: " << this->dataSet.size() << std::endl;
  30. }
  31. // void RegKNN::teach ( const NICE::Vector & x, const double & y )
  32. // {
  33. // std::cerr << "RegKNN::teach one new example" << std::endl;
  34. //
  35. // for ( size_t i = 0 ; i < x.size() ; i++ )
  36. // if ( isnan(x[i]) )
  37. // {
  38. // fprintf (stderr, "There is a NAN value in within this vector: x[%d] = %f\n", (int)i, x[i]);
  39. // cerr << x << endl;
  40. // exit(-1);
  41. // }
  42. //
  43. // dataSet.push_back ( x );
  44. // labelSet.push_back ( y );
  45. //
  46. // std::cerr << "number of known training samples: " << dataSet.size()<< std::endl;
  47. // }
  48. double RegKNN::predict ( const NICE::Vector & x )
  49. {
  50. FullVector distances(dataSet.size());
  51. if ( dataSet.size() <= 0 ) {
  52. fprintf (stderr, "RegKNN: please train this classifier before classifying\n");
  53. exit(-1);
  54. }
  55. for(uint i = 0; i < dataSet.size(); i++){
  56. double distance = distancefunc->calculate (x,dataSet[i]);
  57. if ( isnan(distance) )
  58. {
  59. fprintf (stderr, "RegKNN::classify: NAN value found !!\n");
  60. cerr << x << endl;
  61. cerr << dataSet[i] << endl;
  62. }
  63. distances[i] = distance;
  64. }
  65. std::vector<int> ind;
  66. distances.getSortedIndices(ind);
  67. double response = 0.0;
  68. for(uint i = 0; i < K; i++){
  69. response += labelSet[ind[i]];
  70. }
  71. return (response / (double) K);
  72. }