// evaluateCompleteBoWPipeline.cpp (14 KB)
  1. /**
  2. * @file evaluateCompleteBoWPipeline.cpp
  3. * @brief A complete BoW pipeline: feature extraction, codebook creation, vector quantization, classifier training, evaluation on separate test set
  4. * @author Alexander Freytag
  5. * @date 10-05-2013
  6. */
  7. //STL
  8. #include <iostream>
  9. #include <limits>
  10. //core -- basic stuff
  11. #include <core/basics/Config.h>
  12. #include <core/basics/ResourceStatistics.h>
  13. #include <core/basics/Timer.h>
  14. #include <core/image/Convert.h>
  15. #include <core/vector/VectorT.h>
  16. //vislearning -- basic stuff
  17. #include <vislearning/baselib/Globals.h>
  18. #include <vislearning/baselib/ICETools.h>
  19. #include <vislearning/cbaselib/MultiDataset.h>
  20. #include <vislearning/cbaselib/Example.h>
  21. #include <vislearning/cbaselib/ClassificationResult.h>
  22. #include <vislearning/cbaselib/ClassificationResults.h>
  23. //
  24. // vislearning -- classifier
  25. #include <vislearning/classifier/classifierbase/VecClassifier.h>
  26. #include <vislearning/classifier/genericClassifierSelection.h>
  27. //
  28. // vislearning -- BoW codebooks
  29. #include "vislearning/features/simplefeatures/CodebookPrototypes.h"
  30. #include "vislearning/features/simplefeatures/BoWFeatureConverter.h"
  31. //
  32. // vislearning -- local features
  33. // #include <vislearning/features/localfeatures/LFonHSG.h>
  34. // #include <vislearning/features/localfeatures/LFColorSande.h>
  35. // #include <vislearning/features/localfeatures/LFColorWeijer.h>
  36. // #include <vislearning/features/localfeatures/LFReadCache.h>
  37. // #include <vislearning/features/localfeatures/LFWriteCache.h>
  38. #include <vislearning/features/localfeatures/GenericLFSelection.h>
  39. //
  40. // vislearning -- clustering methods
  41. #include <vislearning/math/cluster/GenericClusterAlgorithmSelection.h>
  42. // #include <vislearning/math/cluster/ClusterAlgorithm.h>
  43. // #include <vislearning/math/cluster/RandomClustering.h>
  44. // #include <vislearning/math/cluster/KMeans.h>
  45. // #include <vislearning/math/cluster/KMedian.h>
  46. // #include <vislearning/math/cluster/GMM.h>
  47. //
  48. using namespace std;
  49. using namespace NICE;
  50. using namespace OBJREC;
  51. // LocalFeatureRepresentation * setFeatureExtractor( const Config * _conf )
  52. // {
  53. // LocalFeatureRepresentation * featureExtractor;
  54. //
  55. // //feature stuff
  56. // // which OpponentSIFT implementation to use {NICE, VANDESANDE}
  57. // std::string opSiftImpl;
  58. // opSiftImpl = _conf->gS ( "Descriptor", "implementation", "VANDESANDE" );
  59. // // read features?
  60. // bool readfeat;
  61. // readfeat = _conf->gB ( "Descriptor", "read", true );
  62. // // write features?
  63. // bool writefeat;
  64. // writefeat = _conf->gB ( "Descriptor", "write", true );
  65. //
  66. // Which OpponentSIFT implementation should be used?
  67. // LocalFeatureRepresentation *cSIFT = NULL;
  68. // LocalFeatureRepresentation *writeFeats = NULL;
  69. // LocalFeatureRepresentation *readFeats = NULL;
  70. // featureExtractor = NULL;
  71. // if ( opSiftImpl == "NICE" )
  72. // {
  73. // cSIFT = new OBJREC::LFonHSG ( _conf, "HSG" );
  74. // }
  75. // else if ( opSiftImpl == "VANDESANDE" )
  76. // {
  77. // cSIFT = new OBJREC::LFColorSande ( _conf, "LFColorSande" );
  78. // }
  79. // else
  80. // {
  81. // fthrow ( Exception, "feattype: %s not yet supported" << opSiftImpl );
  82. // }
  83. //
  84. // featureExtractor = cSIFT;
  85. //
  86. // if ( writefeat )
  87. // {
  88. // // write the features to a file, if there isn't any to read
  89. // writeFeats = new LFWriteCache ( _conf, cSIFT );
  90. // featureExtractor = writeFeats;
  91. // }
  92. //
  93. // if ( readfeat )
  94. // {
  95. // // read the features from a file
  96. // if ( writefeat )
  97. // {
  98. // readFeats = new LFReadCache ( _conf, writeFeats, -1 );
  99. // }
  100. // else
  101. // {
  102. // readFeats = new LFReadCache ( _conf, cSIFT, -1 );
  103. // }
  104. // featureExtractor = readFeats;
  105. // }
  106. //
  107. // //only set feature stuff to NULL, deletion of the underlying object is done in the destructor
  108. // if ( cSIFT != NULL )
  109. // cSIFT = NULL;
  110. // if ( writeFeats != NULL )
  111. // writeFeats = NULL;
  112. // if ( readFeats != NULL )
  113. // readFeats = NULL ;
  114. //
  115. // return featureExtractor;
  116. // }
  117. //
  118. // OBJREC::ClusterAlgorithm * setClusterAlgo( const Config * _conf )
  119. // {
  120. // std::string section ( "clusteringStuff" );
  121. // // define the initial number of clusters our codebook shall contain
  122. // int noClusters = _conf->gI(section, "noClusters", 10);
  123. //
  124. // // define the clustering algorithm to be used
  125. // std::string clusterAlgoString = _conf->gS(section, "clusterAlgo", "kmeans");
  126. //
  127. // OBJREC::ClusterAlgorithm * clusterAlgo;
  128. //
  129. // if (clusterAlgoString.compare("kmeans") == 0)
  130. // {
  131. // clusterAlgo = new OBJREC::KMeans(noClusters);
  132. // }
  133. // else if (clusterAlgoString.compare("kmedian") == 0)
  134. // {
  135. // clusterAlgo = new OBJREC::KMedian(noClusters);
  136. // }
  137. // else if (clusterAlgoString.compare("GMM") == 0)
  138. // {
  139. // clusterAlgo = new OBJREC::GMM( _conf, noClusters );
  140. // }
  141. // else if ( clusterAlgoString.compare("RandomClustering") == 0 )
  142. // {
  143. // clusterAlgo = new OBJREC::RandomClustering( _conf, section );
  144. // }
  145. // else
  146. // {
  147. // std::cerr << "Unknown cluster algorithm selected, use random clustering instead" << std::endl;
  148. // clusterAlgo = new OBJREC::RandomClustering( _conf, section );
  149. // }
  150. //
  151. // return clusterAlgo;
  152. // }
  153. /**
  154. a complete BoW pipeline
  155. possibly, we can make use of objrec/progs/testClassifier.cpp
  156. */
  157. int main( int argc, char **argv )
  158. {
  159. std::set_terminate( __gnu_cxx::__verbose_terminate_handler );
  160. Config * conf = new Config ( argc, argv );
  161. const bool writeClassificationResults = conf->gB( "main", "writeClassificationResults", true );
  162. const std::string resultsfile = conf->gS( "main", "resultsfile", "/tmp/results.txt" );
  163. ResourceStatistics rs;
  164. // ========================================================================
  165. // TRAINING STEP
  166. // ========================================================================
  167. MultiDataset md( conf );
  168. const LabeledSet *trainFiles = md["train"];
  169. //**********************************************
  170. //
  171. // FEATURE EXTRACTION FOR TRAINING IMAGES
  172. //
  173. //**********************************************
  174. std::cerr << "FEATURE EXTRACTION FOR TRAINING IMAGES" << std::endl;
  175. OBJREC::LocalFeatureRepresentation * featureExtractor = OBJREC::GenericLFSelection::selectLocalFeatureRep ( conf, "features", OBJREC::GenericLFSelection::TRAINING );
  176. // LocalFeatureRepresentation * featureExtractor = setFeatureExtractor( conf );
  177. //collect features in a single data structure
  178. NICE::VVector featuresFromAllTrainingImages;
  179. featuresFromAllTrainingImages.clear();
  180. //determine how many training images we actually use to easily allocate the correct amount of memory afterwards
  181. int numberOfTrainingImages ( 0 );
  182. for(LabeledSet::const_iterator classIt = trainFiles->begin() ; classIt != trainFiles->end() ; classIt++)
  183. {
  184. numberOfTrainingImages += classIt->second.size();
  185. std::cerr << "number of examples for this class: " << classIt->second.size() << std::endl;
  186. }
  187. //okay, this is redundant - but I see no way to do it more easy right now...
  188. std::vector<NICE::VVector> featuresOfImages ( numberOfTrainingImages );
  189. //this again is somehow redundant, but we need the labels lateron for easy access - change this to a better solution :)
  190. NICE::VectorT<int> labelsTrain ( numberOfTrainingImages, 0 );
  191. //TODO replace the nasty makro by a suitable for-loop to make it omp-ready (parallelization)
  192. int imgCnt ( 0 );
  193. // the corresponding nasty makro: LOOP_ALL_S( *trainFiles )
  194. for(LabeledSet::const_iterator classIt = trainFiles->begin() ; classIt != trainFiles->end() ; classIt++)
  195. {
  196. for ( std::vector<ImageInfo *>::const_iterator imgIt = classIt->second.begin();
  197. imgIt != classIt->second.end();
  198. imgIt++, imgCnt++
  199. )
  200. {
  201. // the corresponding nasty makro: EACH_INFO( classno, info );
  202. int classno ( classIt->first );
  203. const ImageInfo imgInfo = *(*imgIt);
  204. std::string filename = imgInfo.img();
  205. NICE::ColorImage img( filename );
  206. //compute features
  207. //variables to store feature information
  208. NICE::VVector features;
  209. NICE::VVector positions;
  210. Globals::setCurrentImgFN ( filename );
  211. featureExtractor->extractFeatures ( img, features, positions );
  212. //normalization :)
  213. for ( NICE::VVector::iterator i = features.begin();
  214. i != features.end();
  215. i++)
  216. {
  217. i->normalizeL1();
  218. }
  219. //collect them all in a larger data structure
  220. featuresFromAllTrainingImages.append( features );
  221. //and store it as well in the data struct that additionally keeps the information which features belong to which image
  222. //TODO this can be made more clever!
  223. // featuresOfImages.push_back( features );
  224. featuresOfImages[imgCnt] = features;
  225. labelsTrain[imgCnt] = classno;
  226. }
  227. }
  228. //**********************************************
  229. //
  230. // CODEBOOK CREATION
  231. //
  232. //**********************************************
  233. std::cerr << "CODEBOOK CREATION" << std::endl;
  234. // OBJREC::ClusterAlgorithm * clusterAlgo = setClusterAlgo( conf );
  235. OBJREC::ClusterAlgorithm * clusterAlgo = OBJREC::GenericClusterAlgorithmSelection::selectClusterAlgo ( conf );
  236. NICE::VVector prototypes;
  237. std::vector<double> weights;
  238. std::vector<int> assignments;
  239. std::cerr << "call cluster of cluster algo " << std::endl;
  240. clusterAlgo->cluster( featuresFromAllTrainingImages, prototypes, weights, assignments );
  241. std::cerr << "create new codebook with the computed prototypes" << std::endl;
  242. OBJREC::CodebookPrototypes * codebook = new OBJREC::CodebookPrototypes ( prototypes );
  243. //**********************************************
  244. //
  245. // VECTOR QUANTIZATION OF
  246. // FEATURES OF TRAINING IMAGES
  247. //
  248. //**********************************************
  249. OBJREC::BoWFeatureConverter * bowConverter = new OBJREC::BoWFeatureConverter ( conf, codebook );
  250. OBJREC::LabeledSetVector trainSet;
  251. NICE::VVector histograms ( featuresOfImages.size() /* number of vectors*/, 0 /* dimension of vectors*/ ); //the internal vectors will be resized within calcHistogram
  252. NICE::VVector::iterator histogramIt = histograms.begin();
  253. NICE::VectorT<int>::const_iterator labelsIt = labelsTrain.begin();
  254. for (std::vector<NICE::VVector>::const_iterator imgIt = featuresOfImages.begin(); imgIt != featuresOfImages.end(); imgIt++, histogramIt++, labelsIt++)
  255. {
  256. bowConverter->calcHistogram ( *imgIt, *histogramIt );
  257. bowConverter->normalizeHistogram ( *histogramIt );
  258. //NOTE perhaps we should use add_reference here
  259. trainSet.add( *labelsIt, *histogramIt );
  260. }
  261. //**********************************************
  262. //
  263. // CLASSIFIER TRAINING
  264. //
  265. //**********************************************
  266. std::string classifierType = conf->gS( "main", "classifierType", "GPHIK" );
  267. OBJREC::VecClassifier * classifier = OBJREC::GenericClassifierSelection::selectVecClassifier( conf, classifierType );
  268. //TODO integrate GP-HIK-NICE into vislearning and add it into genericClassifierSelection
  269. //this method adds the training data to the temporary knowledge of our classifier
  270. classifier->teach( trainSet );
  271. //now the actual training step starts (e.g., parameter estimation, ... )
  272. classifier->finishTeaching();
  273. // ========================================================================
  274. // TEST STEP
  275. // ========================================================================
  276. const LabeledSet *testFiles = md["test"];
  277. delete featureExtractor;
  278. featureExtractor = OBJREC::GenericLFSelection::selectLocalFeatureRep ( conf, "features", OBJREC::GenericLFSelection::TESTING );
  279. NICE::Matrix confusionMat ( trainFiles->size() /* number of classes in training */, testFiles->size() /* number of classes for testing*/, 0.0 );
  280. NICE::Timer t;
  281. ClassificationResults results;
  282. LOOP_ALL_S( *testFiles )
  283. {
  284. EACH_INFO( classno, info );
  285. std::string filename = info.img();
  286. //**********************************************
  287. //
  288. // FEATURE EXTRACTION FOR TEST IMAGES
  289. //
  290. //**********************************************
  291. NICE::ColorImage img( filename );
  292. //compute features
  293. //variables to store feature information
  294. NICE::VVector features;
  295. NICE::VVector positions;
  296. Globals::setCurrentImgFN ( filename );
  297. featureExtractor->extractFeatures ( img, features, positions );
  298. //normalization :)
  299. for ( NICE::VVector::iterator i = features.begin();
  300. i != features.end();
  301. i++)
  302. {
  303. i->normalizeL1();
  304. }
  305. //**********************************************
  306. //
  307. // VECTOR QUANTIZATION OF
  308. // FEATURES OF TEST IMAGES
  309. //
  310. //**********************************************
  311. NICE::Vector histogramOfCurrentImg;
  312. bowConverter->calcHistogram ( features, histogramOfCurrentImg );
  313. bowConverter->normalizeHistogram ( histogramOfCurrentImg );
  314. //**********************************************
  315. //
  316. // CLASSIFIER EVALUATION
  317. //
  318. //**********************************************
  319. uint classno_groundtruth = classno;
  320. t.start();
  321. ClassificationResult r = classifier->classify ( histogramOfCurrentImg );
  322. t.stop();
  323. uint classno_estimated = r.classno;
  324. r.classno_groundtruth = classno_groundtruth;
  325. //if we like to store the classification results for external post processing, uncomment this
  326. if ( writeClassificationResults )
  327. {
  328. results.push_back( r );
  329. }
  330. confusionMat( classno_estimated, classno_groundtruth ) += 1;
  331. }
  332. confusionMat.normalizeColumnsL1();
  333. std::cerr << confusionMat << std::endl;
  334. std::cerr << "average recognition rate: " << confusionMat.trace()/confusionMat.rows() << std::endl;
  335. if ( writeClassificationResults )
  336. {
  337. double avgRecogResults ( results.getAverageRecognitionRate () );
  338. std::cerr << "average recognition rate according to classificationResults: " << avgRecogResults << std::endl;
  339. results.writeWEKA ( resultsfile, 0 );
  340. }
  341. return 0;
  342. }