/** * @file SemSegConvolutionalTree.h * @brief Semantic Segmentation using Covolutional Trees * @author Sven Sickert * @date 10/17/2014 */ #include #include "SemSegConvolutionalTree.h" #include "SemSegTools.h" #include "vislearning/classifier/fpclassifier/randomforest/FPCRandomForests.h" #include "vislearning/features/fpfeatures/ConvolutionFeature.h" using namespace OBJREC; using namespace std; using namespace NICE; //###################### CONSTRUCTORS #########################// SemSegConvolutionalTree::SemSegConvolutionalTree () : SemanticSegmentation () { conf = NULL; saveLoadData = false; fileLocation = "classifier.data"; fpc = new FPCRandomForests (); } SemSegConvolutionalTree::SemSegConvolutionalTree ( const Config *conf, const ClassNames *classNames ) : SemanticSegmentation( conf, classNames ) { initFromConfig( conf ); } //###################### DESTRUCTORS ##########################// SemSegConvolutionalTree::~SemSegConvolutionalTree () { if ( fpc != NULL ) delete fpc; } //#################### MEMBER FUNCTIONS #######################// void SemSegConvolutionalTree::initFromConfig( const Config *_conf, const string &s_confSection ) { conf = _conf; saveLoadData = conf->gB ( s_confSection, "save_load_data", false ); fileLocation = conf->gS ( s_confSection, "datafile", "classifier.data" ); fpc = new FPCRandomForests ( _conf, "FPCRandomForests" ); fpc->setMaxClassNo( classNames->getMaxClassno() ); } /** training function */ void SemSegConvolutionalTree::train ( const MultiDataset *md ) { if ( saveLoadData && FileMgt::fileExists( fileLocation ) ) { read( fileLocation ); } else { Examples examples; // image storage vector imgexamples; // create pixel-wise training examples SemSegTools::collectTrainingExamples ( conf, "FPCRandomForests", * ( ( *md ) ["train"] ), *classNames, examples, imgexamples ); assert ( examples.size() > 0 ); FeaturePool fp; ConvolutionFeature cf ( conf ); cf.explode( fp ); // start training using random forests fpc->train( fp, examples); // save trained classifier to file if (saveLoadData) save( fileLocation ); // Cleaning up for ( vector::iterator i = imgexamples.begin(); i != imgexamples.end(); i++ ) delete ( *i ); fp.destroy(); } } /** classification function */ void SemSegConvolutionalTree::semanticseg( CachedExample *ce, Image &segresult, NICE::MultiChannelImageT &probabilities ) { // for speed optimization FPCRandomForests *fpcrf = dynamic_cast ( fpc ); int xsize, ysize; ce->getImageSize ( xsize, ysize ); probabilities.reInit ( xsize, ysize, classNames->getMaxClassno() + 1 ); segresult.resize ( xsize, ysize ); Example pce ( ce, 0, 0 ); for ( int y = 0 ; y < ysize ; y++ ) for ( int x = 0 ; x < xsize ; x++ ) { pce.x = x ; pce.y = y; ClassificationResult r = fpcrf->classify ( pce ); segresult.setPixel ( x, y, r.classno ); for ( int i = 0 ; i < ( int ) probabilities.channels(); i++ ) probabilities[i](x,y) = r.scores[i]; } } ///////////////////// INTERFACE PERSISTENT ///////////////////// // interface specific methods for store and restore ///////////////////// INTERFACE PERSISTENT ///////////////////// void SemSegConvolutionalTree::restore( istream &is, int format ) { //dirty solution to circumvent the const-flag const_cast(this->classNames)->restore ( is, format ); fpc->restore( is, format ); } void SemSegConvolutionalTree::store ( ostream &os, int format ) const { classNames->store( os, format ); fpc->store( os, format ); } void SemSegConvolutionalTree::clear ( ) { fpc->clear(); }