SemSegConvolutionalTree.cpp 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. /**
  2. * @file SemSegConvolutionalTree.cpp
  3. * @brief Semantic Segmentation using Convolutional Trees
  4. * @author Sven Sickert
  5. * @date 10/17/2014
  6. */
  7. #include <iostream>
  8. #include "SemSegConvolutionalTree.h"
  9. #include "SemSegTools.h"
  10. #include "vislearning/classifier/fpclassifier/randomforest/FPCRandomForests.h"
  11. #include "vislearning/features/fpfeatures/ConvolutionFeature.h"
  12. using namespace OBJREC;
  13. using namespace std;
  14. using namespace NICE;
  15. //###################### CONSTRUCTORS #########################//
  16. SemSegConvolutionalTree::SemSegConvolutionalTree () : SemanticSegmentation ()
  17. {
  18. conf = NULL;
  19. saveLoadData = false;
  20. fileLocation = "classifier.data";
  21. fpc = new FPCRandomForests ();
  22. }
  23. SemSegConvolutionalTree::SemSegConvolutionalTree (
  24. const Config *conf,
  25. const ClassNames *classNames )
  26. : SemanticSegmentation( conf, classNames )
  27. {
  28. initFromConfig( conf );
  29. }
  30. //###################### DESTRUCTORS ##########################//
  31. SemSegConvolutionalTree::~SemSegConvolutionalTree ()
  32. {
  33. if ( fpc != NULL )
  34. delete fpc;
  35. }
  36. //#################### MEMBER FUNCTIONS #######################//
  37. void SemSegConvolutionalTree::initFromConfig( const Config *_conf,
  38. const string &s_confSection )
  39. {
  40. conf = _conf;
  41. saveLoadData = conf->gB ( s_confSection, "save_load_data", false );
  42. fileLocation = conf->gS ( s_confSection, "datafile", "classifier.data" );
  43. fpc = new FPCRandomForests ( _conf, "FPCRandomForests" );
  44. fpc->setMaxClassNo( classNames->getMaxClassno() );
  45. }
  46. /** training function */
  47. void SemSegConvolutionalTree::train ( const MultiDataset *md )
  48. {
  49. if ( saveLoadData && FileMgt::fileExists( fileLocation ) )
  50. {
  51. read( fileLocation );
  52. }
  53. else
  54. {
  55. Examples examples;
  56. // image storage
  57. vector<CachedExample *> imgexamples;
  58. // create pixel-wise training examples
  59. SemSegTools::collectTrainingExamples (
  60. conf,
  61. "FPCRandomForests",
  62. * ( ( *md ) ["train"] ),
  63. *classNames,
  64. examples,
  65. imgexamples );
  66. assert ( examples.size() > 0 );
  67. FeaturePool fp;
  68. ConvolutionFeature cf ( conf );
  69. cf.explode( fp );
  70. // start training using random forests
  71. fpc->train( fp, examples);
  72. // save trained classifier to file
  73. if (saveLoadData) save( fileLocation );
  74. // Cleaning up
  75. for ( vector<CachedExample *>::iterator i = imgexamples.begin();
  76. i != imgexamples.end();
  77. i++ )
  78. delete ( *i );
  79. fp.destroy();
  80. }
  81. }
  82. /** classification function */
  83. void SemSegConvolutionalTree::semanticseg(
  84. CachedExample *ce,
  85. Image &segresult,
  86. NICE::MultiChannelImageT<double> &probabilities )
  87. {
  88. // for speed optimization
  89. FPCRandomForests *fpcrf = dynamic_cast<FPCRandomForests *> ( fpc );
  90. int xsize, ysize;
  91. ce->getImageSize ( xsize, ysize );
  92. probabilities.reInit ( xsize, ysize, classNames->getMaxClassno() + 1 );
  93. segresult.resize ( xsize, ysize );
  94. Example pce ( ce, 0, 0 );
  95. for ( int y = 0 ; y < ysize ; y++ )
  96. for ( int x = 0 ; x < xsize ; x++ )
  97. {
  98. pce.x = x ;
  99. pce.y = y;
  100. ClassificationResult r = fpcrf->classify ( pce );
  101. segresult.setPixel ( x, y, r.classno );
  102. for ( int i = 0 ; i < ( int ) probabilities.channels(); i++ )
  103. probabilities[i](x,y) = r.scores[i];
  104. }
  105. }
  106. ///////////////////// INTERFACE PERSISTENT /////////////////////
  107. // interface specific methods for store and restore
  108. ///////////////////// INTERFACE PERSISTENT /////////////////////
  109. void SemSegConvolutionalTree::restore( istream &is, int format )
  110. {
  111. //dirty solution to circumvent the const-flag
  112. const_cast<ClassNames*>(this->classNames)->restore ( is, format );
  113. fpc->restore( is, format );
  114. }
  115. void SemSegConvolutionalTree::store ( ostream &os, int format ) const
  116. {
  117. classNames->store( os, format );
  118. fpc->store( os, format );
  119. }
  120. void SemSegConvolutionalTree::clear ( )
  121. {
  122. fpc->clear();
  123. }