浏览代码

changes in ordner to use more class labels

Sven Sickert 10 年之前
父节点
当前提交
702042b81e

+ 36 - 19
progs/testSemanticSegmentation3D.cpp

@@ -78,10 +78,11 @@ void segmentToOverlay ( const NICE::Image *orig, const NICE::ColorImage & segmen
   }
   }
 }
 }
 
 
-void updateMatrix ( const NICE::Image & img,
-                    const NICE::Image & gt,
+void updateMatrix ( const NICE::ImageT<int> & img,
+                    const NICE::ImageT<int> & gt,
                     NICE::Matrix & M,
                     NICE::Matrix & M,
-                    const set<int> & forbidden_classes )
+                    const set<int> & forbidden_classes,
+                    map<int,int> & classMapping )
 {
 {
   double subsamplex = gt.width() / ( double ) img.width();
   double subsamplex = gt.width() / ( double ) img.width();
   double subsampley = gt.height() / ( double ) img.height();
   double subsampley = gt.height() / ( double ) img.height();
@@ -106,7 +107,7 @@ void updateMatrix ( const NICE::Image & img,
 
 
       if ( forbidden_classes.find ( gimg ) == forbidden_classes.end() )
       if ( forbidden_classes.find ( gimg ) == forbidden_classes.end() )
       {
       {
-        M ( gimg, cimg ) ++;
+        M ( classMapping[gimg], classMapping[cimg] ) ++;
       }
       }
     }
     }
 }
 }
@@ -117,6 +118,7 @@ void startClassification (SemanticSegmentation *semseg,
                           const LabeledSet* testFiles,
                           const LabeledSet* testFiles,
                           const ClassNames & classNames,
                           const ClassNames & classNames,
                           const set<int> & forbidden_classes,
                           const set<int> & forbidden_classes,
+                          map<int,int> & classMapping,
                           const string & resultdir,
                           const string & resultdir,
                           const bool doCrossVal)
                           const bool doCrossVal)
 {
 {
@@ -135,8 +137,8 @@ void startClassification (SemanticSegmentation *semseg,
   semseg->getDepthVector ( testFiles, zsizeVec, run_3Dseg );
   semseg->getDepthVector ( testFiles, zsizeVec, run_3Dseg );
   int depthCount = 0, idx = 0;
   int depthCount = 0, idx = 0;
   vector< string > filelist;
   vector< string > filelist;
-  NICE::MultiChannelImageT<double> segresult;
-  NICE::MultiChannelImageT<double> gt;
+  NICE::MultiChannelImageT<int> segresult;
+  NICE::MultiChannelImageT<int> gt;
 
 
   for (LabeledSet::const_iterator it = testFiles->begin(); it != testFiles->end(); it++)
   for (LabeledSet::const_iterator it = testFiles->begin(); it != testFiles->end(); it++)
   {
   {
@@ -148,8 +150,8 @@ void startClassification (SemanticSegmentation *semseg,
       filelist.push_back ( file );
       filelist.push_back ( file );
       depthCount++;
       depthCount++;
 
 
-      NICE::Image lm;
-      NICE::Image lm_gt;
+      NICE::ImageT<int> lm;
+      NICE::ImageT<int> lm_gt;
       if ( info.hasLocalizationInfo() )
       if ( info.hasLocalizationInfo() )
       {
       {
         const LocalizationResult *l_gt = info.localization();
         const LocalizationResult *l_gt = info.localization();
@@ -205,9 +207,9 @@ void startClassification (SemanticSegmentation *semseg,
           }
           }
 
 
           // confusion matrix
           // confusion matrix
-          NICE::Matrix M ( classNames.getMaxClassno() + 1, classNames.getMaxClassno() + 1 );
+          NICE::Matrix M ( classMapping.size(), classMapping.size() );
           M.set ( 0 );
           M.set ( 0 );
-          updateMatrix ( lm, lm_gt, M, forbidden_classes );
+          updateMatrix ( lm, lm_gt, M, forbidden_classes, classMapping );
           M_vec.push_back ( M );
           M_vec.push_back ( M );
 
 
           classNames.labelToRGB ( lm, rgb );
           classNames.labelToRGB ( lm, rgb );
@@ -315,6 +317,23 @@ int main ( int argc, char **argv )
   classNames.getSelection ( conf.gS ( "analysis", "forbidden_classes", "" ),
   classNames.getSelection ( conf.gS ( "analysis", "forbidden_classes", "" ),
                             forbidden_classes );
                             forbidden_classes );
 
 
+  vector<bool> usedClasses ( classNames.numClasses(), true );
+  for ( set<int>::const_iterator it = forbidden_classes.begin();
+        it != forbidden_classes.end(); ++it)
+  {
+      usedClasses [ *it ] = false;
+  }
+
+  map<int,int> classMapping;
+  int j = 0;
+  for ( int i = 0; i < usedClasses.size(); i++ )
+      if (usedClasses[i])
+      {
+          classMapping[i] = j;
+          j++;
+      }
+
+
   // initialize semantic segmentation method
   // initialize semantic segmentation method
   SemanticSegmentation *semseg = NULL;
   SemanticSegmentation *semseg = NULL;
   
   
@@ -332,7 +351,7 @@ int main ( int argc, char **argv )
     cout << "##############\n" << endl;
     cout << "##############\n" << endl;
     const LabeledSet *testFiles = md["test"];
     const LabeledSet *testFiles = md["test"];
     startClassification (semseg, M_vec, conf, testFiles, classNames,
     startClassification (semseg, M_vec, conf, testFiles, classNames,
-                         forbidden_classes, resultdir, doCrossVal );
+                         forbidden_classes, classMapping, resultdir, doCrossVal );
 
 
     delete semseg;
     delete semseg;
   }
   }
@@ -357,7 +376,7 @@ int main ( int argc, char **argv )
       cout << "#################\n" << endl;
       cout << "#################\n" << endl;
       const LabeledSet *testFiles = md[cvaltest];
       const LabeledSet *testFiles = md[cvaltest];
       startClassification (semseg, M_vec, conf, testFiles, classNames,
       startClassification (semseg, M_vec, conf, testFiles, classNames,
-                           forbidden_classes, resultdir, doCrossVal );
+                           forbidden_classes, classMapping, resultdir, doCrossVal );
 
 
       delete semseg;
       delete semseg;
     }
     }
@@ -379,13 +398,12 @@ int main ( int argc, char **argv )
   double overall = 0.0;
   double overall = 0.0;
   double sumall = 0.0;
   double sumall = 0.0;
 
 
-  NICE::Matrix M ( classNames.getMaxClassno() + 1, classNames.getMaxClassno() + 1 );
+  NICE::Matrix M ( classMapping.size(), classMapping.size() );
   M.set ( 0 );
   M.set ( 0 );
   for ( int s = 0; s < ( int ) M_vec.size(); s++ )
   for ( int s = 0; s < ( int ) M_vec.size(); s++ )
   {
   {
     NICE::Matrix M_tmp = M_vec[s];
     NICE::Matrix M_tmp = M_vec[s];
     for ( int r = 0; r < ( int ) M_tmp.rows(); r++ )
     for ( int r = 0; r < ( int ) M_tmp.rows(); r++ )
-    {
       for ( int c = 0; c < ( int ) M_tmp.cols(); c++ )
       for ( int c = 0; c < ( int ) M_tmp.cols(); c++ )
       {
       {
         if ( r == c )
         if ( r == c )
@@ -394,23 +412,22 @@ int main ( int argc, char **argv )
         sumall += M_tmp ( r, c );
         sumall += M_tmp ( r, c );
         M ( r, c ) += M_tmp ( r, c );
         M ( r, c ) += M_tmp ( r, c );
       }
       }
-    }
   }
   }
   overall /= sumall;
   overall /= sumall;
 
 
   cout << "Confusion Matrix:" << endl;
   cout << "Confusion Matrix:" << endl;
+  cout.precision(4);
   for (int r = 0; r < (int) M.rows(); r++)
   for (int r = 0; r < (int) M.rows(); r++)
   {
   {
     for (int c = 0; c < (int) M.cols(); c++)
     for (int c = 0; c < (int) M.cols(); c++)
-    {
-      cout << M(r,c) << "  ";
-    }
+      cout << M(r,c)/sumall << "  ";
+
     cout << endl;
     cout << endl;
   }
   }
 
 
   // metrics for binary classification
   // metrics for binary classification
   double precision, recall, f1score = -1.0;
   double precision, recall, f1score = -1.0;
-  if (classNames.getMaxClassno()+1 == 2)
+  if (classNames.numClasses() == 2)
   {
   {
     precision = (double)M(1,1) / (double)(M(1,1)+M(0,1));
     precision = (double)M(1,1) / (double)(M(1,1)+M(0,1));
     recall = (double)M(1,1) / (double)(M(1,1)+M(1,0));
     recall = (double)M(1,1) / (double)(M(1,1)+M(1,0));

+ 12 - 61
semseg/SemSegContextTree3D.cpp

@@ -848,7 +848,8 @@ void SemSegContextTree3D::train ( const LabeledSet * trainp )
     ///////////////////////////////////////////////////////////////////////////////
     ///////////////////////////////////////////////////////////////////////////////
     int depthCount = 0;
     int depthCount = 0;
     vector< string > filelist;
     vector< string > filelist;
-    NICE::MultiChannelImageT<uchar> pixelLabels;
+    NICE::MultiChannelImageT<int> pixelLabels;
+    std::map<int, bool> labelExist;
 
 
     for (LabeledSet::const_iterator it = trainp->begin(); it != trainp->end(); it++)
     for (LabeledSet::const_iterator it = trainp->begin(); it != trainp->end(); it++)
     {
     {
@@ -864,7 +865,7 @@ void SemSegContextTree3D::train ( const LabeledSet * trainp )
             const LocalizationResult *locResult = info.localization();
             const LocalizationResult *locResult = info.localization();
 
 
             // getting groundtruth
             // getting groundtruth
-            NICE::Image pL;
+            NICE::ImageT<int> pL;
             pL.resize ( locResult->xsize, locResult->ysize );
             pL.resize ( locResult->xsize, locResult->ysize );
             pL.set ( 0 );
             pL.set ( 0 );
             locResult->calcLabeledImage ( pL, ( *classNames ).getBackgroundClass() );
             locResult->calcLabeledImage ( pL, ( *classNames ).getBackgroundClass() );
@@ -920,9 +921,7 @@ void SemSegContextTree3D::train ( const LabeledSet * trainp )
             }
             }
 
 
             for ( int x = 0; x < xsize; x++ )
             for ( int x = 0; x < xsize; x++ )
-            {
                 for ( int y = 0; y < ysize; y++ )
                 for ( int y = 0; y < ysize; y++ )
-                {
                     for ( int z = 0; z < zsize; z++ )
                     for ( int z = 0; z < zsize; z++ )
                     {
                     {
                         if ( useFeat1 )
                         if ( useFeat1 )
@@ -938,13 +937,11 @@ void SemSegContextTree3D::train ( const LabeledSet * trainp )
                         if ( forbidden_classes.find ( classno ) != forbidden_classes.end() )
                         if ( forbidden_classes.find ( classno ) != forbidden_classes.end() )
                             continue;
                             continue;
 
 
-                        labelcounter[classno]++;
+                        labelExist[classno] = 1;
 
 
                         if ( useCategorization )
                         if ( useCategorization )
                             classesPerImage[imgCounter][classno] = 1;
                             classesPerImage[imgCounter][classno] = 1;
                     }
                     }
-                }
-            }
 
 
             filelist.clear();
             filelist.clear();
             pixelLabels.reInit ( 0,0,0 );
             pixelLabels.reInit ( 0,0,0 );
@@ -954,8 +951,8 @@ void SemSegContextTree3D::train ( const LabeledSet * trainp )
     }
     }
 
 
     int classes = 0;
     int classes = 0;
-    for ( map<int, int>::const_iterator mapit = labelcounter.begin();
-          mapit != labelcounter.end(); mapit++ )
+    for ( map<int, bool>::const_iterator mapit = labelExist.begin();
+          mapit != labelExist.end(); mapit++ )
     {
     {
         labelmap[mapit->first] = classes;
         labelmap[mapit->first] = classes;
         labelmapback[classes] = mapit->first;
         labelmapback[classes] = mapit->first;
@@ -1020,7 +1017,7 @@ void SemSegContextTree3D::train ( const LabeledSet * trainp )
 
 
     //balancing
     //balancing
     a = vector<double> ( classes, 0.0 );
     a = vector<double> ( classes, 0.0 );
-    int featcounter = 0;
+    int selectionCounter = 0;
     for ( int iCounter = 0; iCounter < imgCounter; iCounter++ )
     for ( int iCounter = 0; iCounter < imgCounter; iCounter++ )
     {
     {
         int xsize = ( int ) nodeIndices[iCounter].width();
         int xsize = ( int ) nodeIndices[iCounter].width();
@@ -1028,30 +1025,25 @@ void SemSegContextTree3D::train ( const LabeledSet * trainp )
         int zsize = ( int ) nodeIndices[iCounter].depth();
         int zsize = ( int ) nodeIndices[iCounter].depth();
 
 
         for ( int x = 0; x < xsize; x++ )
         for ( int x = 0; x < xsize; x++ )
-        {
             for ( int y = 0; y < ysize; y++ )
             for ( int y = 0; y < ysize; y++ )
-            {
                 for ( int z = 0; z < zsize; z++ )
                 for ( int z = 0; z < zsize; z++ )
                 {
                 {
-                    featcounter++;
                     int cn = labels[iCounter] ( x, y, ( uint ) z );
                     int cn = labels[iCounter] ( x, y, ( uint ) z );
                     if ( labelmap.find ( cn ) == labelmap.end() )
                     if ( labelmap.find ( cn ) == labelmap.end() )
                         continue;
                         continue;
                     a[labelmap[cn]] ++;
                     a[labelmap[cn]] ++;
+                    selectionCounter++;
                 }
                 }
-            }
-        }
     }
     }
 
 
     for ( int i = 0; i < ( int ) a.size(); i++ )
     for ( int i = 0; i < ( int ) a.size(); i++ )
-    {
-        a[i] /= ( double ) featcounter;
-    }
+        a[i] /= ( double ) selectionCounter;
 
 
 #ifdef VERBOSE
 #ifdef VERBOSE
     cout << "\nDistribution:" << endl;
     cout << "\nDistribution:" << endl;
     for ( int i = 0; i < ( int ) a.size(); i++ )
     for ( int i = 0; i < ( int ) a.size(); i++ )
-        cout << "class " << i << ": " << a[i] << endl;
+        cout << "class '" << classNames->code(labelmapback[i]) << "': "
+             << a[i] << endl;
 #endif
 #endif
 
 
     depth = 0;
     depth = 0;
@@ -1230,47 +1222,6 @@ void SemSegContextTree3D::train ( const LabeledSet * trainp )
 
 
                         assert ( lcounter > 0 && rcounter > 0 );
                         assert ( lcounter > 0 && rcounter > 0 );
 
 
-                        //            if ( lcounter <= 0 || rcounter <= 0 )
-                        //            {
-                        //              cout << "lcounter : " << lcounter << " rcounter: " << rcounter << endl;
-                        //              cout << "splitval: " << splitval << " splittype: " << splitfeat->writeInfos() << endl;
-                        //              cout << "bestig: " << bestig << endl;
-
-                        //              for ( int i = 0; i < imgCounter; i++ )
-                        //              {
-                        //                int xsize = nodeIndices[i].width();
-                        //                int ysize = nodeIndices[i].height();
-                        //                int zsize = nodeIndices[i].depth();
-                        //                int counter = 0;
-
-                        //                for ( int x = 0; x < xsize; x++ )
-                        //                {
-                        //                  for ( int y = 0; y < ysize; y++ )
-                        //                  {
-                        //                    for ( int z = 0; z < zsize; z++ )
-                        //                    {
-                        //                      if ( lastNodeIndices[i].get ( x, y, tree ) == node )
-                        //                      {
-                        //                        if ( ++counter > 30 )
-                        //                          break;
-
-                        //                        Features feat;
-                        //                        feat.feats = &allfeats[i];
-                        //                        feat.rProbs = &lastRegionProbs[i];
-
-                        //                        double val = splitfeat->getVal ( feat, x, y, z );
-                        //                        if ( !isfinite ( val ) ) val = 0.0;
-
-                        //                        cout << "splitval: " << splitval << " val: " << val << endl;
-                        //                      }
-                        //                    }
-                        //                  }
-                        //                }
-                        //              }
-
-                        //              assert ( lcounter > 0 && rcounter > 0 );
-                        //            }
-
                         for ( int c = 0; c < classes; c++ )
                         for ( int c = 0; c < classes; c++ )
                         {
                         {
                             forest[tree][left].dist[c] /= lcounter;
                             forest[tree][left].dist[c] /= lcounter;
@@ -1661,7 +1612,7 @@ void SemSegContextTree3D::addFeatureMaps (
 
 
 void SemSegContextTree3D::classify (
 void SemSegContextTree3D::classify (
         const std::vector<std::string> & filelist,
         const std::vector<std::string> & filelist,
-        NICE::MultiChannelImageT<double> & segresult,
+        NICE::MultiChannelImageT<int> & segresult,
         NICE::MultiChannelImage3DT<double> & probabilities )
         NICE::MultiChannelImage3DT<double> & probabilities )
 {
 {
     ///////////////////////// build MCI3DT from files ///////////////////////////
     ///////////////////////// build MCI3DT from files ///////////////////////////

+ 2 - 2
semseg/SemSegContextTree3D.h

@@ -66,7 +66,7 @@ private:
   int featsPerSplit;
   int featsPerSplit;
 
 
   /** count samples per label */
   /** count samples per label */
-  std::map<int, int> labelcounter;
+  //std::map<int, int> labelcounter;
 
 
   /** map of labels */
   /** map of labels */
   std::map<int, int> labelmap;
   std::map<int, int> labelmap;
@@ -229,7 +229,7 @@ public:
    * @param probabilities probabilities for each pixel (output)
    * @param probabilities probabilities for each pixel (output)
    */
    */
   void classify ( const std::vector<std::string> & filelist,
   void classify ( const std::vector<std::string> & filelist,
-                  NICE::MultiChannelImageT<double> & segresult,
+                  NICE::MultiChannelImageT<int> & segresult,
                   NICE::MultiChannelImage3DT<double> & probabilities );
                   NICE::MultiChannelImage3DT<double> & probabilities );
 
 
   /**
   /**

+ 1 - 1
semseg/SemanticSegmentation.cpp

@@ -97,7 +97,7 @@ void SemanticSegmentation::semanticseg ( const std::string & filename,
 }
 }
 
 
 void SemanticSegmentation::classify ( const std::vector<std::string> & filelist,
 void SemanticSegmentation::classify ( const std::vector<std::string> & filelist,
-                                      NICE::MultiChannelImageT<double> & segresult,
+                                      NICE::MultiChannelImageT<int> & segresult,
                                       NICE::MultiChannelImage3DT<double> & probabilities )
                                       NICE::MultiChannelImage3DT<double> & probabilities )
 {
 {
   for ( int it = 0; it < ( int ) filelist.size(); it++ )
   for ( int it = 0; it < ( int ) filelist.size(); it++ )

+ 1 - 1
semseg/SemanticSegmentation.h

@@ -113,7 +113,7 @@ class SemanticSegmentation : public NICE::Persistent
      * @param probabilities probabilities for each pixel (output)
      * @param probabilities probabilities for each pixel (output)
      */
      */
     virtual void classify ( const std::vector<std::string> & filelist,
     virtual void classify ( const std::vector<std::string> & filelist,
-                    NICE::MultiChannelImageT<double> & segresult,
+                    NICE::MultiChannelImageT<int> & segresult,
                     NICE::MultiChannelImage3DT<double> & probabilities );
                     NICE::MultiChannelImage3DT<double> & probabilities );
 
 
     /** this function has to be overloaded by all subclasses
     /** this function has to be overloaded by all subclasses