/** * @file SemSegConvolutionalTree.h * @brief Semantic Segmentation using Covolutional Trees * @author Sven Sickert * @date 10/17/2014 */ #include <iostream> #include "SemSegConvolutionalTree.h" #include "SemSegTools.h" #include "vislearning/classifier/fpclassifier/randomforest/FPCRandomForests.h" #include "vislearning/features/fpfeatures/ConvolutionFeature.h" #include "vislearning/baselib/cc.h" using namespace OBJREC; using namespace std; using namespace NICE; //###################### CONSTRUCTORS #########################// SemSegConvolutionalTree::SemSegConvolutionalTree () : SemanticSegmentation () { conf = NULL; saveLoadData = false; fileLocation = "classifier.data"; fpc = new FPCRandomForests (); } SemSegConvolutionalTree::SemSegConvolutionalTree ( const Config *conf, const ClassNames *classNames ) : SemanticSegmentation( conf, classNames ) { initFromConfig( conf ); } //###################### DESTRUCTORS ##########################// SemSegConvolutionalTree::~SemSegConvolutionalTree () { if ( fpc != NULL ) delete fpc; } //#################### MEMBER FUNCTIONS #######################// void SemSegConvolutionalTree::convertRGBToHSV ( CachedExample *ce ) const { assert( imgHSV->channels() == 0 ); NICE::MultiChannelImageT<int> * img = NULL; NICE::MultiChannelImageT<double> * imgHSV = NULL; img = & ce->getIChannel( CachedExample::I_COLOR ); imgHSV = & ce->getDChannel( CachedExample::D_EOH ); if ( img->channels() == 3 ) { imgHSV->reInit ( img->width(), img->height(), 3 ); for ( int y = 0; y < img->height(); y++ ) for ( int x = 0; x < img->width(); x++ ) { double h,s,v; double r = (double)img->get( x, y, 0); double g = (double)img->get( x, y, 1); double b = (double)img->get( x, y, 2); ColorConversion::ccRGBtoHSV(r, g, b, &h, &s, &v); imgHSV->set( x, y, h, 0); imgHSV->set( x, y, h, 1); imgHSV->set( x, y, h, 2); } // remove r,g,b (integer) channels img->freeData(); img = NULL; } else { imgHSV = NULL; } } void SemSegConvolutionalTree::initFromConfig( const Config *_conf, const string &s_confSection ) { conf = _conf; saveLoadData = conf->gB ( s_confSection, "save_load_data", false ); fileLocation = conf->gS ( s_confSection, "datafile", "classifier.data" ); fpc = new FPCRandomForests ( _conf, "FPCRandomForests" ); fpc->setMaxClassNo( classNames->getMaxClassno() ); } /** training function */ void SemSegConvolutionalTree::train ( const MultiDataset *md ) { if ( saveLoadData && FileMgt::fileExists( fileLocation ) ) { read( fileLocation ); } else { Examples examples; // image storage vector<CachedExample *> imgexamples; // create pixel-wise training examples SemSegTools::collectTrainingExamples ( conf, "FPCRandomForests", * ( ( *md ) ["train"] ), *classNames, examples, imgexamples ); assert ( examples.size() > 0 ); for ( vector<CachedExample *>::iterator cei = imgexamples.begin(); cei != imgexamples.end(); cei++ ) convertRGBToHSV ( *cei ); FeaturePool fp; ConvolutionFeature cf ( conf ); cf.explode( fp ); // start training using random forests fpc->train( fp, examples); // save trained classifier to file if (saveLoadData) save( fileLocation ); // Cleaning up for ( vector<CachedExample *>::iterator i = imgexamples.begin(); i != imgexamples.end(); i++ ) delete ( *i ); fp.destroy(); } } /** classification function */ void SemSegConvolutionalTree::semanticseg( CachedExample *ce, Image &segresult, NICE::MultiChannelImageT<double> &probabilities ) { // for speed optimization FPCRandomForests *fpcrf = dynamic_cast<FPCRandomForests *> ( fpc ); int xsize, ysize; ce->getImageSize ( xsize, ysize ); probabilities.reInit ( xsize, ysize, classNames->getMaxClassno() + 1 ); segresult.resize ( xsize, ysize ); convertRGBToHSV(ce); Example pce ( ce, 0, 0 ); for ( int y = 0 ; y < ysize ; y++ ) for ( int x = 0 ; x < xsize ; x++ ) { pce.x = x ; pce.y = y; ClassificationResult r = fpcrf->classify ( pce ); segresult.setPixel ( x, y, r.classno ); for ( int i = 0 ; i < ( int ) probabilities.channels(); i++ ) probabilities[i](x,y) = r.scores[i]; } } ///////////////////// INTERFACE PERSISTENT ///////////////////// // interface specific methods for store and restore ///////////////////// INTERFACE PERSISTENT ///////////////////// void SemSegConvolutionalTree::restore( istream &is, int format ) { //dirty solution to circumvent the const-flag const_cast<ClassNames*>(this->classNames)->restore ( is, format ); fpc->restore( is, format ); } void SemSegConvolutionalTree::store ( ostream &os, int format ) const { classNames->store( os, format ); fpc->store( os, format ); } void SemSegConvolutionalTree::clear ( ) { fpc->clear(); }