SemSegConvolutionalTree.cpp 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. /**
  2. * @file SemSegConvolutionalTree.cpp
  3. * @brief Semantic Segmentation using Convolutional Trees
  4. * @author Sven Sickert
  5. * @date 10/17/2014
  6. */
  7. #include <iostream>
  8. #include "SemSegConvolutionalTree.h"
  9. #include "SemSegTools.h"
  10. #include "vislearning/classifier/fpclassifier/randomforest/FPCRandomForests.h"
  11. #include "vislearning/features/fpfeatures/ConvolutionFeature.h"
  12. #include "vislearning/baselib/cc.h"
  13. using namespace OBJREC;
  14. using namespace std;
  15. using namespace NICE;
  16. //###################### CONSTRUCTORS #########################//
  17. SemSegConvolutionalTree::SemSegConvolutionalTree () : SemanticSegmentation ()
  18. {
  19. conf = NULL;
  20. saveLoadData = false;
  21. fileLocation = "classifier.data";
  22. fpc = new FPCRandomForests ();
  23. }
  24. SemSegConvolutionalTree::SemSegConvolutionalTree (
  25. const Config *conf,
  26. const ClassNames *classNames )
  27. : SemanticSegmentation( conf, classNames )
  28. {
  29. initFromConfig( conf );
  30. }
  31. //###################### DESTRUCTORS ##########################//
  32. SemSegConvolutionalTree::~SemSegConvolutionalTree ()
  33. {
  34. if ( fpc != NULL )
  35. delete fpc;
  36. }
  37. //#################### MEMBER FUNCTIONS #######################//
  38. void SemSegConvolutionalTree::convertRGBToHSV (
  39. CachedExample *ce,
  40. bool isColor ) const
  41. {
  42. NICE::MultiChannelImageT<int> * img = NULL;
  43. NICE::MultiChannelImageT<double> * imgD = NULL;
  44. imgD = & ce->getDChannel( CachedExample::D_EOH );
  45. assert( imgD->channels() == 0 );
  46. if ( isColor )
  47. {
  48. img = & ce->getIChannel( CachedExample::I_COLOR );
  49. imgD->reInit ( img->width(), img->height(), 3 );
  50. for ( int y = 0; y < img->height(); y++ )
  51. for ( int x = 0; x < img->width(); x++ )
  52. {
  53. double h,s,v;
  54. double r = (double)img->get( x, y, 0);
  55. double g = (double)img->get( x, y, 1);
  56. double b = (double)img->get( x, y, 2);
  57. ColorConversion::ccRGBtoHSV(r, g, b, &h, &s, &v);
  58. imgD->set( x, y, h, 0);
  59. imgD->set( x, y, s, 1);
  60. imgD->set( x, y, v, 2);
  61. }
  62. // remove integer channels
  63. img->freeData();
  64. }
  65. // FIXME: never true because of CachedExample implementation of getXChannel
  66. else
  67. {
  68. img = & ce->getIChannel( CachedExample::I_GRAYVALUES );
  69. // gray values to range [0,1]
  70. imgD->reInit ( img->width(), img->height(), 1 );
  71. for ( int y = 0; y < img->height(); y++ )
  72. for ( int x = 0; x < img->width(); x++ )
  73. {
  74. double g = (double)img->get( x, y, 0) / 255.0;
  75. imgD->set( x, y, g, 0);
  76. }
  77. // remove integer channel
  78. img->freeData();
  79. }
  80. img = NULL;
  81. imgD = NULL;
  82. }
  83. void SemSegConvolutionalTree::initFromConfig( const Config *_conf,
  84. const string &s_confSection )
  85. {
  86. conf = _conf;
  87. saveLoadData = conf->gB ( s_confSection, "save_load_data", false );
  88. fileLocation = conf->gS ( s_confSection, "datafile", "classifier.data" );
  89. fpc = new FPCRandomForests ( _conf, "FPCRandomForests" );
  90. fpc->setMaxClassNo( classNames->getMaxClassno() );
  91. }
  92. /** training function */
  93. void SemSegConvolutionalTree::train ( const MultiDataset *md )
  94. {
  95. if ( saveLoadData && FileMgt::fileExists( fileLocation ) )
  96. {
  97. read( fileLocation );
  98. }
  99. else
  100. {
  101. Examples examples;
  102. // image storage
  103. vector<CachedExample *> imgexamples;
  104. // create pixel-wise training examples
  105. SemSegTools::collectTrainingExamples (
  106. conf,
  107. "FPCRandomForests",
  108. * ( ( *md ) ["train"] ),
  109. *classNames,
  110. examples,
  111. imgexamples );
  112. assert ( examples.size() > 0 );
  113. FeaturePool fp;
  114. ConvolutionFeature cf ( conf );
  115. cf.explode( fp );
  116. for ( vector<CachedExample *>::iterator cei = imgexamples.begin();
  117. cei != imgexamples.end(); cei++ )
  118. convertRGBToHSV ( *cei, cf.isColorMode() );
  119. // start training using random forests
  120. fpc->train( fp, examples);
  121. // save trained classifier to file
  122. if (saveLoadData) save( fileLocation );
  123. // Cleaning up
  124. for ( vector<CachedExample *>::iterator i = imgexamples.begin();
  125. i != imgexamples.end();
  126. i++ )
  127. delete ( *i );
  128. fp.destroy();
  129. }
  130. }
  131. /** classification function */
  132. void SemSegConvolutionalTree::semanticseg(
  133. CachedExample *ce,
  134. Image &segresult,
  135. NICE::MultiChannelImageT<double> &probabilities )
  136. {
  137. // for speed optimization
  138. FPCRandomForests *fpcrf = dynamic_cast<FPCRandomForests *> ( fpc );
  139. int xsize, ysize;
  140. ce->getImageSize ( xsize, ysize );
  141. probabilities.reInit ( xsize, ysize, classNames->getMaxClassno() + 1 );
  142. segresult.resize ( xsize, ysize );
  143. vector<DecisionTree *> forest = fpcrf->getForest ();
  144. DecisionNode *root = forest[0]->getRoot ();
  145. ConvolutionFeature* cf = dynamic_cast<ConvolutionFeature*> (root->f);
  146. convertRGBToHSV( ce, cf->isColorMode() ); //FIXME!!!
  147. Example pce ( ce, 0, 0 );
  148. for ( int y = 0 ; y < ysize ; y++ )
  149. for ( int x = 0 ; x < xsize ; x++ )
  150. {
  151. pce.x = x ;
  152. pce.y = y;
  153. ClassificationResult r = fpcrf->classify ( pce );
  154. segresult.setPixel ( x, y, r.classno );
  155. for ( int i = 0 ; i < ( int ) probabilities.channels(); i++ )
  156. probabilities[i](x,y) = r.scores[i];
  157. }
  158. }
  159. ///////////////////// INTERFACE PERSISTENT /////////////////////
  160. // interface specific methods for store and restore
  161. ///////////////////// INTERFACE PERSISTENT /////////////////////
  162. void SemSegConvolutionalTree::restore( istream &is, int format )
  163. {
  164. //dirty solution to circumvent the const-flag
  165. const_cast<ClassNames*>(this->classNames)->restore ( is, format );
  166. fpc->restore( is, format );
  167. }
  168. void SemSegConvolutionalTree::store ( ostream &os, int format ) const
  169. {
  170. classNames->store( os, format );
  171. fpc->store( os, format );
  172. }
  173. void SemSegConvolutionalTree::clear ( )
  174. {
  175. fpc->clear();
  176. }