/**
* @file testSemSegConvTrees.cpp
* @brief test semantic segmentation routines of the ConvTree method
* @author Sven Sickert
* @date 10/20/2014
*
* Trains a SemSegObliqueTree on the "train" part of the configured
* MultiDataset, segments every image (2D) or image stack (3D) of the
* "test" part, accumulates a confusion matrix over all non-forbidden
* classes, writes result images, and finally prints resource and
* classification statistics.
*/

#include "core/basics/StringTools.h"
#include "core/basics/Timer.h"
#include "core/image/Morph.h"

#include "semseg/semseg/SemSegObliqueTree.h"
#include "semseg/semseg/SemSegTools.h"

#include <cstddef>
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

using namespace OBJREC;

/**
 * @brief Builds a dense re-indexing of all classes that are not forbidden.
 *
 * Every class number in [0, numClasses) that is not contained in
 * forbiddenClasses is mapped, in increasing order, onto 0, 1, 2, ...
 *
 * @param numClasses       number of class numbers to consider
 * @param forbiddenClasses class numbers excluded from the evaluation;
 *                         entries outside [0, numClasses) are ignored
 *                         instead of being written out of bounds
 * @param classMapping     output: original class number -> dense index
 * @param classMappingInv  output: dense index -> original class number
 */
static void buildClassMapping ( const std::size_t numClasses,
                                const std::set<int> & forbiddenClasses,
                                std::map<int,int> & classMapping,
                                std::map<int,int> & classMappingInv )
{
    classMapping.clear();
    classMappingInv.clear();

    std::vector<bool> usedClasses ( numClasses, true );
    for ( std::set<int>::const_iterator it = forbiddenClasses.begin();
          it != forbiddenClasses.end(); ++it )
    {
        // operator[] does no bounds checking, so reject invalid indices here
        if ( *it >= 0 && static_cast<std::size_t>( *it ) < numClasses )
            usedClasses[ *it ] = false;
    }

    int denseIndex = 0;
    for ( std::size_t i = 0; i < usedClasses.size(); i++ )
    {
        if ( !usedClasses[i] )
            continue;

        classMapping[ static_cast<int>( i ) ] = denseIndex;
        classMappingInv[ denseIndex ] = static_cast<int>( i );
        denseIndex++;
    }
}

int main ( int argc, char **argv )
{
    // variables
    NICE::Config conf ( argc, argv );
    NICE::ResourceStatistics rs;

    MultiDataset md ( &conf );
    const ClassNames & classNames = md.getClassNames ( "train" );
    const LabeledSet *testFiles = md["test"];

    // classes excluded from evaluation (config key: analysis/forbidden_classes)
    std::set<int> forbiddenClasses;
    classNames.getSelection ( conf.gS ( "analysis", "forbidden_classes", "" ),
                              forbiddenClasses );

    std::map<int,int> classMapping, classMappingInv;
    buildClassMapping ( classNames.numClasses(), forbiddenClasses,
                        classMapping, classMappingInv );

    // confusion matrix indexed by the dense (non-forbidden) class indices
    NICE::Matrix M ( classMapping.size(), classMapping.size() );
    M.set( 0 );

    // setup actual segmentation method
    SemanticSegmentation *semseg = new SemSegObliqueTree ( &conf, &classNames );

    // training
    std::cout << "\nTRAINING" << std::endl;
    std::cout << "########\n" << std::endl;
    semseg->train( &md );

    // testing
    NICE::Timer timer;
    std::cout << "\nCLASSIFICATION" << std::endl;
    std::cout << "##############\n" << std::endl;

    // number of slices per test volume; only consulted in 3D mode
    std::vector<int> zsizeVec;
    bool run3Dseg = semseg->isMode3D();
    SemSegTools::getDepthVector ( testFiles, zsizeVec, run3Dseg );

    // depthCount: slices collected for the current volume
    // idx:        index of the current volume in zsizeVec
    int depthCount = 0, idx = 0;
    std::vector<std::string> filelist;
    NICE::MultiChannelImageT<int> segresult, gt;

    for ( LabeledSet::const_iterator it = testFiles->begin();
          it != testFiles->end(); it++ )
    {
        for ( std::vector<ImageInfo *>::const_iterator jt = it->second.begin();
              jt != it->second.end(); jt++ )
        {
            ImageInfo & info = *(*jt);
            std::string file = info.img();
            filelist.push_back( file );
            depthCount++;

            NICE::ImageT<int> gtruth, res;
            if ( info.hasLocalizationInfo() )
            {
                const LocalizationResult *l_gt = info.localization();
                gtruth.resize( l_gt->xsize, l_gt->ysize );
                l_gt->calcLabeledImage ( gtruth, classNames.getBackgroundClass() );
            }
            else
            {
                // processing continues with an empty ground-truth slice
                std::cerr << "testSemSegConvTrees: WARNING: NO localization info found for "
                          << file << std::endl;
            }

            // the ground-truth slice also serves as a placeholder channel
            // for the segmentation result that semanticseg() fills in
            segresult.addChannel( gtruth );
            gt.addChannel( gtruth );

            // in 3D mode, keep collecting slices until the volume is complete;
            // in 2D mode every image is segmented on its own
            int depthBoundary = 1;
            if ( run3Dseg )
                depthBoundary = zsizeVec[idx];

            std::cout << "Slice " << depthCount << "/" << depthBoundary << std::endl;
            if ( depthCount < depthBoundary )
                continue;

            // actual testing
            NICE::MultiChannelImage3DT<double> probabilities;
            timer.start();
            semseg->semanticseg( filelist, segresult, probabilities );
            timer.stop();
            std::cout << "Time for Classification: " << timer.getLastAbsolute() << "\n\n";

            // updating confusion matrix slice by slice
            res = gtruth;
            for ( unsigned int z = 0; z < segresult.channels(); z++ )
            {
                for ( int y = 0; y < res.height(); y++ )
                    for ( int x = 0; x < res.width(); x++ )
                    {
                        res.setPixel ( x, y, segresult.get( x, y, z ) );
                        // in 2D mode gtruth already holds the only slice
                        if ( run3Dseg )
                            gtruth.setPixel ( x, y, gt.get( x, y, z ) );
                    }

                SemSegTools::updateConfusionMatrix ( res, gtruth, M,
                                                     forbiddenClasses, classMapping );

                // saving results to image file
                NICE::ColorImage rgb;
                NICE::ColorImage rgb_gt;
                NICE::ColorImage orig ( filelist[z] );
                classNames.labelToRGB( res, rgb );
                classNames.labelToRGB( gtruth, rgb_gt );

                std::string fname = NICE::StringTools::baseName ( filelist[z], false );
                std::string outStr;
                SemSegTools::saveResultsToImageFile( &conf, "analysis", orig, rgb_gt,
                                                     rgb, fname, outStr );
            }

            // prepare for new 3d image
            filelist.clear();
            segresult.reInit( 0, 0, 0 );
            gt.reInit( 0, 0, 0 );
            depthCount = 0;
            idx++;
        }
    }

    // resource statistics
    SemSegTools::computeResourceStatistics ( rs );

    // evaluation & analysis
    SemSegTools::computeClassificationStatistics( M, classNames, forbiddenClasses,
                                                  classMappingInv );

    // Cleaning up
    delete semseg;

    return 0;
}