VCOneVsAll.cpp 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. /**
  2. * @file VCOneVsAll.cpp
  3. * @author Erik Rodner
  4. * @date 10/25/2007
  5. */
  #include <algorithm>
  #include <cmath>
  #include <cstdio>
  #include <iostream>
  #include <limits>
  #include <vector>

  #include "vislearning/classifier/vclassifier/VCOneVsAll.h"

  using namespace std;
  using namespace NICE;
  using namespace OBJREC;
  11. VCOneVsAll::VCOneVsAll ( const Config *_conf, const VecClassifier *_prototype )
  12. : VecClassifier ( _conf ), prototype(_prototype)
  13. {
  14. }
  15. VCOneVsAll::~VCOneVsAll()
  16. {
  17. }
  18. ClassificationResult VCOneVsAll::classify ( const NICE::Vector & x ) const
  19. {
  20. FullVector scores ( maxClassNo+1 );
  21. vector<bool> exists(maxClassNo+1, false);
  22. scores.set(0.0);
  23. double minval = numeric_limits<double>::max();
  24. for ( vector< pair<int, VecClassifier *> >::const_iterator i =
  25. classifiers.begin(); i != classifiers.end(); i++ )
  26. {
  27. int classno = i->first;
  28. exists[classno] = true;
  29. VecClassifier *classifier = i->second;
  30. ClassificationResult r = classifier->classify(x);
  31. scores[classno] += r.scores[1];
  32. minval = std::min(minval, scores[classno]);
  33. }
  34. for(int i = 0; i <= maxClassNo; i++)
  35. {
  36. if(!exists[i])
  37. {
  38. scores[i] = minval-numeric_limits<double>::epsilon();
  39. }
  40. }
  41. return ClassificationResult ( scores.maxElement(), scores );
  42. }
  43. void VCOneVsAll::teach ( const LabeledSetVector & _teachSet )
  44. {
  45. if ( _teachSet.count() <= 0 )
  46. fthrow(Exception, "Number of training examples is zero!\n");
  47. maxClassNo = _teachSet.getMaxClassno();
  48. classifiers.clear();
  49. for ( int i = 0 ; i <= maxClassNo ; i++ )
  50. {
  51. LabeledSetVector binarySubSet (true);
  52. LabeledSetVector::const_iterator exiv = _teachSet.find(i);
  53. if ( exiv == _teachSet.end() )
  54. {
  55. // a test example might be classified as this class
  56. // if we do not use probability scores
  57. cerr << "Class " << i << " does not have any training examples; skipping training." << endl;
  58. continue;
  59. }
  60. int poscount = _teachSet.count(i);
  61. int negcount = _teachSet.count() - poscount;
  62. int mincount = std::min(poscount, negcount);
  63. int c = 0;
  64. for ( vector<Vector *>::const_iterator exi = exiv->second.begin();
  65. exi != exiv->second.end(); exi++, c++ )
  66. {
  67. binarySubSet.add_reference ( 1, *exi );
  68. if( c >= mincount)
  69. break;
  70. }
  71. c = 0;
  72. for ( LabeledSetVector::const_iterator exjv = _teachSet.begin();
  73. exjv != _teachSet.end(); exjv++ , c++)
  74. {
  75. if ( exjv == exiv ) continue;
  76. for ( vector<Vector *>::const_iterator exj = exjv->second.begin();
  77. exj != exjv->second.end(); exj++ )
  78. binarySubSet.add_reference ( 0, *exj );
  79. if( c >= mincount)
  80. break;
  81. }
  82. VecClassifier *classifier;
  83. classifier = prototype->clone();
  84. fprintf (stderr, "Training classifier: class %d <-> remainder\n", i );
  85. classifier->teach ( binarySubSet );
  86. classifier->finishTeaching();
  87. classifiers.push_back ( pair<int, VecClassifier*> (i, classifier) );
  88. }
  89. }
  90. void VCOneVsAll::finishTeaching()
  91. {
  92. }
  93. VecClassifier *VCOneVsAll::clone(void) const
  94. {
  95. VCOneVsAll *classifier = new VCOneVsAll( *this );
  96. return classifier;
  97. }
  98. VCOneVsAll::VCOneVsAll( const VCOneVsAll &vcova ): VecClassifier()
  99. {
  100. prototype = vcova.prototype->clone();
  101. for(int i = 0; i < (int)vcova.classifiers.size(); i++)
  102. {
  103. classifiers.push_back(pair<int, VecClassifier*>(vcova.classifiers[i].first,vcova.classifiers[i].second->clone()));
  104. }
  105. }