Browse Source

added class SemSegConvolutionalTree

Sven Sickert 10 years ago
parent
commit
11b3954501
2 changed files with 261 additions and 0 deletions
  1. 157 0
      semseg/SemSegConvolutionalTree.cpp
  2. 104 0
      semseg/SemSegConvolutionalTree.h

+ 157 - 0
semseg/SemSegConvolutionalTree.cpp

@@ -0,0 +1,157 @@
+/**
+ * @file SemSegConvolutionalTree.h
+ * @brief Semantic Segmentation using Covolutional Trees
+ * @author Sven Sickert
+ * @date 10/17/2014
+
+*/
+
+#include <iostream>
+
+#include "SemSegConvolutionalTree.h"
+#include "SemSegTools.h"
+
+#include "vislearning/classifier/fpclassifier/randomforest/FPCRandomForests.h"
+#include "vislearning/features/fpfeatures/ConvolutionFeature.h"
+
+using namespace OBJREC;
+
+using namespace std;
+using namespace NICE;
+
+//###################### CONSTRUCTORS #########################//
+
+SemSegConvolutionalTree::SemSegConvolutionalTree () : SemanticSegmentation ()
+{
+    conf = NULL;
+
+    saveLoadData = false;
+    fileLocation = "classifier.data";
+
+    fpc = new FPCRandomForests ();
+}
+
+SemSegConvolutionalTree::SemSegConvolutionalTree (
+        const Config *conf,
+        const ClassNames *classNames )
+    : SemanticSegmentation( conf, classNames )
+{
+    initFromConfig( conf );
+}
+
+//###################### DESTRUCTORS ##########################//
+
+SemSegConvolutionalTree::~SemSegConvolutionalTree ()
+{
+    if ( fpc != NULL )
+        delete fpc;
+}
+
+//#################### MEMBER FUNCTIONS #######################//
+
+void SemSegConvolutionalTree::initFromConfig( const Config *_conf,
+                                         const string &s_confSection )
+{
+    conf = _conf;
+    saveLoadData = conf->gB ( s_confSection, "save_load_data", false );
+    fileLocation = conf->gS ( s_confSection, "datafile", "classifier.data" );
+
+    fpc = new FPCRandomForests ( _conf, "FPCRandomForests" );
+    fpc->setMaxClassNo( classNames->getMaxClassno() );
+}
+
+/** training function */
+void SemSegConvolutionalTree::train ( const MultiDataset *md )
+{
+    if ( saveLoadData && FileMgt::fileExists( fileLocation ) )
+    {
+        read( fileLocation );
+    }
+    else
+    {
+        Examples examples;
+
+        // image storage
+        vector<CachedExample *> imgexamples;
+
+        // create pixel-wise training examples
+        SemSegTools::collectTrainingExamples (
+          conf,
+          "FPCRandomForests",
+          * ( ( *md ) ["train"] ),
+          *classNames,
+          examples,
+          imgexamples );
+
+        assert ( examples.size() > 0 );
+
+        FeaturePool fp;
+        ConvolutionFeature cf ( conf );
+        cf.explode( fp );
+
+        // start training using random forests
+        fpc->train( fp, examples);
+
+        // save trained classifier to file
+        save( fileLocation );
+
+        // Cleaning up
+        for ( vector<CachedExample *>::iterator i = imgexamples.begin();
+              i != imgexamples.end();
+              i++ )
+            delete ( *i );
+
+        fp.destroy();
+    }
+}
+
+/** classification function */
+void SemSegConvolutionalTree::semanticseg(
+        CachedExample *ce,
+        Image &segresult,
+        NICE::MultiChannelImageT<double> &probabilities )
+{
+    // for speed optimization
+    FPCRandomForests *fpcrf = dynamic_cast<FPCRandomForests *> ( fpc );
+
+    int xsize, ysize;
+    ce->getImageSize ( xsize, ysize );
+    probabilities.reInit ( xsize, ysize, classNames->getMaxClassno() + 1 );
+    segresult.resize ( xsize, ysize );
+
+    Example pce ( ce, 0, 0 );
+    long int offset = 0;
+    for ( int y = 0 ; y < ysize ; y++ )
+      for ( int x = 0 ; x < xsize ; x++, offset++ )
+      {
+        pce.x = x ;
+        pce.y = y;
+        ClassificationResult r = fpcrf->classify ( pce );
+        segresult.setPixel ( x, y, r.classno );
+        for ( int i = 0 ; i < ( int ) probabilities.channels(); i++ )
+          probabilities[i](x,y) = r.scores[i];
+      }
+}
+
+///////////////////// INTERFACE PERSISTENT /////////////////////
+// interface specific methods for store and restore
+///////////////////// INTERFACE PERSISTENT /////////////////////
+
+void SemSegConvolutionalTree::restore( istream &is, int format )
+{
+    //dirty solution to circumvent the const-flag
+    const_cast<ClassNames*>(this->classNames)->restore ( is, format );
+
+    fpc->restore( is, format );
+}
+
+void SemSegConvolutionalTree::store ( ostream &os, int format ) const
+{
+    classNames->store( os, format );
+    fpc->store( os, format );
+}
+
+void SemSegConvolutionalTree::clear ( )
+{
+    fpc->clear();
+}

+ 104 - 0
semseg/SemSegConvolutionalTree.h

@@ -0,0 +1,104 @@
+/**
+ * @file SemSegConvolutionalTree.h
+ * @brief Semantic Segmentation using Covolutional Trees
+ * @author Sven Sickert
+ * @date 10/17/2014
+
+*/
+
+#ifndef SEMSEGCONVOLUTIONALTREEINCLUDE
+#define SEMSEGCONVOLUTIONALTREEINCLUDE
+
+// nice-core includes
+
+// nice-vislearning includes
+#include "vislearning/classifier/classifierbase/FeaturePoolClassifier.h"
+
+// nice-semseg includes
+#include "SemanticSegmentation.h"
+
+namespace OBJREC
+{
+
+class SemSegConvolutionalTree : public SemanticSegmentation
+{
+    private:
+
+        /** pointer to config file */
+        const NICE::Config *conf;
+
+        /** save / load trained classifier */
+        bool saveLoadData;
+
+        /** file location of trained classifier */
+        std::string fileLocation;
+
+        /** classifier for categorization */
+        FeaturePoolClassifier *fpc;
+
+    public:
+
+        /** simple constructor */
+        SemSegConvolutionalTree ();
+
+        /** config constructor */
+        SemSegConvolutionalTree ( const NICE::Config *conf,
+                                  const ClassNames *classNames );
+
+        /** simple destructor */
+        virtual ~SemSegConvolutionalTree();
+
+        /**
+         * @brief Setup internal variables and objects used
+         * @param conf Configuration file to specify variable settings
+         * @param s_confSection Section in configuration file
+         */
+        void initFromConfig (
+                const NICE::Config *_conf,
+                const std::string & s_confSection = "SemSegConvolutionalTree" );
+
+        /**
+         * @brief training function / learn classifier
+         * @param md the data set
+         */
+        void train ( const MultiDataset *md );
+
+        /**
+         * @brief classification function
+           @param ce image data
+           @param segresult result of the semantic segmentation with a label
+                  for each pixel
+           @param probabilities multi-channel image with one channel for
+                  each class and corresponding probabilities for each pixel
+         */
+        void semanticseg ( CachedExample *ce,
+                           NICE::Image &segresult,
+                           NICE::MultiChannelImageT<double> &probabilities );
+
+
+        ///////////////////// INTERFACE PERSISTENT /////////////////////
+        // interface specific methods for store and restore
+        ///////////////////// INTERFACE PERSISTENT /////////////////////
+
+        /**
+         * @brief Load segmentation object from external file (stream)
+         */
+        virtual void restore ( std::istream & is, int format = 0 );
+
+        /**
+         * @brief Save segmentation-object to external file (stream)
+         */
+        virtual void store( std::ostream & os, int format = 0 ) const;
+
+        /**
+         * @brief Clear segmentation-object object
+         */
+        virtual void clear ();
+
+};
+
+
+
+} // namespace
+
+#endif