FeatureLearningPrototypes.cpp 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452
  1. #include "FeatureLearningPrototypes.h"
  2. #include <iostream>
  3. #include <core/image/FilterT.h>
  4. #include <core/image/CircleT.h>
  5. #include <core/image/Convert.h>
  6. #include <core/vector/VectorT.h>
  7. #include <vislearning/features/localfeatures/LFonHSG.h>
  8. #include <vislearning/features/localfeatures/LFColorSande.h>
  9. #include <vislearning/features/localfeatures/LFColorWeijer.h>
  10. #include <vislearning/features/localfeatures/LFReadCache.h>
  11. #include <vislearning/features/localfeatures/LFWriteCache.h>
  12. //
  13. #include <vislearning/math/cluster/KMeans.h>
  14. #include <vislearning/math/cluster/GMM.h>
  15. using namespace std;
  16. using namespace NICE;
  17. using namespace OBJREC;
  18. //**********************************************
  19. //
  20. // PROTECTED METHODS
  21. //
  22. //**********************************************
  23. void FeatureLearningPrototypes::setClusterAlgo( const std::string & _clusterAlgoString)
  24. {
  25. //be careful with previously allocated memory
  26. if (this->clusterAlgo != NULL)
  27. delete clusterAlgo;
  28. if (_clusterAlgoString.compare("kmeans") == 0)
  29. {
  30. this->clusterAlgo = new OBJREC::KMeans(this->initialNumberOfClusters);
  31. }
  32. else if (_clusterAlgoString.compare("GMM") == 0)
  33. {
  34. this->clusterAlgo = new OBJREC::GMM(this->conf, this->initialNumberOfClusters);
  35. }
  36. else
  37. {
  38. std::cerr << "Unknown cluster algorithm selected, use k-means instead" << std::endl;
  39. this->clusterAlgo = new OBJREC::KMeans(this->initialNumberOfClusters);
  40. }
  41. }
  42. void FeatureLearningPrototypes::extractFeaturesFromTrainingImages( const OBJREC::MultiDataset *_md, NICE::VVector & examplesTraining )
  43. {
  44. examplesTraining.clear();
  45. int numberOfTrainImage ( 0 );
  46. const LabeledSet *trainFiles = (*_md)["train"];
  47. //run over all training images
  48. LOOP_ALL_S( *trainFiles )
  49. {
  50. EACH_INFO( classno, info );
  51. std::string filename = info.img();
  52. NICE::ColorImage img( filename );
  53. if ( b_showTrainingImages )
  54. {
  55. showImage( img, "Input" );
  56. }
  57. //variables to store feature informatio
  58. NICE::VVector features;
  59. NICE::VVector cfeatures;
  60. NICE::VVector positions;
  61. //compute features
  62. Globals::setCurrentImgFN ( filename );
  63. if (featureExtractor == NULL)
  64. std::cerr << "feature Extractor is NULL" << std::endl;
  65. else
  66. featureExtractor->extractFeatures ( img, features, positions );
  67. //store feature information in larger data structure
  68. for ( NICE::VVector::iterator i = features.begin();
  69. i != features.end();
  70. i++)
  71. {
  72. //normalization :)
  73. i->normalizeL1();
  74. examplesTraining.push_back(*i);
  75. }
  76. //don't waste memory
  77. features.clear();
  78. positions.clear();
  79. numberOfTrainImage++;
  80. }//Loop over all training images
  81. }
  82. void FeatureLearningPrototypes::train ( const OBJREC::MultiDataset *_md )
  83. {
  84. bool loadSuccess = this->loadInitialCodebook();
  85. if ( !loadSuccess )
  86. {
  87. //**********************************************
  88. //
  89. // EXTRACT FEATURES FROM TRAINING IMAGES
  90. //
  91. //**********************************************
  92. std::cerr << " EXTRACT FEATURES FROM TRAINING IMAGES" << std::endl;
  93. NICE::VVector examplesTraining;
  94. this->extractFeaturesFromTrainingImages( _md, examplesTraining );
  95. //**********************************************
  96. //
  97. // CLUSTER FEATURES FROM TRAINING IMAGES
  98. //
  99. // THIS GIVES US AN INITIAL CODEBOOK
  100. //
  101. //**********************************************
  102. std::cerr << " CLUSTER FEATURES FROM TRAINING IMAGES" << std::endl;
  103. //go, go, go...
  104. prototypes.clear();
  105. std::vector< double > weights;
  106. std::vector< int > assignment;
  107. clusterAlgo->cluster ( examplesTraining, prototypes, weights, assignment);
  108. weights.clear();
  109. assignment.clear();
  110. }
  111. this->writeInitialCodebook();
  112. }
  113. bool FeatureLearningPrototypes::loadInitialCodebook ( )
  114. {
  115. if ( b_loadInitialCodebook )
  116. {
  117. std::cerr << " INITIAL CODEBOOK ALREADY COMPUTED - RE-USE IT" << std::endl;
  118. std::cerr << " // WARNING - WE DO NOT VERIFY WHETHER THIS IS THE CORRECT CODEBOOK FOR THIS TRAINING SET!!!!" << std::endl;
  119. prototypes.clear();
  120. try
  121. {
  122. prototypes.read(cacheInitialCodebook);
  123. }
  124. catch (...)
  125. {
  126. std::cerr << "Error while loading initial codebook" << std::endl;
  127. return false;
  128. }
  129. return true;
  130. }
  131. else
  132. return false;
  133. }
  134. bool FeatureLearningPrototypes::writeInitialCodebook ( )
  135. {
  136. if ( b_saveInitialCodebook )
  137. {
  138. std::cerr << " SAVE INITIAL CODEBOOK " << std::endl;
  139. try
  140. {
  141. prototypes.write( cacheInitialCodebook );
  142. }
  143. catch (...)
  144. {
  145. std::cerr << "Error while saving initial codebook" << std::endl;
  146. return false;
  147. }
  148. return true;
  149. }
  150. else
  151. return false;
  152. }
  153. //**********************************************
  154. //
  155. // PUBLIC METHODS
  156. //
  157. //**********************************************
  158. FeatureLearningPrototypes::FeatureLearningPrototypes ( const Config *_conf,
  159. const MultiDataset *_md, const std::string & _section )
  160. : FeatureLearningGeneric ( _conf, _section )
  161. {
  162. //feature stuff
  163. // which OpponentSIFT implementation to use {NICE, VANDESANDE}
  164. std::string opSiftImpl;
  165. opSiftImpl = conf->gS ( "Descriptor", "implementation", "VANDESANDE" );
  166. // read features?
  167. bool readfeat;
  168. readfeat = conf->gB ( "Descriptor", "read", true );
  169. // write features?
  170. bool writefeat;
  171. writefeat = conf->gB ( "Descriptor", "write", true );
  172. // define the initial number of clusters our codebook shall contain
  173. initialNumberOfClusters = conf->gI(section, "initialNumberOfClusters", 10);
  174. // define the clustering algorithm to be used
  175. std::string clusterAlgoString = conf->gS(section, "clusterAlgo", "kmeans");
  176. // define the distance function to be used
  177. std::string distFunctionString = conf->gS(section, "distFunction", "euclidian");
  178. //**********************************************
  179. //
  180. // SET UP VARIABLES AND METHODS
  181. // - FEATURE TYPE
  182. // - CLUSTERING ALGO
  183. // - DISTANCE FUNCTION
  184. // - ...
  185. //
  186. //**********************************************
  187. std::cerr << " SET UP VARIABLES AND METHODS " << std::endl;
  188. // Welche Opponentsift Implementierung soll genutzt werden ?
  189. LocalFeatureRepresentation *cSIFT = NULL;
  190. LocalFeatureRepresentation *writeFeats = NULL;
  191. LocalFeatureRepresentation *readFeats = NULL;
  192. this->featureExtractor = NULL;
  193. if ( opSiftImpl == "NICE" )
  194. {
  195. cSIFT = new OBJREC::LFonHSG ( conf, "HSGtrain" );
  196. }
  197. else if ( opSiftImpl == "VANDESANDE" )
  198. {
  199. cSIFT = new OBJREC::LFColorSande ( conf, "LFColorSandeTrain" );
  200. }
  201. else
  202. {
  203. fthrow ( Exception, "feattype: %s not yet supported" << opSiftImpl );
  204. }
  205. this->featureExtractor = cSIFT;
  206. if ( writefeat )
  207. {
  208. // write the features to a file, if there isn't any to read
  209. writeFeats = new LFWriteCache ( conf, cSIFT );
  210. this->featureExtractor = writeFeats;
  211. }
  212. if ( readfeat )
  213. {
  214. // read the features from a file
  215. if ( writefeat )
  216. {
  217. readFeats = new LFReadCache ( conf, writeFeats, -1 );
  218. }
  219. else
  220. {
  221. readFeats = new LFReadCache ( conf, cSIFT, -1 );
  222. }
  223. this->featureExtractor = readFeats;
  224. }
  225. this->clusterAlgo = NULL;
  226. this->setClusterAlgo( clusterAlgoString );
  227. if (distFunctionString.compare("euclidian") == 0)
  228. {
  229. distFunction = new NICE::EuclidianDistance<double>();
  230. }
  231. else
  232. {
  233. std::cerr << "Unknown vector distance selected, use euclidian instead" << std::endl;
  234. distFunction = new NICE::EuclidianDistance<double>();
  235. }
  236. //run the training to initially compute a codebook and stuff like that
  237. this->train( _md );
  238. //only set feature stuff to NULL, deletion of the underlying object is done in the destructor
  239. if ( cSIFT != NULL )
  240. cSIFT = NULL;
  241. if ( writeFeats != NULL )
  242. writeFeats = NULL;
  243. if ( readFeats != NULL )
  244. readFeats = NULL ;
  245. //so far, we have not seen any new image
  246. this->newImageCounter = 0;
  247. //TODO stupid
  248. this->maxValForVisualization = 0.005;
  249. }
  250. FeatureLearningPrototypes::~FeatureLearningPrototypes()
  251. {
  252. // clean-up
  253. if ( clusterAlgo != NULL )
  254. delete clusterAlgo;
  255. if ( distFunction != NULL )
  256. delete distFunction;
  257. if ( featureExtractor != NULL )
  258. delete featureExtractor;
  259. }
  260. NICE::FloatImage FeatureLearningPrototypes::evaluateCurrentCodebook ( const std::string & _filename , const bool & beforeComputingNewFeatures )
  261. {
  262. NICE::ColorImage img( _filename );
  263. if ( b_showTrainingImages )
  264. {
  265. showImage( img, "Input" );
  266. }
  267. int xsize ( img.width() );
  268. int ysize ( img.height() );
  269. //variables to store feature information
  270. NICE::VVector features;
  271. NICE::VVector cfeatures;
  272. NICE::VVector positions;
  273. //compute features
  274. Globals::setCurrentImgFN ( _filename );
  275. featureExtractor->extractFeatures ( img, features, positions );
  276. FloatImage noveltyImage ( xsize, ysize );
  277. noveltyImage.set ( 0.0 );
  278. double maxDist ( 0.0 );
  279. NICE::VVector::const_iterator posIt = positions.begin();
  280. //store feature information in larger data structure
  281. for ( NICE::VVector::iterator i = features.begin();
  282. i != features.end();
  283. i++, posIt++)
  284. {
  285. //normalization :)
  286. i->normalizeL1();
  287. //loop over codebook representatives
  288. double minDist ( std::numeric_limits<double>::max() );
  289. for (NICE::VVector::const_iterator it = prototypes.begin(); it != prototypes.end(); it++)
  290. {
  291. //compute distance
  292. double tmpDist ( this->distFunction->calculate(*i,*it) );
  293. if (tmpDist < minDist)
  294. minDist = tmpDist;
  295. }
  296. if (minDist > maxDist)
  297. maxDist = minDist;
  298. //take minimum distance and store in in a float image
  299. noveltyImage ( (*posIt)[0], (*posIt)[1] ) = minDist;
  300. }
  301. //gauss-filtering for nicer visualization
  302. FloatImage noveltyImageGaussFiltered ( xsize, ysize );
  303. float sigma ( 3.0 );
  304. FilterT<float, float, float> filter;
  305. filter.filterGaussSigmaApproximate ( noveltyImage, sigma, &noveltyImageGaussFiltered );
  306. double maxFiltered ( noveltyImageGaussFiltered.max() );
  307. std::cerr << "maximum distance of Training images: " << maxDist << std::endl;
  308. std::cerr << "maximum distance of Training images after filtering: " << maxFiltered << std::endl;
  309. if ( beforeComputingNewFeatures )
  310. this->oldMaxDist = maxFiltered;
  311. //for suitable visualization of scores between zero (known) and one (unknown)
  312. // noveltyImageGaussFiltered( 0 , 0 ) = std::max<double>(maxDist, 1.0);
  313. //convert float to RGB
  314. NICE::ColorImage noveltyImageRGB ( xsize, ysize );
  315. // ICETools::convertToRGB ( noveltyImageGaussFiltered, noveltyImageRGB );
  316. if ( beforeComputingNewFeatures )
  317. {
  318. imageToPseudoColorWithRangeSpecification( noveltyImageGaussFiltered, noveltyImageRGB, 0 /* min */, maxValForVisualization /* maxFiltered*/ /* max */ );
  319. std::cerr << "set max value to: " << noveltyImageGaussFiltered.max() << std::endl;
  320. }
  321. else
  322. {
  323. imageToPseudoColorWithRangeSpecification( noveltyImageGaussFiltered, noveltyImageRGB, 0 /* min */, maxValForVisualization /*this->oldMaxDist*/ /* max */ );
  324. std::cerr << "set max value to: " << this->oldMaxDist << std::endl;
  325. }
  326. if ( b_showResults )
  327. showImage(noveltyImageRGB, "Novelty Image");
  328. else
  329. {
  330. std::vector< std::string > list2;
  331. StringTools::split ( _filename, '/', list2 );
  332. std::string destination ( s_resultdir + NICE::intToString(this->newImageCounter -1 ) + "_" + list2.back() + "_3_updatedNoveltyMap.ppm");
  333. if ( beforeComputingNewFeatures )
  334. destination = s_resultdir + NICE::intToString(this->newImageCounter) + "_" + list2.back() + "_0_initialNoveltyMap.ppm";
  335. noveltyImageRGB.writePPM( destination );
  336. }
  337. // now look where the closest features for the current cluster indices are
  338. int tmpProtCnt ( 0 );
  339. for (NICE::VVector::const_iterator protIt = prototypes.begin(); protIt != prototypes.end(); protIt++, tmpProtCnt++)
  340. {
  341. double distToNewCluster ( std::numeric_limits<double>::max() );
  342. int indexOfMostSimFeat( 0 );
  343. double tmpDist;
  344. int tmpCnt ( 0 );
  345. for ( NICE::VVector::iterator i = features.begin();
  346. i != features.end();
  347. i++, tmpCnt++)
  348. {
  349. tmpDist = this->distFunction->calculate( *i, *protIt );
  350. if ( tmpDist < distToNewCluster )
  351. {
  352. distToNewCluster = tmpDist;
  353. indexOfMostSimFeat = tmpCnt;
  354. }
  355. }
  356. int posX ( ( positions[indexOfMostSimFeat] ) [0] );
  357. int posY ( ( positions[indexOfMostSimFeat] ) [1] );
  358. NICE::Circle circ ( Coord( posX, posY), 2*tmpProtCnt /* radius*/, Color(200,0,255 ) );
  359. img.draw(circ);
  360. }
  361. if ( b_showResults )
  362. showImage(img, "Current image and most similar features for current cluster");
  363. else
  364. {
  365. std::vector< std::string > list2;
  366. StringTools::split ( _filename, '/', list2 );
  367. std::string destination ( s_resultdir + NICE::intToString(this->newImageCounter-1) + "_" + list2.back() + "_3_updatedCurrentCluster.ppm");
  368. if ( beforeComputingNewFeatures )
  369. destination = s_resultdir + NICE::intToString(this->newImageCounter) + "_" + list2.back() + "_0_initialCurrentCluster.ppm";
  370. img.writePPM( destination );
  371. }
  372. return noveltyImageGaussFiltered;
  373. }