|
@@ -0,0 +1,103 @@
|
|
|
+/**
|
|
|
+ * @file testSemSegConvTrees.cpp
|
|
|
+ * @brief test semantic segmentation routines of the ConvTree method
|
|
|
+ * @author Sven Sickert
|
|
|
+ * @date 10/20/2014
|
|
|
+
|
|
|
+*/
|
|
|
+
|
|
|
+#include "core/basics/StringTools.h"
|
|
|
+#include "core/basics/ResourceStatistics.h"
|
|
|
+
|
|
|
+#include "semseg/semseg/SemSegConvolutionalTree.h"
|
|
|
+#include "semseg/semseg/SemSegTools.h"
|
|
|
+
|
|
|
+#include <fstream>
|
|
|
+#include <vector>
|
|
|
+
|
|
|
+using namespace OBJREC;
|
|
|
+
|
|
|
+int main ( int argc, char **argv )
|
|
|
+{
|
|
|
+ // variables
|
|
|
+ NICE::Config conf (argc, argv );
|
|
|
+ NICE::ResourceStatistics rs;
|
|
|
+
|
|
|
+ MultiDataset md ( &conf );
|
|
|
+ const ClassNames & classNames = md.getClassNames ( "train" );
|
|
|
+ const LabeledSet *testFiles = md["test"];
|
|
|
+ std::set<int> forbiddenClasses;
|
|
|
+ classNames.getSelection ( conf.gS ( "analysis", "forbidden_classes", "" ),
|
|
|
+ forbiddenClasses );
|
|
|
+
|
|
|
+ NICE::Matrix M ( classNames.getMaxClassno() + 1,
|
|
|
+ classNames.getMaxClassno() + 1 );
|
|
|
+ M.set( 0 );
|
|
|
+
|
|
|
+ // initialize semantic segmentation method
|
|
|
+ SemanticSegmentation *semseg = NULL;
|
|
|
+
|
|
|
+ // setup actual segmentation method
|
|
|
+ semseg = new SemSegConvolutionalTree ( &conf, &classNames );
|
|
|
+
|
|
|
+ // training
|
|
|
+ std::cout << "\nTRAINING" << std::endl;
|
|
|
+ std::cout << "########\n" << std::endl;
|
|
|
+ semseg->train( &md );
|
|
|
+
|
|
|
+ // testing
|
|
|
+ std::cout << "\nCLASSIFICATION" << std::endl;
|
|
|
+ std::cout << "##############\n" << std::endl;
|
|
|
+ for (LabeledSet::const_iterator it = testFiles->begin(); it != testFiles->end(); it++)
|
|
|
+ {
|
|
|
+ for (std::vector<ImageInfo *>::const_iterator jt = it->second.begin();
|
|
|
+ jt != it->second.end(); jt++)
|
|
|
+ {
|
|
|
+ ImageInfo & info = *(*jt);
|
|
|
+ std::string file = info.img();
|
|
|
+
|
|
|
+ NICE::Image segresult, gtruth;
|
|
|
+ if ( info.hasLocalizationInfo() )
|
|
|
+ {
|
|
|
+ const LocalizationResult *l_gt = info.localization();
|
|
|
+
|
|
|
+ segresult.resize ( l_gt->xsize, l_gt->ysize );
|
|
|
+ segresult.set( 0 );
|
|
|
+ gtruth.resize( l_gt->xsize, l_gt->ysize );
|
|
|
+ gtruth.set ( 0 );
|
|
|
+ l_gt->calcLabeledImage ( gtruth, classNames.getBackgroundClass() );
|
|
|
+ }
|
|
|
+ else
|
|
|
+ {
|
|
|
+ std::cerr << "testSemSegConvTrees: WARNING: NO localization info found for "
|
|
|
+ << file << std::endl;
|
|
|
+ }
|
|
|
+
|
|
|
+ // actual testing
|
|
|
+ NICE::MultiChannelImageT<double> probabilities;
|
|
|
+ semseg->semanticseg( file, segresult, probabilities );
|
|
|
+
|
|
|
+ // updating confusion matrix
|
|
|
+ SemSegTools::updateConfusionMatrix (
|
|
|
+ segresult, gtruth, M, forbiddenClasses );
|
|
|
+
|
|
|
+ // saving results to image file
|
|
|
+ NICE::ColorImage rgb;
|
|
|
+ NICE::ColorImage rgb_gt;
|
|
|
+ NICE::ColorImage orig ( file );
|
|
|
+ classNames.labelToRGB( segresult, rgb);
|
|
|
+ classNames.labelToRGB( gtruth, rgb_gt);
|
|
|
+ std::string fname = NICE::StringTools::baseName ( file, false );
|
|
|
+
|
|
|
+ SemSegTools::saveResultsToImageFile(
|
|
|
+ &conf, "analysis", orig, rgb_gt, rgb, fname );
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // evaluation & analysis
|
|
|
+ SemSegTools::computeClassificationStatistics( M, classNames, forbiddenClasses);
|
|
|
+
|
|
|
+ // Cleaning up
|
|
|
+ delete semseg;
|
|
|
+
|
|
|
+}
|