// testSemSegConvTrees.cpp (3.2 KB)
/**
 * @file testSemSegConvTrees.cpp
 * @brief Test the semantic segmentation routines of the ConvTree method
 * @author Sven Sickert
 * @date 10/20/2014
 */
#include <fstream>
#include <iostream>
#include <set>
#include <string>
#include <vector>

#include "core/basics/StringTools.h"
#include "core/basics/ResourceStatistics.h"
#include "semseg/semseg/SemSegConvolutionalTree.h"
#include "semseg/semseg/SemSegTools.h"

using namespace OBJREC;
  14. int main ( int argc, char **argv )
  15. {
  16. // variables
  17. NICE::Config conf (argc, argv );
  18. NICE::ResourceStatistics rs;
  19. MultiDataset md ( &conf );
  20. const ClassNames & classNames = md.getClassNames ( "train" );
  21. const LabeledSet *testFiles = md["test"];
  22. std::set<int> forbiddenClasses;
  23. classNames.getSelection ( conf.gS ( "analysis", "forbidden_classes", "" ),
  24. forbiddenClasses );
  25. NICE::Matrix M ( classNames.getMaxClassno() + 1,
  26. classNames.getMaxClassno() + 1 );
  27. M.set( 0 );
  28. // initialize semantic segmentation method
  29. SemanticSegmentation *semseg = NULL;
  30. // setup actual segmentation method
  31. semseg = new SemSegConvolutionalTree ( &conf, &classNames );
  32. // training
  33. std::cout << "\nTRAINING" << std::endl;
  34. std::cout << "########\n" << std::endl;
  35. semseg->train( &md );
  36. // testing
  37. std::cout << "\nCLASSIFICATION" << std::endl;
  38. std::cout << "##############\n" << std::endl;
  39. for (LabeledSet::const_iterator it = testFiles->begin(); it != testFiles->end(); it++)
  40. {
  41. for (std::vector<ImageInfo *>::const_iterator jt = it->second.begin();
  42. jt != it->second.end(); jt++)
  43. {
  44. ImageInfo & info = *(*jt);
  45. std::string file = info.img();
  46. NICE::Image segresult, gtruth;
  47. if ( info.hasLocalizationInfo() )
  48. {
  49. const LocalizationResult *l_gt = info.localization();
  50. segresult.resize ( l_gt->xsize, l_gt->ysize );
  51. segresult.set( 0 );
  52. gtruth.resize( l_gt->xsize, l_gt->ysize );
  53. gtruth.set ( 0 );
  54. l_gt->calcLabeledImage ( gtruth, classNames.getBackgroundClass() );
  55. }
  56. else
  57. {
  58. std::cerr << "testSemSegConvTrees: WARNING: NO localization info found for "
  59. << file << std::endl;
  60. }
  61. // actual testing
  62. NICE::MultiChannelImageT<double> probabilities;
  63. semseg->semanticseg( file, segresult, probabilities );
  64. // updating confusion matrix
  65. SemSegTools::updateConfusionMatrix (
  66. segresult, gtruth, M, forbiddenClasses );
  67. // saving results to image file
  68. NICE::ColorImage rgb;
  69. NICE::ColorImage rgb_gt;
  70. NICE::ColorImage orig ( file );
  71. classNames.labelToRGB( segresult, rgb);
  72. classNames.labelToRGB( gtruth, rgb_gt);
  73. std::string fname = NICE::StringTools::baseName ( file, false );
  74. SemSegTools::saveResultsToImageFile(
  75. &conf, "analysis", orig, rgb_gt, rgb, fname );
  76. }
  77. }
  78. // evaluation & analysis
  79. SemSegTools::computeClassificationStatistics( M, classNames, forbiddenClasses);
  80. // Cleaning up
  81. delete semseg;
  82. }