SemSegSTF.cpp 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. /**
  2. * @file SemSegSTF.cpp
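  3. * @brief Semantic segmentation with a two-stage random forest (pixel forest followed by a segmentation forest, semantic-texton-forest style)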
  4. * @author Erik Rodner
  5. * @date 02/11/2008
  6. */
  7. #include <iostream>
  8. #include "SemSegSTF.h"
  9. #include "vislearning/baselib/Globals.h"
  10. #include "vislearning/baselib/Preprocess.h"
  11. #include "vislearning/baselib/ProgressBar.h"
  12. #include "core/basics/StringTools.h"
  13. #include "vislearning/baselib/Globals.h"
  14. #include "vislearning/cbaselib/CachedExample.h"
  15. #include "vislearning/cbaselib/PascalResults.h"
  16. #include "vislearning/features/fpfeatures/PixelPairFeature.h"
  17. #include "vislearning/features/fpfeatures/SemanticFeature.h"
  18. #include "vislearning/features/fpfeatures/FIGradients.h"
  19. #include "FIShotton.h"
  20. #include "SemSegTools.h"
  21. using namespace OBJREC;
  22. using namespace std;
  23. using namespace NICE;
  24. SemSegSTF::SemSegSTF ( const Config *conf,
  25. const MultiDataset *md )
  26. : SemanticSegmentation ( conf, & ( md->getClassNames ( "train" ) ) )
  27. {
  28. use_semantic_features = conf->gB ( "bost", "use_semantic_features", true );
  29. use_pixelpair_features = conf->gB ( "bost", "use_pixelpair_features", true );
  30. subsamplex = conf->gI ( "bost", "subsamplex", 5 );
  31. subsampley = conf->gI ( "bost", "subsampley", 5 );
  32. numClasses = md->getClassNames ( "train" ).numClasses();
  33. read_pixel_cache = conf->gB ( "FPCPixel", "read_cache", false );
  34. cachepixel = conf->gS ( "FPCPixel", "cache", "fpc.data" );
  35. read_seg_cache = conf->gB ( "FPCSeg", "read_cache", true );
  36. cacheseg = conf->gS ( "FPCSeg", "cache", "segforest.data" );
  37. Examples examples;
  38. vector<CachedExample *> imgexamples;
  39. fpcPixel = new FPCRandomForests ( conf, "FPCPixel" );
  40. fpcPixel->setMaxClassNo ( classNames->getMaxClassno() );
  41. if ( !read_pixel_cache || !read_seg_cache )
  42. {
  43. // Generate Positioned Examples
  44. SemSegTools::collectTrainingExamples ( conf, "FPCPixel", * ( ( *md ) ["train"] ), *classNames,
  45. examples, imgexamples );
  46. }
  47. if ( ! read_pixel_cache )
  48. {
  49. ///////////////////////////////////
  50. // Train Single Pixel Classifier
  51. //////////////////////////////////
  52. FeaturePool fp;
  53. for ( vector<CachedExample *>::const_iterator k = imgexamples.begin();
  54. k != imgexamples.end();
  55. k++ )
  56. fillCachePixel ( *k );
  57. PixelPairFeature hf ( conf );
  58. hf.explode ( fp );
  59. fpcPixel->train ( fp, examples );
  60. fpcPixel->save ( cachepixel );
  61. fp.destroy();
  62. } else {
  63. fprintf ( stderr, "SemSegSTF:: Reading pixel classifier data from %s\n", cachepixel.c_str() );
  64. fpcPixel->read ( cachepixel );
  65. }
  66. fpcSeg = new FPCRandomForests ( conf, "FPCSeg" );
  67. fpcSeg->setMaxClassNo ( classNames->getMaxClassno() );
  68. maxdepthSegmentationForest = conf->gI ( "bost", "maxdepth", 5 );
  69. maxdepthSegmentationForestScores = conf->gI ( "bost", "maxdepth_scores", 9999 );
  70. if ( ! read_seg_cache )
  71. {
  72. ///////////////////////////////////
  73. // Train Segmentation Forest
  74. //////////////////////////////////
  75. fprintf ( stderr, "Calculating Prior Statistics\n" );
  76. ProgressBar pbseg ( "Calculating Prior Statistics" );
  77. pbseg.show();
  78. for ( vector<CachedExample *>::const_iterator k = imgexamples.begin();
  79. k != imgexamples.end();
  80. k++ )
  81. {
  82. pbseg.update ( imgexamples.size() );
  83. fillCacheSegmentation ( *k );
  84. }
  85. pbseg.hide();
  86. FeaturePool fp;
  87. if ( use_semantic_features )
  88. {
  89. set<int> classnos;
  90. classNames->getSelection ( conf->gS ( "FPCSeg", "train_selection" )
  91. , classnos );
  92. SemanticFeature sf ( conf, &classnos );
  93. sf.explode ( fp );
  94. }
  95. fprintf ( stderr, "Training Segmentation Forest\n" );
  96. fpcSeg->train ( fp, examples );
  97. fpcSeg->save ( cacheseg );
  98. // clean up memory !!
  99. for ( vector<CachedExample *>::iterator i = imgexamples.begin();
  100. i != imgexamples.end();
  101. i++ )
  102. delete ( *i );
  103. fp.destroy();
  104. } else {
  105. fprintf ( stderr, "SemSegSTF:: Reading region classifier data from %s\n", cacheseg.c_str() );
  106. fpcSeg->read ( cacheseg );
  107. fprintf ( stderr, "SemSegSTF:: successfully read\n" );
  108. }
  109. }
  110. SemSegSTF::~SemSegSTF()
  111. {
  112. }
  113. void SemSegSTF::fillCacheSegmentation ( CachedExample *ce )
  114. {
  115. FIShotton::buildSemanticMap ( ce,
  116. fpcPixel,
  117. subsamplex,
  118. subsampley,
  119. numClasses );
  120. }
  121. void SemSegSTF::fillCachePixel ( CachedExample *ce )
  122. {
  123. }
  124. void SemSegSTF::semanticseg ( CachedExample *ce,
  125. NICE::Image & segresult,
  126. NICE::MultiChannelImageT<double> & probabilities )
  127. {
  128. int xsize;
  129. int ysize;
  130. ce->getImageSize ( xsize, ysize );
  131. int numClasses = classNames->numClasses();
  132. fillCachePixel ( ce );
  133. fillCacheSegmentation ( ce );
  134. fprintf ( stderr, "BoST classification !\n" );
  135. Example pce ( ce, 0, 0 );
  136. int xsize_s = xsize / subsamplex;
  137. int ysize_s = ysize / subsampley;
  138. ClassificationResult *results = new ClassificationResult [xsize_s*ysize_s];
  139. /** classify each pixel of the image */
  140. FullVector prior ( classNames->getMaxClassno() );
  141. probabilities.reInit ( xsize_s, ysize_s, numClasses, true );
  142. probabilities.setAll ( 0 );
  143. long offset_s = 0;
  144. for ( int ys = 0 ; ys < ysize_s ; ys ++ )
  145. for ( int xs = 0 ; xs < xsize_s ; xs++, offset_s++ )
  146. {
  147. int x = xs * subsamplex;
  148. int y = ys * subsampley;
  149. pce.x = x ;
  150. pce.y = y ;
  151. results[offset_s] = fpcSeg->classify ( pce );
  152. for ( int i = 0 ; i < results[offset_s].scores.size(); i++ )
  153. probabilities.data[i][offset_s] = results[offset_s].scores[i];
  154. /*
  155. if ( imagePriorMethod != IMAGE_PRIOR_NONE )
  156. prior.add ( results[offset_s].scores );
  157. */
  158. }
  159. fprintf ( stderr, "BoST classification ready\n" );
  160. /** save results */
  161. segresult.resize ( xsize_s, ysize_s );
  162. segresult.set ( classNames->classno ( "various" ) );
  163. long int offset = 0;
  164. for ( int y = 0 ; y < ysize_s ; y++ )
  165. for ( int x = 0 ; x < xsize_s ; x++, offset++ )
  166. {
  167. double maxvalue = - numeric_limits<double>::max();
  168. int maxindex = 0;
  169. for ( int i = 0 ; i < ( int ) probabilities.numChannels; i++ )
  170. if ( probabilities.data[i][offset] > maxvalue )
  171. {
  172. maxindex = i;
  173. maxvalue = probabilities.data[i][offset];
  174. }
  175. segresult.setPixel ( x, y, maxindex );
  176. }
  177. }