VCNearestNeighbour.cpp 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. /**
  2. * @file VCNearestNeighbour.cpp
  3. * @brief Simple Nearest Neighbour Implementation
  4. * @author Erik Rodner
  5. * @date 10/25/2007
  6. */
  7. #include <iostream>
  8. #include <queue>
  9. #include "vislearning/classifier/vclassifier/VCNearestNeighbour.h"
  10. using namespace OBJREC;
  11. using namespace std;
  12. using namespace NICE;
  13. #undef DEBUG_VCN
  14. VCNearestNeighbour::VCNearestNeighbour ( const Config *_conf, NICE::VectorDistance<double> *_distancefunc )
  15. : VecClassifier ( _conf ), distancefunc (_distancefunc)
  16. {
  17. K = _conf->gI("VCNearestNeighbour", "K", 1 );
  18. if ( _distancefunc == NULL )
  19. distancefunc = new EuclidianDistance<double>();
  20. }
  21. VCNearestNeighbour::VCNearestNeighbour ( const VCNearestNeighbour & src ) : VecClassifier()
  22. {
  23. if ( src.teachSet.size() )
  24. fthrow(Exception, "It is not yet possible to clone an already trained nearest neighbour classifier.");
  25. distancefunc = src.distancefunc;
  26. K = src.K;
  27. maxClassNo = src.maxClassNo;
  28. }
  29. VCNearestNeighbour::~VCNearestNeighbour()
  30. {
  31. }
  32. /** classify using simple vector */
  33. ClassificationResult VCNearestNeighbour::classify ( const NICE::Vector & x ) const
  34. {
  35. double mindist = std::numeric_limits<double>::max();
  36. int minclass = 0;
  37. FullVector mindists ( maxClassNo + 1 );
  38. mindists.set ( mindist );
  39. if ( teachSet.count() <= 0 ) {
  40. fprintf (stderr, "VCNearestNeighbour: please train this classifier before classifying\n");
  41. exit(-1);
  42. }
  43. priority_queue< pair<double, int> > examples;
  44. LOOP_ALL(teachSet)
  45. {
  46. EACH(classno,y)
  47. double distance = distancefunc->calculate ( x, y );
  48. if ( NICE::isNaN(distance) )
  49. {
  50. fprintf (stderr, "VCNearestNeighbour::classify: NAN value found !!\n");
  51. cerr << x << endl;
  52. cerr << y << endl;
  53. }
  54. if ( mindists[classno] > distance )
  55. mindists[classno] = distance;
  56. if ( mindist > distance )
  57. {
  58. minclass = classno;
  59. mindist = distance;
  60. }
  61. if ( K > 1 )
  62. examples.push ( pair<double, int> ( -distance, classno ) );
  63. }
  64. if ( mindist == 0.0 )
  65. fprintf (stderr, "VCNearestNeighbour::classify WARNING distance is zero, reclassification?\n");
  66. #ifdef DEBUG_VCN
  67. for ( int i = 0 ; i < mindists.size() ; i++ )
  68. fprintf (stderr, "class %d : %f\n", i, mindists.get(i) );
  69. #endif
  70. if ( K > 1 )
  71. {
  72. FullVector votes ( maxClassNo + 1 );
  73. votes.set(0.0);
  74. for ( int k = 0 ; k < K ; k++ )
  75. {
  76. const pair<double, int> & t = examples.top();
  77. votes[ t.second ]++;
  78. examples.pop();
  79. }
  80. votes.normalize();
  81. return ClassificationResult ( votes.maxElement(), votes );
  82. }
  83. else
  84. {
  85. //do we really want to do this? Only useful, if we want to obtain probability like scores
  86. // for ( int i = 0 ; i < mindists.size() ; i++ )
  87. // {
  88. // mindists[i] = 1.0 / (mindists[i] + 1.0);
  89. // }
  90. //mindists.normalize();
  91. return ClassificationResult ( minclass, mindists );
  92. }
  93. }
  94. /** classify using a simple vector */
  95. void VCNearestNeighbour::teach ( const LabeledSetVector & _teachSet )
  96. {
  97. fprintf (stderr, "teach using all !\n");
  98. maxClassNo = _teachSet.getMaxClassno();
  99. //NOTE this is crucial if we clear _teachSet afterwards!
  100. //therefore, take care NOT to call _techSet.clear() somewhere out of this method
  101. this->teachSet = _teachSet;
  102. std::cerr << "number of known training samples: " << this->teachSet.begin()->second.size() << std::endl;
  103. // //just for testing - remove everything but the first element
  104. // map< int, vector<NICE::Vector *> >::iterator it = this->teachSet.begin();
  105. // it++;
  106. // this->teachSet.erase(it, this->teachSet.end());
  107. // std::cerr << "keep " << this->teachSet.size() << " elements" << std::endl;
  108. }
  109. void VCNearestNeighbour::teach ( int classno, const NICE::Vector & x )
  110. {
  111. std::cerr << "VCNearestNeighbour::teach one new example" << std::endl;
  112. for ( size_t i = 0 ; i < x.size() ; i++ )
  113. if ( NICE::isNaN(x[i]) )
  114. {
  115. fprintf (stderr, "There is a NAN value in within this vector: x[%d] = %f\n", (int)i, x[i]);
  116. cerr << x << endl;
  117. exit(-1);
  118. }
  119. if ( classno > maxClassNo ) maxClassNo = classno;
  120. teachSet.add ( classno, x );
  121. std::cerr << "adden class " << classno << " with feature " << x << std::endl;
  122. int tmpCnt(0);
  123. for (LabeledSetVector::const_iterator it = this->teachSet.begin(); it != this->teachSet.end(); it++)
  124. {
  125. tmpCnt += it->second.size();
  126. }
  127. std::cerr << "number of known training samples: " << tmpCnt << std::endl;
  128. }
  129. void VCNearestNeighbour::finishTeaching()
  130. {
  131. }
  132. VCNearestNeighbour *VCNearestNeighbour::clone() const
  133. {
  134. VCNearestNeighbour *myclone = new VCNearestNeighbour ( *this );
  135. return myclone;
  136. }
  137. void VCNearestNeighbour::clear ()
  138. {
  139. teachSet.clear();
  140. }
  141. void VCNearestNeighbour::store ( std::ostream & os, int format ) const
  142. {
  143. teachSet.store ( os, format );
  144. }
  145. void VCNearestNeighbour::restore ( std::istream & is, int format )
  146. {
  147. teachSet.restore ( is, format );
  148. maxClassNo = teachSet.getMaxClassno();
  149. }