VCOneVsAll.cpp 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. /**
  2. * @file VCOneVsAll.cpp
  3. * @author Erik Rodner
  4. * @date 10/25/2007
  5. */
  #ifdef NOVISUAL
  #include <vislearning/nice_nonvis.h>
  #else
  #include <vislearning/nice.h>
  #endif
  #include <algorithm>
  #include <cmath>
  #include <cstdio>
  #include <iostream>
  #include <limits>
  #include <utility>
  #include <vector>
  #include "vislearning/classifier/vclassifier/VCOneVsAll.h"
  using namespace std;
  using namespace NICE;
  using namespace OBJREC;
  16. VCOneVsAll::VCOneVsAll ( const Config *_conf, const VecClassifier *_prototype )
  17. : VecClassifier ( _conf ), prototype(_prototype)
  18. {
  19. }
  20. VCOneVsAll::~VCOneVsAll()
  21. {
  22. }
  23. ClassificationResult VCOneVsAll::classify ( const NICE::Vector & x ) const
  24. {
  25. FullVector scores ( maxClassNo+1 );
  26. vector<bool> exists(maxClassNo+1, false);
  27. scores.set(0.0);
  28. double minval = numeric_limits<double>::max();
  29. for ( vector< pair<int, VecClassifier *> >::const_iterator i =
  30. classifiers.begin(); i != classifiers.end(); i++ )
  31. {
  32. int classno = i->first;
  33. exists[classno] = true;
  34. VecClassifier *classifier = i->second;
  35. ClassificationResult r = classifier->classify(x);
  36. scores[classno] += r.scores[1];
  37. minval = std::min(minval, scores[classno]);
  38. }
  39. for(int i = 0; i <= maxClassNo; i++)
  40. {
  41. if(!exists[i])
  42. {
  43. scores[i] = minval-numeric_limits<double>::epsilon();
  44. }
  45. }
  46. return ClassificationResult ( scores.maxElement(), scores );
  47. }
  48. void VCOneVsAll::teach ( const LabeledSetVector & _teachSet )
  49. {
  50. if ( _teachSet.count() <= 0 )
  51. fthrow(Exception, "Number of training examples is zero!\n");
  52. maxClassNo = _teachSet.getMaxClassno();
  53. classifiers.clear();
  54. for ( int i = 0 ; i <= maxClassNo ; i++ )
  55. {
  56. LabeledSetVector binarySubSet (true);
  57. LabeledSetVector::const_iterator exiv = _teachSet.find(i);
  58. if ( exiv == _teachSet.end() )
  59. {
  60. // a test example might be classified as this class
  61. // if we do not use probability scores
  62. cerr << "Class " << i << " does not have any training examples; skipping training." << endl;
  63. continue;
  64. }
  65. int poscount = _teachSet.count(i);
  66. int negcount = _teachSet.count() - poscount;
  67. int mincount = std::min(poscount, negcount);
  68. int c = 0;
  69. for ( vector<Vector *>::const_iterator exi = exiv->second.begin();
  70. exi != exiv->second.end(); exi++, c++ )
  71. {
  72. binarySubSet.add_reference ( 1, *exi );
  73. if( c >= mincount)
  74. break;
  75. }
  76. c = 0;
  77. for ( LabeledSetVector::const_iterator exjv = _teachSet.begin();
  78. exjv != _teachSet.end(); exjv++ , c++)
  79. {
  80. if ( exjv == exiv ) continue;
  81. for ( vector<Vector *>::const_iterator exj = exjv->second.begin();
  82. exj != exjv->second.end(); exj++ )
  83. binarySubSet.add_reference ( 0, *exj );
  84. if( c >= mincount)
  85. break;
  86. }
  87. VecClassifier *classifier;
  88. classifier = prototype->clone();
  89. fprintf (stderr, "Training classifier: class %d <-> remainder\n", i );
  90. classifier->teach ( binarySubSet );
  91. classifier->finishTeaching();
  92. classifiers.push_back ( pair<int, VecClassifier*> (i, classifier) );
  93. }
  94. }
  95. void VCOneVsAll::finishTeaching()
  96. {
  97. }
  98. VecClassifier *VCOneVsAll::clone(void) const
  99. {
  100. VCOneVsAll *classifier = new VCOneVsAll( *this );
  101. return classifier;
  102. }
  103. VCOneVsAll::VCOneVsAll( const VCOneVsAll &vcova ): VecClassifier()
  104. {
  105. prototype = vcova.prototype->clone();
  106. for(int i = 0; i < (int)vcova.classifiers.size(); i++)
  107. {
  108. classifiers.push_back(pair<int, VecClassifier*>(vcova.classifiers[i].first,vcova.classifiers[i].second->clone()));
  109. }
  110. }