浏览代码

minor bugfixes

Bjoern Froehlich 13 年之前
父节点
当前提交
fd9cd73f35
共有 1 个文件被更改,包括 16 次插入9 次删除
  1. 16 9
      semseg/SemSegContextTree.cpp

+ 16 - 9
semseg/SemSegContextTree.cpp

@@ -370,9 +370,16 @@ double SemSegContextTree::getBestSplit (std::vector<NICE::MultiChannelImageT<dou
       feat.rProbs = &regionProbs[(*it) [0]];
       feat.rProbs = &regionProbs[(*it) [0]];
       
       
       double val = featsel[f]->getVal (feat, (*it) [1], (*it) [2]);
       double val = featsel[f]->getVal (feat, (*it) [1], (*it) [2]);
-      vals.push_back (val);
-      maxval = std::max (val, maxval);
-      minval = std::min (val, minval);
+      if(isfinite(val))
+      {
+        vals.push_back (val);
+        maxval = std::max (val, maxval);
+        minval = std::min (val, minval);
+      }
+      else
+      {
+        cerr << "non finite value for " << featsel[f]->writeInfos() <<  endl << (*it) [1] << " " <<  (*it) [2] << endl;
+      }
     }
     }
 
 
     if (minval == maxval)
     if (minval == maxval)
@@ -915,6 +922,7 @@ void SemSegContextTree::train (const MultiDataset *md)
           }
           }
 
 
           forest[tree][i].feat = splitfeat;
           forest[tree][i].feat = splitfeat;
+          
           forest[tree][i].decision = splitval;
           forest[tree][i].decision = splitval;
 
 
           if (splitfeat != NULL)
           if (splitfeat != NULL)
@@ -1225,6 +1233,8 @@ void SemSegContextTree::train (const MultiDataset *md)
   for (int d = 0; d < maxDepth; d++)
   for (int d = 0; d < maxDepth; d++)
   {
   {
     double sum =  contextOverview[d][0] + contextOverview[d][1];
     double sum =  contextOverview[d][0] + contextOverview[d][1];
+    if(sum == 0)
+      sum = 1;
 
 
     contextOverview[d][0] /= sum;
     contextOverview[d][0] /= sum;
     contextOverview[d][1] /= sum;
     contextOverview[d][1] /= sum;
@@ -1842,21 +1852,18 @@ void SemSegContextTree::restore (std::istream & is, int format)
       is >> forest[t][n].left;
       is >> forest[t][n].left;
       is >> forest[t][n].right;
       is >> forest[t][n].right;
       is >> forest[t][n].decision;
       is >> forest[t][n].decision;
-      cout << 1 << endl;
       is >> forest[t][n].isleaf;
       is >> forest[t][n].isleaf;
       is >> forest[t][n].depth;
       is >> forest[t][n].depth;
       is >> forest[t][n].featcounter;
       is >> forest[t][n].featcounter;
-cout << 1 << endl;
       is >> forest[t][n].nodeNumber;
       is >> forest[t][n].nodeNumber;
-      cout << forest[t][n].nodeNumber << endl;
-cout << 2 << endl;
+
       is >> forest[t][n].dist;
       is >> forest[t][n].dist;
-cout << 3 << endl;
+      
       int feattype;
       int feattype;
       is >> feattype;
       is >> feattype;
       assert (feattype < NBOPERATIONS);
       assert (feattype < NBOPERATIONS);
       forest[t][n].feat = NULL;
       forest[t][n].feat = NULL;
-      cout << 4 << endl;
+
       if (feattype >= 0)
       if (feattype >= 0)
       {
       {
         for (uint o = 0; o < ops.size(); o++)
         for (uint o = 0; o < ops.size(); o++)