123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166 |
- /**
- * @file testSemSegConvTrees.cpp
- * @brief test semantic segmentation routines of the ConvTree method
- * @author Sven Sickert
- * @date 10/20/2014
- */
#include "core/basics/StringTools.h"
#include "core/basics/Timer.h"
#include "core/image/Morph.h"
#include "semseg/semseg/SemSegObliqueTree.h"
#include "semseg/semseg/SemSegTools.h"
#include <fstream>
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>
using namespace OBJREC;
- int main ( int argc, char **argv )
- {
- // variables
- NICE::Config conf (argc, argv );
- NICE::ResourceStatistics rs;
- MultiDataset md ( &conf );
- const ClassNames & classNames = md.getClassNames ( "train" );
- const LabeledSet *testFiles = md["test"];
- std::set<int> forbiddenClasses;
- classNames.getSelection ( conf.gS ( "analysis", "forbidden_classes", "" ),
- forbiddenClasses );
- std::vector<bool> usedClasses ( classNames.numClasses(), true );
- for ( std::set<int>::const_iterator it = forbiddenClasses.begin();
- it != forbiddenClasses.end(); ++it)
- {
- usedClasses [ *it ] = false;
- }
- std::map<int,int> classMapping, classMappingInv;
- int j = 0;
- for ( int i = 0; i < usedClasses.size(); i++ )
- if (usedClasses[i])
- {
- classMapping[i] = j;
- classMappingInv[j] = i;
- j++;
- }
- NICE::Matrix M ( classMapping.size(), classMapping.size() );
- M.set( 0 );
- // initialize semantic segmentation method
- SemanticSegmentation *semseg = NULL;
- // setup actual segmentation method
- semseg = new SemSegObliqueTree ( &conf, &classNames );
- // training
- std::cout << "\nTRAINING" << std::endl;
- std::cout << "########\n" << std::endl;
- semseg->train( &md );
- // testing
- NICE::Timer timer;
- std::cout << "\nCLASSIFICATION" << std::endl;
- std::cout << "##############\n" << std::endl;
- std::vector<int> zsizeVec;
- bool run3Dseg = semseg->isMode3D();
- SemSegTools::getDepthVector ( testFiles, zsizeVec, run3Dseg );
- int depthCount = 0, idx = 0;
- std::vector<std::string> filelist;
- NICE::MultiChannelImageT<int> segresult, gt;
- for (LabeledSet::const_iterator it = testFiles->begin(); it != testFiles->end(); it++)
- {
- for (std::vector<ImageInfo *>::const_iterator jt = it->second.begin();
- jt != it->second.end(); jt++)
- {
- ImageInfo & info = *(*jt);
- std::string file = info.img();
- filelist.push_back(file);
- depthCount++;
- NICE::ImageT<int> gtruth, res;
- if ( info.hasLocalizationInfo() )
- {
- const LocalizationResult *l_gt = info.localization();
- gtruth.resize( l_gt->xsize, l_gt->ysize );
- l_gt->calcLabeledImage ( gtruth, classNames.getBackgroundClass() );
- }
- else
- {
- std::cerr << "testSemSegConvTrees: WARNING: NO localization info found for "
- << file << std::endl;
- }
- segresult.addChannel(gtruth);
- gt.addChannel(gtruth);
- int depthBoundary = 1;
- if ( run3Dseg )
- depthBoundary = zsizeVec[idx];
- std::cout << "Slice " << depthCount << "/"
- << depthBoundary << std::endl;
- if ( depthCount < depthBoundary )
- continue;
- // actual testing
- NICE::MultiChannelImage3DT<double> probabilities;
- timer.start();
- semseg->semanticseg( filelist, segresult, probabilities );
- timer.stop();
- std::cout << "Time for Classification: " << timer.getLastAbsolute()
- << "\n\n";
- // updating confusion matrix
- res = gtruth;
- for ( int z = 0; z < segresult.channels(); z++ )
- {
- for ( int y = 0; y < res.height(); y++ )
- for ( int x = 0; x < res.width(); x++)
- {
- res.setPixel ( x, y, segresult.get(x,y,(unsigned int)z) );
- if ( run3Dseg )
- gtruth.setPixel ( x, y, gt.get(x,y,(unsigned int)z) );
- }
- SemSegTools::updateConfusionMatrix ( res, gtruth, M,
- forbiddenClasses, classMapping );
- // saving results to image file
- NICE::ColorImage rgb;
- NICE::ColorImage rgb_gt;
- NICE::ColorImage orig ( filelist[z] );
- classNames.labelToRGB( res, rgb);
- classNames.labelToRGB( gtruth, rgb_gt);
- std::string fname = NICE::StringTools::baseName ( filelist[z], false );
- std::string outStr;
- SemSegTools::saveResultsToImageFile( &conf, "analysis", orig,
- rgb_gt, rgb, fname, outStr );
- }
- // prepare for new 3d image
- filelist.clear();
- segresult.reInit(0,0,0);
- gt.reInit(0,0,0);
- depthCount = 0;
- idx++;
- }
- }
- // resource statistics
- SemSegTools::computeResourceStatistics ( rs );
- // evaluation & analysis
- SemSegTools::computeClassificationStatistics(
- M, classNames, forbiddenClasses, classMappingInv );
- // Cleaning up
- delete semseg;
- }
|