Browse Source

output of prob maps for testSemanticSegmentation3D

Sven Sickert 10 years ago
parent
commit
3a590d0a7c
1 changed files with 54 additions and 28 deletions
  1. 54 28
      progs/testSemanticSegmentation3D.cpp

+ 54 - 28
progs/testSemanticSegmentation3D.cpp

@@ -13,13 +13,13 @@
 
 #include "core/basics/Config.h"
 #include "core/basics/StringTools.h"
-#include <vislearning/baselib/ICETools.h>
+#include "vislearning/baselib/ICETools.h"
 
-#include <core/image/MultiChannelImage3DT.h>
-#include <semseg/semseg/SemSegContextTree3D.h>
+#include "core/image/MultiChannelImage3DT.h"
+#include "semseg/semseg/SemSegContextTree3D.h"
 
-#include <core/basics/ResourceStatistics.h>
-#include <core/image/Morph.h>
+#include "core/basics/ResourceStatistics.h"
+#include "core/image/Morph.h"
 
 #include <fstream>
 #include <vector>
@@ -132,7 +132,6 @@ void startClassification (SemanticSegmentation *semseg,
 
   vector< int > zsizeVec;
   semseg->getDepthVector ( testFiles, zsizeVec, run_3Dseg );
-
   int depthCount = 0, idx = 0;
   vector< string > filelist;
   NICE::MultiChannelImageT<double> segresult;
@@ -193,8 +192,6 @@ void startClassification (SemanticSegmentation *semseg,
           NICE::ColorImage rgb_gt;
           NICE::ColorImage ov_rgb;
           NICE::ColorImage ov_rgb_gt;
-          NICE::ColorImage prob_map( probabilities.width(), probabilities.height() );
-          prob_map.set(0,0,0);
 
           for ( int y = 0 ; y < orig.height(); y++ )
           {
@@ -206,15 +203,6 @@ void startClassification (SemanticSegmentation *semseg,
             }
           }
 
-          for ( int y = 0 ; y < probabilities.height(); y++ )
-            for ( int x = 0 ; x < probabilities.width(); x++ )
-            {
-              double probVal = probabilities.get( x, y, z, 1 ) * 255.0;
-              int tmp = round(probVal);
-              for ( int c = 0 ; c < 3 ; c++ )
-                prob_map.setPixel( x, y, c, tmp );
-            }
-
           // confusion matrix
           NICE::Matrix M ( classNames.getMaxClassno() + 1, classNames.getMaxClassno() + 1 );
           M.set ( 0 );
@@ -255,7 +243,29 @@ void startClassification (SemanticSegmentation *semseg,
             rgb_gt.write ( out.str() + "_groundtruth." + output_type );
             ov_rgb.write ( out.str() + "_overlay_res." + output_type );
             ov_rgb_gt.write ( out.str() + "_overlay_gt." + output_type );
-            prob_map.write ( out.str() + "_probs." + output_type );
+
+            // write Probability maps
+            NICE::ColorImage prob_map( probabilities.width(), probabilities.height() );
+            prob_map.set(0,0,0);
+            int iNumChannels = probabilities.channels();
+            for ( int idxProbMap = 0; idxProbMap < iNumChannels; idxProbMap++)
+            {
+                for ( int y = 0 ; y < probabilities.height(); y++ )
+                {
+                    for ( int x = 0 ; x < probabilities.width(); x++ )
+                    {
+                        double probVal = probabilities.get( x, y, z, idxProbMap ) * 255.0;
+                        int tmp = round(probVal);
+                        for ( int c = 0 ; c < 3 ; c++ )
+                            prob_map.setPixel( x, y, c, tmp );
+                    }
+                }
+                std::stringstream ssFileProbMap;
+                //ssFileProbMap << out.str() << "_probs." << "c" << idxProbMap << "." << output_type;
+                ssFileProbMap << out.str() << "_probs." << "c-" << classNames.code( idxProbMap ) << "." << output_type;
+                //classNames
+                prob_map.write ( ssFileProbMap.str() );
+            }
           }
         }
       }
@@ -303,7 +313,6 @@ int main ( int argc, char **argv )
 
   // initialize semantic segmentation method
   SemanticSegmentation *semseg = NULL;
-  //semseg = new SemSegContextTree3D ( &conf, &classNames );
   
   // TRAINING AND TESTING
   if (!doCrossVal)
@@ -360,6 +369,9 @@ int main ( int argc, char **argv )
   cout << "CPU Time (user): " << userCPUTime << " seconds" << endl;
   cout << "CPU Time (sys):  " << sysCPUTime << " seconds" << endl;
 
+  cout << "\nPERFORMANCE" << endl;
+  cout << "###########\n" << endl;
+
   double overall = 0.0;
   double sumall = 0.0;
 
@@ -382,6 +394,25 @@ int main ( int argc, char **argv )
   }
   overall /= sumall;
 
+  cout << "Confusion Matrix:" << endl;
+  for (int r = 0; r < (int) M.rows(); r++)
+  {
+    for (int c = 0; c < (int) M.cols(); c++)
+    {
+      cout << M(r,c) << "  ";
+    }
+    cout << endl;
+  }
+
+  // metrics for binary classification
+  double precision, recall, f1score = -1.0;
+  if (classNames.getMaxClassno()+1 == 2)
+  {
+    precision = (double)M(1,1) / (double)(M(1,1)+M(0,1));
+    recall = (double)M(1,1) / (double)(M(1,1)+M(1,0));
+    f1score = 2.0*(precision*recall)/(precision+recall);
+  }
+
   // normalizing M using rows
   for ( int r = 0 ; r < ( int ) M.rows() ; r++ )
   {
@@ -415,16 +446,13 @@ int main ( int argc, char **argv )
     }
   }
 
-  // print/save results of evaluation
-  cout << "\nPERFORMANCE" << endl;
-  cout << "###########\n" << endl;
-  ofstream fout ( ( resultdir + "/res.txt" ).c_str(), ios::out );
-  fout << "Overall Recognition Rate: " << overall << endl;
-  fout << "Average Recognition Rate: " << avg_perf / ( classes_trained ) << endl;
-  fout << "Lower Bound: " << 1.0  / classes_trained << endl;
+  // print results of evaluation
   cout << "Overall Recogntion Rate: " << overall << endl;
   cout << "Average Recogntion Rate: " << avg_perf / ( classes_trained ) << endl;
   cout << "Lower Bound: " << 1.0  / classes_trained << endl;
+  cout << "Precision: " << precision << endl;
+  cout << "Recall: " << recall << endl;
+  cout << "F1Score: " << f1score << endl;
 
   cout <<"\nClasses:" << endl;
   for ( int r = 0 ; r < ( int ) M.rows() ; r++ )
@@ -432,11 +460,9 @@ int main ( int argc, char **argv )
     if ( ( classNames.existsClassno ( r ) ) && ( forbidden_classes.find ( r ) == forbidden_classes.end() ) )
     {
       std::string classname = classNames.text ( r );
-      fout << classname.c_str() << ": " << M ( r, r ) << endl;
       cout << classname.c_str() << ": " << M ( r, r ) << endl;
     }
   }
-  fout.close();
 
   return 0;
 }