SemSegConvolutionalTree.cpp 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
/**
* @file SemSegConvolutionalTree.cpp
* @brief Semantic Segmentation using Convolutional Trees
* @author Sven Sickert
* @date 10/17/2014
*/
  7. #include <iostream>
  8. #include "SemSegConvolutionalTree.h"
  9. #include "SemSegTools.h"
  10. #include "vislearning/classifier/fpclassifier/randomforest/FPCRandomForests.h"
  11. #include "vislearning/features/fpfeatures/ConvolutionFeature.h"
  12. #include "vislearning/baselib/cc.h"
  13. using namespace OBJREC;
  14. using namespace std;
  15. using namespace NICE;
  16. //###################### CONSTRUCTORS #########################//
  17. SemSegConvolutionalTree::SemSegConvolutionalTree () : SemanticSegmentation ()
  18. {
  19. conf = NULL;
  20. saveLoadData = false;
  21. fileLocation = "classifier.data";
  22. fpc = new FPCRandomForests ();
  23. }
  24. SemSegConvolutionalTree::SemSegConvolutionalTree (
  25. const Config *conf,
  26. const ClassNames *classNames )
  27. : SemanticSegmentation( conf, classNames )
  28. {
  29. initFromConfig( conf );
  30. }
  31. //###################### DESTRUCTORS ##########################//
  32. SemSegConvolutionalTree::~SemSegConvolutionalTree ()
  33. {
  34. if ( fpc != NULL )
  35. delete fpc;
  36. }
  37. //#################### MEMBER FUNCTIONS #######################//
  38. void SemSegConvolutionalTree::convertRGBToHSV ( CachedExample *ce ) const
  39. {
  40. assert( imgHSV->channels() == 0 );
  41. NICE::MultiChannelImageT<int> * img = NULL;
  42. NICE::MultiChannelImageT<double> * imgHSV = NULL;
  43. img = & ce->getIChannel( CachedExample::I_COLOR );
  44. imgHSV = & ce->getDChannel( CachedExample::D_EOH );
  45. if ( img->channels() == 3 )
  46. {
  47. imgHSV->reInit ( img->width(), img->height(), 3 );
  48. for ( int y = 0; y < img->height(); y++ )
  49. for ( int x = 0; x < img->width(); x++ )
  50. {
  51. double h,s,v;
  52. double r = (double)img->get( x, y, 0);
  53. double g = (double)img->get( x, y, 1);
  54. double b = (double)img->get( x, y, 2);
  55. ColorConversion::ccRGBtoHSV(r, g, b, &h, &s, &v);
  56. imgHSV->set( x, y, h, 0);
  57. imgHSV->set( x, y, h, 1);
  58. imgHSV->set( x, y, h, 2);
  59. }
  60. // remove r,g,b (integer) channels
  61. img->freeData();
  62. img = NULL;
  63. }
  64. else
  65. {
  66. imgHSV = NULL;
  67. }
  68. }
  69. void SemSegConvolutionalTree::initFromConfig( const Config *_conf,
  70. const string &s_confSection )
  71. {
  72. conf = _conf;
  73. saveLoadData = conf->gB ( s_confSection, "save_load_data", false );
  74. fileLocation = conf->gS ( s_confSection, "datafile", "classifier.data" );
  75. fpc = new FPCRandomForests ( _conf, "FPCRandomForests" );
  76. fpc->setMaxClassNo( classNames->getMaxClassno() );
  77. }
  78. /** training function */
  79. void SemSegConvolutionalTree::train ( const MultiDataset *md )
  80. {
  81. if ( saveLoadData && FileMgt::fileExists( fileLocation ) )
  82. {
  83. read( fileLocation );
  84. }
  85. else
  86. {
  87. Examples examples;
  88. // image storage
  89. vector<CachedExample *> imgexamples;
  90. // create pixel-wise training examples
  91. SemSegTools::collectTrainingExamples (
  92. conf,
  93. "FPCRandomForests",
  94. * ( ( *md ) ["train"] ),
  95. *classNames,
  96. examples,
  97. imgexamples );
  98. assert ( examples.size() > 0 );
  99. for ( vector<CachedExample *>::iterator cei = imgexamples.begin();
  100. cei != imgexamples.end(); cei++ )
  101. convertRGBToHSV ( *cei );
  102. FeaturePool fp;
  103. ConvolutionFeature cf ( conf );
  104. cf.explode( fp );
  105. // start training using random forests
  106. fpc->train( fp, examples);
  107. // save trained classifier to file
  108. if (saveLoadData) save( fileLocation );
  109. // Cleaning up
  110. for ( vector<CachedExample *>::iterator i = imgexamples.begin();
  111. i != imgexamples.end();
  112. i++ )
  113. delete ( *i );
  114. fp.destroy();
  115. }
  116. }
  117. /** classification function */
  118. void SemSegConvolutionalTree::semanticseg(
  119. CachedExample *ce,
  120. Image &segresult,
  121. NICE::MultiChannelImageT<double> &probabilities )
  122. {
  123. // for speed optimization
  124. FPCRandomForests *fpcrf = dynamic_cast<FPCRandomForests *> ( fpc );
  125. int xsize, ysize;
  126. ce->getImageSize ( xsize, ysize );
  127. probabilities.reInit ( xsize, ysize, classNames->getMaxClassno() + 1 );
  128. segresult.resize ( xsize, ysize );
  129. convertRGBToHSV(ce);
  130. Example pce ( ce, 0, 0 );
  131. for ( int y = 0 ; y < ysize ; y++ )
  132. for ( int x = 0 ; x < xsize ; x++ )
  133. {
  134. pce.x = x ;
  135. pce.y = y;
  136. ClassificationResult r = fpcrf->classify ( pce );
  137. segresult.setPixel ( x, y, r.classno );
  138. for ( int i = 0 ; i < ( int ) probabilities.channels(); i++ )
  139. probabilities[i](x,y) = r.scores[i];
  140. }
  141. }
  142. ///////////////////// INTERFACE PERSISTENT /////////////////////
  143. // interface specific methods for store and restore
  144. ///////////////////// INTERFACE PERSISTENT /////////////////////
  145. void SemSegConvolutionalTree::restore( istream &is, int format )
  146. {
  147. //dirty solution to circumvent the const-flag
  148. const_cast<ClassNames*>(this->classNames)->restore ( is, format );
  149. fpc->restore( is, format );
  150. }
  151. void SemSegConvolutionalTree::store ( ostream &os, int format ) const
  152. {
  153. classNames->store( os, format );
  154. fpc->store( os, format );
  155. }
  156. void SemSegConvolutionalTree::clear ( )
  157. {
  158. fpc->clear();
  159. }