/**
 * @file testSemSegConvTrees.cpp
 * @brief test semantic segmentation routines of the ConvTree method
 *        (driven through OBJREC::SemSegObliqueTree)
 * @author Sven Sickert
 * @date 10/20/2014
 */

#include "core/basics/StringTools.h"
#include "core/basics/ResourceStatistics.h"
#include "core/basics/Timer.h"
#include "core/image/Morph.h"

#include "semseg/semseg/SemSegObliqueTree.h"
#include "semseg/semseg/SemSegTools.h"

#include <cstddef>
#include <fstream>
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

using namespace OBJREC;

/**
 * Train a SemSegObliqueTree on the "train" split of the configured
 * MultiDataset, segment every image of the "test" split, accumulate a
 * confusion matrix and print the resulting classification statistics.
 *
 * Config keys read directly here:
 *  - [SemSegObliqueTree] post_process      (bool, default false; only warns,
 *                                           post processing is not implemented)
 *  - [analysis]          forbidden_classes (class selection string, default "")
 *
 * Images of the test split without localization (ground-truth) info are
 * reported on stderr and skipped, since they cannot be evaluated.
 */
int main ( int argc, char **argv )
{
  NICE::Config conf ( argc, argv );

  const bool postProcessing = conf.gB ( "SemSegObliqueTree", "post_process", false );

  MultiDataset md ( &conf );
  const ClassNames & classNames = md.getClassNames ( "train" );
  const LabeledSet *testFiles = md["test"];

  // classes selected via [analysis] forbidden_classes are excluded from evaluation
  std::set<int> forbiddenClasses;
  classNames.getSelection ( conf.gS ( "analysis", "forbidden_classes", "" ),
                            forbiddenClasses );

  std::vector<bool> usedClasses ( classNames.numClasses(), true );
  for ( std::set<int>::const_iterator it = forbiddenClasses.begin();
        it != forbiddenClasses.end(); ++it )
  {
    usedClasses[ *it ] = false;
  }

  // map each used class number onto a consecutive row/column index
  // (0, 1, 2, ...) of the confusion matrix
  std::map<int,int> classMapping;
  int j = 0;
  for ( std::size_t i = 0; i < usedClasses.size(); ++i )
  {
    if ( usedClasses[i] )
    {
      classMapping[ static_cast<int> ( i ) ] = j;
      ++j;
    }
  }

  NICE::Matrix M ( classMapping.size(), classMapping.size() );
  M.set ( 0 );

  // setup actual segmentation method
  SemanticSegmentation *semseg = new SemSegObliqueTree ( &conf, &classNames );

  // training
  std::cout << "\nTRAINING" << std::endl;
  std::cout << "########\n" << std::endl;
  semseg->train ( &md );

  // testing
  NICE::Timer timer;
  std::cout << "\nCLASSIFICATION" << std::endl;
  std::cout << "##############\n" << std::endl;

  for ( LabeledSet::const_iterator it = testFiles->begin();
        it != testFiles->end(); ++it )
  {
    for ( std::vector<ImageInfo *>::const_iterator jt = it->second.begin();
          jt != it->second.end(); ++jt )
    {
      ImageInfo & info = *(*jt);
      std::string file = info.img();

      if ( !info.hasLocalizationInfo() )
      {
        std::cerr << "testSemSegConvTrees: WARNING: NO localization info found for " << file << std::endl;
        // Without ground truth, gtruth would stay empty and be fed into the
        // confusion-matrix update and the GT visualisation below with a size
        // that does not match the segmentation result -> skip this image.
        continue;
      }

      // pre-size result and ground-truth images to the annotated image size
      const LocalizationResult *l_gt = info.localization();
      NICE::ImageT<int> segresult, gtruth;
      segresult.resize ( l_gt->xsize, l_gt->ysize );
      segresult.set ( 0 );
      gtruth.resize ( l_gt->xsize, l_gt->ysize );
      gtruth.set ( 0 );
      l_gt->calcLabeledImage ( gtruth, classNames.getBackgroundClass() );

      // actual testing
      NICE::MultiChannelImageT<double> probabilities;
      timer.start();
      semseg->semanticseg ( file, segresult, probabilities );
      timer.stop();
      std::cout << "Time for Classification: " << timer.getLastAbsolute() << "\n\n";

      // post processing results
      if ( postProcessing )
      {
        std::cerr << "testSemSegConvTrees: WARNING: Post processing not yet supported." << std::endl;
      }

      // updating confusion matrix
      SemSegTools::updateConfusionMatrix ( segresult, gtruth, M, forbiddenClasses, classMapping );

      // saving results to image file
      NICE::ColorImage rgb;
      NICE::ColorImage rgb_gt;
      NICE::ColorImage orig ( file );
      classNames.labelToRGB ( segresult, rgb );
      classNames.labelToRGB ( gtruth, rgb_gt );
      std::string fname = NICE::StringTools::baseName ( file, false );
      SemSegTools::saveResultsToImageFile ( &conf, "analysis", orig, rgb_gt, rgb, fname );
    }
  }

  // evaluation & analysis
  SemSegTools::computeClassificationStatistics ( M, classNames, forbiddenClasses );

  // Cleaning up
  delete semseg;

  return 0;
}