SemSegTools.cpp 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305
  1. /**
  2. * @file SemSegTools.cpp
  3. * @brief tools for semantic segmentation
  4. * @author Erik Rodner, Sven Sickert
  5. * @date 03/19/2009
  6. */
  7. #include <iostream>
  8. #include "SemSegTools.h"
  9. using namespace OBJREC;
  10. using namespace std;
  11. using namespace NICE;
  12. #undef DEBUG_LOCALIZATION
  13. #undef DEBUG
  14. void SemSegTools::segmentToOverlay (
  15. const NICE::Image *orig,
  16. const NICE::ColorImage & segment,
  17. NICE::ColorImage & result )
  18. {
  19. int xsize = orig->width();
  20. int ysize = orig->height();
  21. result.resize( xsize, ysize );
  22. std::vector< NICE::MatrixT<double> > channelMat;
  23. double alpha = .3;
  24. for (int c = 0; c < 3; c++)
  25. {
  26. NICE::MatrixT<double> chan ( xsize, ysize );
  27. channelMat.push_back( chan );
  28. }
  29. for (int y = 0; y < ysize; y++)
  30. for (int x = 0; x < xsize; x++)
  31. {
  32. uchar val = orig->getPixelQuick(x,y);
  33. for (int c = 0; c < 3; c++)
  34. channelMat[c](x,y) = alpha*(double)val
  35. + (1.0-alpha)*(double)segment.getPixel( x, y, c );
  36. }
  37. for (int y = 0; y < ysize; y++)
  38. for (int x = 0; x < xsize; x++)
  39. for (int c = 0; c < 3; c++)
  40. {
  41. int val = channelMat[c](x,y);
  42. result.setPixel( x, y, c, (uchar)val);
  43. }
  44. }
  45. void SemSegTools::updateConfusionMatrix(
  46. const Image &img,
  47. const Image &gt,
  48. Matrix &M,
  49. const std::set<int> &forbiddenClasses )
  50. {
  51. double subsamplex = gt.width() / ( double ) img.width();
  52. double subsampley = gt.height() / ( double ) img.height();
  53. for ( int y = 0 ; y < gt.height() ; y++ )
  54. for ( int x = 0 ; x < gt.width() ; x++ )
  55. {
  56. int xx = ( int ) ( x / subsamplex );
  57. int yy = ( int ) ( y / subsampley );
  58. if ( xx < 0 ) xx = 0;
  59. if ( yy < 0 ) yy = 0;
  60. if ( xx > img.width() - 1 ) xx = img.width() - 1;
  61. if ( yy > img.height() - 1 ) yy = img.height() - 1;
  62. int cimg = img.getPixel ( xx, yy );
  63. int gimg = gt.getPixel ( x, y );
  64. if ( forbiddenClasses.find ( gimg ) == forbiddenClasses.end() )
  65. {
  66. M ( gimg, cimg ) ++;
  67. }
  68. }
  69. }
  70. void SemSegTools::computeClassificationStatistics(
  71. Matrix &confMat,
  72. const ClassNames &classNames,
  73. const std::set<int> &forbiddenClasses )
  74. {
  75. double overallTrue = 0.0;
  76. double sumAll = 0.0;
  77. // print confusion matrix & get overall recognition rate
  78. std::cout << "Confusion Matrix:" << std::endl;
  79. for ( int r = 0; r < (int) confMat.rows(); r++ )
  80. {
  81. for ( int c = 0; c < (int) confMat.cols(); c++ )
  82. {
  83. if ( r == c )
  84. overallTrue += confMat( r, c );
  85. sumAll += confMat( r, c );
  86. std::cout << confMat( r, c ) << " ";
  87. }
  88. std::cout << std::endl;
  89. }
  90. overallTrue /= sumAll;
  91. // binary classification metrics
  92. double precision, recall, f1score = -1.0;
  93. if ( confMat.rows() == 2 )
  94. {
  95. precision = (double)confMat(1,1) / (double)(confMat(1,1)+confMat(0,1));
  96. recall = (double)confMat(1,1) / (double)(confMat(1,1)+confMat(1,0));
  97. f1score = 2.0*(precision*recall)/(precision+recall);
  98. }
  99. // normalizing confMat using rows
  100. for ( int r = 0 ; r < (int) confMat.rows() ; r++ )
  101. {
  102. double sum = 0.0;
  103. for ( int c = 0 ; c < (int) confMat.cols() ; c++ )
  104. sum += confMat ( r, c );
  105. if ( std::fabs ( sum ) > 1e-4 )
  106. for ( int c = 0 ; c < (int) confMat.cols() ; c++ )
  107. confMat ( r, c ) /= sum;
  108. }
  109. // get average recognition rate
  110. double avgTrue = 0.0;
  111. int classesTrained = 0;
  112. for ( int r = 0 ; r < (int) confMat.rows() ; r++ )
  113. {
  114. if ( classNames.existsClassno ( r )
  115. && ( forbiddenClasses.find ( r ) == forbiddenClasses.end() ) )
  116. {
  117. avgTrue += confMat ( r, r );
  118. double lsum = 0.0;
  119. for ( int r2 = 0; r2 < ( int ) confMat.rows(); r2++ )
  120. lsum += confMat ( r,r2 );
  121. if ( lsum != 0.0 )
  122. classesTrained++;
  123. }
  124. }
  125. // print classification statistics
  126. std::cout << "\nOverall Recogntion Rate: " << overallTrue;
  127. std::cout << "\nAverage Recogntion Rate: " << avgTrue / ( classesTrained );
  128. std::cout << "\nLower Bound: " << 1.0 /(double)classesTrained;
  129. std::cout << "\nPrecision: " << precision;
  130. std::cout << "\nRecall: " << recall;
  131. std::cout << "\nF1Score: " << f1score;
  132. std::cout <<"\n\nClasses:" << std::endl;
  133. for ( int r = 0 ; r < (int) confMat.rows() ; r++ )
  134. {
  135. if ( classNames.existsClassno ( r )
  136. && ( forbiddenClasses.find ( r ) == forbiddenClasses.end() ) )
  137. {
  138. std::string cname = classNames.text ( r );
  139. std::cout << cname.c_str() << ": " << confMat ( r, r ) << std::endl;
  140. }
  141. }
  142. }
  143. void SemSegTools::saveResultsToImageFile(
  144. const Config *conf,
  145. const string &section,
  146. const ColorImage &orig,
  147. const ColorImage &gtruth,
  148. const ColorImage &segment,
  149. const string &file )
  150. {
  151. std::string resultDir = conf->gS ( section, "resultdir", "." );
  152. std::string outputType = conf->gS ( section, "output_type", "ppm" );
  153. std::string outputPostfix = conf->gS ( section, "output_postfix", "" );
  154. NICE::ColorImage overlaySegment, overlayGTruth;
  155. NICE::Image* origGrey = orig.getChannel(1);
  156. segmentToOverlay( origGrey, segment, overlaySegment );
  157. segmentToOverlay( origGrey, gtruth, overlayGTruth );
  158. std::stringstream out;
  159. out << resultDir << "/" << file << outputPostfix;
  160. #ifdef DEBUG
  161. std::cout << "Writing to file " << out.str() << "_*." << outputType << std::endl;
  162. #endif
  163. orig.write ( out.str() + "_orig." + outputType );
  164. segment.write ( out.str() + "_result." + outputType );
  165. gtruth.write ( out.str() + "_groundtruth." + outputType );
  166. overlaySegment.write ( out.str() + "_overlay_res." + outputType );
  167. overlayGTruth.write ( out.str() + "_overlay_gt." + outputType );
  168. }
  169. void SemSegTools::collectTrainingExamples (
  170. const Config * conf,
  171. const std::string & section,
  172. const LabeledSet & train,
  173. const ClassNames & cn,
  174. Examples & examples,
  175. vector<CachedExample *> & imgexamples )
  176. {
  177. assert ( train.count() > 0 );
  178. examples.clear();
  179. imgexamples.clear();
  180. int grid_size_x = conf->gI(section, "grid_size_x", 5 );
  181. int grid_size_y = conf->gI(section, "grid_size_y", 5 );
  182. int grid_border_x = conf->gI(section, "grid_border_x", 20 );
  183. int grid_border_y = conf->gI(section, "grid_border_y", 20 );
  184. std::string selection = conf->gS(section, "train_selection" );
  185. set<int> classnoSelection;
  186. cn.getSelection ( selection, classnoSelection );
  187. bool useExcludedAsBG = conf->gB(section, "use_excluded_as_background", false );
  188. int backgroundClassNo = 0;
  189. if ( useExcludedAsBG )
  190. {
  191. backgroundClassNo = cn.classno("various");
  192. assert ( backgroundClassNo >= 0 );
  193. }
  194. LOOP_ALL_S (train)
  195. {
  196. EACH_INFO(image_classno,imgInfo);
  197. std::string imgfn = imgInfo.img();
  198. if ( ! imgInfo.hasLocalizationInfo() ) {
  199. std::cerr << "WARNING: NO localization info found for "
  200. << imgfn << " !" << std::endl;
  201. continue;
  202. }
  203. int xsize, ysize;
  204. CachedExample *ce = new CachedExample ( imgfn );
  205. ce->getImageSize ( xsize, ysize );
  206. imgexamples.push_back ( ce );
  207. const LocalizationResult *locResult = imgInfo.localization();
  208. if ( locResult->size() <= 0 ) {
  209. std::cerr << "WARNING: NO ground truth polygons found for "
  210. << imgfn << " !" << std::endl;
  211. continue;
  212. }
  213. std::cerr << "SemSegTools: Collecting pixel examples from localization info: "
  214. << imgfn << std::endl;
  215. NICE::Image pixelLabels (xsize, ysize);
  216. pixelLabels.set(0);
  217. locResult->calcLabeledImage ( pixelLabels, cn.getBackgroundClass() );
  218. #ifdef DEBUG_LOCALIZATION
  219. NICE::Image img (imgfn);
  220. showImage(img);
  221. showImage(pixelLabels);
  222. #endif
  223. Example pce ( ce, 0, 0 );
  224. for ( int x = 0 ; x < xsize ; x += grid_size_x )
  225. for ( int y = 0 ; y < ysize ; y += grid_size_y )
  226. {
  227. if ( (x >= grid_border_x) &&
  228. ( y >= grid_border_y ) && ( x < xsize - grid_border_x ) &&
  229. ( y < ysize - grid_border_x ) )
  230. {
  231. pce.x = x; pce.y = y;
  232. int classno = pixelLabels.getPixel(x,y);
  233. if ( classnoSelection.find(classno) != classnoSelection.end() ) {
  234. examples.push_back ( pair<int, Example> (
  235. classno,
  236. pce // FIXME: offset handling
  237. ) );
  238. } else if ( useExcludedAsBG ) {
  239. examples.push_back ( pair<int, Example> (
  240. backgroundClassNo,
  241. pce // FIXME: offset handling
  242. ) );
  243. }
  244. }
  245. }
  246. }
  247. std::cerr << "total number of examples: " << (int)examples.size() << std::endl;
  248. }