testSemSegObliqueTrees.cpp 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
/**
 * @file testSemSegObliqueTrees.cpp
 * @brief test semantic segmentation routines of the oblique tree method (SemSegObliqueTree)
 * @author Sven Sickert
 * @date 10/20/2014
 */
#include "core/basics/StringTools.h"
#include "core/basics/ResourceStatistics.h"
#include "core/basics/Timer.h"
#include "core/image/Morph.h"
#include "semseg/semseg/SemSegObliqueTree.h"
#include "semseg/semseg/SemSegTools.h"

#include <cstddef>
#include <fstream>
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>
  15. using namespace OBJREC;
  16. int main ( int argc, char **argv )
  17. {
  18. // variables
  19. NICE::Config conf (argc, argv );
  20. NICE::ResourceStatistics rs;
  21. bool postProcessing = conf.gB( "SemSegObliqueTree", "post_process", false);
  22. MultiDataset md ( &conf );
  23. const ClassNames & classNames = md.getClassNames ( "train" );
  24. const LabeledSet *testFiles = md["test"];
  25. std::set<int> forbiddenClasses;
  26. classNames.getSelection ( conf.gS ( "analysis", "forbidden_classes", "" ),
  27. forbiddenClasses );
  28. std::vector<bool> usedClasses ( classNames.numClasses(), true );
  29. for ( std::set<int>::const_iterator it = forbiddenClasses.begin();
  30. it != forbiddenClasses.end(); ++it)
  31. {
  32. usedClasses [ *it ] = false;
  33. }
  34. std::map<int,int> classMapping, classMappingInv;
  35. int j = 0;
  36. for ( int i = 0; i < usedClasses.size(); i++ )
  37. if (usedClasses[i])
  38. {
  39. classMapping[i] = j;
  40. classMappingInv[j] = i;
  41. j++;
  42. }
  43. NICE::Matrix M ( classMapping.size(), classMapping.size() );
  44. M.set( 0 );
  45. // initialize semantic segmentation method
  46. SemanticSegmentation *semseg = NULL;
  47. // setup actual segmentation method
  48. semseg = new SemSegObliqueTree ( &conf, &classNames );
  49. // training
  50. std::cout << "\nTRAINING" << std::endl;
  51. std::cout << "########\n" << std::endl;
  52. semseg->train( &md );
  53. // testing
  54. NICE::Timer timer;
  55. std::cout << "\nCLASSIFICATION" << std::endl;
  56. std::cout << "##############\n" << std::endl;
  57. for (LabeledSet::const_iterator it = testFiles->begin(); it != testFiles->end(); it++)
  58. {
  59. for (std::vector<ImageInfo *>::const_iterator jt = it->second.begin();
  60. jt != it->second.end(); jt++)
  61. {
  62. ImageInfo & info = *(*jt);
  63. std::string file = info.img();
  64. NICE::ImageT<int> segresult, gtruth;
  65. if ( info.hasLocalizationInfo() )
  66. {
  67. const LocalizationResult *l_gt = info.localization();
  68. segresult.resize ( l_gt->xsize, l_gt->ysize );
  69. segresult.set( 0 );
  70. gtruth.resize( l_gt->xsize, l_gt->ysize );
  71. gtruth.set ( 0 );
  72. l_gt->calcLabeledImage ( gtruth, classNames.getBackgroundClass() );
  73. }
  74. else
  75. {
  76. std::cerr << "testSemSegConvTrees: WARNING: NO localization info found for "
  77. << file << std::endl;
  78. }
  79. // actual testing
  80. NICE::MultiChannelImageT<double> probabilities;
  81. timer.start();
  82. semseg->semanticseg( file, segresult, probabilities );
  83. timer.stop();
  84. std::cout << "Time for Classification: " << timer.getLastAbsolute()
  85. << "\n\n";
  86. // post processing results
  87. if (postProcessing)
  88. {
  89. std::cerr << "testSemSegConvTrees: WARNING: Post processing not yet supported."
  90. << std::endl;
  91. }
  92. // updating confusion matrix
  93. SemSegTools::updateConfusionMatrix ( segresult, gtruth, M,
  94. forbiddenClasses, classMapping );
  95. // saving results to image file
  96. NICE::ColorImage rgb;
  97. NICE::ColorImage rgb_gt;
  98. NICE::ColorImage orig ( file );
  99. classNames.labelToRGB( segresult, rgb);
  100. classNames.labelToRGB( gtruth, rgb_gt);
  101. std::string fname = NICE::StringTools::baseName ( file, false );
  102. SemSegTools::saveResultsToImageFile(
  103. &conf, "analysis", orig, rgb_gt, rgb, fname );
  104. }
  105. }
  106. // evaluation & analysis
  107. SemSegTools::computeClassificationStatistics(
  108. M, classNames, forbiddenClasses, classMappingInv );
  109. // Cleaning up
  110. delete semseg;
  111. }