瀏覽代碼

small fix for arbitrary crossval

Sven Sickert 9 年之前
父節點
當前提交
71f4900f62
共有 1 個文件被更改,包括 7 次插入7 次删除
  1. 7 7
      progs/testSemanticSegmentation3D.cpp

+ 7 - 7
progs/testSemanticSegmentation3D.cpp

@@ -39,11 +39,11 @@ void startClassification (SemanticSegmentation *semseg,
                           const ClassNames & classNames,
                           const set<int> & forbidden_classes,
                           std::map<int,int> & classMapping,
-                          const bool doCrossVal)
+                          const unsigned short cvRuns)
 {
     bool write_results = conf.gB ( "debug", "write_results", false );
     bool writeProbMaps = conf.gB ( "debug", "write_prob_maps", false );
-    if (doCrossVal)
+    if (cvRuns > 1)
         write_results = false;
 
     bool run_3Dseg = conf.gB( "SSContextTree", "run_3dseg", false);
@@ -181,7 +181,7 @@ int main ( int argc, char **argv )
     ResourceStatistics rs;
 
     /*---------------CONFIGURATION---------------*/
-    bool doCrossVal = conf.gB ( "debug", "do_crossval", false );
+    unsigned short crossValRuns = conf.gI ( "debug", "cross_val_runs", 1 );
     /*-------------------------------------------*/
 
 #ifdef DEBUG
@@ -219,7 +219,7 @@ int main ( int argc, char **argv )
     SemanticSegmentation *semseg = NULL;
 
     // TRAINING AND TESTING
-    if (!doCrossVal)
+    if ( crossValRuns == 1 )
     {
         semseg = new SemSegContextTree3D ( &conf, &classNames );
 
@@ -232,14 +232,14 @@ int main ( int argc, char **argv )
         cout << "##############\n" << endl;
         const LabeledSet *testFiles = md["test"];
         startClassification (semseg, M_vec, conf, testFiles, classNames,
-                             forbidden_classes, classMapping, doCrossVal );
+                             forbidden_classes, classMapping, crossValRuns );
 
         delete semseg;
     }
     else
     {
         // CROSS-VALIDATION
-        for (int cval = 1; cval <= 10; cval++)
+        for (int cval = 1; cval <= crossValRuns; cval++)
         {
             semseg = new SemSegContextTree3D ( &conf, &classNames );
 
@@ -257,7 +257,7 @@ int main ( int argc, char **argv )
             cout << "#################\n" << endl;
             const LabeledSet *testFiles = md[cvaltest];
             startClassification (semseg, M_vec, conf, testFiles, classNames,
-                                 forbidden_classes, classMapping, doCrossVal );
+                                 forbidden_classes, classMapping, crossValRuns );
 
             delete semseg;
         }