FeatureLearningClusterBased.cpp 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  #include "FeatureLearningClusterBased.h"

  #include <algorithm>
  #include <iostream>
  #include <limits>

  #include <core/image/FilterT.h>
  #include <core/vector/VectorT.h>
  #include <vislearning/baselib/ICETools.h>
  //
  #include <vislearning/features/localfeatures/LFonHSG.h>
  #include <vislearning/features/localfeatures/LFColorSande.h>
  #include <vislearning/features/localfeatures/LFColorWeijer.h>
  #include <vislearning/features/localfeatures/LFReadCache.h>
  #include <vislearning/features/localfeatures/LFWriteCache.h>
  //
  #include <vislearning/math/cluster/KMeans.h>
  #include <vislearning/math/cluster/GMM.h>
  15. using namespace std;
  16. using namespace NICE;
  17. using namespace OBJREC;
  18. //**********************************************
  19. //
  20. // PROTECTED METHODS
  21. //
  22. //**********************************************
  23. void FeatureLearningClusterBased::extractFeaturesFromTrainingImages( const OBJREC::MultiDataset *_md, NICE::VVector & examplesTraining )
  24. {
  25. examplesTraining.clear();
  26. int numberOfTrainImage ( 0 );
  27. const LabeledSet *trainFiles = (*_md)["train"];
  28. //run over all training images
  29. LOOP_ALL_S( *trainFiles )
  30. {
  31. EACH_INFO( classno, info );
  32. std::string filename = info.img();
  33. NICE::ColorImage img( filename );
  34. if ( showTrainingImages )
  35. {
  36. showImage( img, "Input" );
  37. }
  38. //variables to store feature informatio
  39. NICE::VVector features;
  40. NICE::VVector cfeatures;
  41. NICE::VVector positions;
  42. //compute features
  43. Globals::setCurrentImgFN ( filename );
  44. if (featureExtractor == NULL)
  45. std::cerr << "feature Extractor is NULL" << std::endl;
  46. else
  47. featureExtractor->extractFeatures ( img, features, positions );
  48. //store feature information in larger data structure
  49. for ( NICE::VVector::iterator i = features.begin();
  50. i != features.end();
  51. i++)
  52. {
  53. //normalization :)
  54. i->normalizeL1();
  55. examplesTraining.push_back(*i);
  56. }
  57. //don't waste memory
  58. features.clear();
  59. positions.clear();
  60. numberOfTrainImage++;
  61. }//Loop over all training images
  62. }
  63. void FeatureLearningClusterBased::train ( const OBJREC::MultiDataset *_md )
  64. {
  65. //**********************************************
  66. //
  67. // EXTRACT FEATURES FROM TRAINING IMAGES
  68. //
  69. //**********************************************
  70. std::cerr << " EXTRACT FEATURES FROM TRAINING IMAGES" << std::endl;
  71. NICE::VVector examplesTraining;
  72. this->extractFeaturesFromTrainingImages( _md, examplesTraining );
  73. //**********************************************
  74. //
  75. // CLUSTER FEATURES FROM TRAINING IMAGES
  76. //
  77. // THIS GIVES US AN INITIAL CODEBOOK
  78. //
  79. //**********************************************
  80. std::cerr << " CLUSTER FEATURES FROM TRAINING IMAGES" << std::endl;
  81. //go, go, go...
  82. prototypes.clear();
  83. std::vector< double > weights;
  84. std::vector< int > assignment;
  85. clusterAlgo->cluster ( examplesTraining, prototypes, weights, assignment);
  86. weights.clear();
  87. assignment.clear();
  88. }
  89. //**********************************************
  90. //
  91. // PUBLIC METHODS
  92. //
  93. //**********************************************
  94. FeatureLearningClusterBased::FeatureLearningClusterBased ( const Config *_conf,
  95. const MultiDataset *_md, const std::string & _section )
  96. : FeatureLearningGeneric ( _conf )
  97. {
  98. this->section = _section;
  99. //feature stuff
  100. //! which OpponentSIFT implementation to use {NICE, VANDESANDE}
  101. std::string opSiftImpl;
  102. opSiftImpl = conf->gS ( "Descriptor", "implementation", "VANDESANDE" );
  103. //! read features?
  104. bool readfeat;
  105. readfeat = conf->gB ( "Descriptor", "read", true );
  106. //! write features?
  107. bool writefeat;
  108. writefeat = conf->gB ( "Descriptor", "write", true );
  109. showTrainingImages = conf->gB( section, "showTrainingImages", false );
  110. //! define the initial number of clusters our codebook shall contain
  111. initialNumberOfClusters = conf->gI(section, "initialNumberOfClusters", 10);
  112. //! define the clustering algorithm to be used
  113. std::string clusterAlgoString = conf->gS(section, "clusterAlgo", "kmeans");
  114. //! define the distance function to be used
  115. std::string distFunctionString = conf->gS(section, "distFunction", "euclidian");
  116. //**********************************************
  117. //
  118. // SET UP VARIABLES AND METHODS
  119. // - FEATURE TYPE
  120. // - CLUSTERING ALGO
  121. // - DISTANCE FUNCTION
  122. // - ...
  123. //
  124. //**********************************************
  125. std::cerr << " SET UP VARIABLES AND METHODS " << std::endl;
  126. // Welche Opponentsift Implementierung soll genutzt werden ?
  127. LocalFeatureRepresentation *cSIFT = NULL;
  128. LocalFeatureRepresentation *writeFeats = NULL;
  129. LocalFeatureRepresentation *readFeats = NULL;
  130. this->featureExtractor = NULL;
  131. if ( opSiftImpl == "NICE" )
  132. {
  133. cSIFT = new OBJREC::LFonHSG ( conf, "HSGtrain" );
  134. }
  135. else if ( opSiftImpl == "VANDESANDE" )
  136. {
  137. cSIFT = new OBJREC::LFColorSande ( conf, "LFColorSandeTrain" );
  138. }
  139. else
  140. {
  141. fthrow ( Exception, "feattype: %s not yet supported" << opSiftImpl );
  142. }
  143. this->featureExtractor = cSIFT;
  144. if ( writefeat )
  145. {
  146. // write the features to a file, if there isn't any to read
  147. writeFeats = new LFWriteCache ( conf, cSIFT );
  148. this->featureExtractor = writeFeats;
  149. }
  150. if ( readfeat )
  151. {
  152. // read the features from a file
  153. if ( writefeat )
  154. {
  155. readFeats = new LFReadCache ( conf, writeFeats, -1 );
  156. }
  157. else
  158. {
  159. readFeats = new LFReadCache ( conf, cSIFT, -1 );
  160. }
  161. this->featureExtractor = readFeats;
  162. }
  163. if (clusterAlgoString.compare("kmeans") == 0)
  164. {
  165. clusterAlgo = new OBJREC::KMeans(initialNumberOfClusters);
  166. }
  167. else if (clusterAlgoString.compare("GMM") == 0)
  168. {
  169. clusterAlgo = new OBJREC::GMM(conf, initialNumberOfClusters);
  170. }
  171. else
  172. {
  173. std::cerr << "Unknown cluster algorithm selected, use k-means instead" << std::endl;
  174. clusterAlgo = new OBJREC::KMeans(initialNumberOfClusters);
  175. }
  176. if (distFunctionString.compare("euclidian") == 0)
  177. {
  178. distFunction = new NICE::EuclidianDistance<double>();
  179. }
  180. else
  181. {
  182. std::cerr << "Unknown vector distance selected, use euclidian instead" << std::endl;
  183. distFunction = new NICE::EuclidianDistance<double>();
  184. }
  185. //run the training to initially compute a codebook and stuff like that
  186. this->train( _md );
  187. //only set feature stuff to NULL, deletion of the underlying object is done in the destructor
  188. if ( cSIFT != NULL )
  189. cSIFT = NULL;
  190. if ( writeFeats != NULL )
  191. writeFeats = NULL;
  192. if ( readFeats != NULL )
  193. readFeats = NULL ;
  194. }
  195. FeatureLearningClusterBased::~FeatureLearningClusterBased()
  196. {
  197. // clean-up
  198. if ( clusterAlgo != NULL )
  199. delete clusterAlgo;
  200. if ( distFunction != NULL )
  201. delete distFunction;
  202. if ( featureExtractor != NULL )
  203. delete featureExtractor;
  204. }
  205. void FeatureLearningClusterBased::learnNewFeatures ( OBJREC::CachedExample *_ce )
  206. {
  207. }
  208. void FeatureLearningClusterBased::evaluateCurrentCodebook ( const std::string & filename )
  209. {
  210. NICE::ColorImage img( filename );
  211. if ( showTrainingImages )
  212. {
  213. showImage( img, "Input" );
  214. }
  215. int xsize ( img.width() );
  216. int ysize ( img.height() );
  217. //variables to store feature information
  218. NICE::VVector features;
  219. NICE::VVector cfeatures;
  220. NICE::VVector positions;
  221. //compute features
  222. Globals::setCurrentImgFN ( filename );
  223. featureExtractor->extractFeatures ( img, features, positions );
  224. FloatImage noveltyImage ( xsize, ysize );
  225. noveltyImage.set ( 0.0 );
  226. double maxDist ( 0.0 );
  227. NICE::VVector::const_iterator posIt = positions.begin();
  228. //store feature information in larger data structure
  229. for ( NICE::VVector::iterator i = features.begin();
  230. i != features.end();
  231. i++, posIt++)
  232. {
  233. //normalization :)
  234. i->normalizeL1();
  235. //loop over codebook representatives
  236. double minDist ( std::numeric_limits<double>::max() );
  237. for (NICE::VVector::const_iterator it = prototypes.begin(); it != prototypes.end(); it++)
  238. {
  239. //compute distance
  240. double tmpDist ( distFunction->calculate(*i,*it) );
  241. if (tmpDist < minDist)
  242. minDist = tmpDist;
  243. }
  244. if (minDist > maxDist)
  245. maxDist = minDist;
  246. //take minimum distance and store in in a float image
  247. noveltyImage ( (*posIt)[0], (*posIt)[1] ) = minDist;
  248. }
  249. //gauss-filtering for nicer visualization
  250. FloatImage noveltyImageGaussFiltered ( xsize, ysize );
  251. float sigma ( 3.0 );
  252. FilterT<float, float, float> filter;
  253. filter.filterGaussSigmaApproximate ( noveltyImage, sigma, &noveltyImageGaussFiltered );
  254. std::cerr << "maximum distance of Training images: " << maxDist;
  255. //for suitable visualization of scores between zero (known) and one (unknown)
  256. // noveltyImageGaussFiltered( 0 , 0 ) = std::max<double>(maxDist, 1.0);
  257. //convert float to RGB
  258. NICE::ColorImage noveltyImageRGB ( xsize, ysize );
  259. ICETools::convertToRGB ( noveltyImageGaussFiltered, noveltyImageRGB );
  260. showImage(noveltyImageRGB, "Novelty Image");
  261. }