// SemSegObliqueTree.cpp (8.5 KB)
/**
 * @file SemSegObliqueTree.cpp
 * @brief Semantic Segmentation using Oblique Trees
 * @author Sven Sickert
 * @date 10/17/2014
 */
#include <cassert>
#include <iostream>
#include <stdexcept>
#include "SemSegObliqueTree.h"
#include "SemSegTools.h"
#include "vislearning/classifier/fpclassifier/randomforest/FPCRandomForests.h"
#include "vislearning/features/fpfeatures/ConvolutionFeature.h"
#include "vislearning/baselib/cc.h"
using namespace OBJREC;
using namespace std;
using namespace NICE;
  16. //###################### CONSTRUCTORS #########################//
  17. SemSegObliqueTree::SemSegObliqueTree () : SemanticSegmentation ()
  18. {
  19. conf = NULL;
  20. saveLoadData = false;
  21. fileLocation = "classifier.data";
  22. colorMode = 0;
  23. fpc = new FPCRandomForests ();
  24. }
  25. SemSegObliqueTree::SemSegObliqueTree (
  26. const Config *conf,
  27. const ClassNames *classNames )
  28. : SemanticSegmentation( conf, classNames )
  29. {
  30. initFromConfig( conf );
  31. }
  32. //###################### DESTRUCTORS ##########################//
  33. SemSegObliqueTree::~SemSegObliqueTree ()
  34. {
  35. if ( fpc != NULL )
  36. delete fpc;
  37. }
  38. //#################### MEMBER FUNCTIONS #######################//
  39. void SemSegObliqueTree::preprocessChannels (
  40. CachedExample *ce,
  41. bool isColor ) const
  42. {
  43. NICE::MultiChannelImage3DT<int> * img = NULL;
  44. NICE::MultiChannelImage3DT<double> * imgD = NULL;
  45. imgD = & ce->getDChannel3( CachedExample::D_EOH );
  46. assert( imgD->channels() == 0 );
  47. if ( isColor )
  48. {
  49. img = & ce->getIChannel3( CachedExample::I_COLOR );
  50. imgD->reInit ( img->width(), img->height(), img->depth(), 3 );
  51. for ( int z = 0; z < img->depth(); z++ )
  52. for ( int y = 0; y < img->height(); y++ )
  53. for ( int x = 0; x < img->width(); x++ )
  54. {
  55. double r = (double)img->get( x, y, z, 0);
  56. double g = (double)img->get( x, y, z, 1);
  57. double b = (double)img->get( x, y, z, 2);
  58. if ( colorMode == 1 )
  59. {
  60. double h,s,v;
  61. ColorConversion::ccRGBtoHSV(r, g, b, &h, &s, &v);
  62. imgD->set( x, y, h, 0);
  63. imgD->set( x, y, s, 1);
  64. imgD->set( x, y, v, 2);
  65. }
  66. else if ( colorMode == 2 )
  67. {
  68. double cX, cY, cZ, cL, ca, cb;
  69. r /= 255.0;
  70. g /= 255.0;
  71. b /= 255.0;
  72. ColorConversion::ccRGBtoXYZ( r, g, b, &cX, &cY, &cZ, 0 );
  73. ColorConversion::ccXYZtoCIE_Lab( cX, cY, cZ, &cL, &ca, &cb, 0 );
  74. imgD->set( x, y, z, cL, 0);
  75. imgD->set( x, y, z, ca, 1);
  76. imgD->set( x, y, z, cb, 2);
  77. }
  78. else
  79. {
  80. imgD->set( x, y, z, r/255.0, 0 );
  81. imgD->set( x, y, z, g/255.0, 1 );
  82. imgD->set( x, y, z, b/255.0, 2 );
  83. }
  84. }
  85. // remove integer channels
  86. img->freeData();
  87. }
  88. else
  89. {
  90. img = & ce->getIChannel3( CachedExample::I_GRAYVALUES );
  91. // gray values to range [0,1]
  92. imgD->reInit ( img->width(), img->height(), img->depth(), 1 );
  93. for ( int z = 0; z < img->depth(); z++ )
  94. for ( int y = 0; y < img->height(); y++ )
  95. for ( int x = 0; x < img->width(); x++ )
  96. {
  97. double g = (double)img->get( x, y, z, 0) / 255.0;
  98. imgD->set( x, y, z, g, 0);
  99. }
  100. // remove integer channel
  101. img->freeData();
  102. }
  103. img = NULL;
  104. imgD = NULL;
  105. }
  106. void SemSegObliqueTree::initFromConfig( const Config *_conf,
  107. const string &s_confSection )
  108. {
  109. conf = _conf;
  110. saveLoadData = conf->gB ( s_confSection, "save_load_data", false );
  111. run3Dseg = conf->gB ( s_confSection, "run_3dseg", false );
  112. colorMode = conf->gI ( s_confSection, "color_mode", 0 );
  113. fileLocation = conf->gS ( s_confSection, "datafile", "classifier.data" );
  114. fpc = new FPCRandomForests ( _conf, "FPCRandomForests" );
  115. fpc->setMaxClassNo( classNames->getMaxClassno() );
  116. }
  117. /** training function */
  118. void SemSegObliqueTree::train ( const MultiDataset *md )
  119. {
  120. if ( saveLoadData && FileMgt::fileExists( fileLocation ) )
  121. {
  122. read( fileLocation );
  123. }
  124. else
  125. {
  126. Examples examples;
  127. // image storage
  128. vector<CachedExample *> imgexamples;
  129. // create pixel-wise training examples
  130. SemSegTools::collectTrainingExamples (
  131. conf,
  132. "FPCRandomForests",
  133. * ( ( *md ) ["train"] ),
  134. *classNames,
  135. examples,
  136. imgexamples,
  137. run3Dseg );
  138. assert ( examples.size() > 0 );
  139. FeaturePool fp;
  140. ConvolutionFeature cf ( conf );
  141. cf.explode( fp );
  142. for ( vector<CachedExample *>::iterator cei = imgexamples.begin();
  143. cei != imgexamples.end(); cei++ )
  144. preprocessChannels ( *cei, cf.isColorMode() );
  145. // start training using random forests
  146. fpc->train( fp, examples);
  147. // save trained classifier to file
  148. if (saveLoadData) save( fileLocation );
  149. // Cleaning up
  150. for ( vector<CachedExample *>::iterator i = imgexamples.begin();
  151. i != imgexamples.end();
  152. i++ )
  153. delete ( *i );
  154. fp.destroy();
  155. }
  156. }
  157. /** classification function */
  158. void SemSegObliqueTree::semanticseg(
  159. CachedExample *ce,
  160. ImageT<int> &segresult,
  161. NICE::MultiChannelImageT<double> &probabilities )
  162. {
  163. // for speed optimization
  164. FPCRandomForests *fpcrf = dynamic_cast<FPCRandomForests *> ( fpc );
  165. int xsize, ysize;
  166. ce->getImageSize ( xsize, ysize );
  167. probabilities.reInit ( xsize, ysize, classNames->getMaxClassno() + 1 );
  168. segresult.resize ( xsize, ysize );
  169. vector<DecisionTree *> forest = fpcrf->getForest ();
  170. DecisionNode *root = forest[0]->getRoot ();
  171. ConvolutionFeature* cf = dynamic_cast<ConvolutionFeature*> (root->f);
  172. preprocessChannels( ce, cf->isColorMode() ); //FIXME!!!
  173. Example pce ( ce, 0, 0 );
  174. for ( int y = 0 ; y < ysize ; y++ )
  175. for ( int x = 0 ; x < xsize ; x++ )
  176. {
  177. pce.x = x;
  178. pce.y = y;
  179. ClassificationResult r = fpcrf->classify ( pce );
  180. segresult.setPixel ( x, y, r.classno );
  181. for ( int i = 0 ; i < ( int ) probabilities.channels(); i++ )
  182. probabilities[i](x,y) = r.scores[i];
  183. }
  184. }
  185. /** classification function 3d */
  186. void SemSegObliqueTree::semanticseg(
  187. OBJREC::CachedExample *ce,
  188. NICE::MultiChannelImageT<int> &segresult,
  189. NICE::MultiChannelImage3DT<double> &probabilities )
  190. {
  191. // for speed optimization
  192. FPCRandomForests *fpcrf = dynamic_cast<FPCRandomForests *> ( fpc );
  193. int xsize, ysize, zsize;
  194. ce->getImageSize3 ( xsize, ysize, zsize );
  195. probabilities.reInit ( xsize, ysize, zsize, classNames->getMaxClassno() + 1 );
  196. segresult.reInit ( xsize, ysize, (unsigned int)zsize );
  197. vector<DecisionTree *> forest = fpcrf->getForest ();
  198. DecisionNode *root = forest[0]->getRoot ();
  199. ConvolutionFeature* cf = dynamic_cast<ConvolutionFeature*> (root->f);
  200. preprocessChannels( ce, cf->isColorMode() ); //FIXME!!!
  201. Example pce ( ce, 0, 0, 0 );
  202. for ( int z = 0 ; z < zsize; z++ )
  203. for ( int y = 0 ; y < ysize ; y++ )
  204. for ( int x = 0 ; x < xsize ; x++ )
  205. {
  206. pce.x = x;
  207. pce.y = y;
  208. pce.z = z;
  209. ClassificationResult r = fpcrf->classify ( pce );
  210. segresult.set ( x, y, z, r.classno );
  211. for ( int i = 0 ; i < ( int ) probabilities.channels(); i++ )
  212. probabilities[i](x,y,z) = r.scores[i];
  213. }
  214. }
  215. ///////////////////// INTERFACE PERSISTENT /////////////////////
  216. // interface specific methods for store and restore
  217. ///////////////////// INTERFACE PERSISTENT /////////////////////
  218. void SemSegObliqueTree::restore( istream &is, int format )
  219. {
  220. //dirty solution to circumvent the const-flag
  221. const_cast<ClassNames*>(this->classNames)->restore ( is, format );
  222. fpc->restore( is, format );
  223. }
  224. void SemSegObliqueTree::store ( ostream &os, int format ) const
  225. {
  226. classNames->store( os, format );
  227. fpc->store( os, format );
  228. }
  229. void SemSegObliqueTree::clear ( )
  230. {
  231. fpc->clear();
  232. }