VCOneVsOne.cpp 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. /**
  2. * @file VCOneVsOne.cpp
  3. * @author Erik Rodner
  4. * @date 10/25/2007
  5. */
  6. #include <iostream>
  7. #include "core/basics/StringTools.h"
  8. #include "vislearning/classifier/vclassifier/VCOneVsOne.h"
  9. using namespace OBJREC;
  10. using namespace std;
  11. using namespace NICE;
  12. VCOneVsOne::VCOneVsOne ( const Config *conf, VecClassifier *_prototype )
  13. : VecClassifier ( conf ), prototype ( _prototype )
  14. {
  15. use_weighted_voting = conf->gB("VCOneVsOne", "use_weighted_voting", false);
  16. }
  17. VCOneVsOne::~VCOneVsOne()
  18. {
  19. }
  20. ClassificationResult VCOneVsOne::classify ( const NICE::Vector & x ) const
  21. {
  22. FullVector scores ( maxClassNo+1 );
  23. scores.set(0);
  24. for ( vector< triplet<int, int, VecClassifier *> >::const_iterator i =
  25. classifiers.begin(); i != classifiers.end(); i++ )
  26. {
  27. VecClassifier *classifier = i->third;
  28. ClassificationResult r = classifier->classify(x);
  29. int classi = i->first;
  30. int classj = i->second;
  31. if ( use_weighted_voting )
  32. {
  33. if ( r.classno == 0 )
  34. scores[classi]-=r.scores[1];
  35. else
  36. scores[classj]+=r.scores[1];
  37. } else {
  38. if ( r.classno == 0 )
  39. scores[classi]++;
  40. else
  41. scores[classj]++;
  42. }
  43. }
  44. scores.normalize();
  45. return ClassificationResult ( scores.maxElement(), scores );
  46. }
  47. void VCOneVsOne::teach ( const LabeledSetVector & _teachSet )
  48. {
  49. maxClassNo = _teachSet.getMaxClassno();
  50. classifiers.clear();
  51. assert ( maxClassNo+1 == _teachSet.numClasses() );
  52. for ( int i = 0 ; i <= maxClassNo ; i++ )
  53. {
  54. for ( int j = i+1 ; j <= maxClassNo ; j++ )
  55. {
  56. LabeledSetVector binarySubSet (true);
  57. LabeledSetVector::const_iterator exiv = _teachSet.find(i);
  58. for ( vector<Vector *>::const_iterator exi = exiv->second.begin();
  59. exi != exiv->second.end(); exi++ )
  60. binarySubSet.add_reference ( 0, *exi );
  61. LabeledSetVector::const_iterator exjv = _teachSet.find(j);
  62. for ( vector<Vector *>::const_iterator exj = exjv->second.begin();
  63. exj != exjv->second.end(); exj++ )
  64. binarySubSet.add_reference ( 1, *exj );
  65. VecClassifier *classifier;
  66. classifier = prototype->clone();
  67. fprintf (stderr, "Training classifier: class %d <-> class %d\n", i, j );
  68. classifier->teach ( binarySubSet );
  69. classifier->finishTeaching();
  70. classifiers.push_back ( triplet<int, int, VecClassifier*> (i,j,classifier) );
  71. }
  72. }
  73. }
  74. void VCOneVsOne::finishTeaching()
  75. {
  76. }
  77. void VCOneVsOne::read (const string& s, int format)
  78. {
  79. ifstream ifs ( s.c_str(), ios::in );
  80. ifs >> maxClassNo;
  81. ifs.close();
  82. for ( int i = 0 ; i <= maxClassNo ; i++ )
  83. {
  84. for ( int j = i+1 ; j <= maxClassNo ; j++ )
  85. {
  86. VecClassifier *classifier;
  87. classifier = prototype->clone();
  88. string classifiercache = s + ".onevsone." + StringTools::convertToString<int> ( i ) + "." + StringTools::convertToString<int> ( j );
  89. fprintf (stderr, "Loading classifier: class %d <-> class %d\n", i, j );
  90. classifier->read ( classifiercache, format );
  91. classifiers.push_back ( triplet<int, int, VecClassifier*> (i,j,classifier) );
  92. }
  93. }
  94. }
  95. void VCOneVsOne::save (const string& s, int format) const
  96. {
  97. ofstream ofs ( s.c_str(), ios::out );
  98. ofs << maxClassNo << endl;
  99. ofs.close();
  100. for ( vector< triplet<int, int, VecClassifier *> >::const_iterator i =
  101. classifiers.begin(); i != classifiers.end(); i++ )
  102. {
  103. int classi = i->first;
  104. int classj = i->second;
  105. VecClassifier *classifier = i->third;
  106. string classifiercache = s + ".onevsone." + StringTools::convertToString<int> ( classi ) + "." + StringTools::convertToString<int> ( classj );
  107. classifier->save ( classifiercache, format );
  108. }
  109. }
  110. void VCOneVsOne::store ( std::ostream & os, int format ) const
  111. {
  112. fprintf (stderr, "VCOneVsOne: unable to write to stream! please use read()\n");
  113. }
  114. void VCOneVsOne::restore ( std::istream & is, int format )
  115. {
  116. fprintf (stderr, "VCOneVsOne: unable to read from stream! please use save()\n");
  117. exit (-1);
  118. }