SemSegNovelty.cpp 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355
  1. #include <sstream>
  2. #include <iostream>
  3. #include "SemSegNovelty.h"
  4. #include "core/image/FilterT.h"
  5. #include "gp-hik-exp/GPHIKClassifierNICE.h"
  6. #include "vislearning/baselib/ICETools.h"
  7. #include "vislearning/baselib/Globals.h"
  8. #include "vislearning/features/fpfeatures/SparseVectorFeature.h"
  9. #include "core/basics/StringTools.h"
  10. #include "core/basics/Timer.h"
  11. using namespace std;
  12. using namespace NICE;
  13. using namespace OBJREC;
  14. SemSegNovelty::SemSegNovelty ( const Config *conf,
  15. const MultiDataset *md )
  16. : SemanticSegmentation ( conf, & ( md->getClassNames ( "train" ) ) )
  17. {
  18. this->conf = conf;
  19. string section = "SemSegNovelty";
  20. featExtract = new LFColorWeijer ( conf );
  21. save_cache = conf->gB ( "FPCPixel", "save_cache", true );
  22. read_cache = conf->gB ( "FPCPixel", "read_cache", false );
  23. uncertdir = conf->gS("debug", "uncertainty", "uncertainty");
  24. cache = conf->gS ( "cache", "root", "" );
  25. classifier = new GPHIKClassifierNICE ( conf, "ClassiferGPHIK" );;
  26. whs = conf->gI ( section, "window_size", 10 );
  27. featdist = conf->gI ( section, "grid", 10 );
  28. testWSize = conf->gI (section, "test_window_size", 10);
  29. cn = md->getClassNames ( "train" );
  30. if ( read_cache )
  31. {
  32. string classifierdst = "/classifier.data";
  33. fprintf ( stderr, "SemSegNovelty:: Reading classifier data from %s\n", ( cache + classifierdst ).c_str() );
  34. try
  35. {
  36. if ( classifier != NULL )
  37. {
  38. classifier->read ( cache + classifierdst );
  39. }
  40. fprintf ( stderr, "SemSegNovelty:: successfully read\n" );
  41. }
  42. catch ( char *str )
  43. {
  44. cerr << "error reading data: " << str << endl;
  45. }
  46. }
  47. else
  48. {
  49. train ( md );
  50. }
  51. }
  52. SemSegNovelty::~SemSegNovelty()
  53. {
  54. // clean-up
  55. if ( classifier != NULL )
  56. delete classifier;
  57. if ( featExtract != NULL )
  58. delete featExtract;
  59. }
  60. void SemSegNovelty::train ( const MultiDataset *md )
  61. {
  62. const LabeledSet train = * ( *md ) ["train"];
  63. const LabeledSet *trainp = &train;
  64. ////////////////////////
  65. // feature extraction //
  66. ////////////////////////
  67. std::string forbidden_classes_s = conf->gS ( "analysis", "donttrain", "" );
  68. if ( forbidden_classes_s == "" )
  69. {
  70. forbidden_classes_s = conf->gS ( "analysis", "forbidden_classes", "" );
  71. }
  72. cn.getSelection ( forbidden_classes_s, forbidden_classes );
  73. cerr << "forbidden: " << forbidden_classes_s << endl;
  74. ProgressBar pb ( "Local Feature Extraction" );
  75. pb.show();
  76. int imgnb = 0;
  77. Examples examples;
  78. examples.filename = "training";
  79. int featdim = -1;
  80. LOOP_ALL_S ( *trainp )
  81. {
  82. //EACH_S(classno, currentFile);
  83. EACH_INFO ( classno, info );
  84. std::string currentFile = info.img();
  85. CachedExample *ce = new CachedExample ( currentFile );
  86. const LocalizationResult *locResult = info.localization();
  87. if ( locResult->size() <= 0 )
  88. {
  89. fprintf ( stderr, "WARNING: NO ground truth polygons found for %s !\n",
  90. currentFile.c_str() );
  91. continue;
  92. }
  93. int xsize, ysize;
  94. ce->getImageSize ( xsize, ysize );
  95. Image labels ( xsize, ysize );
  96. labels.set ( 0 );
  97. locResult->calcLabeledImage ( labels, ( *classNames ).getBackgroundClass() );
  98. NICE::ColorImage img;
  99. try {
  100. img = ColorImage ( currentFile );
  101. } catch ( Exception ) {
  102. cerr << "SemSegNovelty: error opening image file <" << currentFile << ">" << endl;
  103. continue;
  104. }
  105. Globals::setCurrentImgFN ( currentFile );
  106. MultiChannelImageT<double> feats;
  107. // extract features
  108. featExtract->getFeats ( img, feats );
  109. featdim = feats.channels();
  110. feats.addChannel(featdim);
  111. for (int c = 0; c < featdim; c++)
  112. {
  113. ImageT<double> tmp = feats[c];
  114. ImageT<double> tmp2 = feats[c+featdim];
  115. NICE::FilterT<double, double, double>::gradientStrength (tmp, tmp2);
  116. }
  117. featdim += featdim;
  118. // compute integral images
  119. for ( int c = 0; c < featdim; c++ )
  120. {
  121. feats.calcIntegral ( c );
  122. }
  123. for ( int y = 0; y < ysize; y += featdist )
  124. {
  125. for ( int x = 0; x < xsize; x += featdist )
  126. {
  127. int classno = labels ( x, y );
  128. if ( forbidden_classes.find ( classno ) != forbidden_classes.end() )
  129. continue;
  130. Example example;
  131. example.vec = NULL;
  132. example.svec = new SparseVector ( featdim );
  133. for ( int f = 0; f < featdim; f++ )
  134. {
  135. double val = feats.getIntegralValue ( x - whs, y - whs, x + whs, y + whs, f );
  136. if ( val > 1e-10 )
  137. ( *example.svec ) [f] = val;
  138. }
  139. example.svec->normalize();
  140. example.position = imgnb;
  141. examples.push_back ( pair<int, Example> ( classno, example ) );
  142. }
  143. }
  144. delete ce;
  145. imgnb++;
  146. pb.update ( trainp->count() );
  147. }
  148. pb.hide();
  149. //////////////////////
  150. // train classifier //
  151. //////////////////////
  152. FeaturePool fp;
  153. Feature *f = new SparseVectorFeature ( featdim );
  154. f->explode ( fp );
  155. delete f;
  156. if ( classifier != NULL )
  157. classifier->train ( fp, examples );
  158. else
  159. {
  160. cerr << "no classifier selected?!" << endl;
  161. exit ( -1 );
  162. }
  163. fp.destroy();
  164. if ( save_cache )
  165. {
  166. if ( classifier != NULL )
  167. classifier->save ( cache + "/classifier.data" );
  168. }
  169. ////////////
  170. //clean up//
  171. ////////////
  172. for ( int i = 0; i < ( int ) examples.size(); i++ )
  173. {
  174. examples[i].second.clean();
  175. }
  176. examples.clear();
  177. cerr << "SemSeg training finished" << endl;
  178. }
  179. void SemSegNovelty::semanticseg ( CachedExample *ce, NICE::Image & segresult, NICE::MultiChannelImageT<double> & probabilities )
  180. {
  181. Timer timer;
  182. timer.start();
  183. Examples examples;
  184. examples.filename = "testing";
  185. segresult.set ( 0 );
  186. int featdim = -1;
  187. std::string currentFile = Globals::getCurrentImgFN();
  188. int xsize, ysize;
  189. ce->getImageSize ( xsize, ysize );
  190. probabilities.reInit( xsize, ysize, cn.getMaxClassno() + 1);
  191. probabilities.set ( 0.0 );
  192. NICE::ColorImage img;
  193. try {
  194. img = ColorImage ( currentFile );
  195. } catch ( Exception ) {
  196. cerr << "SemSegNovelty: error opening image file <" << currentFile << ">" << endl;
  197. return;
  198. }
  199. MultiChannelImageT<double> feats;
  200. // extract features
  201. featExtract->getFeats ( img, feats );
  202. featdim = feats.channels();
  203. feats.addChannel(featdim);
  204. for (int c = 0; c < featdim; c++)
  205. {
  206. ImageT<double> tmp = feats[c];
  207. ImageT<double> tmp2 = feats[c+featdim];
  208. NICE::FilterT<double, double, double>::gradientStrength (tmp, tmp2);
  209. }
  210. featdim += featdim;
  211. // compute integral images
  212. for ( int c = 0; c < featdim; c++ )
  213. {
  214. feats.calcIntegral ( c );
  215. }
  216. FloatImage uncert ( xsize, ysize );
  217. uncert.set ( 0.0 );
  218. double maxunc = -numeric_limits<double>::max();
  219. timer.stop();
  220. cout << "first: " << timer.getLastAbsolute() << endl;
  221. timer.start();
  222. #pragma omp parallel for
  223. for ( int y = 0; y < ysize; y += testWSize )
  224. {
  225. Example example;
  226. example.vec = NULL;
  227. example.svec = new SparseVector ( featdim );
  228. for ( int x = 0; x < xsize; x += testWSize)
  229. {
  230. for ( int f = 0; f < featdim; f++ )
  231. {
  232. double val = feats.getIntegralValue ( x - whs, y - whs, x + whs, y + whs, f );
  233. if ( val > 1e-10 )
  234. ( *example.svec ) [f] = val;
  235. }
  236. example.svec->normalize();
  237. ClassificationResult cr = classifier->classify ( example );
  238. int xs = std::max(0, x - testWSize/2);
  239. int xe = std::min(xsize - 1, x + testWSize/2);
  240. int ys = std::max(0, y - testWSize/2);
  241. int ye = std::min(ysize - 1, y + testWSize/2);
  242. for (int yl = ys; yl <= ye; yl++)
  243. {
  244. for (int xl = xs; xl <= xe; xl++)
  245. {
  246. for ( int j = 0 ; j < cr.scores.size(); j++ )
  247. {
  248. probabilities ( xl, yl, j ) = cr.scores[j];
  249. }
  250. segresult ( xl, yl ) = cr.classno;
  251. uncert ( xl, yl ) = cr.uncertainty;
  252. }
  253. }
  254. if (maxunc < cr.uncertainty)
  255. maxunc = cr.uncertainty;
  256. example.svec->clear();
  257. }
  258. delete example.svec;
  259. example.svec = NULL;
  260. }
  261. cout << "maxunertainty: " << maxunc << endl;
  262. timer.stop();
  263. cout << "second: " << timer.getLastAbsolute() << endl;
  264. timer.start();
  265. ColorImage imgrgb ( xsize, ysize );
  266. std::stringstream out;
  267. std::vector< std::string > list2;
  268. StringTools::split ( Globals::getCurrentImgFN (), '/', list2 );
  269. out << uncertdir << "/" << list2.back();
  270. uncert.writeRaw(out.str() + ".rawfloat");
  271. uncert(0, 0) = 0.0;
  272. uncert(0, 1) = 1.0;
  273. ICETools::convertToRGB ( uncert, imgrgb );
  274. imgrgb.write ( out.str() + "rough.png" );
  275. timer.stop();
  276. cout << "last: " << timer.getLastAbsolute() << endl;
  277. }