SemSegTools.cpp 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306
  1. /**
  2. * @file SemSegTools.cpp
  3. * @brief tools for semantic segmentation
  4. * @author Erik Rodner, Sven Sickert
  5. * @date 03/19/2009
  6. */
  7. #include <iostream>
  8. #include "SemSegTools.h"
  9. using namespace OBJREC;
  10. using namespace std;
  11. using namespace NICE;
  12. #undef DEBUG_LOCALIZATION
  13. #undef DEBUG
  14. void SemSegTools::segmentToOverlay (
  15. const NICE::Image *orig,
  16. const NICE::ColorImage & segment,
  17. NICE::ColorImage & result )
  18. {
  19. int xsize = orig->width();
  20. int ysize = orig->height();
  21. result.resize( xsize, ysize );
  22. std::vector< NICE::MatrixT<double> > channelMat;
  23. double alpha = .3;
  24. for (int c = 0; c < 3; c++)
  25. {
  26. NICE::MatrixT<double> chan ( xsize, ysize );
  27. channelMat.push_back( chan );
  28. }
  29. for (int y = 0; y < ysize; y++)
  30. for (int x = 0; x < xsize; x++)
  31. {
  32. uchar val = orig->getPixelQuick(x,y);
  33. for (int c = 0; c < 3; c++)
  34. channelMat[c](x,y) = alpha*(double)val
  35. + (1.0-alpha)*(double)segment.getPixel( x, y, c );
  36. }
  37. for (int y = 0; y < ysize; y++)
  38. for (int x = 0; x < xsize; x++)
  39. for (int c = 0; c < 3; c++)
  40. {
  41. int val = channelMat[c](x,y);
  42. result.setPixel( x, y, c, (uchar)val);
  43. }
  44. }
  45. void SemSegTools::updateConfusionMatrix(
  46. const ImageT<int> &img,
  47. const ImageT<int> &gt,
  48. Matrix &M,
  49. const std::set<int> &forbiddenClasses,
  50. map<int,int> & classMapping )
  51. {
  52. double subsamplex = gt.width() / ( double ) img.width();
  53. double subsampley = gt.height() / ( double ) img.height();
  54. for ( int y = 0 ; y < gt.height() ; y++ )
  55. for ( int x = 0 ; x < gt.width() ; x++ )
  56. {
  57. int xx = ( int ) ( x / subsamplex );
  58. int yy = ( int ) ( y / subsampley );
  59. if ( xx < 0 ) xx = 0;
  60. if ( yy < 0 ) yy = 0;
  61. if ( xx > img.width() - 1 ) xx = img.width() - 1;
  62. if ( yy > img.height() - 1 ) yy = img.height() - 1;
  63. int cimg = img.getPixel ( xx, yy );
  64. int gimg = gt.getPixel ( x, y );
  65. if ( forbiddenClasses.find ( gimg ) == forbiddenClasses.end() )
  66. {
  67. M ( classMapping[gimg], classMapping[cimg] ) ++;
  68. }
  69. }
  70. }
  71. void SemSegTools::computeClassificationStatistics(
  72. Matrix &confMat,
  73. const ClassNames &classNames,
  74. const std::set<int> &forbiddenClasses )
  75. {
  76. double overallTrue = 0.0;
  77. double sumAll = 0.0;
  78. // print confusion matrix & get overall recognition rate
  79. std::cout << "Confusion Matrix:" << std::endl;
  80. for ( int r = 0; r < (int) confMat.rows(); r++ )
  81. {
  82. for ( int c = 0; c < (int) confMat.cols(); c++ )
  83. {
  84. if ( r == c )
  85. overallTrue += confMat( r, c );
  86. sumAll += confMat( r, c );
  87. std::cout << confMat( r, c ) << " ";
  88. }
  89. std::cout << std::endl;
  90. }
  91. overallTrue /= sumAll;
  92. // binary classification metrics
  93. double precision, recall, f1score = -1.0;
  94. if ( confMat.rows() == 2 )
  95. {
  96. precision = (double)confMat(1,1) / (double)(confMat(1,1)+confMat(0,1));
  97. recall = (double)confMat(1,1) / (double)(confMat(1,1)+confMat(1,0));
  98. f1score = 2.0*(precision*recall)/(precision+recall);
  99. }
  100. // normalizing confMat using rows
  101. for ( int r = 0 ; r < (int) confMat.rows() ; r++ )
  102. {
  103. double sum = 0.0;
  104. for ( int c = 0 ; c < (int) confMat.cols() ; c++ )
  105. sum += confMat ( r, c );
  106. if ( std::fabs ( sum ) > 1e-4 )
  107. for ( int c = 0 ; c < (int) confMat.cols() ; c++ )
  108. confMat ( r, c ) /= sum;
  109. }
  110. // get average recognition rate
  111. double avgTrue = 0.0;
  112. int classesTrained = 0;
  113. for ( int r = 0 ; r < (int) confMat.rows() ; r++ )
  114. {
  115. if ( classNames.existsClassno ( r )
  116. && ( forbiddenClasses.find ( r ) == forbiddenClasses.end() ) )
  117. {
  118. avgTrue += confMat ( r, r );
  119. double lsum = 0.0;
  120. for ( int r2 = 0; r2 < ( int ) confMat.rows(); r2++ )
  121. lsum += confMat ( r,r2 );
  122. if ( lsum != 0.0 )
  123. classesTrained++;
  124. }
  125. }
  126. // print classification statistics
  127. std::cout << "\nOverall Recogntion Rate: " << overallTrue;
  128. std::cout << "\nAverage Recogntion Rate: " << avgTrue / ( classesTrained );
  129. std::cout << "\nLower Bound: " << 1.0 /(double)classesTrained;
  130. std::cout << "\nPrecision: " << precision;
  131. std::cout << "\nRecall: " << recall;
  132. std::cout << "\nF1Score: " << f1score;
  133. std::cout <<"\n\nClasses:" << std::endl;
  134. for ( int r = 0 ; r < (int) confMat.rows() ; r++ )
  135. {
  136. if ( classNames.existsClassno ( r )
  137. && ( forbiddenClasses.find ( r ) == forbiddenClasses.end() ) )
  138. {
  139. std::string cname = classNames.text ( r );
  140. std::cout << cname.c_str() << ": " << confMat ( r, r ) << std::endl;
  141. }
  142. }
  143. }
  144. void SemSegTools::saveResultsToImageFile(
  145. const Config *conf,
  146. const string &section,
  147. const ColorImage &orig,
  148. const ColorImage &gtruth,
  149. const ColorImage &segment,
  150. const string &file )
  151. {
  152. std::string resultDir = conf->gS ( section, "resultdir", "." );
  153. std::string outputType = conf->gS ( section, "output_type", "ppm" );
  154. std::string outputPostfix = conf->gS ( section, "output_postfix", "" );
  155. NICE::ColorImage overlaySegment, overlayGTruth;
  156. NICE::Image* origGrey = orig.getChannel(1);
  157. segmentToOverlay( origGrey, segment, overlaySegment );
  158. segmentToOverlay( origGrey, gtruth, overlayGTruth );
  159. std::stringstream out;
  160. out << resultDir << "/" << file << outputPostfix;
  161. #ifdef DEBUG
  162. std::cout << "Writing to file " << out.str() << "_*." << outputType << std::endl;
  163. #endif
  164. orig.write ( out.str() + "_orig." + outputType );
  165. segment.write ( out.str() + "_result." + outputType );
  166. gtruth.write ( out.str() + "_groundtruth." + outputType );
  167. overlaySegment.write ( out.str() + "_overlay_res." + outputType );
  168. overlayGTruth.write ( out.str() + "_overlay_gt." + outputType );
  169. }
  170. void SemSegTools::collectTrainingExamples (
  171. const Config * conf,
  172. const std::string & section,
  173. const LabeledSet & train,
  174. const ClassNames & cn,
  175. Examples & examples,
  176. vector<CachedExample *> & imgexamples )
  177. {
  178. assert ( train.count() > 0 );
  179. examples.clear();
  180. imgexamples.clear();
  181. int grid_size_x = conf->gI(section, "grid_size_x", 5 );
  182. int grid_size_y = conf->gI(section, "grid_size_y", 5 );
  183. int grid_border_x = conf->gI(section, "grid_border_x", 20 );
  184. int grid_border_y = conf->gI(section, "grid_border_y", 20 );
  185. std::string selection = conf->gS(section, "train_selection" );
  186. set<int> classnoSelection;
  187. cn.getSelection ( selection, classnoSelection );
  188. bool useExcludedAsBG = conf->gB(section, "use_excluded_as_background", false );
  189. int backgroundClassNo = 0;
  190. if ( useExcludedAsBG )
  191. {
  192. backgroundClassNo = cn.classno("various");
  193. assert ( backgroundClassNo >= 0 );
  194. }
  195. LOOP_ALL_S (train)
  196. {
  197. EACH_INFO(image_classno,imgInfo);
  198. std::string imgfn = imgInfo.img();
  199. if ( ! imgInfo.hasLocalizationInfo() ) {
  200. std::cerr << "WARNING: NO localization info found for "
  201. << imgfn << " !" << std::endl;
  202. continue;
  203. }
  204. int xsize, ysize;
  205. CachedExample *ce = new CachedExample ( imgfn );
  206. ce->getImageSize ( xsize, ysize );
  207. imgexamples.push_back ( ce );
  208. const LocalizationResult *locResult = imgInfo.localization();
  209. if ( locResult->size() <= 0 ) {
  210. std::cerr << "WARNING: NO ground truth polygons found for "
  211. << imgfn << " !" << std::endl;
  212. continue;
  213. }
  214. std::cerr << "SemSegTools: Collecting pixel examples from localization info: "
  215. << imgfn << std::endl;
  216. NICE::Image pixelLabels (xsize, ysize);
  217. pixelLabels.set(0);
  218. locResult->calcLabeledImage ( pixelLabels, cn.getBackgroundClass() );
  219. #ifdef DEBUG_LOCALIZATION
  220. NICE::Image img (imgfn);
  221. showImage(img);
  222. showImage(pixelLabels);
  223. #endif
  224. Example pce ( ce, 0, 0 );
  225. for ( int x = 0 ; x < xsize ; x += grid_size_x )
  226. for ( int y = 0 ; y < ysize ; y += grid_size_y )
  227. {
  228. if ( (x >= grid_border_x) &&
  229. ( y >= grid_border_y ) && ( x < xsize - grid_border_x ) &&
  230. ( y < ysize - grid_border_x ) )
  231. {
  232. pce.x = x; pce.y = y;
  233. int classno = pixelLabels.getPixel(x,y);
  234. if ( classnoSelection.find(classno) != classnoSelection.end() ) {
  235. examples.push_back ( pair<int, Example> (
  236. classno,
  237. pce // FIXME: offset handling
  238. ) );
  239. } else if ( useExcludedAsBG ) {
  240. examples.push_back ( pair<int, Example> (
  241. backgroundClassNo,
  242. pce // FIXME: offset handling
  243. ) );
  244. }
  245. }
  246. }
  247. }
  248. std::cerr << "total number of examples: " << (int)examples.size() << std::endl;
  249. }