浏览代码

add ability to run cross-validation

Sven Sickert 11 年之前
父节点
当前提交
17de72c88a
共有 1 个文件被更改,包括 103 次插入53 次删除
  1. 103 53
      progs/testSemanticSegmentation3D.cpp

+ 103 - 53
progs/testSemanticSegmentation3D.cpp

@@ -78,8 +78,10 @@ void segmentToOverlay ( const NICE::Image *orig, const NICE::ColorImage & segmen
   }
 }
 
-void updateMatrix ( const NICE::Image & img, const NICE::Image & gt,
-                    NICE::Matrix & M, const set<int> & forbidden_classes )
+void updateMatrix ( const NICE::Image & img,
+                    const NICE::Image & gt,
+                    NICE::Matrix & M,
+                    const set<int> & forbidden_classes )
 {
   double subsamplex = gt.width() / ( double ) img.width();
   double subsampley = gt.height() / ( double ) img.height();
@@ -109,51 +111,24 @@ void updateMatrix ( const NICE::Image & img, const NICE::Image & gt,
     }
 }
 
-/**
- test semantic segmentation routines
-*/
-int main ( int argc, char **argv )
+void startClassification (SemanticSegmentation *semseg,
+                          std::vector< NICE::Matrix > & M_vec,
+                          const Config & conf,
+                          const LabeledSet* testFiles,
+                          const ClassNames & classNames,
+                          const set<int> & forbidden_classes,
+                          const string & resultdir,
+                          const bool doCrossVal)
 {
-  std::set_terminate ( __gnu_cxx::__verbose_terminate_handler );
-
-  Config conf ( argc, argv );
-
-  ResourceStatistics rs;
+  bool show_results = conf.gB ( "debug", "show_results", false );
+  bool write_results = conf.gB ( "debug", "write_results", false );
+  if (doCrossVal)
+    write_results = false;
 
-  /*-------------I/O CONFIGURATION-------------*/
-  bool postProcessing = conf.gB( "main", "post_process", false);
   bool run_3Dseg = conf.gB( "SSContextTree", "run_3dseg", false);
-  bool show_result = conf.gB ( "debug", "show_results", false );
-  bool write_results = conf.gB ( "debug", "write_results", false );
+  bool postProcessing = conf.gB( "main", "post_process", false);
   string output_type = conf.gS ( "debug", "output_type", "ppm" );
   string output_postfix = conf.gS ( "debug", "output_postfix", "" );
-  string resultdir = conf.gS ( "debug", "resultdir", "." );
-  /*-------------------------------------------*/
-
-#ifdef DEBUG
-  cerr << "Writing Results to " << resultdir << endl;
-#endif
-
-
-  MultiDataset md ( &conf );
-
-  const ClassNames & classNames = md.getClassNames ( "train" );
-
-  // initialize semantic segmentation method
-  SemanticSegmentation *semseg = NULL;
-  semseg = new SemSegContextTree3D ( &conf, &md );
-  
-  // train semantic segmentation method
-  cout << "\nTRAINING" << endl;
-  cout << "########\n" << endl;
-  semseg->train( &md );
-  
-  const LabeledSet *testFiles = md["test"];
-
-  set<int> forbidden_classes;
-  std::string forbidden_classes_s = conf.gS ( "analysis", "forbidden_classes", "" );
-
-  classNames.getSelection ( forbidden_classes_s, forbidden_classes );
 
   vector< int > zsizeVec;
   semseg->getDepthVector ( testFiles, zsizeVec, run_3Dseg );
@@ -162,10 +137,7 @@ int main ( int argc, char **argv )
   vector< string > filelist;
   NICE::MultiChannelImageT<double> segresult;
   NICE::MultiChannelImageT<double> gt;
-  std::vector< NICE::Matrix > M_vec;
 
-  cout << "\nCLASSIFICATION" << endl;
-  cout << "##############\n" << endl;
   for (LabeledSet::const_iterator it = testFiles->begin(); it != testFiles->end(); it++)
   {
     for (std::vector<ImageInfo *>::const_iterator jt = it->second.begin();
@@ -214,7 +186,7 @@ int main ( int argc, char **argv )
       {
         std::string fname = StringTools::baseName ( filelist[z], false );
 
-        if ( show_result || write_results )
+        if ( show_results || write_results )
         {
           NICE::ColorImage orig ( filelist[z] );
           NICE::ColorImage rgb;
@@ -265,12 +237,12 @@ int main ( int argc, char **argv )
             }
           }
 
-          segmentToOverlay ( orig.getChannel(1), rgb, ov_rgb );
-          segmentToOverlay ( orig.getChannel(1), rgb_gt, ov_rgb_gt );
-          
           if ( write_results )
           {
-            std::stringstream out;       
+            segmentToOverlay ( orig.getChannel(1), rgb, ov_rgb );
+            segmentToOverlay ( orig.getChannel(1), rgb_gt, ov_rgb_gt );
+
+            std::stringstream out;
             if ( output_postfix.size() > 0 )
               out << resultdir << "/" << fname << output_postfix;
             else
@@ -296,11 +268,91 @@ int main ( int argc, char **argv )
       idx++;
     }
   }
-  
+
   segresult.freeData();
+}
+
+/**
+ test semantic segmentation routines
+*/
+int main ( int argc, char **argv )
+{
+  std::set_terminate ( __gnu_cxx::__verbose_terminate_handler );
+
+  Config conf ( argc, argv );
+
+  ResourceStatistics rs;
+
+  /*---------------CONFIGURATION---------------*/
+  bool doCrossVal = conf.gB ( "debug", "do_crossval", false );
+  string resultdir = conf.gS ( "debug", "resultdir", "." );
+  /*-------------------------------------------*/
+
+#ifdef DEBUG
+  cerr << "Writing Results to " << resultdir << endl;
+#endif
+
+  std::vector< NICE::Matrix > M_vec;
+
+  MultiDataset md ( &conf );
+
+  const ClassNames & classNames = md.getClassNames ( "train" );
+  set<int> forbidden_classes;
+  classNames.getSelection ( conf.gS ( "analysis", "forbidden_classes", "" ),
+                            forbidden_classes );
+
+  // initialize semantic segmentation method
+  SemanticSegmentation *semseg = NULL;
+  semseg = new SemSegContextTree3D ( &conf, &classNames );
+  
+  // TRAINING AND TESTING
+  if (!doCrossVal)
+  {
+    semseg = new SemSegContextTree3D ( &conf, &classNames );
+
+    // STANDARD EVALUATION
+    cout << "\nTRAINING" << endl;
+    cout << "########\n" << endl;
+    semseg->train( &md );
+
+    cout << "\nCLASSIFICATION" << endl;
+    cout << "##############\n" << endl;
+    const LabeledSet *testFiles = md["test"];
+    startClassification (semseg, M_vec, conf, testFiles, classNames,
+                         forbidden_classes, resultdir, doCrossVal );
+
+    delete semseg;
+  }
+  else
+  {
+    // CROSS-VALIDATION
+    for (int cval = 1; cval <= 10; cval++)
+    {
+      semseg = new SemSegContextTree3D ( &conf, &classNames );
+
+      stringstream ss;
+      ss << cval;
+      string cvaltrain = "train_cv" + ss.str();
+      string cvaltest = "test_cv" + ss.str();
+
+      cout << "\nTRAINING " << cval << endl;
+      cout << "###########\n" << endl;
+      const LabeledSet *trainFiles = md[cvaltrain];
+      semseg->train( trainFiles );
+
+      cout << "\nCLASSIFICATION " << cval << endl;
+      cout << "#################\n" << endl;
+      const LabeledSet *testFiles = md[cvaltest];
+      startClassification (semseg, M_vec, conf, testFiles, classNames,
+                           forbidden_classes, resultdir, doCrossVal );
+
+      delete semseg;
+    }
+  }
 
   cout << "\nSTATISTICS" << endl;
   cout << "##########\n" << endl;
+
   long maxMemory;
   double userCPUTime, sysCPUTime;
   rs.getStatistics ( maxMemory, userCPUTime, sysCPUTime );
@@ -386,7 +438,5 @@ int main ( int argc, char **argv )
   }
   fout.close();
 
-  delete semseg;
-
   return 0;
 }