SemSegConvolutionalTree.cpp 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  /**
  * @file SemSegConvolutionalTree.cpp
  * @brief Semantic Segmentation using Convolutional Trees
  * @author Sven Sickert
  * @date 10/17/2014
  */
  7. #include <iostream>
  8. #include "SemSegConvolutionalTree.h"
  9. #include "SemSegTools.h"
  10. #include "vislearning/classifier/fpclassifier/randomforest/FPCRandomForests.h"
  11. #include "vislearning/features/fpfeatures/ConvolutionFeature.h"
  12. #include "vislearning/baselib/cc.h"
  13. using namespace OBJREC;
  14. using namespace std;
  15. using namespace NICE;
  16. //###################### CONSTRUCTORS #########################//
  17. SemSegConvolutionalTree::SemSegConvolutionalTree () : SemanticSegmentation ()
  18. {
  19. conf = NULL;
  20. saveLoadData = false;
  21. fileLocation = "classifier.data";
  22. fpc = new FPCRandomForests ();
  23. }
  24. SemSegConvolutionalTree::SemSegConvolutionalTree (
  25. const Config *conf,
  26. const ClassNames *classNames )
  27. : SemanticSegmentation( conf, classNames )
  28. {
  29. initFromConfig( conf );
  30. }
  31. //###################### DESTRUCTORS ##########################//
  32. SemSegConvolutionalTree::~SemSegConvolutionalTree ()
  33. {
  34. if ( fpc != NULL )
  35. delete fpc;
  36. }
  37. //#################### MEMBER FUNCTIONS #######################//
  38. void SemSegConvolutionalTree::convertRGBToHSV ( CachedExample *ce ) const
  39. {
  40. NICE::MultiChannelImageT<int> * img = NULL;
  41. NICE::MultiChannelImageT<double> * imgD = NULL;
  42. img = & ce->getIChannel( CachedExample::I_COLOR );
  43. imgD = & ce->getDChannel( CachedExample::D_EOH );
  44. assert( imgD->channels() == 0 );
  45. if ( img->channels() == 3 )
  46. {
  47. imgD->reInit ( img->width(), img->height(), 3 );
  48. for ( int y = 0; y < img->height(); y++ )
  49. for ( int x = 0; x < img->width(); x++ )
  50. {
  51. double h,s,v;
  52. double r = (double)img->get( x, y, 0);
  53. double g = (double)img->get( x, y, 1);
  54. double b = (double)img->get( x, y, 2);
  55. ColorConversion::ccRGBtoHSV(r, g, b, &h, &s, &v);
  56. imgD->set( x, y, h, 0);
  57. imgD->set( x, y, s, 1);
  58. imgD->set( x, y, v, 2);
  59. }
  60. // remove integer channels
  61. img->freeData();
  62. }
  63. else if ( img->channels() == 1 )
  64. {
  65. // gray values to range [0,1]
  66. imgD->reInit ( img->width(), img->height(), 1 );
  67. for ( int y = 0; y < img->height(); y++ )
  68. for ( int x = 0; x < img->width(); x++ )
  69. {
  70. double g = (double)img->get( x, y, 0) / 255.0;
  71. imgD->set( x, y, g, 0);
  72. }
  73. // remove integer channel
  74. img->freeData();
  75. }
  76. img = NULL;
  77. imgD = NULL;
  78. }
  79. void SemSegConvolutionalTree::initFromConfig( const Config *_conf,
  80. const string &s_confSection )
  81. {
  82. conf = _conf;
  83. saveLoadData = conf->gB ( s_confSection, "save_load_data", false );
  84. fileLocation = conf->gS ( s_confSection, "datafile", "classifier.data" );
  85. fpc = new FPCRandomForests ( _conf, "FPCRandomForests" );
  86. fpc->setMaxClassNo( classNames->getMaxClassno() );
  87. }
  88. /** training function */
  89. void SemSegConvolutionalTree::train ( const MultiDataset *md )
  90. {
  91. if ( saveLoadData && FileMgt::fileExists( fileLocation ) )
  92. {
  93. read( fileLocation );
  94. }
  95. else
  96. {
  97. Examples examples;
  98. // image storage
  99. vector<CachedExample *> imgexamples;
  100. // create pixel-wise training examples
  101. SemSegTools::collectTrainingExamples (
  102. conf,
  103. "FPCRandomForests",
  104. * ( ( *md ) ["train"] ),
  105. *classNames,
  106. examples,
  107. imgexamples );
  108. assert ( examples.size() > 0 );
  109. for ( vector<CachedExample *>::iterator cei = imgexamples.begin();
  110. cei != imgexamples.end(); cei++ )
  111. convertRGBToHSV ( *cei );
  112. FeaturePool fp;
  113. ConvolutionFeature cf ( conf );
  114. cf.explode( fp );
  115. // start training using random forests
  116. fpc->train( fp, examples);
  117. // save trained classifier to file
  118. if (saveLoadData) save( fileLocation );
  119. // Cleaning up
  120. for ( vector<CachedExample *>::iterator i = imgexamples.begin();
  121. i != imgexamples.end();
  122. i++ )
  123. delete ( *i );
  124. fp.destroy();
  125. }
  126. }
  127. /** classification function */
  128. void SemSegConvolutionalTree::semanticseg(
  129. CachedExample *ce,
  130. Image &segresult,
  131. NICE::MultiChannelImageT<double> &probabilities )
  132. {
  133. // for speed optimization
  134. FPCRandomForests *fpcrf = dynamic_cast<FPCRandomForests *> ( fpc );
  135. int xsize, ysize;
  136. ce->getImageSize ( xsize, ysize );
  137. probabilities.reInit ( xsize, ysize, classNames->getMaxClassno() + 1 );
  138. segresult.resize ( xsize, ysize );
  139. convertRGBToHSV(ce);
  140. Example pce ( ce, 0, 0 );
  141. for ( int y = 0 ; y < ysize ; y++ )
  142. for ( int x = 0 ; x < xsize ; x++ )
  143. {
  144. pce.x = x ;
  145. pce.y = y;
  146. ClassificationResult r = fpcrf->classify ( pce );
  147. segresult.setPixel ( x, y, r.classno );
  148. for ( int i = 0 ; i < ( int ) probabilities.channels(); i++ )
  149. probabilities[i](x,y) = r.scores[i];
  150. }
  151. }
  152. ///////////////////// INTERFACE PERSISTENT /////////////////////
  153. // interface specific methods for store and restore
  154. ///////////////////// INTERFACE PERSISTENT /////////////////////
  155. void SemSegConvolutionalTree::restore( istream &is, int format )
  156. {
  157. //dirty solution to circumvent the const-flag
  158. const_cast<ClassNames*>(this->classNames)->restore ( is, format );
  159. fpc->restore( is, format );
  160. }
  161. void SemSegConvolutionalTree::store ( ostream &os, int format ) const
  162. {
  163. classNames->store( os, format );
  164. fpc->store( os, format );
  165. }
  166. void SemSegConvolutionalTree::clear ( )
  167. {
  168. fpc->clear();
  169. }