testSemSegObliqueTrees.cpp 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
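/**
 * @file testSemSegObliqueTrees.cpp
 * @brief test semantic segmentation routines of the ObliqueTree method
 *        (trains a SemSegObliqueTree on the "train" set, segments the
 *        "test" set and reports confusion-matrix based statistics)
 * @author Sven Sickert
 * @date 10/20/2014
 */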
#include <fstream>
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

#include "core/basics/StringTools.h"
#include "core/basics/Timer.h"
#include "core/image/Morph.h"
#include "semseg/semseg/SemSegObliqueTree.h"
#include "semseg/semseg/SemSegTools.h"

using namespace OBJREC;
  15. int main ( int argc, char **argv )
  16. {
  17. // variables
  18. NICE::Config conf (argc, argv );
  19. NICE::ResourceStatistics rs;
  20. MultiDataset md ( &conf );
  21. const ClassNames & classNames = md.getClassNames ( "train" );
  22. const LabeledSet *testFiles = md["test"];
  23. std::set<int> forbiddenClasses;
  24. classNames.getSelection ( conf.gS ( "analysis", "forbidden_classes", "" ),
  25. forbiddenClasses );
  26. std::vector<bool> usedClasses ( classNames.numClasses(), true );
  27. for ( std::set<int>::const_iterator it = forbiddenClasses.begin();
  28. it != forbiddenClasses.end(); ++it)
  29. {
  30. usedClasses [ *it ] = false;
  31. }
  32. std::map<int,int> classMapping, classMappingInv;
  33. int j = 0;
  34. for ( int i = 0; i < usedClasses.size(); i++ )
  35. if (usedClasses[i])
  36. {
  37. classMapping[i] = j;
  38. classMappingInv[j] = i;
  39. j++;
  40. }
  41. NICE::Matrix M ( classMapping.size(), classMapping.size() );
  42. M.set( 0 );
  43. // initialize semantic segmentation method
  44. SemanticSegmentation *semseg = NULL;
  45. // setup actual segmentation method
  46. semseg = new SemSegObliqueTree ( &conf, &classNames );
  47. // training
  48. std::cout << "\nTRAINING" << std::endl;
  49. std::cout << "########\n" << std::endl;
  50. semseg->train( &md );
  51. // testing
  52. NICE::Timer timer;
  53. std::cout << "\nCLASSIFICATION" << std::endl;
  54. std::cout << "##############\n" << std::endl;
  55. std::vector<int> zsizeVec;
  56. bool run3Dseg = semseg->isMode3D();
  57. SemSegTools::getDepthVector ( testFiles, zsizeVec, run3Dseg );
  58. int depthCount = 0, idx = 0;
  59. std::vector<std::string> filelist;
  60. NICE::MultiChannelImageT<int> segresult, gt;
  61. for (LabeledSet::const_iterator it = testFiles->begin(); it != testFiles->end(); it++)
  62. {
  63. for (std::vector<ImageInfo *>::const_iterator jt = it->second.begin();
  64. jt != it->second.end(); jt++)
  65. {
  66. ImageInfo & info = *(*jt);
  67. std::string file = info.img();
  68. filelist.push_back(file);
  69. depthCount++;
  70. NICE::ImageT<int> gtruth, res;
  71. if ( info.hasLocalizationInfo() )
  72. {
  73. const LocalizationResult *l_gt = info.localization();
  74. gtruth.resize( l_gt->xsize, l_gt->ysize );
  75. l_gt->calcLabeledImage ( gtruth, classNames.getBackgroundClass() );
  76. }
  77. else
  78. {
  79. std::cerr << "testSemSegConvTrees: WARNING: NO localization info found for "
  80. << file << std::endl;
  81. }
  82. segresult.addChannel(gtruth);
  83. gt.addChannel(gtruth);
  84. int depthBoundary = 1;
  85. if ( run3Dseg )
  86. depthBoundary = zsizeVec[idx];
  87. std::cout << "Slice " << depthCount << "/"
  88. << depthBoundary << std::endl;
  89. if ( depthCount < depthBoundary )
  90. continue;
  91. // actual testing
  92. NICE::MultiChannelImage3DT<double> probabilities;
  93. timer.start();
  94. semseg->semanticseg( filelist, segresult, probabilities );
  95. timer.stop();
  96. std::cout << "Time for Classification: " << timer.getLastAbsolute()
  97. << "\n\n";
  98. // updating confusion matrix
  99. res = gtruth;
  100. for ( int z = 0; z < segresult.channels(); z++ )
  101. {
  102. for ( int y = 0; y < res.height(); y++ )
  103. for ( int x = 0; x < res.width(); x++)
  104. {
  105. res.setPixel ( x, y, segresult.get(x,y,(unsigned int)z) );
  106. if ( run3Dseg )
  107. gtruth.setPixel ( x, y, gt.get(x,y,(unsigned int)z) );
  108. }
  109. SemSegTools::updateConfusionMatrix ( res, gtruth, M,
  110. forbiddenClasses, classMapping );
  111. // saving results to image file
  112. NICE::ColorImage rgb;
  113. NICE::ColorImage rgb_gt;
  114. NICE::ColorImage orig ( filelist[z] );
  115. classNames.labelToRGB( res, rgb);
  116. classNames.labelToRGB( gtruth, rgb_gt);
  117. std::string fname = NICE::StringTools::baseName ( filelist[z], false );
  118. std::string outStr;
  119. SemSegTools::saveResultsToImageFile( &conf, "analysis", orig,
  120. rgb_gt, rgb, fname, outStr );
  121. }
  122. // prepare for new 3d image
  123. filelist.clear();
  124. segresult.reInit(0,0,0);
  125. gt.reInit(0,0,0);
  126. depthCount = 0;
  127. idx++;
  128. }
  129. }
  130. // resource statistics
  131. SemSegTools::computeResourceStatistics ( rs );
  132. // evaluation & analysis
  133. SemSegTools::computeClassificationStatistics(
  134. M, classNames, forbiddenClasses, classMappingInv );
  135. // Cleaning up
  136. delete semseg;
  137. }