SemSegObliqueTree.cpp 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. /**
  2. * @file SemSegObliqueTree.cpp
  3. * @brief Semantic Segmentation using Oblique Trees
  4. * @author Sven Sickert
  5. * @date 10/17/2014
  6. */
#include <iostream>
#include <stdexcept>
#include "SemSegObliqueTree.h"
#include "SemSegTools.h"
#include "vislearning/classifier/fpclassifier/randomforest/FPCRandomForests.h"
#include "vislearning/features/fpfeatures/ConvolutionFeature.h"
#include "vislearning/baselib/cc.h"
  13. using namespace OBJREC;
  14. using namespace std;
  15. using namespace NICE;
  16. //###################### CONSTRUCTORS #########################//
  17. SemSegObliqueTree::SemSegObliqueTree () : SemanticSegmentation ()
  18. {
  19. conf = NULL;
  20. saveLoadData = false;
  21. fileLocation = "classifier.data";
  22. colorMode = 0;
  23. fpc = new FPCRandomForests ();
  24. }
  25. SemSegObliqueTree::SemSegObliqueTree (
  26. const Config *conf,
  27. const ClassNames *classNames )
  28. : SemanticSegmentation( conf, classNames )
  29. {
  30. initFromConfig( conf );
  31. }
  32. //###################### DESTRUCTORS ##########################//
  33. SemSegObliqueTree::~SemSegObliqueTree ()
  34. {
  35. if ( fpc != NULL )
  36. delete fpc;
  37. }
  38. //#################### MEMBER FUNCTIONS #######################//
  39. void SemSegObliqueTree::preprocessChannels (
  40. CachedExample *ce,
  41. bool isColor ) const
  42. {
  43. NICE::MultiChannelImage3DT<int> * img = NULL;
  44. NICE::MultiChannelImage3DT<double> * imgD = NULL;
  45. imgD = & ce->getDChannel3( CachedExample::D_EOH );
  46. assert( imgD->channels() == 0 );
  47. if ( isColor )
  48. {
  49. img = & ce->getIChannel3( CachedExample::I_COLOR );
  50. imgD->reInit ( img->width(), img->height(), img->depth(), 3 );
  51. for ( int z = 0; z < img->depth(); z++ )
  52. for ( int y = 0; y < img->height(); y++ )
  53. for ( int x = 0; x < img->width(); x++ )
  54. {
  55. double r = (double)img->get( x, y, z, 0);
  56. double g = (double)img->get( x, y, z, 1);
  57. double b = (double)img->get( x, y, z, 2);
  58. if ( colorMode == 1 )
  59. {
  60. double h,s,v;
  61. ColorConversion::ccRGBtoHSV(r, g, b, &h, &s, &v);
  62. imgD->set( x, y, h, 0);
  63. imgD->set( x, y, s, 1);
  64. imgD->set( x, y, v, 2);
  65. }
  66. else if ( colorMode == 2 )
  67. {
  68. double cX, cY, cZ, cL, ca, cb;
  69. r /= 255.0;
  70. g /= 255.0;
  71. b /= 255.0;
  72. ColorConversion::ccRGBtoXYZ( r, g, b, &cX, &cY, &cZ, 0 );
  73. ColorConversion::ccXYZtoCIE_Lab( cX, cY, cZ, &cL, &ca, &cb, 0 );
  74. imgD->set( x, y, z, cL, 0);
  75. imgD->set( x, y, z, ca, 1);
  76. imgD->set( x, y, z, cb, 2);
  77. }
  78. else
  79. {
  80. imgD->set( x, y, z, r/255.0, 0 );
  81. imgD->set( x, y, z, g/255.0, 1 );
  82. imgD->set( x, y, z, b/255.0, 2 );
  83. }
  84. }
  85. // remove integer channels
  86. img->freeData();
  87. }
  88. else
  89. {
  90. img = & ce->getIChannel3( CachedExample::I_GRAYVALUES );
  91. // gray values to range [0,1]
  92. imgD->reInit ( img->width(), img->height(), img->depth(), 1 );
  93. for ( int z = 0; z < img->depth(); z++ )
  94. for ( int y = 0; y < img->height(); y++ )
  95. for ( int x = 0; x < img->width(); x++ )
  96. {
  97. double g = (double)img->get( x, y, z, 0) / 255.0;
  98. imgD->set( x, y, z, g, 0);
  99. }
  100. // remove integer channel
  101. img->freeData();
  102. }
  103. img = NULL;
  104. imgD = NULL;
  105. }
  106. void SemSegObliqueTree::initFromConfig( const Config *_conf,
  107. const string &s_confSection )
  108. {
  109. conf = _conf;
  110. saveLoadData = conf->gB ( s_confSection, "save_load_data", false );
  111. run3Dseg = conf->gB ( s_confSection, "run_3dseg", false );
  112. colorMode = conf->gI ( s_confSection, "color_mode", 0 );
  113. fileLocation = conf->gS ( s_confSection, "datafile", "classifier.data" );
  114. fpc = new FPCRandomForests ( _conf, "FPCRandomForests" );
  115. fpc->setMaxClassNo( classNames->getMaxClassno() );
  116. }
  117. /** training function */
  118. void SemSegObliqueTree::train ( const MultiDataset *md )
  119. {
  120. if ( saveLoadData && FileMgt::fileExists( fileLocation ) )
  121. {
  122. read( fileLocation );
  123. }
  124. else
  125. {
  126. Examples examples;
  127. // image storage
  128. vector<CachedExample *> imgexamples;
  129. // create pixel-wise training examples
  130. SemSegTools::collectTrainingExamples (
  131. conf,
  132. "FPCRandomForests",
  133. * ( ( *md ) ["train"] ),
  134. *classNames,
  135. examples,
  136. imgexamples,
  137. run3Dseg );
  138. assert ( examples.size() > 0 );
  139. FeaturePool fp;
  140. ConvolutionFeature cf ( conf );
  141. cf.explode( fp );
  142. for ( vector<CachedExample *>::iterator cei = imgexamples.begin();
  143. cei != imgexamples.end(); cei++ )
  144. preprocessChannels ( *cei, cf.isColorMode() );
  145. // start training using random forests
  146. fpc->train( fp, examples);
  147. // save trained classifier to file
  148. if (saveLoadData) save( fileLocation );
  149. // Cleaning up
  150. for ( vector<CachedExample *>::iterator i = imgexamples.begin();
  151. i != imgexamples.end();
  152. i++ )
  153. delete ( *i );
  154. fp.destroy();
  155. }
  156. }
  157. /** classification function */
  158. void SemSegObliqueTree::semanticseg(
  159. CachedExample *ce,
  160. ImageT<int> &segresult,
  161. NICE::MultiChannelImageT<double> &probabilities )
  162. {
  163. // for speed optimization
  164. FPCRandomForests *fpcrf = dynamic_cast<FPCRandomForests *> ( fpc );
  165. int xsize, ysize;
  166. ce->getImageSize ( xsize, ysize );
  167. probabilities.reInit ( xsize, ysize, classNames->getMaxClassno() + 1 );
  168. segresult.resize ( xsize, ysize );
  169. vector<DecisionTree *> forest = fpcrf->getForest ();
  170. DecisionNode *root = forest[0]->getRoot ();
  171. ConvolutionFeature* cf = dynamic_cast<ConvolutionFeature*> (root->f);
  172. preprocessChannels( ce, cf->isColorMode() ); //FIXME!!!
  173. Example pce ( ce, 0, 0 );
  174. for ( int y = 0 ; y < ysize ; y++ )
  175. for ( int x = 0 ; x < xsize ; x++ )
  176. {
  177. pce.x = x ;
  178. pce.y = y;
  179. ClassificationResult r = fpcrf->classify ( pce );
  180. segresult.setPixel ( x, y, r.classno );
  181. for ( int i = 0 ; i < ( int ) probabilities.channels(); i++ )
  182. probabilities[i](x,y) = r.scores[i];
  183. }
  184. }
  185. ///////////////////// INTERFACE PERSISTENT /////////////////////
  186. // interface specific methods for store and restore
  187. ///////////////////// INTERFACE PERSISTENT /////////////////////
  188. void SemSegObliqueTree::restore( istream &is, int format )
  189. {
  190. //dirty solution to circumvent the const-flag
  191. const_cast<ClassNames*>(this->classNames)->restore ( is, format );
  192. fpc->restore( is, format );
  193. }
  194. void SemSegObliqueTree::store ( ostream &os, int format ) const
  195. {
  196. classNames->store( os, format );
  197. fpc->store( os, format );
  198. }
  199. void SemSegObliqueTree::clear ( )
  200. {
  201. fpc->clear();
  202. }