瀏覽代碼

kategorisierung überabeitet

Bjoern Froehlich 13 年之前
父節點
當前提交
06021dfa77
共有 1 個文件被更改,包括 12 次插入8 次删除
  1. 12 8
      semseg/SemSegContextTree.cpp

+ 12 - 8
semseg/SemSegContextTree.cpp

@@ -68,7 +68,8 @@ SemSegContextTree::SemSegContextTree (const Config *conf, const MultiDataset *md
 
   pixelWiseLabeling = false;
 
-  useRegionFeature = conf->gB (section, "use_region_feat", true);
+  useRegionFeature = conf->gB (section, "use_region_feat", false);
+
   if (segmentationtype == "meanshift")
     segmentation = new RSMeanShift (conf);
   else if (segmentationtype == "none")
@@ -554,6 +555,8 @@ inline double computeWeight (const double &d, const double &dim)
 
 void SemSegContextTree::train (const MultiDataset *md)
 {
+  int shortsize = numeric_limits<short>::max();
+  
   Timer timer;
   timer.start();
   const LabeledSet train = * (*md) ["train"];
@@ -968,7 +971,7 @@ void SemSegContextTree::train (const MultiDataset *md)
                       if (labelmap.find (labels[iCounter] (x, y)) != labelmap.end())
                         forest[tree][left].dist[labelmap[labels[iCounter] (x, y) ]]++;
                       forest[tree][left].featcounter++;
-                      if(useCategorization)
+                      if(useCategorization && leftu < shortsize)
                         (*globalCategorFeats[iCounter])[leftu]+=weight;
                     }
                     else
@@ -978,7 +981,7 @@ void SemSegContextTree::train (const MultiDataset *md)
                         forest[tree][right].dist[labelmap[labels[iCounter] (x, y) ]]++;
                       forest[tree][right].featcounter++;
                       
-                      if(useCategorization)
+                      if(useCategorization && rightu < shortsize)
                         (*globalCategorFeats[iCounter])[rightu]+=weight;
                     }
                   }
@@ -1132,8 +1135,11 @@ void SemSegContextTree::train (const MultiDataset *md)
   cerr << "learning finished in: " << timer.getLastAbsolute() << " seconds" << endl;
   timer.start();
   
+  cout << "uniquenumber " << uniquenumber << endl;
+  
   if(useCategorization)
   {
+    uniquenumber = std::min(shortsize, uniquenumber);
     for(uint i = 0; i < globalCategorFeats.size(); i++)
     {
       globalCategorFeats[i]->setDim(uniquenumber);
@@ -1161,9 +1167,7 @@ void SemSegContextTree::train (const MultiDataset *md)
     fasthik->train(globalCategorFeats, ys);
     
   }
-
-  cout << "uniquenumber " << uniquenumber << endl;
-  //getchar();
+  
 #ifdef DEBUG
   for (int tree = 0; tree < nbTrees; tree++)
   {
@@ -1463,7 +1467,7 @@ void SemSegContextTree::semanticseg (CachedExample *ce, NICE::Image & segresult,
               currentfeats.set (x, y, forest[tree][t].left, tree);
 #pragma omp critical
               {
-                if(useCategorization)
+                if(useCategorization && forest[tree][forest[tree][t].left].nodeNumber < uniquenumber)
                   (*globalCategorFeat)[forest[tree][forest[tree][t].left].nodeNumber] += weight;
               }
             }
@@ -1472,7 +1476,7 @@ void SemSegContextTree::semanticseg (CachedExample *ce, NICE::Image & segresult,
               currentfeats.set (x, y, forest[tree][t].right, tree);
 #pragma omp critical
               {
-                if(useCategorization)
+                if(useCategorization && forest[tree][forest[tree][t].right].nodeNumber < uniquenumber)
                   (*globalCategorFeat)[forest[tree][forest[tree][t].right].nodeNumber] += weight;
               }
             }