Преглед изворни кода

added program for SemSegConvolutionalTree

Sven Sickert пре 10 година
родитељ
комит
77a6393da6
1 измењених фајлова са 103 додато и 0 уклоњено
  1. 103 0
      progs/testSemSegConvTrees.cpp

+ 103 - 0
progs/testSemSegConvTrees.cpp

@@ -0,0 +1,103 @@
+/**
+ * @file testSemSegConvTrees.cpp
+ * @brief test semantic segmentation routines of the ConvTree method
+ * @author Sven Sickert
+ * @date 10/20/2014
+
+*/
+
+#include "core/basics/StringTools.h"
+#include "core/basics/ResourceStatistics.h"
+
+#include "semseg/semseg/SemSegConvolutionalTree.h"
+#include "semseg/semseg/SemSegTools.h"
+
+#include <fstream>
+#include <vector>
+
+using namespace OBJREC;
+
+int main ( int argc, char **argv )
+{
+    // variables
+    NICE::Config conf (argc, argv );
+    NICE::ResourceStatistics rs;
+
+    MultiDataset md ( &conf );
+    const ClassNames & classNames = md.getClassNames ( "train" );
+    const LabeledSet *testFiles = md["test"];
+    std::set<int> forbiddenClasses;
+    classNames.getSelection ( conf.gS ( "analysis", "forbidden_classes", "" ),
+                              forbiddenClasses );
+
+    NICE::Matrix M ( classNames.getMaxClassno() + 1,
+                     classNames.getMaxClassno() + 1 );
+    M.set( 0 );
+
+    // initialize semantic segmentation method
+    SemanticSegmentation *semseg = NULL;
+
+    // setup actual segmentation method
+    semseg = new SemSegConvolutionalTree ( &conf, &classNames );
+
+    // training
+    std::cout << "\nTRAINING" << std::endl;
+    std::cout << "########\n" << std::endl;
+    semseg->train( &md );
+
+    // testing
+    std::cout << "\nCLASSIFICATION" << std::endl;
+    std::cout << "##############\n" << std::endl;
+    for (LabeledSet::const_iterator it = testFiles->begin(); it != testFiles->end(); it++)
+    {
+        for (std::vector<ImageInfo *>::const_iterator jt = it->second.begin();
+            jt != it->second.end(); jt++)
+        {
+            ImageInfo & info = *(*jt);
+            std::string file = info.img();
+
+            NICE::Image segresult, gtruth;
+            if ( info.hasLocalizationInfo() )
+            {
+                const LocalizationResult *l_gt = info.localization();
+
+                segresult.resize ( l_gt->xsize, l_gt->ysize );
+                segresult.set( 0 );
+                gtruth.resize( l_gt->xsize, l_gt->ysize );
+                gtruth.set ( 0 );
+                l_gt->calcLabeledImage ( gtruth, classNames.getBackgroundClass() );
+            }
+            else
+            {
+                std::cerr << "testSemSegConvTrees: WARNING: NO localization info found for "
+                          << file << std::endl;
+            }
+
+            // actual testing
+            NICE::MultiChannelImageT<double> probabilities;
+            semseg->semanticseg( file, segresult, probabilities );
+
+            // updating confusion matrix
+            SemSegTools::updateConfusionMatrix (
+                        segresult, gtruth, M, forbiddenClasses );
+
+            // saving results to image file
+            NICE::ColorImage rgb;
+            NICE::ColorImage rgb_gt;
+            NICE::ColorImage orig ( file );
+            classNames.labelToRGB( segresult, rgb);
+            classNames.labelToRGB( gtruth, rgb_gt);
+            std::string fname = NICE::StringTools::baseName ( file, false );
+
+            SemSegTools::saveResultsToImageFile(
+                        &conf, "analysis", orig, rgb_gt, rgb, fname );
+        }
+    }
+
+    // evaluation & analysis
+    SemSegTools::computeClassificationStatistics( M, classNames, forbiddenClasses);
+
+    // Cleaning up
+    delete semseg;
+
+}