瀏覽代碼

improved class. statistics output

Sven Sickert 9 年之前
父節點
當前提交
121896824e
共有 2 個文件被更改,包括 64 次插入56 次删除
  1. 61 55
      semseg/SemSegTools.cpp
  2. 3 1
      semseg/SemSegTools.h

+ 61 - 55
semseg/SemSegTools.cpp

@@ -6,6 +6,7 @@
 
 */
 #include <iostream>
+#include <iomanip>
 
 #include "SemSegTools.h"
 
@@ -42,7 +43,7 @@ void SemSegTools::segmentToOverlay (
             uchar val = orig->getPixelQuick(x,y);
             for (int c = 0; c < 3; c++)
                 channelMat[c](x,y) = alpha*(double)val
-                                   + (1.0-alpha)*(double)segment.getPixel( x, y, c );
+                        + (1.0-alpha)*(double)segment.getPixel( x, y, c );
         }
 
     for (int y = 0; y < ysize; y++)
@@ -92,34 +93,41 @@ void SemSegTools::updateConfusionMatrix(
 void SemSegTools::computeClassificationStatistics(
         Matrix &confMat,
         const ClassNames &classNames,
-        const std::set<int> &forbiddenClasses )
+        const std::set<int> &forbiddenClasses,
+        map<int,int> & classMappingInv )
 {
+    std::cout << "\nPERFORMANCE" << std::endl;
+    std::cout << "###########\n" << std::endl;
+
     double overallTrue = 0.0;
     double sumAll  = 0.0;
 
-    // print confusion matrix & get overall recognition rate
-    std::cout << "Confusion Matrix:" <<  std::endl;
+    // overall recognition rate
     for ( int r = 0; r < (int) confMat.rows(); r++ )
-    {
         for ( int c = 0; c < (int) confMat.cols(); c++ )
         {
             if ( r == c )
                 overallTrue += confMat( r, c );
 
             sumAll += confMat( r, c );
-            std::cout << confMat( r, c ) << " ";
         }
-        std::cout << std::endl;
-    }
+
     overallTrue /= sumAll;
 
+    double truePos = (double)confMat(1,1);
+    //double trueNeg = (double)confMat(0,0);
+    double falsePos = (double)confMat(0,1);
+    double falseNeg = (double)confMat(1,0);
+
     // binary classification metrics
-    double precision, recall, f1score = -1.0;
-    if ( confMat.rows() == 2 )
+    if ( classMappingInv.size() == 2 )
     {
-        precision = (double)confMat(1,1) / (double)(confMat(1,1)+confMat(0,1));
-        recall = (double)confMat(1,1) / (double)(confMat(1,1)+confMat(1,0));
-        f1score = 2.0*(precision*recall)/(precision+recall);
+        double precision = truePos / (truePos+falsePos);
+        double recall = truePos / (truePos+falseNeg);
+        double f1score = 2.0*(precision*recall)/(precision+recall);
+        std::cout << "\nPrecision: " << precision;
+        std::cout << "\nRecall: " << recall;
+        std::cout << "\nF1Score: " << f1score;
     }
 
     // normalizing confMat using rows
@@ -135,43 +143,42 @@ void SemSegTools::computeClassificationStatistics(
                 confMat ( r, c ) /= sum;
     }
 
-    // get average recognition rate
-    double avgTrue = 0.0;
-    int classesTrained = 0;
-    for ( int r = 0 ; r < (int) confMat.rows() ; r++ )
+    // printing confusion matrix
+    short int printWidth = 16;
+    std::cout.precision(6);
+    std::cout << std::setw(printWidth) << "";
+    for (int r = 0; r < (int) confMat.rows(); r++)
     {
-        if ( classNames.existsClassno ( r )
-             && ( forbiddenClasses.find ( r ) == forbiddenClasses.end() ) )
+        int cl = classMappingInv[r];
+        if ( classNames.existsClassno ( cl )
+             && ( forbiddenClasses.find ( cl ) == forbiddenClasses.end() ) )
         {
-            avgTrue += confMat ( r, r );
-            double lsum = 0.0;
-            for ( int r2 = 0; r2 < ( int ) confMat.rows(); r2++ )
-                lsum += confMat ( r,r2 );
-
-            if ( lsum != 0.0 )
-                classesTrained++;
+            std::string cname = classNames.text ( cl );
+            std::cout << std::setw(printWidth) << cname.c_str();
         }
     }
-
-    // print classification statistics
-    std::cout << "\nOverall Recogntion Rate: " << overallTrue;
-    std::cout << "\nAverage Recogntion Rate: " << avgTrue / ( classesTrained );
-    std::cout << "\nLower Bound: " << 1.0 /(double)classesTrained;
-    std::cout << "\nPrecision: " << precision;
-    std::cout << "\nRecall: " << recall;
-    std::cout << "\nF1Score: " << f1score;
-
-    std::cout <<"\n\nClasses:" << std::endl;
-    for ( int r = 0 ; r < (int) confMat.rows() ; r++ )
+    std::cout << std::endl;
+    for (int r = 0; r < (int) confMat.rows(); r++)
     {
-        if ( classNames.existsClassno ( r )
-             && ( forbiddenClasses.find ( r ) == forbiddenClasses.end() ) )
+        int cl = classMappingInv[r];
+        if ( classNames.existsClassno ( cl )
+             && ( forbiddenClasses.find ( cl ) == forbiddenClasses.end() ) )
         {
-            std::string cname = classNames.text ( r );
-            std::cout << cname.c_str() << ": " << confMat ( r, r ) << std::endl;
+            std::string cname = classNames.text ( cl );
+            std::cout << std::setw(printWidth) << cname.c_str();
+
+            for (int c = 0; c < (int) confMat.cols(); c++)
+                std::cout << std::setw(printWidth) << std::fixed << confMat (r, c);
+
+            std::cout << std::endl;
         }
     }
 
+    // print classification statistics
+    std::cout << "\nOverall Recogntion Rate: " << overallTrue;
+    std::cout << "\nAverage Recogntion Rate: " << confMat.trace() / (double)classMappingInv.size();
+    std::cout << "\nLower Bound: " << 1.0 /(double)classMappingInv.size();
+    std::cout << std::endl;
 }
 
 void SemSegTools::saveResultsToImageFile(
@@ -233,7 +240,7 @@ void SemSegTools::collectTrainingExamples (
 
     int backgroundClassNo = 0;
     
-    if ( useExcludedAsBG ) 
+    if ( useExcludedAsBG )
     {
         backgroundClassNo = cn.classno("various");
         assert ( backgroundClassNo >= 0 );
@@ -269,33 +276,33 @@ void SemSegTools::collectTrainingExamples (
         pixelLabels.set(0);
         locResult->calcLabeledImage ( pixelLabels, cn.getBackgroundClass() );
 
-    #ifdef DEBUG_LOCALIZATION
+#ifdef DEBUG_LOCALIZATION
         NICE::Image img (imgfn);
         showImage(img);
         showImage(pixelLabels);
-    #endif
+#endif
 
         Example pce ( ce, 0, 0 );
         for ( int x = 0 ; x < xsize ; x += grid_size_x )
             for ( int y = 0 ; y < ysize ; y += grid_size_y )
             {
                 if ( (x >= grid_border_x) &&
-                    ( y >= grid_border_y ) && ( x < xsize - grid_border_x ) &&
-                    ( y < ysize - grid_border_x ) )
+                     ( y >= grid_border_y ) && ( x < xsize - grid_border_x ) &&
+                     ( y < ysize - grid_border_x ) )
                 {
                     pce.x = x; pce.y = y;
                     int classno = pixelLabels.getPixel(x,y);
 
                     if ( classnoSelection.find(classno) != classnoSelection.end() ) {
-                    examples.push_back ( pair<int, Example> (
-                        classno,
-                        pce // FIXME: offset handling
-                    ) );
+                        examples.push_back ( pair<int, Example> (
+                                                 classno,
+                                                 pce // FIXME: offset handling
+                                                 ) );
                     } else if ( useExcludedAsBG ) {
-                    examples.push_back ( pair<int, Example> (
-                        backgroundClassNo,
-                        pce // FIXME: offset handling
-                    ) );
+                        examples.push_back ( pair<int, Example> (
+                                                 backgroundClassNo,
+                                                 pce // FIXME: offset handling
+                                                 ) );
                     }
                 }
             }
@@ -303,4 +310,3 @@ void SemSegTools::collectTrainingExamples (
 
     std::cerr << "total number of examples: " << (int)examples.size() << std::endl;
 }
-

+ 3 - 1
semseg/SemSegTools.h

@@ -56,11 +56,13 @@ class SemSegTools
      * @param confMat confusion matrix
      * @param classNames class names object
      * @param forbidden_classes set of classes, that should be ignored
+     * @param classMappingInv mapping for a subset of classes
      */
     static void computeClassificationStatistics (
             NICE::Matrix & confMat,
             const OBJREC::ClassNames & classNames,
-            const std::set<int> & forbiddenClasses );
+            const std::set<int> & forbiddenClasses,
+            std::map<int,int> & classMappingInv );
 
     /**
      * @brief save results to image file