Bladeren bron

kategorisierung überabeitet

Bjoern Froehlich 13 jaren geleden
bovenliggende
commit
06021dfa77
1 gewijzigde bestanden met toevoegingen van 12 en 8 verwijderingen
  1. 12 8
      semseg/SemSegContextTree.cpp

+ 12 - 8
semseg/SemSegContextTree.cpp

@@ -68,7 +68,8 @@ SemSegContextTree::SemSegContextTree (const Config *conf, const MultiDataset *md
 
 
   pixelWiseLabeling = false;
   pixelWiseLabeling = false;
 
 
-  useRegionFeature = conf->gB (section, "use_region_feat", true);
+  useRegionFeature = conf->gB (section, "use_region_feat", false);
+
   if (segmentationtype == "meanshift")
   if (segmentationtype == "meanshift")
     segmentation = new RSMeanShift (conf);
     segmentation = new RSMeanShift (conf);
   else if (segmentationtype == "none")
   else if (segmentationtype == "none")
@@ -554,6 +555,8 @@ inline double computeWeight (const double &d, const double &dim)
 
 
 void SemSegContextTree::train (const MultiDataset *md)
 void SemSegContextTree::train (const MultiDataset *md)
 {
 {
+  int shortsize = numeric_limits<short>::max();
+  
   Timer timer;
   Timer timer;
   timer.start();
   timer.start();
   const LabeledSet train = * (*md) ["train"];
   const LabeledSet train = * (*md) ["train"];
@@ -968,7 +971,7 @@ void SemSegContextTree::train (const MultiDataset *md)
                       if (labelmap.find (labels[iCounter] (x, y)) != labelmap.end())
                       if (labelmap.find (labels[iCounter] (x, y)) != labelmap.end())
                         forest[tree][left].dist[labelmap[labels[iCounter] (x, y) ]]++;
                         forest[tree][left].dist[labelmap[labels[iCounter] (x, y) ]]++;
                       forest[tree][left].featcounter++;
                       forest[tree][left].featcounter++;
-                      if(useCategorization)
+                      if(useCategorization && leftu < shortsize)
                         (*globalCategorFeats[iCounter])[leftu]+=weight;
                         (*globalCategorFeats[iCounter])[leftu]+=weight;
                     }
                     }
                     else
                     else
@@ -978,7 +981,7 @@ void SemSegContextTree::train (const MultiDataset *md)
                         forest[tree][right].dist[labelmap[labels[iCounter] (x, y) ]]++;
                         forest[tree][right].dist[labelmap[labels[iCounter] (x, y) ]]++;
                       forest[tree][right].featcounter++;
                       forest[tree][right].featcounter++;
                       
                       
-                      if(useCategorization)
+                      if(useCategorization && rightu < shortsize)
                         (*globalCategorFeats[iCounter])[rightu]+=weight;
                         (*globalCategorFeats[iCounter])[rightu]+=weight;
                     }
                     }
                   }
                   }
@@ -1132,8 +1135,11 @@ void SemSegContextTree::train (const MultiDataset *md)
   cerr << "learning finished in: " << timer.getLastAbsolute() << " seconds" << endl;
   cerr << "learning finished in: " << timer.getLastAbsolute() << " seconds" << endl;
   timer.start();
   timer.start();
   
   
+  cout << "uniquenumber " << uniquenumber << endl;
+  
   if(useCategorization)
   if(useCategorization)
   {
   {
+    uniquenumber = std::min(shortsize, uniquenumber);
     for(uint i = 0; i < globalCategorFeats.size(); i++)
     for(uint i = 0; i < globalCategorFeats.size(); i++)
     {
     {
       globalCategorFeats[i]->setDim(uniquenumber);
       globalCategorFeats[i]->setDim(uniquenumber);
@@ -1161,9 +1167,7 @@ void SemSegContextTree::train (const MultiDataset *md)
     fasthik->train(globalCategorFeats, ys);
     fasthik->train(globalCategorFeats, ys);
     
     
   }
   }
-
-  cout << "uniquenumber " << uniquenumber << endl;
-  //getchar();
+  
 #ifdef DEBUG
 #ifdef DEBUG
   for (int tree = 0; tree < nbTrees; tree++)
   for (int tree = 0; tree < nbTrees; tree++)
   {
   {
@@ -1463,7 +1467,7 @@ void SemSegContextTree::semanticseg (CachedExample *ce, NICE::Image & segresult,
               currentfeats.set (x, y, forest[tree][t].left, tree);
               currentfeats.set (x, y, forest[tree][t].left, tree);
 #pragma omp critical
 #pragma omp critical
               {
               {
-                if(useCategorization)
+                if(useCategorization && forest[tree][forest[tree][t].left].nodeNumber < uniquenumber)
                   (*globalCategorFeat)[forest[tree][forest[tree][t].left].nodeNumber] += weight;
                   (*globalCategorFeat)[forest[tree][forest[tree][t].left].nodeNumber] += weight;
               }
               }
             }
             }
@@ -1472,7 +1476,7 @@ void SemSegContextTree::semanticseg (CachedExample *ce, NICE::Image & segresult,
               currentfeats.set (x, y, forest[tree][t].right, tree);
               currentfeats.set (x, y, forest[tree][t].right, tree);
 #pragma omp critical
 #pragma omp critical
               {
               {
-                if(useCategorization)
+                if(useCategorization && forest[tree][forest[tree][t].right].nodeNumber < uniquenumber)
                   (*globalCategorFeat)[forest[tree][forest[tree][t].right].nodeNumber] += weight;
                   (*globalCategorFeat)[forest[tree][forest[tree][t].right].nodeNumber] += weight;
               }
               }
             }
             }