Bjoern Froehlich 13 жил өмнө
parent
commit
25c52555f6

+ 14 - 10
semseg/SemSegContextTree.cpp

@@ -796,7 +796,7 @@ double SemSegContextTree::getBestSplit ( std::vector<NICE::MultiChannelImageT<do
 
   for ( uint i = 0; i < fraction.size(); i++ )
   {
-    if ( forbidden_classes.find ( labelmapback[i]-1 ) != forbidden_classes.end() )
+    if ( forbidden_classes.find ( labelmapback[i] ) != forbidden_classes.end() )
       fraction[i] = 0;
     else
       fraction[i] = ( ( double ) maxSamples ) / ( ( double ) featcounter * a[i] * a.size() );
@@ -819,8 +819,11 @@ double SemSegContextTree::getBestSplit ( std::vector<NICE::MultiChannelImageT<do
         {
           int cn = labels[iCounter] ( x, y );
           double randD = ( double ) rand() / ( double ) RAND_MAX;
-
-          if ( randD < fraction[labelmap[cn]-1] )
+          
+          if(labelmap.find(cn) == labelmap.end())
+            continue;
+          
+          if ( randD < fraction[labelmap[cn]] )
           {
             vector<int> tmp ( 3, 0 );
             tmp[0] = iCounter;
@@ -1319,8 +1322,7 @@ void SemSegContextTree::train ( const MultiDataset *md )
 
   for ( mapit = labelcounter.begin(); mapit != labelcounter.end(); mapit++ )
   {
-    labelmap[mapit->first] = classes+1;
-
+    labelmap[mapit->first] = classes;
     labelmapback[classes] = mapit->first;
     classes++;
   }
@@ -1341,7 +1343,9 @@ void SemSegContextTree::train ( const MultiDataset *md )
       {
         featcounter++;
         int cn = labels[iCounter] ( x, y );
-        a[labelmap[cn]-1] ++;
+        if(labelmap.find(cn) == labelmap.end())
+            continue;
+        a[labelmap[cn]] ++;
       }
     }
   }
@@ -1491,8 +1495,8 @@ void SemSegContextTree::train ( const MultiDataset *md )
                     if ( val < splitval )
                     {
                       currentfeats[iCounter].set ( x, y, left, tree );
-                      if(labelmap[labels[iCounter] ( x, y ) ] > 0)
-                        forest[tree][left].dist[labelmap[labels[iCounter] ( x, y )]-1]++;
+                      if(labelmap.find(labels[iCounter] ( x, y )) != labelmap.end())
+                        forest[tree][left].dist[labelmap[labels[iCounter] ( x, y ) ]]++;
                       forest[tree][left].featcounter++;
                       SparseVector v;
                       v.insert ( pair<int, double> ( leftu, 1.0 ) );
@@ -1501,8 +1505,8 @@ void SemSegContextTree::train ( const MultiDataset *md )
                     else
                     {
                       currentfeats[iCounter].set ( x, y, right, tree );
-                      if(labelmap[labels[iCounter] ( x, y ) ] > 0)
-                        forest[tree][right].dist[labelmap[labels[iCounter] ( x, y )]-1]++;
+                      if(labelmap.find(labels[iCounter] ( x, y )) != labelmap.end())
+                        forest[tree][right].dist[labelmap[labels[iCounter] ( x, y ) ]]++;
                       forest[tree][right].featcounter++;
                       //feld im subsampled finden und in diesem rechts hochzählen
                       SparseVector v;