  /**
  * @file testSemSegConvTrees.cpp
  * @brief Test program for the semantic segmentation routines of the ConvTree method
  * @author Sven Sickert
  * @date 10/20/2014
  */
  #include "core/basics/StringTools.h"
  #include "core/basics/ResourceStatistics.h"
  #include "core/basics/Timer.h"
  #include "core/image/Morph.h"
  #include "semseg/semseg/SemSegConvolutionalTree.h"
  #include "semseg/semseg/SemSegTools.h"
  #include <fstream>
  #include <iostream>
  #include <set>
  #include <string>
  #include <vector>
  using namespace OBJREC;
  16. int main ( int argc, char **argv )
  17. {
  18. // variables
  19. NICE::Config conf (argc, argv );
  20. NICE::ResourceStatistics rs;
  21. MultiDataset md ( &conf );
  22. const ClassNames & classNames = md.getClassNames ( "train" );
  23. const LabeledSet *testFiles = md["test"];
  24. std::set<int> forbiddenClasses;
  25. classNames.getSelection ( conf.gS ( "analysis", "forbidden_classes", "" ),
  26. forbiddenClasses );
  27. NICE::Matrix M ( classNames.getMaxClassno() + 1,
  28. classNames.getMaxClassno() + 1 );
  29. M.set( 0 );
  30. // initialize semantic segmentation method
  31. SemanticSegmentation *semseg = NULL;
  32. // setup actual segmentation method
  33. semseg = new SemSegConvolutionalTree ( &conf, &classNames );
  34. // training
  35. std::cout << "\nTRAINING" << std::endl;
  36. std::cout << "########\n" << std::endl;
  37. semseg->train( &md );
  38. // testing
  39. NICE::Timer timer;
  40. std::cout << "\nCLASSIFICATION" << std::endl;
  41. std::cout << "##############\n" << std::endl;
  42. for (LabeledSet::const_iterator it = testFiles->begin(); it != testFiles->end(); it++)
  43. {
  44. for (std::vector<ImageInfo *>::const_iterator jt = it->second.begin();
  45. jt != it->second.end(); jt++)
  46. {
  47. ImageInfo & info = *(*jt);
  48. std::string file = info.img();
  49. NICE::Image segresult, gtruth;
  50. if ( info.hasLocalizationInfo() )
  51. {
  52. const LocalizationResult *l_gt = info.localization();
  53. segresult.resize ( l_gt->xsize, l_gt->ysize );
  54. segresult.set( 0 );
  55. gtruth.resize( l_gt->xsize, l_gt->ysize );
  56. gtruth.set ( 0 );
  57. l_gt->calcLabeledImage ( gtruth, classNames.getBackgroundClass() );
  58. }
  59. else
  60. {
  61. std::cerr << "testSemSegConvTrees: WARNING: NO localization info found for "
  62. << file << std::endl;
  63. }
  64. // actual testing
  65. NICE::MultiChannelImageT<double> probabilities;
  66. timer.start();
  67. semseg->semanticseg( file, segresult, probabilities );
  68. timer.stop();
  69. std::cout << "Time for Classification: " << timer.getLastAbsolute()
  70. << "\n\n";
  71. // post processing results
  72. NICE::Image postIm(segresult.width(), segresult.height());
  73. NICE::median(segresult, &postIm, 1);
  74. segresult = postIm;
  75. // updating confusion matrix
  76. SemSegTools::updateConfusionMatrix (
  77. segresult, gtruth, M, forbiddenClasses );
  78. // saving results to image file
  79. NICE::ColorImage rgb;
  80. NICE::ColorImage rgb_gt;
  81. NICE::ColorImage orig ( file );
  82. classNames.labelToRGB( segresult, rgb);
  83. classNames.labelToRGB( gtruth, rgb_gt);
  84. std::string fname = NICE::StringTools::baseName ( file, false );
  85. SemSegTools::saveResultsToImageFile(
  86. &conf, "analysis", orig, rgb_gt, rgb, fname );
  87. }
  88. }
  89. // evaluation & analysis
  90. SemSegTools::computeClassificationStatistics( M, classNames, forbiddenClasses);
  91. // Cleaning up
  92. delete semseg;
  93. }