Ver Fonte

allows writing of tree

Bjoern Froehlich há 13 anos atrás
pai
commit
c41d8e6c86
2 ficheiros alterados com 283 adições e 31 exclusões
  1. 229 23
      semseg/SemSegContextTree.cpp
  2. 54 8
      semseg/SemSegContextTree.h

+ 229 - 23
semseg/SemSegContextTree.cpp

@@ -11,6 +11,7 @@
 #include "core/basics/numerictools.h"
 
 #include "core/basics/Timer.h"
+#include "core/basics/vectorio.h"
 
 #include <omp.h>
 #include <iostream>
@@ -25,8 +26,6 @@ using namespace std;
 
 using namespace NICE;
 
-
-
 class MCImageAccess: public ValueAccess
 {
 
@@ -40,6 +39,11 @@ public:
   {
     return "raw";
   }
+  
+  virtual ValueTypes getType()
+  {
+    return RAWFEAT;
+  }
 };
 
 class ClassificationResultAcess: public ValueAccess
@@ -55,8 +59,49 @@ public:
   {
     return "context";
   }
+  
+  virtual ValueTypes getType()
+  {
+    return CONTEXT;
+  }
 };
 
+void Operation::restore ( std::istream &is )
+{
+  is >> x1;
+  is >> x2;
+  is >> y1;
+  is >> y2;
+  is >> channel1;
+  is >> channel2;
+
+  int tmp;
+  is >> tmp;
+
+  cout << writeInfos() << " " << tmp << endl;
+  
+  if ( tmp >= 0 )
+  {
+    if ( tmp == RAWFEAT )
+    {
+      values = new MCImageAccess();
+    }
+    else if ( tmp == CONTEXT )
+    {
+      values = new ClassificationResultAcess();
+    }
+    else
+    {
+      throw("no valid ValueAccess");
+    }
+  }
+  else
+  {
+    values = NULL;
+  }
+}
+
+
 class Minus: public Operation
 {
 
@@ -134,7 +179,8 @@ public:
     int xsize, ysize;
     getXY ( feats, xsize, ysize );
     double v1 = values->getVal ( feats, BOUND ( x + x1, 0, xsize - 1 ), BOUND ( y + y1, 0, ysize - 1 ), channel1 );
-    double v2 = values->getVal ( feats, BOUND ( x + x2, 0, xsize - 1 ), BOUND ( y + y2, 0, ysize - 1 ), channel2 );
+    double v2 = values->getVal ( feats, BOUND ( x + x2, 0, xsize - 1 ), BOUND ( y + y2, 0, ysize -
+    1 ), channel2 );
     return v1 + v2;
   }
 
@@ -621,6 +667,11 @@ SemSegContextTree::SemSegContextTree ( const Config *conf, const MultiDataset *m
   string segmentationtype = conf->gS ( section, "segmentation_type", "meanshift" );
 
   useGaussian = conf->gB ( section, "use_gaussian", true );
+  
+  randomTests = conf->gI ( section, "random_tests", 10 );
+  
+  bool saveLoadData = conf->gB ("debug", "save_load_data", false);
+  string fileLocation = conf->gS ( "debug", "datafile", "tmp.txt" );
 
   if ( useGaussian )
     throw ( "there something wrong with using gaussian! first fix it!" );
@@ -669,14 +720,27 @@ SemSegContextTree::SemSegContextTree ( const Config *conf, const MultiDataset *m
   // Train Segmentation Context Trees
   ///////////////////////////////////
 
-  train ( md );
+  if(saveLoadData)
+  {
+    if(FileMgt::fileExists(fileLocation))
+      read(fileLocation);
+    else
+    {
+      train ( md );
+      write(fileLocation);
+    }
+  }
+  else
+  {
+    train ( md );
+  }
 }
 
 SemSegContextTree::~SemSegContextTree()
 {
 }
 
-double SemSegContextTree::getBestSplit ( std::vector<NICE::MultiChannelImageT<double> > &feats, std::vector<NICE::MultiChannelImageT<int> > &currentfeats, std::vector<NICE::MultiChannelImageT<double> > &integralImgs, const std::vector<NICE::MatrixT<int> > &labels, int node, Operation *&splitop, double &splitval, const int &tree )
+double SemSegContextTree::getBestSplit ( std::vector<NICE::MultiChannelImageT<double> > &feats, std::vector<NICE::MultiChannelImageT<unsigned short int> > &currentfeats, std::vector<NICE::MultiChannelImageT<double> > &integralImgs, const std::vector<NICE::MatrixT<int> > &labels, int node, Operation *&splitop, double &splitval, const int &tree )
 {
   Timer t;
   t.start();
@@ -854,6 +918,8 @@ double SemSegContextTree::getBestSplit ( std::vector<NICE::MultiChannelImageT<do
     set<vector<int> >::iterator it;
     vector<double> vals;
 
+    double maxval = -numeric_limits<double>::max();
+    double minval = numeric_limits<double>::max();
     for ( it = selFeats.begin() ; it != selFeats.end(); it++ )
     {
       Features feat;
@@ -862,15 +928,27 @@ double SemSegContextTree::getBestSplit ( std::vector<NICE::MultiChannelImageT<do
       feat.cTree = tree;
       feat.tree = &forest[tree];
       feat.integralImg = &integralImgs[ ( *it ) [0]];
-      vals.push_back ( featsel[f]->getVal ( feat, ( *it ) [1], ( *it ) [2] ) );
+      double val = featsel[f]->getVal ( feat, ( *it ) [1], ( *it ) [2] );
+      vals.push_back ( val );
+      maxval = std::max(val,maxval);
+      minval = std::min(val,minval);
+    }
+    
+    if(minval == maxval)
+      continue;
+    
+    double scale = maxval - minval;
+    vector<double> splits;
+    
+    for(int r = 0; r < randomTests; r++)
+    {
+      splits.push_back((( double ) rand() / ( double ) RAND_MAX*scale) + minval);
     }
 
-    int counter = 0;
-
-    for ( it = selFeats.begin() ; it != selFeats.end(); it++ , counter++ )
+    for ( int run = 0 ; run < randomTests; run++ )
     {
       set<vector<int> >::iterator it2;
-      double val = vals[counter];
+      double val = splits[run];
 
       map<int, int> eL, eR;
       int counterL = 0, counterR = 0;
@@ -965,7 +1043,7 @@ double SemSegContextTree::getBestSplit ( std::vector<NICE::MultiChannelImageT<do
   return bestig;
 }
 
-inline double SemSegContextTree::getMeanProb ( const int &x, const int &y, const int &channel, const MultiChannelImageT<int> &currentfeats )
+inline double SemSegContextTree::getMeanProb ( const int &x, const int &y, const int &channel, const MultiChannelImageT<unsigned short int> &currentfeats )
 {
   double val = 0.0;
 
@@ -977,7 +1055,7 @@ inline double SemSegContextTree::getMeanProb ( const int &x, const int &y, const
   return val / ( double ) nbTrees;
 }
 
-void SemSegContextTree::computeIntegralImage ( const NICE::MultiChannelImageT<int> &currentfeats, const NICE::MultiChannelImageT<double> &lfeats, NICE::MultiChannelImageT<double> &integralImage )
+void SemSegContextTree::computeIntegralImage ( const NICE::MultiChannelImageT<unsigned short int> &currentfeats, const NICE::MultiChannelImageT<double> &lfeats, NICE::MultiChannelImageT<double> &integralImage )
 {
   int xsize = currentfeats.width();
   int ysize = currentfeats.height();
@@ -1061,7 +1139,7 @@ void SemSegContextTree::train ( const MultiDataset *md )
 
   //TODO: Speichefresser!, lohnt sich sparse?
   vector<MultiChannelImageT<double> > allfeats;
-  vector<MultiChannelImageT<int> > currentfeats;
+  vector<MultiChannelImageT<unsigned short int> > currentfeats;
   vector<MatrixT<int> > labels;
 
   std::string forbidden_classes_s = conf->gS ( "analysis", "donttrain", "" );
@@ -1132,7 +1210,7 @@ void SemSegContextTree::train ( const MultiDataset *md )
 
     MatrixT<int> tmpMat ( xsize, ysize );
 
-    currentfeats.push_back ( MultiChannelImageT<int> ( xsize, ysize, nbTrees ) );
+    currentfeats.push_back ( MultiChannelImageT<unsigned short int> ( xsize, ysize, nbTrees ) );
     currentfeats[imgcounter].setAll ( 0 );
 
     labels.push_back ( tmpMat );
@@ -1270,7 +1348,7 @@ void SemSegContextTree::train ( const MultiDataset *md )
     cout << "depth: " << depth << endl;
 #endif
     allleaf = true;
-    vector<MultiChannelImageT<int> > lastfeats = currentfeats;
+    vector<MultiChannelImageT<unsigned short int> > lastfeats = currentfeats;
 
 #if 1
     Timer timer;
@@ -1284,11 +1362,6 @@ void SemSegContextTree::train ( const MultiDataset *md )
       startnode[tree] = t;
       //TODO vielleicht parallel wenn nächste schleife trotzdem noch parallelsiert würde, die hat mehr gewicht
       //#pragma omp parallel for
-#if 0
-      timer.stop();
-      cout << "time before tree: " << timer.getLast() << endl;
-      timer.start();
-#endif
       for ( int i = s; i < t; i++ )
       {
         if ( !forest[tree][i].isleaf && forest[tree][i].left < 0 )
@@ -1591,18 +1664,18 @@ void SemSegContextTree::semanticseg ( CachedExample *ce, NICE::Image & segresult
 
   MultiChannelImageT<double> integralImg;
 
-  MultiChannelImageT<int> currentfeats ( xsize, ysize, nbTrees );
+  MultiChannelImageT<unsigned short int> currentfeats ( xsize, ysize, nbTrees );
 
   currentfeats.setAll ( 0 );
 
   depth = 0;
 
-  while ( !allleaf )
+  for(int d = 0; d < maxDepth && !allleaf; d++)
   {
     allleaf = true;
     //TODO vielleicht parallel wenn nächste schleife auch noch parallelsiert würde, die hat mehr gewicht
     //#pragma omp parallel for
-    MultiChannelImageT<int> lastfeats = currentfeats;
+    MultiChannelImageT<unsigned short int> lastfeats = currentfeats;
 
     for ( int tree = 0; tree < nbTrees; tree++ )
     {
@@ -1757,3 +1830,136 @@ void SemSegContextTree::semanticseg ( CachedExample *ce, NICE::Image & segresult
   
   cout << "segmentation finished" << endl;
 }
+
+void SemSegContextTree::store (std::ostream & os, int format) const
+{
+  os << nbTrees << endl;
+  classnames.store(os);
+
+  map<int,int>::const_iterator it;
+  
+  os << labelmap.size() << endl;
+  for ( it=labelmap.begin() ; it != labelmap.end(); it++ )
+    os << (*it).first << " " << (*it).second << endl;
+
+  os << labelmapback.size() << endl;
+  for ( it=labelmapback.begin() ; it != labelmapback.end(); it++ )
+    os << (*it).first << " " << (*it).second << endl;
+
+  int trees = forest.size();
+  os << trees << endl;
+
+  for(int t = 0; t < trees; t++)
+  {
+    int nodes = forest[t].size(); 
+    os << nodes << endl;
+    for(int n = 0; n < nodes; n++)
+    {
+      os << forest[t][n].left << " " << forest[t][n].right << " " << forest[t][n].decision << " " << forest[t][n].isleaf << " " << forest[t][n].depth << " " << forest[t][n].featcounter << endl;
+      os << forest[t][n].dist << endl;
+      
+      if(forest[t][n].feat==NULL)
+        os << -1 << endl;
+      else
+      {
+        os << forest[t][n].feat->getOps() << endl;
+        forest[t][n].feat->store(os);
+      }
+    }
+  }
+}
+
+void SemSegContextTree::restore (std::istream & is, int format)
+{
+  is >> nbTrees;
+  
+  classnames.restore(is);
+  
+  int lsize;
+  is >> lsize;
+  
+  labelmap.clear();
+  for(int l = 0; l < lsize; l++)
+  {
+    int first, second;
+    is >> first;
+    is >> second;
+    labelmap[first] = second;
+  }
+  
+  is >> lsize;
+  labelmapback.clear();
+  for(int l = 0; l < lsize; l++)
+  {
+    int first, second;
+    is >> first;
+    is >> second;
+    labelmapback[first] = second;
+  }
+  
+  int trees;
+  is >> trees;
+  forest.clear();
+  
+  for(int t = 0; t < trees; t++)
+  {
+    vector<TreeNode> tmptree;
+    forest.push_back(tmptree);
+    int nodes;
+    is >> nodes;
+    cout << "nodes: " << nodes << endl;
+    for(int n = 0; n < nodes; n++)
+    {
+      TreeNode tmpnode;
+      forest[t].push_back(tmpnode);
+      is >> forest[t][n].left;
+      is >> forest[t][n].right;
+      is >> forest[t][n].decision;
+      is >> forest[t][n].isleaf;
+      is >> forest[t][n].depth;
+      is >> forest[t][n].featcounter;
+      is >> forest[t][n].dist;
+/*    
+      cout << "forest[t][n].left" << forest[t][n].left << endl;
+      cout << "forest[t][n].right" << forest[t][n].right << endl;
+      cout << "forest[t][n].decision" << forest[t][n].decision << endl;
+      cout << "forest[t][n].isleaf" << forest[t][n].isleaf << endl;
+      cout << "forest[t][n].depth" << forest[t][n].depth << endl;
+      cout << "forest[t][n].featcounter" << forest[t][n].featcounter << endl;
+      cout << "forest[t][n].dist" << forest[t][n].dist << endl;
+*/
+      
+      int feattype;
+      is >> feattype;
+      assert(feattype < NBOPERATIONS);
+      forest[t][n].feat = NULL;
+      if(feattype >= 0)
+      {
+        for(int o = 0; o < ops.size(); o++)
+        {
+          if(ops[o]->getOps() == feattype)
+          {
+            forest[t][n].feat = ops[o]->clone();
+            break;
+          }
+        }
+        
+        if(forest[t][n].feat == NULL)
+        {
+          for(int o = 0; o < cops.size(); o++)
+          {
+            if(cops[o]->getOps() == feattype)
+            {
+              forest[t][n].feat = cops[o]->clone();
+              break;
+            }
+          }
+        }
+        assert(forest[t][n].feat != NULL);
+        forest[t][n].feat->restore(is);
+      }
+    }    
+  }
+}
+
+

+ 54 - 8
semseg/SemSegContextTree.h

@@ -21,9 +21,6 @@ class TreeNode
 {
 
   public:
-    /** probabilities for each class */
-    std::vector<double> probs;
-
     /** left child node */
     int left;
 
@@ -57,20 +54,30 @@ class TreeNode
 
 struct Features {
   NICE::MultiChannelImageT<double> *feats;
-  MultiChannelImageT<int> *cfeats;
+  MultiChannelImageT<unsigned short int> *cfeats;
   int cTree;
   std::vector<TreeNode> *tree;
   NICE::MultiChannelImageT<double> *integralImg;
 };
 
+enum ValueTypes
+{
+  RAWFEAT,
+  CONTEXT,
+  NBVALUETYPES
+};
+
 class ValueAccess
 {
 
   public:
     virtual double getVal ( const Features &feats, const int &x, const int &y, const int &channel ) = 0;
     virtual std::string writeInfos() = 0;
+    virtual ValueTypes getType() = 0;
 };
 
+
+
 enum OperationTypes {
   MINUS,
   MINUSABS,
@@ -135,11 +142,22 @@ class Operation
     }
 
     virtual OperationTypes getOps() = 0;
+    
+    virtual void store(std::ostream & os)
+    {
+      os << x1 << " " << x2 << " " << y1 << " " << y2 << " " << channel1 << " " << channel2 << std::endl;
+      if(values == NULL)
+        os << -1 << std::endl;
+      else
+        os << values->getType() << std::endl;
+    }
+    
+    virtual void restore(std::istream & is);
 };
 
 /** Localization system */
 
-class SemSegContextTree : public SemanticSegmentation
+class SemSegContextTree : public SemanticSegmentation, public NICE::Persistent
 {
     /** Segmentation Method */
     RegionSegmentationMethod *segmentation;
@@ -188,6 +206,9 @@ class SemSegContextTree : public SemanticSegmentation
 
     /** current depth for training */
     int depth;
+    
+    /** how many splittests */
+    int randomTests;
 
     /** operations for pairwise features */
     std::vector<Operation*> ops;
@@ -250,7 +271,7 @@ class SemSegContextTree : public SemanticSegmentation
      * @param integralImage output image (must be initilized)
      * @return void
      **/
-    void computeIntegralImage ( const NICE::MultiChannelImageT<int> &currentfeats, const NICE::MultiChannelImageT<double> &lfeats, NICE::MultiChannelImageT<double> &integralImage );
+    void computeIntegralImage ( const NICE::MultiChannelImageT<unsigned short int> &currentfeats, const NICE::MultiChannelImageT<double> &lfeats, NICE::MultiChannelImageT<double> &integralImage );
 
     /**
      * compute best split for current settings
@@ -262,7 +283,7 @@ class SemSegContextTree : public SemanticSegmentation
      * @param splitval
      * @return best information gain
      */
-    double getBestSplit ( std::vector<NICE::MultiChannelImageT<double> > &feats, std::vector<NICE::MultiChannelImageT<int> > &currentfeats, std::vector<NICE::MultiChannelImageT<double> > &integralImgs, const std::vector<NICE::MatrixT<int> > &labels, int node, Operation *&splitop, double &splitval, const int &tree );
+    double getBestSplit ( std::vector<NICE::MultiChannelImageT<double> > &feats, std::vector<NICE::MultiChannelImageT<unsigned short int> > &currentfeats, std::vector<NICE::MultiChannelImageT<double> > &integralImgs, const std::vector<NICE::MatrixT<int> > &labels, int node, Operation *&splitop, double &splitval, const int &tree );
 
     /**
      * @brief computes the mean probability for a given class over all trees
@@ -272,8 +293,33 @@ class SemSegContextTree : public SemanticSegmentation
      * @param currentfeats information about the nodes
      * @return double mean value
      **/
-    inline double getMeanProb ( const int &x, const int &y, const int &channel, const MultiChannelImageT<int> &currentfeats );
+    inline double getMeanProb ( const int &x, const int &y, const int &channel, const MultiChannelImageT<unsigned short int> &currentfeats );
 
+    /**
+     * @brief load all data to is stream
+     *
+     * @param is input stream
+     * @param format has no influence
+     * @return void
+     **/
+    virtual void restore (std::istream & is, int format = 0);
+    
+    /**
+     * @brief save all data to is stream
+     *
+     * @param os output stream
+     * @param format has no influence
+     * @return void
+     **/
+    virtual void store (std::ostream & os, int format = 0) const;
+   
+    /**
+     * @brief clean up
+     *
+     * @return void
+     **/
+    virtual void clear (){}
+    
 };