VCNearestClassMean.cpp 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. #ifdef NICE_USELIB_ICE
  2. #include <iostream>
  3. #include "vislearning/classifier/vclassifier/VCNearestClassMean.h"
  4. using namespace OBJREC;
  5. using namespace std;
  6. using namespace NICE;
  7. VCNearestClassMean::VCNearestClassMean( const Config *_conf, NICE::VectorDistance<double> *_distancefunc )
  8. : VecClassifier ( _conf ), distancefunc (_distancefunc)
  9. {
  10. if ( _distancefunc == NULL )
  11. distancefunc = new EuclidianDistance<double>();
  12. }
  13. VCNearestClassMean::~VCNearestClassMean()
  14. {
  15. clear();
  16. }
  17. /** classify using simple vector */
  18. ClassificationResult VCNearestClassMean::classify ( const NICE::Vector & x ) const
  19. {
  20. double min_distance= std::numeric_limits<double>::max();
  21. int min_class = -1;
  22. FullVector scores ( classNo.size() );
  23. for(uint i=0;i<this->classNo.size();i++)
  24. {
  25. double distance = distancefunc->calculate ( x, means[i] );
  26. scores[i] = - distance;
  27. if ( distance < min_distance)
  28. {
  29. min_distance = distance;
  30. min_class = classNo[i];
  31. }
  32. }
  33. return ClassificationResult ( min_class, scores );
  34. }
  35. void VCNearestClassMean::teach ( const LabeledSetVector & _teachSet )
  36. {
  37. _teachSet.getClasses ( this->classNo );
  38. //initialize means
  39. NICE::Vector zero( _teachSet.dimension() );
  40. for(uint d=0;d<zero.size();d++) zero[d]=0.0;
  41. for(uint c=0;c<this->classNo.size();c++)
  42. {
  43. means.push_back(zero);
  44. }
  45. //add all class-specific vectors
  46. int index=0;
  47. LOOP_ALL(_teachSet)
  48. {
  49. EACH(classno,x);
  50. //determine index
  51. for(uint c=0;c<this->classNo.size();c++)
  52. {
  53. if(classno==classNo[c]) index=c;
  54. }
  55. for(uint d=0;d<zero.size();d++)
  56. {
  57. means[index][d]+=x[d];
  58. }
  59. }
  60. //normalize vectors
  61. for(uint c=0;c<this->classNo.size();c++)
  62. {
  63. for(uint d=0;d<zero.size();d++)
  64. {
  65. means[c][d]/=_teachSet.count(this->classNo[c]);
  66. }
  67. }
  68. }
  69. void VCNearestClassMean::finishTeaching()
  70. {
  71. //nothing more to do
  72. }
  73. void VCNearestClassMean::clear ()
  74. {
  75. //nothing to do
  76. }
  77. void VCNearestClassMean::store ( std::ostream & os, int format ) const
  78. {
  79. fprintf (stderr, "NOT YET IMPLEMENTED\n");
  80. exit(-1);
  81. }
  82. void VCNearestClassMean::restore ( std::istream & is, int format )
  83. {
  84. fprintf (stderr, "NOT YET IMPLEMENTED\n");
  85. exit(-1);
  86. }
  87. #endif