testClassifierGMM.cpp 9.0 KB


/**
* @file testClassifierGMM.cpp
* @brief main program for classifier evaluation, with optional GMM-based feature transformation
* @author Erik Rodner
* @date 2007-10-12
*/
  7. #include <objrec/nice_nonvis.h>
  8. #include <fstream>
  9. #include <iostream>
  10. #include <objrec/cbaselib/MultiDataset.h>
  11. #include <objrec/iclassifier/icgeneric/CSGeneric.h>
  12. #include <objrec/cbaselib/ClassificationResults.h>
  13. #include <objrec/iclassifier/codebook/MutualInformation.h>
  14. #include "objrec/classifier/classifierbase/FeaturePoolClassifier.h"
  15. #include <objrec/classifier/fpclassifier/randomforest/FPCRandomForestTransfer.h>
  16. #include <objrec/classifier/classifierinterfaces/VCFeaturePool.h>
  17. #include <objrec/baselib/Config.h>
  18. #include <objrec/baselib/Preprocess.h>
  19. #include <objrec/baselib/StringTools.h>
  20. #include "objrec/math/cluster/GMM.h"
  21. #undef DEBUG
  22. using namespace OBJREC;
  23. using namespace NICE;
  24. using namespace std;
  25. void binarizeVector ( NICE::Vector & xout, const NICE::Vector & x, const NICE::Vector & thresholds )
  26. {
  27. xout.resize(x.size());
  28. for ( size_t i = 0 ; i < x.size() ; i++ )
  29. if ( fabs(x[i]) > thresholds[i] )
  30. xout[i] = 1.0;
  31. else
  32. xout[i] = 0.0;
  33. }
  34. void binarizeSet ( LabeledSetVector & dst, const LabeledSetVector & src, const NICE::Vector & thresholds )
  35. {
  36. LOOP_ALL(src)
  37. {
  38. EACH(classno,x);
  39. NICE::Vector dstv;
  40. binarizeVector ( dstv, x, thresholds );
  41. dst.add ( classno, dstv );
  42. }
  43. }
  44. int main (int argc, char **argv)
  45. {
  46. fprintf (stderr, "testClassifier: init\n");
  47. std::set_terminate(__gnu_cxx::__verbose_terminate_handler);
  48. Config conf ( argc, argv );
  49. string wekafile = conf.gS("main", "weka", "");
  50. string trainfn = conf.gS("main", "train", "train.vec");
  51. string testfn = conf.gS("main", "test", "test.vec");
  52. int format = conf.gI("main", "format", 0 );
  53. bool binarize = conf.gB("main", "binarize", false );
  54. int wekaclass = conf.gI("main", "wekaclass", 1 );
  55. string classifier_cache = conf.gS("main", "classifiercache", "");
  56. string classifier_cache_in = conf.gS("main", "classifierin", "");
  57. int numRuns = conf.gI("main", "runs", 1);
  58. // classno:text,classno:text,...
  59. string classes = conf.gS("main", "classes", "");
  60. int classesnb = conf.gI("main", "classes", 0);
  61. string classesconf = conf.gS("main", "classesconf", "");
  62. fprintf (stderr, "testClassifier: reading config\n");
  63. Preprocess::Init ( &conf );
  64. fprintf (stderr, "testClassifier: reading multi dataset\n");
  65. int testMaxClassNo;
  66. int trainMaxClassNo;
  67. ClassNames *classNames;
  68. if(classes.size() == 0 && classesnb != 0)
  69. {
  70. classNames = new ClassNames ();
  71. for ( int classno = 0 ; classno < classesnb ; classno++ )
  72. {
  73. classNames->addClass ( classno, StringTools::convertToString<int> ( classno ), StringTools::convertToString<int> (classno) );
  74. }
  75. trainMaxClassNo = classNames->getMaxClassno();
  76. testMaxClassNo = trainMaxClassNo;
  77. }
  78. else
  79. if ( classes.size() > 0 )
  80. {
  81. classNames = new ClassNames ();
  82. vector<string> classes_sub;
  83. StringTools::split ( string(classes), ',', classes_sub );
  84. for ( vector<string>::const_iterator i = classes_sub.begin();
  85. i != classes_sub.end(); i++ )
  86. {
  87. vector<string> desc;
  88. StringTools::split ( *i, ':', desc);
  89. if ( desc.size() != 2 )
  90. break;
  91. int classno = StringTools::convert<int> ( desc[0] );
  92. classNames->addClass ( classno, desc[1], desc[1] );
  93. }
  94. trainMaxClassNo = classNames->getMaxClassno();
  95. testMaxClassNo = trainMaxClassNo;
  96. classNames->store(cout);
  97. }
  98. else if ( classesconf.size() > 0 ) {
  99. classNames = new ClassNames ();
  100. Config cConf ( classesconf );
  101. classNames->readFromConfig ( cConf, "*" );
  102. trainMaxClassNo = classNames->getMaxClassno();
  103. testMaxClassNo = trainMaxClassNo;
  104. }
  105. else
  106. {
  107. MultiDataset md ( &conf );
  108. classNames = new ClassNames ( md.getClassNames("train"), "*" );
  109. testMaxClassNo = md.getClassNames("test").getMaxClassno();
  110. trainMaxClassNo = md.getClassNames("train").getMaxClassno();
  111. }
  112. LabeledSetVector train;
  113. if ( classifier_cache_in.size() <= 0 )
  114. {
  115. fprintf (stderr, "testClassifier: Reading training dataset from %s\n", trainfn.c_str() );
  116. train.read ( trainfn, format );
  117. train.printInformation();
  118. } else {
  119. fprintf (stderr, "testClassifier: skipping training set %s\n", trainfn.c_str() );
  120. }
  121. LabeledSetVector test;
  122. fprintf (stderr, "testClassifier: Reading test dataset from %s\n", testfn.c_str() );
  123. test.read ( testfn, format );
  124. GMM *gmm = NULL;
  125. int nbgmm = conf.gI("main", "gmm", 0);
  126. if(nbgmm > 0)
  127. {
  128. gmm = new GMM(&conf, nbgmm);
  129. VVector vset;
  130. Vector l;
  131. train.getFlatRepresentation(vset,l);
  132. gmm->computeMixture(vset);
  133. map<int, vector<NICE::Vector *> >::iterator iter;
  134. for( iter = train.begin(); iter != train.end(); ++iter )
  135. {
  136. for(uint i = 0; i < iter->second.size(); ++i)
  137. {
  138. gmm->getProbs(*(iter->second[i]),*(iter->second[i]));
  139. }
  140. }
  141. for( iter = test.begin(); iter != test.end(); ++iter )
  142. {
  143. for(uint i = 0; i < iter->second.size(); ++i)
  144. {
  145. gmm->getProbs(*(iter->second[i]),*(iter->second[i]));
  146. }
  147. }
  148. }
  149. ClassificationResults cresults;
  150. for (int runs = 0 ; runs < numRuns ; runs++ ) {
  151. VecClassifier *vec_classifier = NULL;
  152. if ( conf.gS("main", "classifier") == "random_forest_transfer" )
  153. {
  154. FeaturePoolClassifier *fpc = new FPCRandomForestTransfer ( &conf, classNames );
  155. vec_classifier = new VCFeaturePool ( &conf, fpc );
  156. } else {
  157. vec_classifier = CSGeneric::selectVecClassifier ( &conf, "main" );
  158. }
  159. NICE::Vector thresholds;
  160. if ( classifier_cache_in.size() <= 0 )
  161. {
  162. if ( binarize ) {
  163. LabeledSetVector trainbin;
  164. NICE::Vector mis;
  165. MutualInformation mi;
  166. fprintf (stderr, "testClassifier: computing mutual information\n");
  167. mi.computeThresholdsOverall ( train, thresholds, mis );
  168. fprintf (stderr, "testClassifier: done!\n");
  169. binarizeSet ( trainbin, train, thresholds );
  170. vec_classifier->teach ( trainbin );
  171. } else {
  172. vec_classifier->teach ( train );
  173. }
  174. vec_classifier->finishTeaching();
  175. if ( classifier_cache.size() > 0 )
  176. vec_classifier->save ( classifier_cache );
  177. } else {
  178. vec_classifier->setMaxClassNo ( classNames->getMaxClassno() );
  179. vec_classifier->read ( classifier_cache_in );
  180. }
  181. ProgressBar pb ("Classification");
  182. pb.show();
  183. std::vector<int> count ( testMaxClassNo+1, 0 );
  184. std::vector<int> correct ( testMaxClassNo+1, 0 );
  185. MatrixT<int> confusionMatrix ( testMaxClassNo+1, trainMaxClassNo+1, 0 );
  186. int n = test.count();
  187. LOOP_ALL(test)
  188. {
  189. EACH(classno,v);
  190. pb.update ( n );
  191. fprintf (stderr, "\tclassification\n" );
  192. ClassificationResult r;
  193. if ( binarize )
  194. {
  195. NICE::Vector vout;
  196. binarizeVector ( vout, v, thresholds );
  197. r = vec_classifier->classify ( vout );
  198. } else {
  199. r = vec_classifier->classify ( v );
  200. }
  201. r.classno_groundtruth = classno;
  202. r.classname = classNames->text( r.classno );
  203. #ifdef DEBUG
  204. if ( r.classno == classno )
  205. fprintf (stderr, "+ classification %d (\"%s\") <-> %d (\"%s\") score=%f\n", classno,
  206. classNames->text(classno).c_str(), r.classno, r.classname.c_str(), r.scores[r.classno]);
  207. else
  208. fprintf (stderr, "- classification %d (\"%s\") <-> %d (\"%s\") score=%f\n", classno,
  209. classNames->text(classno).c_str(), r.classno, r.classname.c_str(), r.scores[r.classno] );
  210. #endif
  211. r.scores.store ( cerr );
  212. if ( r.classno >= 0 )
  213. {
  214. if ( classno == r.classno ) correct[classno]++;
  215. count[classno]++;
  216. if ( r.ok() ) {
  217. confusionMatrix(classno, r.classno)++;
  218. }
  219. cresults.push_back ( r );
  220. }
  221. }
  222. pb.hide();
  223. if ( wekafile.size() > 0 )
  224. {
  225. string wekafile_s = wekafile;
  226. if ( numRuns > 1 )
  227. wekafile_s = wekafile_s + "." + StringTools::convertToString<int>(runs) + ".txt";
  228. cresults.writeWEKA ( wekafile_s, wekaclass );
  229. }
  230. int count_total = 0;
  231. int correct_total = 0;
  232. int classes_tested = 0;
  233. double avg_recognition = 0.0;
  234. for ( size_t classno = 0; classno < correct.size(); classno++ )
  235. {
  236. if ( count[classno] == 0 ) {
  237. fprintf (stdout, "class %d not tested !!\n", (int)classno);
  238. } else {
  239. fprintf (stdout, "classification result class %d (\"%s\") : %5.2f %%\n",
  240. (int)classno, classNames->text(classno).c_str(), correct[classno]*100.0/count[classno] );
  241. avg_recognition += correct[classno]/(double)count[classno];
  242. classes_tested++;
  243. }
  244. count_total += count[classno];
  245. correct_total += correct[classno];
  246. }
  247. avg_recognition /= classes_tested;
  248. fprintf (stdout, "overall recognition rate : %-5.3f %%\n", correct_total*100.0/count_total );
  249. fprintf (stdout, "average recognition rate : %-5.3f %%\n", avg_recognition*100 );
  250. fprintf (stdout, "total:%d misclassified:%d\n", count_total, count_total - correct_total );
  251. int max_count = *(max_element( count.begin(), count.end() ));
  252. fprintf (stdout, "no of classes : %d\n", classNames->numClasses() );
  253. fprintf (stdout, "lower bound 1 : %f\n", 100.0/(classNames->numClasses()));
  254. fprintf (stdout, "lower bound 2 : %f\n", max_count * 100.0 / (double) count_total);
  255. cout << confusionMatrix << endl;
  256. delete vec_classifier;
  257. }
  258. delete classNames;
  259. return 0;
  260. }