testClassifierGMM.cpp 9.0 KB


/**
* @file testClassifierGMM.cpp
* @brief main program for classifier evaluation, optionally on GMM-transformed feature vectors
* @author Erik Rodner
* @date 2007-10-12
*/
  7. #include <fstream>
  8. #include <iostream>
  9. #include <vislearning/cbaselib/MultiDataset.h>
  10. #include <objrec/iclassifier/icgeneric/CSGeneric.h>
  11. #include <vislearning/cbaselib/ClassificationResults.h>
  12. #include <objrec/iclassifier/codebook/MutualInformation.h>
  13. #include "vislearning/classifier/classifierbase/FeaturePoolClassifier.h"
  14. #include <vislearning/classifier/fpclassifier/randomforest/FPCRandomForestTransfer.h>
  15. #include <vislearning/classifier/classifierinterfaces/VCFeaturePool.h>
  16. #include "core/basics/Config.h"
  17. #include <vislearning/baselib/Preprocess.h>
  18. #include <core/basics/StringTools.h>
  19. #include "vislearning/math/cluster/GMM.h"
  20. #undef DEBUG
  21. using namespace OBJREC;
  22. using namespace NICE;
  23. using namespace std;
  24. void binarizeVector ( NICE::Vector & xout, const NICE::Vector & x, const NICE::Vector & thresholds )
  25. {
  26. xout.resize(x.size());
  27. for ( size_t i = 0 ; i < x.size() ; i++ )
  28. if ( fabs(x[i]) > thresholds[i] )
  29. xout[i] = 1.0;
  30. else
  31. xout[i] = 0.0;
  32. }
  33. void binarizeSet ( LabeledSetVector & dst, const LabeledSetVector & src, const NICE::Vector & thresholds )
  34. {
  35. LOOP_ALL(src)
  36. {
  37. EACH(classno,x);
  38. NICE::Vector dstv;
  39. binarizeVector ( dstv, x, thresholds );
  40. dst.add ( classno, dstv );
  41. }
  42. }
  43. int main (int argc, char **argv)
  44. {
  45. fprintf (stderr, "testClassifier: init\n");
  46. std::set_terminate(__gnu_cxx::__verbose_terminate_handler);
  47. Config conf ( argc, argv );
  48. string wekafile = conf.gS("main", "weka", "");
  49. string trainfn = conf.gS("main", "train", "train.vec");
  50. string testfn = conf.gS("main", "test", "test.vec");
  51. int format = conf.gI("main", "format", 0 );
  52. bool binarize = conf.gB("main", "binarize", false );
  53. int wekaclass = conf.gI("main", "wekaclass", 1 );
  54. string classifier_cache = conf.gS("main", "classifiercache", "");
  55. string classifier_cache_in = conf.gS("main", "classifierin", "");
  56. int numRuns = conf.gI("main", "runs", 1);
  57. // classno:text,classno:text,...
  58. string classes = conf.gS("main", "classes", "");
  59. int classesnb = conf.gI("main", "classes", 0);
  60. string classesconf = conf.gS("main", "classesconf", "");
  61. fprintf (stderr, "testClassifier: reading config\n");
  62. Preprocess::Init ( &conf );
  63. fprintf (stderr, "testClassifier: reading multi dataset\n");
  64. int testMaxClassNo;
  65. int trainMaxClassNo;
  66. ClassNames *classNames;
  67. if(classes.size() == 0 && classesnb != 0)
  68. {
  69. classNames = new ClassNames ();
  70. for ( int classno = 0 ; classno < classesnb ; classno++ )
  71. {
  72. classNames->addClass ( classno, StringTools::convertToString<int> ( classno ), StringTools::convertToString<int> (classno) );
  73. }
  74. trainMaxClassNo = classNames->getMaxClassno();
  75. testMaxClassNo = trainMaxClassNo;
  76. }
  77. else
  78. if ( classes.size() > 0 )
  79. {
  80. classNames = new ClassNames ();
  81. vector<string> classes_sub;
  82. StringTools::split ( string(classes), ',', classes_sub );
  83. for ( vector<string>::const_iterator i = classes_sub.begin();
  84. i != classes_sub.end(); i++ )
  85. {
  86. vector<string> desc;
  87. StringTools::split ( *i, ':', desc);
  88. if ( desc.size() != 2 )
  89. break;
  90. int classno = StringTools::convert<int> ( desc[0] );
  91. classNames->addClass ( classno, desc[1], desc[1] );
  92. }
  93. trainMaxClassNo = classNames->getMaxClassno();
  94. testMaxClassNo = trainMaxClassNo;
  95. classNames->store(cout);
  96. }
  97. else if ( classesconf.size() > 0 ) {
  98. classNames = new ClassNames ();
  99. Config cConf ( classesconf );
  100. classNames->readFromConfig ( cConf, "*" );
  101. trainMaxClassNo = classNames->getMaxClassno();
  102. testMaxClassNo = trainMaxClassNo;
  103. }
  104. else
  105. {
  106. MultiDataset md ( &conf );
  107. classNames = new ClassNames ( md.getClassNames("train"), "*" );
  108. testMaxClassNo = md.getClassNames("test").getMaxClassno();
  109. trainMaxClassNo = md.getClassNames("train").getMaxClassno();
  110. }
  111. LabeledSetVector train;
  112. if ( classifier_cache_in.size() <= 0 )
  113. {
  114. fprintf (stderr, "testClassifier: Reading training dataset from %s\n", trainfn.c_str() );
  115. train.read ( trainfn, format );
  116. train.printInformation();
  117. } else {
  118. fprintf (stderr, "testClassifier: skipping training set %s\n", trainfn.c_str() );
  119. }
  120. LabeledSetVector test;
  121. fprintf (stderr, "testClassifier: Reading test dataset from %s\n", testfn.c_str() );
  122. test.read ( testfn, format );
  123. GMM *gmm = NULL;
  124. int nbgmm = conf.gI("main", "gmm", 0);
  125. if(nbgmm > 0)
  126. {
  127. gmm = new GMM(&conf, nbgmm);
  128. VVector vset;
  129. Vector l;
  130. train.getFlatRepresentation(vset,l);
  131. gmm->computeMixture(vset);
  132. map<int, vector<NICE::Vector *> >::iterator iter;
  133. for( iter = train.begin(); iter != train.end(); ++iter )
  134. {
  135. for(uint i = 0; i < iter->second.size(); ++i)
  136. {
  137. gmm->getProbs(*(iter->second[i]),*(iter->second[i]));
  138. }
  139. }
  140. for( iter = test.begin(); iter != test.end(); ++iter )
  141. {
  142. for(uint i = 0; i < iter->second.size(); ++i)
  143. {
  144. gmm->getProbs(*(iter->second[i]),*(iter->second[i]));
  145. }
  146. }
  147. }
  148. ClassificationResults cresults;
  149. for (int runs = 0 ; runs < numRuns ; runs++ ) {
  150. VecClassifier *vec_classifier = NULL;
  151. if ( conf.gS("main", "classifier") == "random_forest_transfer" )
  152. {
  153. FeaturePoolClassifier *fpc = new FPCRandomForestTransfer ( &conf, classNames );
  154. vec_classifier = new VCFeaturePool ( &conf, fpc );
  155. } else {
  156. vec_classifier = CSGeneric::selectVecClassifier ( &conf, "main" );
  157. }
  158. NICE::Vector thresholds;
  159. if ( classifier_cache_in.size() <= 0 )
  160. {
  161. if ( binarize ) {
  162. LabeledSetVector trainbin;
  163. NICE::Vector mis;
  164. MutualInformation mi;
  165. fprintf (stderr, "testClassifier: computing mutual information\n");
  166. mi.computeThresholdsOverall ( train, thresholds, mis );
  167. fprintf (stderr, "testClassifier: done!\n");
  168. binarizeSet ( trainbin, train, thresholds );
  169. vec_classifier->teach ( trainbin );
  170. } else {
  171. vec_classifier->teach ( train );
  172. }
  173. vec_classifier->finishTeaching();
  174. if ( classifier_cache.size() > 0 )
  175. vec_classifier->save ( classifier_cache );
  176. } else {
  177. vec_classifier->setMaxClassNo ( classNames->getMaxClassno() );
  178. vec_classifier->read ( classifier_cache_in );
  179. }
  180. ProgressBar pb ("Classification");
  181. pb.show();
  182. std::vector<int> count ( testMaxClassNo+1, 0 );
  183. std::vector<int> correct ( testMaxClassNo+1, 0 );
  184. MatrixT<int> confusionMatrix ( testMaxClassNo+1, trainMaxClassNo+1, 0 );
  185. int n = test.count();
  186. LOOP_ALL(test)
  187. {
  188. EACH(classno,v);
  189. pb.update ( n );
  190. fprintf (stderr, "\tclassification\n" );
  191. ClassificationResult r;
  192. if ( binarize )
  193. {
  194. NICE::Vector vout;
  195. binarizeVector ( vout, v, thresholds );
  196. r = vec_classifier->classify ( vout );
  197. } else {
  198. r = vec_classifier->classify ( v );
  199. }
  200. r.classno_groundtruth = classno;
  201. r.classname = classNames->text( r.classno );
  202. #ifdef DEBUG
  203. if ( r.classno == classno )
  204. fprintf (stderr, "+ classification %d (\"%s\") <-> %d (\"%s\") score=%f\n", classno,
  205. classNames->text(classno).c_str(), r.classno, r.classname.c_str(), r.scores[r.classno]);
  206. else
  207. fprintf (stderr, "- classification %d (\"%s\") <-> %d (\"%s\") score=%f\n", classno,
  208. classNames->text(classno).c_str(), r.classno, r.classname.c_str(), r.scores[r.classno] );
  209. #endif
  210. r.scores.store ( cerr );
  211. if ( r.classno >= 0 )
  212. {
  213. if ( classno == r.classno ) correct[classno]++;
  214. count[classno]++;
  215. if ( r.ok() ) {
  216. confusionMatrix(classno, r.classno)++;
  217. }
  218. cresults.push_back ( r );
  219. }
  220. }
  221. pb.hide();
  222. if ( wekafile.size() > 0 )
  223. {
  224. string wekafile_s = wekafile;
  225. if ( numRuns > 1 )
  226. wekafile_s = wekafile_s + "." + StringTools::convertToString<int>(runs) + ".txt";
  227. cresults.writeWEKA ( wekafile_s, wekaclass );
  228. }
  229. int count_total = 0;
  230. int correct_total = 0;
  231. int classes_tested = 0;
  232. double avg_recognition = 0.0;
  233. for ( size_t classno = 0; classno < correct.size(); classno++ )
  234. {
  235. if ( count[classno] == 0 ) {
  236. fprintf (stdout, "class %d not tested !!\n", (int)classno);
  237. } else {
  238. fprintf (stdout, "classification result class %d (\"%s\") : %5.2f %%\n",
  239. (int)classno, classNames->text(classno).c_str(), correct[classno]*100.0/count[classno] );
  240. avg_recognition += correct[classno]/(double)count[classno];
  241. classes_tested++;
  242. }
  243. count_total += count[classno];
  244. correct_total += correct[classno];
  245. }
  246. avg_recognition /= classes_tested;
  247. fprintf (stdout, "overall recognition rate : %-5.3f %%\n", correct_total*100.0/count_total );
  248. fprintf (stdout, "average recognition rate : %-5.3f %%\n", avg_recognition*100 );
  249. fprintf (stdout, "total:%d misclassified:%d\n", count_total, count_total - correct_total );
  250. int max_count = *(max_element( count.begin(), count.end() ));
  251. fprintf (stdout, "no of classes : %d\n", classNames->numClasses() );
  252. fprintf (stdout, "lower bound 1 : %f\n", 100.0/(classNames->numClasses()));
  253. fprintf (stdout, "lower bound 2 : %f\n", max_count * 100.0 / (double) count_total);
  254. cout << confusionMatrix << endl;
  255. delete vec_classifier;
  256. }
  257. delete classNames;
  258. return 0;
  259. }