SemSegObliqueTree.cpp 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. /**
  2. * @file SemSegObliqueTree.cpp
  3. * @brief Semantic Segmentation using Oblique Trees
  4. * @author Sven Sickert
  5. * @date 10/17/2014
  6. */
  #include <iostream>
  #include <stdexcept>
  #include "SemSegObliqueTree.h"
  #include "SemSegTools.h"
  #include "vislearning/classifier/fpclassifier/randomforest/FPCRandomForests.h"
  #include "vislearning/features/fpfeatures/ConvolutionFeature.h"
  #include "vislearning/baselib/cc.h"
  using namespace OBJREC;
  using namespace std;
  using namespace NICE;
  16. //###################### CONSTRUCTORS #########################//
  17. SemSegObliqueTree::SemSegObliqueTree () : SemanticSegmentation ()
  18. {
  19. conf = NULL;
  20. saveLoadData = false;
  21. fileLocation = "classifier.data";
  22. colorMode = 0;
  23. fpc = new FPCRandomForests ();
  24. }
  25. SemSegObliqueTree::SemSegObliqueTree (
  26. const Config *conf,
  27. const ClassNames *classNames )
  28. : SemanticSegmentation( conf, classNames )
  29. {
  30. initFromConfig( conf );
  31. }
  32. //###################### DESTRUCTORS ##########################//
  33. SemSegObliqueTree::~SemSegObliqueTree ()
  34. {
  35. if ( fpc != NULL )
  36. delete fpc;
  37. }
  38. //#################### MEMBER FUNCTIONS #######################//
  39. void SemSegObliqueTree::preprocessChannels (
  40. CachedExample *ce,
  41. bool isColor ) const
  42. {
  43. NICE::MultiChannelImageT<int> * img = NULL;
  44. NICE::MultiChannelImageT<double> * imgD = NULL;
  45. imgD = & ce->getDChannel( CachedExample::D_EOH );
  46. assert( imgD->channels() == 0 );
  47. if ( isColor )
  48. {
  49. img = & ce->getIChannel( CachedExample::I_COLOR );
  50. imgD->reInit ( img->width(), img->height(), 3 );
  51. for ( int y = 0; y < img->height(); y++ )
  52. for ( int x = 0; x < img->width(); x++ )
  53. {
  54. double r = (double)img->get( x, y, 0);
  55. double g = (double)img->get( x, y, 1);
  56. double b = (double)img->get( x, y, 2);
  57. if ( colorMode == 1 )
  58. {
  59. double h,s,v;
  60. ColorConversion::ccRGBtoHSV(r, g, b, &h, &s, &v);
  61. imgD->set( x, y, h, 0);
  62. imgD->set( x, y, s, 1);
  63. imgD->set( x, y, v, 2);
  64. }
  65. else if ( colorMode == 2 )
  66. {
  67. double cX, cY, cZ, cL, ca, cb;
  68. r /= 255.0;
  69. g /= 255.0;
  70. b /= 255.0;
  71. ColorConversion::ccRGBtoXYZ( r, g, b, &cX, &cY, &cZ, 0 );
  72. ColorConversion::ccXYZtoCIE_Lab( cX, cY, cZ, &cL, &ca, &cb, 0 );
  73. imgD->set( x, y, cL, 0);
  74. imgD->set( x, y, ca, 1);
  75. imgD->set( x, y, cb, 2);
  76. }
  77. else
  78. {
  79. imgD->set( x, y, r/255.0, 0 );
  80. imgD->set( x, y, g/255.0, 1 );
  81. imgD->set( x, y, b/255.0, 2 );
  82. }
  83. }
  84. // remove integer channels
  85. img->freeData();
  86. }
  87. else
  88. {
  89. img = & ce->getIChannel( CachedExample::I_GRAYVALUES );
  90. // gray values to range [0,1]
  91. imgD->reInit ( img->width(), img->height(), 1 );
  92. for ( int y = 0; y < img->height(); y++ )
  93. for ( int x = 0; x < img->width(); x++ )
  94. {
  95. double g = (double)img->get( x, y, 0) / 255.0;
  96. imgD->set( x, y, g, 0);
  97. }
  98. // remove integer channel
  99. img->freeData();
  100. }
  101. img = NULL;
  102. imgD = NULL;
  103. }
  104. void SemSegObliqueTree::initFromConfig( const Config *_conf,
  105. const string &s_confSection )
  106. {
  107. conf = _conf;
  108. saveLoadData = conf->gB ( s_confSection, "save_load_data", false );
  109. colorMode = conf->gI ( s_confSection, "color_mode", 0 );
  110. fileLocation = conf->gS ( s_confSection, "datafile", "classifier.data" );
  111. fpc = new FPCRandomForests ( _conf, "FPCRandomForests" );
  112. fpc->setMaxClassNo( classNames->getMaxClassno() );
  113. }
  114. /** training function */
  115. void SemSegObliqueTree::train ( const MultiDataset *md )
  116. {
  117. if ( saveLoadData && FileMgt::fileExists( fileLocation ) )
  118. {
  119. read( fileLocation );
  120. }
  121. else
  122. {
  123. Examples examples;
  124. // image storage
  125. vector<CachedExample *> imgexamples;
  126. // create pixel-wise training examples
  127. SemSegTools::collectTrainingExamples (
  128. conf,
  129. "FPCRandomForests",
  130. * ( ( *md ) ["train"] ),
  131. *classNames,
  132. examples,
  133. imgexamples );
  134. assert ( examples.size() > 0 );
  135. FeaturePool fp;
  136. ConvolutionFeature cf ( conf );
  137. cf.explode( fp );
  138. for ( vector<CachedExample *>::iterator cei = imgexamples.begin();
  139. cei != imgexamples.end(); cei++ )
  140. preprocessChannels ( *cei, cf.isColorMode() );
  141. // start training using random forests
  142. fpc->train( fp, examples);
  143. // save trained classifier to file
  144. if (saveLoadData) save( fileLocation );
  145. // Cleaning up
  146. for ( vector<CachedExample *>::iterator i = imgexamples.begin();
  147. i != imgexamples.end();
  148. i++ )
  149. delete ( *i );
  150. fp.destroy();
  151. }
  152. }
  153. /** classification function */
  154. void SemSegObliqueTree::semanticseg(
  155. CachedExample *ce,
  156. ImageT<int> &segresult,
  157. NICE::MultiChannelImageT<double> &probabilities )
  158. {
  159. // for speed optimization
  160. FPCRandomForests *fpcrf = dynamic_cast<FPCRandomForests *> ( fpc );
  161. int xsize, ysize;
  162. ce->getImageSize ( xsize, ysize );
  163. probabilities.reInit ( xsize, ysize, classNames->getMaxClassno() + 1 );
  164. segresult.resize ( xsize, ysize );
  165. vector<DecisionTree *> forest = fpcrf->getForest ();
  166. DecisionNode *root = forest[0]->getRoot ();
  167. ConvolutionFeature* cf = dynamic_cast<ConvolutionFeature*> (root->f);
  168. preprocessChannels( ce, cf->isColorMode() ); //FIXME!!!
  169. Example pce ( ce, 0, 0 );
  170. for ( int y = 0 ; y < ysize ; y++ )
  171. for ( int x = 0 ; x < xsize ; x++ )
  172. {
  173. pce.x = x ;
  174. pce.y = y;
  175. ClassificationResult r = fpcrf->classify ( pce );
  176. segresult.setPixel ( x, y, r.classno );
  177. for ( int i = 0 ; i < ( int ) probabilities.channels(); i++ )
  178. probabilities[i](x,y) = r.scores[i];
  179. }
  180. }
  181. ///////////////////// INTERFACE PERSISTENT /////////////////////
  182. // interface specific methods for store and restore
  183. ///////////////////// INTERFACE PERSISTENT /////////////////////
  184. void SemSegObliqueTree::restore( istream &is, int format )
  185. {
  186. //dirty solution to circumvent the const-flag
  187. const_cast<ClassNames*>(this->classNames)->restore ( is, format );
  188. fpc->restore( is, format );
  189. }
  190. void SemSegObliqueTree::store ( ostream &os, int format ) const
  191. {
  192. classNames->store( os, format );
  193. fpc->store( os, format );
  194. }
  195. void SemSegObliqueTree::clear ( )
  196. {
  197. fpc->clear();
  198. }