SemSegSTF.cpp 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233
/**
 * @file SemSegSTF.cpp
 * @brief Semantic segmentation using a pixel-level random forest followed by a segmentation random forest
 * @author Erik Rodner
 * @date 02/11/2008
 */
  7. #ifdef NOVISUAL
  8. #include <objrec/nice_nonvis.h>
  9. #else
  10. #include <objrec/nice.h>
  11. #endif
  12. #include <iostream>
  13. #include "SemSegSTF.h"
  14. #include "objrec/baselib/Globals.h"
  15. #include "objrec/baselib/Preprocess.h"
  16. #include "objrec/baselib/ProgressBar.h"
  17. #include "objrec/baselib/StringTools.h"
  18. #include "objrec/baselib/Globals.h"
  19. #include "objrec/cbaselib/CachedExample.h"
  20. #include "objrec/cbaselib/PascalResults.h"
  21. #include "objrec/features/fpfeatures/PixelPairFeature.h"
  22. #include "objrec/features/fpfeatures/SemanticFeature.h"
  23. #include "objrec/features/fpfeatures/FIGradients.h"
  24. #include "FIShotton.h"
  25. #include "SemSegTools.h"
  26. using namespace OBJREC;
  27. using namespace std;
  28. using namespace NICE;
  29. SemSegSTF::SemSegSTF( const Config *conf,
  30. const MultiDataset *md )
  31. : SemanticSegmentation ( conf, &(md->getClassNames("train")) )
  32. {
  33. use_semantic_features = conf->gB("bost", "use_semantic_features", true );
  34. use_pixelpair_features = conf->gB("bost", "use_pixelpair_features", true );
  35. subsamplex = conf->gI("bost", "subsamplex", 5);
  36. subsampley = conf->gI("bost", "subsampley", 5);
  37. numClasses = md->getClassNames("train").numClasses();
  38. read_pixel_cache = conf->gB("FPCPixel", "read_cache", false );
  39. cachepixel = conf->gS("FPCPixel", "cache", "fpc.data" );
  40. read_seg_cache = conf->gB("FPCSeg", "read_cache", true );
  41. cacheseg = conf->gS("FPCSeg", "cache", "segforest.data" );
  42. Examples examples;
  43. vector<CachedExample *> imgexamples;
  44. fpcPixel = new FPCRandomForests ( conf, "FPCPixel" );
  45. fpcPixel->setMaxClassNo ( classNames->getMaxClassno() );
  46. if ( !read_pixel_cache || !read_seg_cache )
  47. {
  48. // Generate Positioned Examples
  49. SemSegTools::collectTrainingExamples ( conf, "FPCPixel", *((*md)["train"]), *classNames,
  50. examples, imgexamples );
  51. }
  52. if ( ! read_pixel_cache )
  53. {
  54. ///////////////////////////////////
  55. // Train Single Pixel Classifier
  56. //////////////////////////////////
  57. FeaturePool fp;
  58. for ( vector<CachedExample *>::const_iterator k = imgexamples.begin();
  59. k != imgexamples.end();
  60. k++ )
  61. fillCachePixel (*k);
  62. PixelPairFeature hf (conf);
  63. hf.explode ( fp );
  64. fpcPixel->train ( fp, examples );
  65. fpcPixel->save ( cachepixel );
  66. fp.destroy();
  67. } else {
  68. fprintf (stderr, "SemSegSTF:: Reading pixel classifier data from %s\n", cachepixel.c_str() );
  69. fpcPixel->read ( cachepixel );
  70. }
  71. fpcSeg = new FPCRandomForests ( conf, "FPCSeg" );
  72. fpcSeg->setMaxClassNo ( classNames->getMaxClassno() );
  73. maxdepthSegmentationForest = conf->gI("bost", "maxdepth", 5);
  74. maxdepthSegmentationForestScores = conf->gI("bost", "maxdepth_scores", 9999);
  75. if ( ! read_seg_cache )
  76. {
  77. ///////////////////////////////////
  78. // Train Segmentation Forest
  79. //////////////////////////////////
  80. fprintf (stderr, "Calculating Prior Statistics\n");
  81. ProgressBar pbseg ("Calculating Prior Statistics");
  82. pbseg.show();
  83. for ( vector<CachedExample *>::const_iterator k = imgexamples.begin();
  84. k != imgexamples.end();
  85. k++ )
  86. {
  87. pbseg.update ( imgexamples.size() );
  88. fillCacheSegmentation ( *k );
  89. }
  90. pbseg.hide();
  91. FeaturePool fp;
  92. if ( use_semantic_features )
  93. {
  94. set<int> classnos;
  95. classNames->getSelection ( conf->gS("FPCSeg", "train_selection")
  96. , classnos );
  97. SemanticFeature sf ( conf, &classnos );
  98. sf.explode ( fp );
  99. }
  100. fprintf (stderr, "Training Segmentation Forest\n");
  101. fpcSeg->train ( fp, examples );
  102. fpcSeg->save ( cacheseg );
  103. // clean up memory !!
  104. for ( vector<CachedExample *>::iterator i = imgexamples.begin();
  105. i != imgexamples.end();
  106. i++ )
  107. delete ( *i );
  108. fp.destroy();
  109. } else {
  110. fprintf (stderr, "SemSegSTF:: Reading region classifier data from %s\n", cacheseg.c_str() );
  111. fpcSeg->read ( cacheseg );
  112. fprintf (stderr, "SemSegSTF:: successfully read\n" );
  113. }
  114. }
  115. SemSegSTF::~SemSegSTF()
  116. {
  117. }
  118. void SemSegSTF::fillCacheSegmentation ( CachedExample *ce )
  119. {
  120. FIShotton::buildSemanticMap ( ce,
  121. fpcPixel,
  122. subsamplex,
  123. subsampley,
  124. numClasses );
  125. }
  126. void SemSegSTF::fillCachePixel ( CachedExample *ce )
  127. {
  128. }
  129. void SemSegSTF::semanticseg ( CachedExample *ce,
  130. NICE::Image & segresult,
  131. GenericImage<double> & probabilities )
  132. {
  133. int xsize;
  134. int ysize;
  135. ce->getImageSize ( xsize, ysize );
  136. int numClasses = classNames->numClasses();
  137. fillCachePixel ( ce );
  138. fillCacheSegmentation ( ce );
  139. fprintf (stderr, "BoST classification !\n");
  140. Example pce ( ce, 0, 0 );
  141. int xsize_s = xsize / subsamplex;
  142. int ysize_s = ysize / subsampley;
  143. ClassificationResult *results = new ClassificationResult [xsize_s*ysize_s];
  144. /** classify each pixel of the image */
  145. FullVector prior ( classNames->getMaxClassno() );
  146. probabilities.reInit ( xsize_s, ysize_s, numClasses, true );
  147. probabilities.setAll ( 0 );
  148. long offset_s = 0;
  149. for ( int ys = 0 ; ys < ysize_s ; ys ++ )
  150. for ( int xs = 0 ; xs < xsize_s ; xs++,offset_s++ )
  151. {
  152. int x = xs * subsamplex;
  153. int y = ys * subsampley;
  154. pce.x = x ; pce.y = y ;
  155. results[offset_s] = fpcSeg->classify ( pce );
  156. for ( int i = 0 ; i < results[offset_s].scores.size(); i++ )
  157. probabilities.data[i][offset_s] = results[offset_s].scores[i];
  158. /*
  159. if ( imagePriorMethod != IMAGE_PRIOR_NONE )
  160. prior.add ( results[offset_s].scores );
  161. */
  162. }
  163. fprintf (stderr, "BoST classification ready\n");
  164. /** save results */
  165. segresult.resize(xsize_s, ysize_s);
  166. segresult.set( classNames->classno("various") );
  167. long int offset = 0;
  168. for ( int y = 0 ; y < ysize_s ; y++ )
  169. for ( int x = 0 ; x < xsize_s ; x++,offset++ )
  170. {
  171. double maxvalue = - numeric_limits<double>::max();
  172. int maxindex = 0;
  173. for ( int i = 0 ; i < (int)probabilities.numChannels; i++ )
  174. if ( probabilities.data[i][offset] > maxvalue )
  175. {
  176. maxindex = i;
  177. maxvalue = probabilities.data[i][offset];
  178. }
  179. segresult.setPixel(x,y,maxindex);
  180. }
  181. }