testClassifierGMM.cpp

/**
 * @file testClassifierGMM.cpp
 * @brief main program for classifier evaluation
 * @author Erik Rodner
 * @date 2007-10-12
 */
#include <algorithm>
#include <cmath>
#include <fstream>
#include <iostream>

#include <vislearning/cbaselib/MultiDataset.h>
#include "vislearning/classifier/genericClassifierSelection.h"
#include <vislearning/cbaselib/ClassificationResults.h>
#include <vislearning/cbaselib/MutualInformation.h>
#include "vislearning/classifier/classifierbase/FeaturePoolClassifier.h"
#include <vislearning/classifier/fpclassifier/randomforest/FPCRandomForestTransfer.h>
#include <vislearning/classifier/classifierinterfaces/VCFeaturePool.h>
#include "core/basics/Config.h"
#include <vislearning/baselib/Preprocess.h>
#include <core/basics/StringTools.h>
#include "vislearning/math/cluster/GMM.h"

#undef DEBUG

using namespace OBJREC;
using namespace NICE;
using namespace std;
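
// Binarization helpers, used when main::binarize is set: every feature dimension is
// thresholded against a per-dimension threshold (computed later via MutualInformation),
// turning each vector into a 0/1 vector.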
void binarizeVector( NICE::Vector & xout, const NICE::Vector & x, const NICE::Vector & thresholds )
{
  xout.resize( x.size() );
  for ( size_t i = 0 ; i < x.size() ; i++ )
    if ( fabs( x[i] ) > thresholds[i] )
      xout[i] = 1.0;
    else
      xout[i] = 0.0;
}

void binarizeSet( LabeledSetVector & dst, const LabeledSetVector & src, const NICE::Vector & thresholds )
{
  LOOP_ALL( src )
  {
    EACH( classno, x );
    NICE::Vector dstv;
    binarizeVector( dstv, x, thresholds );
    dst.add( classno, dstv );
  }
}
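
/* Program flow: read the configuration, set up class names, load the train/test vector
 * sets, optionally replace features by GMM component probabilities, then train (or load)
 * a classifier and report per-class and overall recognition rates plus a confusion matrix. */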
int main( int argc, char **argv )
{
  fprintf( stderr, "testClassifier: init\n" );

  std::set_terminate( __gnu_cxx::__verbose_terminate_handler );

  Config conf( argc, argv );

  string wekafile = conf.gS( "main", "weka", "" );
  string trainfn = conf.gS( "main", "train", "train.vec" );
  string testfn = conf.gS( "main", "test", "test.vec" );
  int format = conf.gI( "main", "format", 0 );
  bool binarize = conf.gB( "main", "binarize", false );
  int wekaclass = conf.gI( "main", "wekaclass", 1 );
  string classifier_cache = conf.gS( "main", "classifiercache", "" );
  string classifier_cache_in = conf.gS( "main", "classifierin", "" );
  int numRuns = conf.gI( "main", "runs", 1 );

  // classno:text,classno:text,...
  string classes = conf.gS( "main", "classes", "" );
  int classesnb = conf.gI( "main", "classes", 0 );
  string classesconf = conf.gS( "main", "classesconf", "" );
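
  /* Example configuration (a sketch only, assuming the INI-style "[section] key = value"
   * syntax commonly used with NICE::Config; file names and the classifier value are
   * placeholders, adjust them to your setup):
   *
   *   [main]
   *   train      = train.vec
   *   test       = test.vec
   *   format     = 0
   *   classifier = ...
   *   gmm        = 0
   *   binarize   = false
   *   runs       = 1
   *   weka       = results_weka.txt
   */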
  fprintf( stderr, "testClassifier: reading config\n" );
  Preprocess::Init( &conf );

  fprintf( stderr, "testClassifier: reading multi dataset\n" );
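
  /* Class names can come from four sources, checked in this order: a plain class count
   * (main::classes given as an integer), an explicit "classno:text,..." list
   * (main::classes given as a string), a separate config file (main::classesconf),
   * or the MultiDataset itself. */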
  int testMaxClassNo;
  int trainMaxClassNo;
  ClassNames *classNames;

  if ( classes.size() == 0 && classesnb != 0 )
  {
    classNames = new ClassNames();
    for ( int classno = 0 ; classno < classesnb ; classno++ )
    {
      classNames->addClass( classno, StringTools::convertToString<int> ( classno ), StringTools::convertToString<int> ( classno ) );
    }
    trainMaxClassNo = classNames->getMaxClassno();
    testMaxClassNo = trainMaxClassNo;
  }
  else if ( classes.size() > 0 )
  {
    classNames = new ClassNames();
    vector<string> classes_sub;
    StringTools::split( string( classes ), ',', classes_sub );
    for ( vector<string>::const_iterator i = classes_sub.begin();
          i != classes_sub.end(); i++ )
    {
      vector<string> desc;
      StringTools::split( *i, ':', desc );
      if ( desc.size() != 2 )
        break;
      int classno = StringTools::convert<int> ( desc[0] );
      classNames->addClass( classno, desc[1], desc[1] );
    }
    trainMaxClassNo = classNames->getMaxClassno();
    testMaxClassNo = trainMaxClassNo;
    classNames->store( cout );
  }
  else if ( classesconf.size() > 0 )
  {
    classNames = new ClassNames();
    Config cConf( classesconf );
    classNames->readFromConfig( cConf, "*" );
    trainMaxClassNo = classNames->getMaxClassno();
    testMaxClassNo = trainMaxClassNo;
  }
  else
  {
    MultiDataset md( &conf );
    classNames = new ClassNames( md.getClassNames( "train" ), "*" );
    testMaxClassNo = md.getClassNames( "test" ).getMaxClassno();
    trainMaxClassNo = md.getClassNames( "train" ).getMaxClassno();
  }
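
  /* Load the training data unless a pre-trained classifier will be read from disk;
   * the test set is always needed. */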
  LabeledSetVector train;
  if ( classifier_cache_in.size() <= 0 )
  {
    fprintf( stderr, "testClassifier: Reading training dataset from %s\n", trainfn.c_str() );
    train.read( trainfn, format );
    train.printInformation();
  } else {
    fprintf( stderr, "testClassifier: skipping training set %s\n", trainfn.c_str() );
  }

  LabeledSetVector test;
  fprintf( stderr, "testClassifier: Reading test dataset from %s\n", testfn.c_str() );
  test.read( testfn, format );
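
  /* Optional GMM preprocessing: if main::gmm > 0, a Gaussian mixture with that many
   * components is fitted to all training vectors, and every feature vector (train and
   * test) is replaced in place by its vector of per-component probabilities. */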
  GMM *gmm = NULL;
  int nbgmm = conf.gI( "main", "gmm", 0 );
  if ( nbgmm > 0 )
  {
    gmm = new GMM( &conf, nbgmm );
    VVector vset;
    Vector l;
    train.getFlatRepresentation( vset, l );
    gmm->computeMixture( vset );

    map<int, vector<NICE::Vector *> >::iterator iter;
    for ( iter = train.begin(); iter != train.end(); ++iter )
    {
      for ( uint i = 0; i < iter->second.size(); ++i )
      {
        gmm->getProbs( *( iter->second[i] ), *( iter->second[i] ) );
      }
    }
    for ( iter = test.begin(); iter != test.end(); ++iter )
    {
      for ( uint i = 0; i < iter->second.size(); ++i )
      {
        gmm->getProbs( *( iter->second[i] ), *( iter->second[i] ) );
      }
    }
  }
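
  /* Each run selects a classifier, trains it (or loads a cached one), evaluates it on
   * the test set, and accumulates per-class and overall statistics. */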
  ClassificationResults cresults;
  for ( int runs = 0 ; runs < numRuns ; runs++ )
  {
    VecClassifier *vec_classifier = NULL;
    if ( conf.gS( "main", "classifier" ) == "random_forest_transfer" )
    {
      FeaturePoolClassifier *fpc = new FPCRandomForestTransfer( &conf, classNames );
      vec_classifier = new VCFeaturePool( &conf, fpc );
    } else {
      vec_classifier = GenericClassifierSelection::selectVecClassifier( &conf, "main" );
    }

    NICE::Vector thresholds;
    if ( classifier_cache_in.size() <= 0 )
    {
      if ( binarize ) {
        LabeledSetVector trainbin;
        NICE::Vector mis;
        MutualInformation mi;
        fprintf( stderr, "testClassifier: computing mutual information\n" );
        mi.computeThresholdsOverall( train, thresholds, mis );
        fprintf( stderr, "testClassifier: done!\n" );
        binarizeSet( trainbin, train, thresholds );
        vec_classifier->teach( trainbin );
      } else {
        vec_classifier->teach( train );
      }
      vec_classifier->finishTeaching();
      if ( classifier_cache.size() > 0 )
        vec_classifier->save( classifier_cache );
    } else {
      vec_classifier->setMaxClassNo( classNames->getMaxClassno() );
      vec_classifier->read( classifier_cache_in );
    }
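
    /* Evaluate on every test example: classify, record whether the prediction matches
     * the ground truth, and update the (ground truth x prediction) confusion matrix. */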
    ProgressBar pb( "Classification" );
    pb.show();

    std::vector<int> count( testMaxClassNo + 1, 0 );
    std::vector<int> correct( testMaxClassNo + 1, 0 );
    MatrixT<int> confusionMatrix( testMaxClassNo + 1, trainMaxClassNo + 1, 0 );
    int n = test.count();

    LOOP_ALL( test )
    {
      EACH( classno, v );
      pb.update( n );
      fprintf( stderr, "\tclassification\n" );

      ClassificationResult r;
      if ( binarize )
      {
        NICE::Vector vout;
        binarizeVector( vout, v, thresholds );
        r = vec_classifier->classify( vout );
      } else {
        r = vec_classifier->classify( v );
      }

      r.classno_groundtruth = classno;
      r.classname = classNames->text( r.classno );

#ifdef DEBUG
      if ( r.classno == classno )
        fprintf( stderr, "+ classification %d (\"%s\") <-> %d (\"%s\") score=%f\n", classno,
                 classNames->text( classno ).c_str(), r.classno, r.classname.c_str(), r.scores[r.classno] );
      else
        fprintf( stderr, "- classification %d (\"%s\") <-> %d (\"%s\") score=%f\n", classno,
                 classNames->text( classno ).c_str(), r.classno, r.classname.c_str(), r.scores[r.classno] );
#endif

      r.scores.store( cerr );

      if ( r.classno >= 0 )
      {
        if ( classno == r.classno ) correct[classno]++;
        count[classno]++;
        if ( r.ok() ) {
          confusionMatrix( classno, r.classno )++;
        }
        cresults.push_back( r );
      }
    }
    pb.hide();

    if ( wekafile.size() > 0 )
    {
      string wekafile_s = wekafile;
      if ( numRuns > 1 )
        wekafile_s = wekafile_s + "." + StringTools::convertToString<int>( runs ) + ".txt";
      cresults.writeWEKA( wekafile_s, wekaclass );
    }
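
    /* Summary statistics: the overall recognition rate weights every test example
     * equally, while the average recognition rate is the mean of the per-class rates.
     * "lower bound 1" is the chance level for uniform guessing (100 / #classes), and
     * "lower bound 2" is the accuracy of always predicting the most frequent test class. */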
    int count_total = 0;
    int correct_total = 0;
    int classes_tested = 0;
    double avg_recognition = 0.0;
    for ( size_t classno = 0; classno < correct.size(); classno++ )
    {
      if ( count[classno] == 0 ) {
        fprintf( stdout, "class %d not tested !!\n", ( int )classno );
      } else {
        fprintf( stdout, "classification result class %d (\"%s\") : %5.2f %%\n",
                 ( int )classno, classNames->text( classno ).c_str(), correct[classno]*100.0 / count[classno] );
        avg_recognition += correct[classno] / ( double )count[classno];
        classes_tested++;
      }
      count_total += count[classno];
      correct_total += correct[classno];
    }
    avg_recognition /= classes_tested;

    fprintf( stdout, "overall recognition rate : %-5.3f %%\n", correct_total*100.0 / count_total );
    fprintf( stdout, "average recognition rate : %-5.3f %%\n", avg_recognition*100 );
    fprintf( stdout, "total:%d misclassified:%d\n", count_total, count_total - correct_total );

    int max_count = *( max_element( count.begin(), count.end() ) );
    fprintf( stdout, "no of classes : %d\n", classNames->numClasses() );
    fprintf( stdout, "lower bound 1 : %f\n", 100.0 / ( classNames->numClasses() ) );
    fprintf( stdout, "lower bound 2 : %f\n", max_count * 100.0 / ( double ) count_total );

    cout << confusionMatrix << endl;

    delete vec_classifier;
  }

  delete classNames;

  return 0;
}