testSemSegConvTrees.cpp 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. /**
  2. * @file testSemSegConvTrees.cpp
  3. * @brief test semantic segmentation routines of the ConvTree method
  4. * @author Sven Sickert
  5. * @date 10/20/2014
  6. */
  7. #include "core/basics/StringTools.h"
  8. #include "core/basics/ResourceStatistics.h"
  9. #include "core/basics/Timer.h"
  10. #include "semseg/semseg/SemSegConvolutionalTree.h"
  11. #include "semseg/semseg/SemSegTools.h"
  12. #include <fstream>
  13. #include <vector>
  14. using namespace OBJREC;
  15. int main ( int argc, char **argv )
  16. {
  17. // variables
  18. NICE::Config conf (argc, argv );
  19. NICE::ResourceStatistics rs;
  20. MultiDataset md ( &conf );
  21. const ClassNames & classNames = md.getClassNames ( "train" );
  22. const LabeledSet *testFiles = md["test"];
  23. std::set<int> forbiddenClasses;
  24. classNames.getSelection ( conf.gS ( "analysis", "forbidden_classes", "" ),
  25. forbiddenClasses );
  26. NICE::Matrix M ( classNames.getMaxClassno() + 1,
  27. classNames.getMaxClassno() + 1 );
  28. M.set( 0 );
  29. // initialize semantic segmentation method
  30. SemanticSegmentation *semseg = NULL;
  31. // setup actual segmentation method
  32. semseg = new SemSegConvolutionalTree ( &conf, &classNames );
  33. // training
  34. std::cout << "\nTRAINING" << std::endl;
  35. std::cout << "########\n" << std::endl;
  36. semseg->train( &md );
  37. // testing
  38. NICE::Timer timer;
  39. std::cout << "\nCLASSIFICATION" << std::endl;
  40. std::cout << "##############\n" << std::endl;
  41. for (LabeledSet::const_iterator it = testFiles->begin(); it != testFiles->end(); it++)
  42. {
  43. for (std::vector<ImageInfo *>::const_iterator jt = it->second.begin();
  44. jt != it->second.end(); jt++)
  45. {
  46. ImageInfo & info = *(*jt);
  47. std::string file = info.img();
  48. NICE::Image segresult, gtruth;
  49. if ( info.hasLocalizationInfo() )
  50. {
  51. const LocalizationResult *l_gt = info.localization();
  52. segresult.resize ( l_gt->xsize, l_gt->ysize );
  53. segresult.set( 0 );
  54. gtruth.resize( l_gt->xsize, l_gt->ysize );
  55. gtruth.set ( 0 );
  56. l_gt->calcLabeledImage ( gtruth, classNames.getBackgroundClass() );
  57. }
  58. else
  59. {
  60. std::cerr << "testSemSegConvTrees: WARNING: NO localization info found for "
  61. << file << std::endl;
  62. }
  63. // actual testing
  64. NICE::MultiChannelImageT<double> probabilities;
  65. timer.start();
  66. semseg->semanticseg( file, segresult, probabilities );
  67. timer.stop();
  68. std::cout << "Time for Classification: " << timer.getLastAbsolute()
  69. << "\n\n";
  70. // updating confusion matrix
  71. SemSegTools::updateConfusionMatrix (
  72. segresult, gtruth, M, forbiddenClasses );
  73. // saving results to image file
  74. NICE::ColorImage rgb;
  75. NICE::ColorImage rgb_gt;
  76. NICE::ColorImage orig ( file );
  77. classNames.labelToRGB( segresult, rgb);
  78. classNames.labelToRGB( gtruth, rgb_gt);
  79. std::string fname = NICE::StringTools::baseName ( file, false );
  80. SemSegTools::saveResultsToImageFile(
  81. &conf, "analysis", orig, rgb_gt, rgb, fname );
  82. }
  83. }
  84. // evaluation & analysis
  85. SemSegTools::computeClassificationStatistics( M, classNames, forbiddenClasses);
  86. // Cleaning up
  87. delete semseg;
  88. }