SemSegRegionBased.cpp 35 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440
  1. #ifdef NICE_USELIB_OPENMP
  2. #include <omp.h>
  3. #endif
  4. #include "SemSegRegionBased.h"
  5. #include <iostream>
  6. #include "vislearning/cbaselib/CachedExample.h"
  7. #include "vislearning/classifier/fpclassifier/randomforest/FPCRandomForests.h"
  8. #include "vislearning/classifier/fpclassifier/logisticregression/FPCSMLR.h"
  9. #include <objrec/iclassifier/icgeneric/CSGeneric.h>
  10. #include "vislearning/features/fpfeatures/PixelPairFeature.h"
  11. #include "vislearning/classifier/genericClassifierSelection.h"
  12. #include "SemSegTools.h"
  13. #include "objrec/segmentation/RSMeanShift.h"
  14. #include "objrec/segmentation/RSCache.h"
  15. #include "objrec/segmentation/RSGraphBased.h"
  16. #include "vislearning/baselib/Globals.h"
  17. #include <vislearning/cbaselib/VectorFeature.h>
  18. #include "vislearning/features/fpfeatures/SparseVectorFeature.h"
  19. #include "vislearning/features/localfeatures/LFColorWeijer.h"
  20. #include "vislearning/features/localfeatures/LFColorSande.h"
  21. #include "vislearning/features/localfeatures/LocalFeatureSift.h"
  22. #include "vislearning/features/localfeatures/LocalFeatureOpponnentSift.h"
  23. #include "vislearning/features/localfeatures/LocalFeatureLFInterface.h"
  24. #include "vislearning/features/localfeatures/LocalFeatureRGBSift.h"
  25. #include "vislearning/features/localfeatures/LFCache.h"
  26. #include "objrec/features/regionfeatures/RFColor.h"
  27. #include "objrec/features/regionfeatures/RFHoG.h"
  28. #include "objrec/features/regionfeatures/RFBoV.h"
  29. #include "objrec/features/regionfeatures/RFBoVCodebook.h"
  30. #include "objrec/features/regionfeatures/RFCsurka.h"
  31. #include "objrec/iclassifier/codebook/CodebookRandomForest.h"
  32. #include "vislearning/math/cluster/GMM.h"
  33. #undef DEMO
  34. #undef WRITEFEATS
  35. using namespace OBJREC;
  36. using namespace std;
  37. using namespace NICE;
  38. #define DEBUG_PRINTS
  39. SemSegRegionBased::SemSegRegionBased ( const Config *c, const MultiDataset *md )
  40. : SemanticSegmentation ( c, & ( md->getClassNames ( "train" ) ) )
  41. {
  42. #ifdef DEBUG_PRINTS
  43. cout << "SemSegRegionBased Constructor starts" << endl;
  44. #endif
  45. conf = c;
  46. save_cache = conf->gB ( "SemSegRegion", "save_cache", true );
  47. read_cache = conf->gB ( "SemSegRegion", "read_cache", false );
  48. classifiercache = conf->gS ( "SemSegRegion", "cache", "classifier.data" );
  49. cache = conf->gS ( "cache", "root", "tmp/" );
  50. bool colorw = conf->gB ( "SemSegRegion", "colorw", false );
  51. bool bov = conf->gB ( "SemSegRegion", "bov", false );
  52. bool hog = conf->gB ( "SemSegRegion", "hog", false );
  53. bool structf = conf->gB ( "SemSegRegion", "struct", false );
  54. string classifiertype = conf->gS ( "SemSegRegion", "classifier", "RF" );
  55. bool usegcopt = conf->gB ( "SemSegRegion", "gcopt", false );
  56. bool bovmoosmann = conf->gB ( "SemSegRegion", "bovmoosmann", false );
  57. bool csurka = conf->gB ( "SemSegRegion", "csurka", false );
  58. if ( colorw )
  59. {
  60. LocalFeature *lfcw = new LFColorWeijer ( conf );
  61. rfc = new RFColor ( conf, lfcw );
  62. }
  63. else
  64. {
  65. rfc = NULL;
  66. }
  67. if ( hog )
  68. {
  69. rfhog = new RFHoG ( conf );
  70. }
  71. else
  72. {
  73. rfhog = NULL;
  74. }
  75. if ( structf )
  76. {
  77. rfstruct = new RFStruct ( conf );
  78. }
  79. else
  80. {
  81. rfstruct = NULL;
  82. }
  83. LocalFeature *lfcache = NULL;
  84. if ( bov || bovmoosmann || csurka )
  85. {
  86. string ftype = conf->gS ( "BOV", "feature", "sift" );
  87. siftFeats = NULL;
  88. if ( ftype == "sift" )
  89. {
  90. siftFeats = new LocalFeatureSift ( conf );
  91. lfcache = new LFCache ( conf, siftFeats );
  92. }
  93. if ( ftype == "osift" )
  94. {
  95. siftFeats = new LocalFeatureOpponnentSift ( conf );
  96. lfcache = new LFCache ( conf, siftFeats );
  97. }
  98. if ( ftype == "rsift" )
  99. {
  100. siftFeats = new LocalFeatureRGBSift ( conf );
  101. lfcache = new LFCache ( conf, siftFeats );
  102. }
  103. if ( ftype == "sande" )
  104. {
  105. LocalFeatureRepresentation *sande = new LFColorSande ( conf, "LFColorSandeTrain" );
  106. siftFeats = new LocalFeatureLFInterface ( conf, sande );
  107. LocalFeatureRepresentation *sande2 = new LFColorSande ( conf, "LFColorSandeTest" );
  108. LocalFeature *siftFeats2 = new LocalFeatureLFInterface ( conf, sande2 );
  109. lfcache = new LFCache ( conf, siftFeats2 );
  110. }
  111. if ( siftFeats == NULL )
  112. {
  113. throw "please choose one of the following features für BOV: osift, rsift, sift, sande";
  114. }
  115. }
  116. if ( csurka )
  117. {
  118. rfCsurka = new RFCsurka ( conf, lfcache );
  119. }
  120. else
  121. {
  122. rfCsurka = NULL;
  123. }
  124. if ( bov )
  125. {
  126. rfbov = new RFBoV ( conf, lfcache );
  127. }
  128. else
  129. {
  130. rfbov = NULL;
  131. }
  132. if ( bovmoosmann )
  133. {
  134. rfbovcrdf = new RFBoVCodebook ( conf, lfcache );
  135. }
  136. else
  137. {
  138. rfbovcrdf = NULL;
  139. }
  140. // setting classifier
  141. fpc = NULL;
  142. vclassifier = NULL;
  143. if ( classifiertype == "RF" )
  144. {
  145. fpc = new FPCRandomForests ( conf, "ClassifierForest" );
  146. }
  147. else if ( classifiertype == "SMLR" )
  148. {
  149. fpc = new FPCSMLR ( conf, "ClassifierSMLR" );
  150. }
  151. else if ( classifiertype == "VECC" )
  152. {
  153. vclassifier = CSGeneric::selectVecClassifier ( conf, "vecClassifier" );
  154. }
  155. else
  156. {
  157. throw "classifiertype not (yet) supported";
  158. }
  159. if ( fpc != NULL )
  160. fpc->setMaxClassNo ( classNames->getMaxClassno() );
  161. else if ( vclassifier != NULL )
  162. vclassifier->setMaxClassNo ( classNames->getMaxClassno() );
  163. cn = md->getClassNames ( "train" );
  164. // setting segmentation method
  165. RegionSegmentationMethod *tmprsm = new RSMeanShift ( conf );
  166. rsm = new RSCache ( conf, tmprsm );
  167. //rsm = new RSGraphBased(conf);
  168. // use global optimization (MRF)
  169. if ( usegcopt )
  170. gcopt = new PPGraphCut ( conf );
  171. else
  172. gcopt = NULL;
  173. classifiercache = cache + classifiercache;
  174. // read training data or start training
  175. if ( read_cache )
  176. {
  177. fprintf ( stderr, "SemSegRegion:: Reading classifier data from %s\n", cache.c_str() );
  178. if ( fpc != NULL )
  179. fpc->read ( classifiercache );
  180. else if ( vclassifier != NULL )
  181. vclassifier->read ( classifiercache );
  182. if ( rfCsurka != NULL )
  183. {
  184. bool usegmm = conf->gB ( "Csurka", "usegmm", false );
  185. bool usepca = conf->gB ( "Csurka", "usepca", false );
  186. if ( usepca || usegmm )
  187. {
  188. RFCsurka *_rfcsurka = dynamic_cast< RFCsurka * > ( rfCsurka );
  189. if ( usepca )
  190. {
  191. int pcadim = conf->gI ( "Csurka", "pcadim", 100 );
  192. PCA *pca = new PCA ( pcadim );
  193. string pcadst = cache + "/csurka.pca";
  194. if ( !FileMgt::fileExists ( pcadst ) )
  195. {
  196. throw ( pcadst + " not found" );
  197. }
  198. else
  199. {
  200. pca->read ( pcadst );
  201. }
  202. _rfcsurka->setPCA ( pca );
  203. }
  204. if ( usegmm )
  205. {
  206. int gaussians = conf->gI ( "Csurka", "gaussians", 1024 );
  207. GMM *g = new GMM ( conf, gaussians );
  208. string gmmdst = cache + "/csurka.gmm";
  209. if ( !g->loadData ( cache + "/gmmSIFT" ) )
  210. {
  211. throw ( gmmdst + " not found" );
  212. }
  213. _rfcsurka->setGMM ( g );
  214. }
  215. }
  216. }
  217. if ( rfbov != NULL )
  218. {
  219. RFBoV *rfbovdyn = dynamic_cast< RFBoV * > ( rfbov );
  220. int gaussians = conf->gI ( "SIFTTrain", "gaussians", 512 );
  221. GMM *g = new GMM ( conf, gaussians );
  222. PCA *pca = new PCA ( 100 );
  223. string pcadst = cache + "/bov.pca";
  224. if ( !g->loadData ( cache + "/gmmSIFT" ) || !FileMgt::fileExists ( pcadst ) )
  225. {
  226. throw ( "pca or gmm not found" );
  227. }
  228. else
  229. {
  230. pca->read ( pcadst );
  231. }
  232. rfbovdyn->setPCA ( pca );
  233. rfbovdyn->setGMM ( g );
  234. }
  235. fprintf ( stderr, "SemSegRegion:: successfully read\n" );
  236. }
  237. else
  238. {
  239. train ( md );
  240. }
  241. #ifdef DEBUG_PRINTS
  242. cout << "SemSegRegionBased Constructor finished" << endl;
  243. #endif
  244. }
  245. SemSegRegionBased::~SemSegRegionBased()
  246. {
  247. #ifdef DEBUG_PRINTS
  248. cout << "SemSegRegionBased Destructor starts" << endl;
  249. #endif
  250. if ( fpc != NULL )
  251. {
  252. delete fpc;
  253. }
  254. if ( vclassifier != NULL )
  255. {
  256. delete vclassifier;
  257. }
  258. #ifdef DEBUG_PRINTS
  259. cout << "SemSegRegionBased Destructor finished" << endl;
  260. #endif
  261. }
  262. void SemSegRegionBased::train ( const MultiDataset *md )
  263. {
  264. #ifdef DEBUG_PRINTS
  265. cout << "SemSegRegionBased::train starts" << endl;
  266. #endif
  267. Examples examples;
  268. examples.filename = "training";
  269. const LabeledSet train = * ( *md ) ["train"];
  270. set<int> forbidden_classes;
  271. std::string forbidden_classes_s = conf->gS ( "analysis", "donttrain", "" );
  272. if ( forbidden_classes_s == "" )
  273. {
  274. forbidden_classes_s = conf->gS ( "analysis", "forbidden_classes", "" );
  275. }
  276. cn.getSelection ( forbidden_classes_s, forbidden_classes );
  277. if ( gcopt != NULL )
  278. gcopt->setClassNo ( cn.numClasses() );
  279. LabeledSet::Permutation perm;
  280. train.getPermutation ( perm );
  281. learnHighLevel ( perm );
  282. //FIXME:Moosmann
  283. int imgcounter = 0;
  284. vector<vector<FeatureType> > feats;
  285. // loop over all training images
  286. for ( LabeledSet::Permutation::const_iterator i = perm.begin();
  287. i != perm.end(); i++, imgcounter++ )
  288. {
  289. const string fn = i->second->img();
  290. Globals::setCurrentImgFN ( fn );
  291. cout << fn << endl;
  292. NICE::ColorImage cimg ( fn );
  293. NICE::Matrix mask;
  294. RegionGraph rg;
  295. rsm->getGraphRepresentation ( cimg, mask, rg );
  296. #ifdef DEMO
  297. rsm->visualizeGraphRepresentation ( cimg, mask );
  298. #endif
  299. // get label
  300. const LocalizationResult *locResult = i->second->localization();
  301. NICE::Image pixelLabels ( cimg.width(), cimg.height() );
  302. pixelLabels.set ( 0 );
  303. locResult->calcLabeledImage ( pixelLabels, ( *classNames ).getBackgroundClass() );
  304. getRegionLabel ( mask, rg, pixelLabels );
  305. getFeats ( cimg, mask, rg, feats );
  306. //#pragma omp critical
  307. for ( int i = 0; i < rg.size(); i++ )
  308. {
  309. int classno = rg[i]->getLabel();
  310. Example example;
  311. example.position = imgcounter;
  312. examples.push_back ( pair<int, Example> ( classno, example ) );
  313. }
  314. //#pragma omp critical
  315. if ( gcopt != NULL )
  316. gcopt->trainImage ( rg );
  317. }
  318. cout << "train classifier starts" << endl;
  319. trainClassifier ( feats, examples );
  320. cout << "train classifier finished" << endl;
  321. if ( gcopt != NULL )
  322. gcopt->finishPP ( cn );
  323. // clean up
  324. /*for(int i = 0; i < (int) examples.size(); i++)
  325. {
  326. examples[i].second.clean();
  327. }*/
  328. #ifdef DEBUG_PRINTS
  329. cout << "SemSegRegionBased::train finished" << endl;
  330. #endif
  331. }
  332. void SemSegRegionBased::getRegionLabel ( NICE::Matrix &mask, RegionGraph &rg, NICE::Image &pixelLabels )
  333. {
  334. #ifdef DEBUG_PRINTS
  335. cout << "SemSegRegionBased::getRegionLabel starts" << endl;
  336. #endif
  337. vector<vector<int> > hists;
  338. int regionsize = rg.size();
  339. int xsize = pixelLabels.width();
  340. int ysize = pixelLabels.height();
  341. for ( int i = 0; i < regionsize; i++ )
  342. {
  343. vector<int> hist ( cn.numClasses(), 0 );
  344. hists.push_back ( hist );
  345. }
  346. for ( int x = 0; x < xsize; x++ )
  347. {
  348. for ( int y = 0; y < ysize; y++ )
  349. {
  350. int numb = mask ( x, y );
  351. hists[numb][pixelLabels.getPixel ( x,y ) ]++;
  352. }
  353. }
  354. for ( int i = 0; i < regionsize; i++ )
  355. {
  356. int maxval = -numeric_limits<int>::max();
  357. int smaxval = -numeric_limits<int>::max();
  358. int maxpos = -1;
  359. int secondpos = -1;
  360. for ( int k = 0; k < ( int ) hists[i].size(); k++ )
  361. {
  362. if ( maxval < hists[i][k] )
  363. {
  364. secondpos = maxpos;
  365. smaxval = maxval;
  366. maxval = hists[i][k];
  367. maxpos = k;
  368. }
  369. else
  370. {
  371. if ( smaxval < hists[i][k] )
  372. {
  373. smaxval = hists[i][k];
  374. secondpos = k;
  375. }
  376. }
  377. }
  378. // FIXME: das für alle verbotenen Klassen einbauen
  379. //if ( forbidden_classes.find ( classno ) != forbidden_classes.end() )
  380. if ( cn.text ( maxpos ) == "various" && smaxval > 0 )
  381. rg[i]->setLabel ( secondpos );
  382. else
  383. rg[i]->setLabel ( maxpos );
  384. }
  385. #ifdef DEBUG_PRINTS
  386. cout << "SemSegRegionBased::getRegionLabel finished" << endl;
  387. #endif
  388. }
  389. void SemSegRegionBased::getExample ( const vector<vector<FeatureType> > &feats, Examples &examples )
  390. {
  391. #ifdef DEBUG_PRINTS
  392. cout << "SemSegRegionBased::getExample starts" << endl;
  393. #endif
  394. for ( int j = 0; j < ( int ) feats.size(); j++ )
  395. {
  396. int counter = 0;
  397. for ( int i = 0; i < ( int ) feats[0].size(); i++, counter++ )
  398. {
  399. if ( examples[counter].second.vec == NULL )
  400. {
  401. NICE::Vector *vec = new NICE::Vector ( feats[j][i].getVec() );
  402. examples[counter].second.vec = vec;
  403. }
  404. else
  405. {
  406. examples[counter].second.vec->append ( feats[j][i].getVec() );
  407. }
  408. }
  409. }
  410. #ifdef DEBUG_PRINTS
  411. cout << "SemSegRegionBased::getExample finished" << endl;
  412. #endif
  413. }
  414. void SemSegRegionBased::getFeaturePool ( const vector<vector<FeatureType> > &feats, FeaturePool &fp )
  415. {
  416. #ifdef DEBUG_PRINTS
  417. cout << "SemSegRegionBased::getFeaturePool starts" << endl;
  418. #endif
  419. int olddim = 0;
  420. int fulldim = 0;
  421. for ( int j = 0; j < ( int ) feats.size(); j++ )
  422. {
  423. fulldim += feats[j][0].getDim();
  424. }
  425. for ( int j = 0; j < ( int ) feats.size(); j++ )
  426. {
  427. int dimension = feats[j][0].getDim();
  428. for ( int i = olddim ; i < olddim + dimension ; i++ )
  429. {
  430. VectorFeature *f = new VectorFeature ( fulldim );
  431. f->feature_index = i;
  432. fp.addFeature ( f, 1.0 / dimension );
  433. }
  434. olddim += dimension;
  435. }
  436. #ifdef DEBUG_PRINTS
  437. cout << "SemSegRegionBased::getFeaturePool finished" << endl;
  438. #endif
  439. }
  440. void SemSegRegionBased::trainClassifier ( vector<vector<FeatureType> > &feats, Examples & examples )
  441. {
  442. #ifdef DEBUG_PRINTS
  443. cout << "SemSegRegionBased::trainClassifier starts" << endl;
  444. #endif
  445. assert ( feats.size() > 0 );
  446. assert ( feats[0].size() > 0 );
  447. // delete nonrelevant features
  448. for ( int i = ( int ) examples.size() - 1; i >= 0; i-- )
  449. {
  450. if ( cn.text ( examples[i].first ) == "various" )
  451. {
  452. examples.erase ( examples.begin() + i );
  453. for ( int k = 0; k < ( int ) feats.size(); k++ )
  454. {
  455. feats[k].erase ( feats[k].begin() + i );
  456. }
  457. }
  458. }
  459. #ifdef WRITEFEATS
  460. // mermale in datei schreiben
  461. ofstream fout ( "trainfeats", ios_base::out );
  462. //vector<int> ccounter(cn.getMaxClassno(),0);
  463. //int maxv = 100;
  464. for ( int i = 0; i < ( int ) examples.size(); i++ )
  465. {
  466. //if(ccounter[examples[i].first]++ < maxv)
  467. //{
  468. fout << examples[i].first << " ";
  469. for ( int j = 0; j < ( int ) feats.size(); j++ )
  470. {
  471. for ( int k = 0; k < feats[j][i].getDim(); k++ )
  472. {
  473. fout << feats[j][i].get ( k ) << " ";
  474. }
  475. }
  476. fout << endl;
  477. //}
  478. }
  479. #endif
  480. if ( fpc != NULL )
  481. {
  482. FeaturePool fp;
  483. getExample ( feats, examples );
  484. getFeaturePool ( feats, fp );
  485. fpc->train ( fp, examples );
  486. fp.destroy();
  487. if ( save_cache )
  488. {
  489. fpc->save ( classifiercache );
  490. }
  491. //#pragma omp parallel for
  492. for ( int i = 0; i < ( int ) examples.size(); i++ )
  493. {
  494. if ( examples[i].second.vec != NULL )
  495. {
  496. delete examples[i].second.vec;
  497. examples[i].second.vec = NULL;
  498. }
  499. }
  500. }
  501. else if ( vclassifier != NULL )
  502. {
  503. LabeledSetVector lsv;
  504. //#pragma omp parallel for
  505. for ( int i = 0; i < ( int ) feats[0].size(); i++ )
  506. {
  507. NICE::Vector *v = new NICE::Vector ( feats[0][i].getVec() );
  508. for ( int j = 1; j < ( int ) feats.size(); j++ )
  509. {
  510. v->append ( feats[j][i].getVec() );
  511. }
  512. //#pragma omp critical
  513. lsv.add_reference ( examples[i].first, v );
  514. }
  515. vclassifier->teach ( lsv );
  516. vclassifier->finishTeaching();
  517. lsv.clear();
  518. if ( save_cache )
  519. {
  520. vclassifier->save ( classifiercache );
  521. }
  522. }
  523. #ifdef DEBUG_PRINTS
  524. cout << "SemSegRegionBased::trainClassifier finished" << endl;
  525. #endif
  526. }
  527. void SemSegRegionBased::classify ( const vector<vector<FeatureType> > &feats, Examples &examples, vector<vector<double> > &probs )
  528. {
  529. #ifdef DEBUG_PRINTS
  530. cout << "SemSegRegionBased::classify starts" << endl;
  531. #endif
  532. for ( int i = 0; i < ( int ) feats[0].size(); i++ )
  533. {
  534. Example example;
  535. examples.push_back ( pair<int, Example> ( -1, example ) );
  536. }
  537. getExample ( feats, examples );
  538. int nbcl = classNames->getMaxClassno() + 1;
  539. for ( int i = 0; i < ( int ) examples.size(); i++ )
  540. {
  541. vector<double> p;
  542. ClassificationResult r;
  543. if ( fpc != NULL )
  544. {
  545. r = fpc->classify ( examples[i].second );
  546. }
  547. else if ( vclassifier != NULL )
  548. {
  549. r = vclassifier->classify ( * ( examples[i].second.vec ) );
  550. }
  551. for ( int j = 0 ; j < nbcl; j++ )
  552. {
  553. p.push_back ( r.scores[j] );
  554. }
  555. probs.push_back ( p );
  556. }
  557. #ifdef DEBUG_PRINTS
  558. cout << "SemSegRegionBased::classify finished" << endl;
  559. #endif
  560. }
  561. void SemSegRegionBased::semanticseg ( CachedExample *ce, NICE::Image & segresult, NICE::MultiChannelImageT<double> & probabilities )
  562. {
  563. #ifdef DEBUG_PRINTS
  564. cout << "SemSegRegionBased::semanticseg starts" << endl;
  565. #endif
  566. int xsize, ysize;
  567. ce->getImageSize ( xsize, ysize );
  568. probabilities.reInit ( xsize, ysize, classNames->getMaxClassno() + 1);
  569. std::string currentFile = Globals::getCurrentImgFN();
  570. NICE::ColorImage cimg ( currentFile );
  571. NICE::Matrix mask;
  572. RegionGraph rg;
  573. rsm->getGraphRepresentation ( cimg, mask, rg );
  574. #ifdef DEMO
  575. rsm->visualizeGraphRepresentation ( cimg, mask );
  576. #endif
  577. vector<vector<FeatureType> > feats;
  578. getFeats ( cimg, mask, rg, feats );
  579. #ifdef WRITEFEATS
  580. getRegionLabel ( mask, rg, segresult );
  581. ofstream fout ( "testfeats", ios_base::app );
  582. for ( int i = 0; i < ( int ) rg.size(); i++ )
  583. {
  584. fout << rg[i]->getLabel() << " ";
  585. for ( int j = 0; j < ( int ) feats.size(); j++ )
  586. {
  587. for ( int k = 0; k < feats[j][i].getDim(); k++ )
  588. {
  589. fout << feats[j][i].get ( k ) << " ";
  590. }
  591. }
  592. fout << endl;
  593. }
  594. #endif
  595. segresult = NICE::Image ( xsize, ysize );
  596. segresult.set ( 0 );
  597. Examples examples;
  598. vector<vector<double> > probs;
  599. classify ( feats, examples, probs );
  600. labelRegions ( rg, probs );
  601. if ( gcopt != NULL )
  602. gcopt->optimizeImage ( rg, probs );
  603. labelImage ( segresult, mask, rg );
  604. #ifdef DEBUG_PRINTS
  605. cout << "SemSegRegionBased::semanticseg finished" << endl;
  606. #endif
  607. }
  608. void SemSegRegionBased::labelRegions ( RegionGraph &rg, vector<vector<double> > &probs )
  609. {
  610. #ifdef DEBUG_PRINTS
  611. cout << "SemSegRegionBased::labelRegions starts" << endl;
  612. #endif
  613. for ( int i = 0; i < rg.size(); i++ )
  614. {
  615. int bestclass = -1;
  616. double bestval = -numeric_limits<int>::max();
  617. for ( int j = 0; j < ( int ) probs[i].size(); j++ )
  618. {
  619. if ( bestval < probs[i][j] )
  620. {
  621. bestval = probs[i][j];
  622. bestclass = j;
  623. }
  624. }
  625. rg[i]->setLabel ( bestclass );
  626. }
  627. #ifdef DEBUG_PRINTS
  628. cout << "SemSegRegionBased::labelRegions finished" << endl;
  629. #endif
  630. }
  631. void SemSegRegionBased::labelImage ( NICE::Image &segresult, NICE::Matrix &mask, RegionGraph &rg )
  632. {
  633. #ifdef DEBUG_PRINTS
  634. cout << "SemSegRegionBased::labelImage starts" << endl;
  635. #endif
  636. for ( int y = 0; y < segresult.height(); y++ )
  637. {
  638. for ( int x = 0; x < segresult.width(); x++ )
  639. {
  640. int r = ( int ) mask ( x, y );
  641. segresult.setPixel ( x, y, rg[r]->getLabel() );
  642. }
  643. }
  644. #ifdef DEBUG_PRINTS
  645. cout << "SemSegRegionBased::labelImage finished" << endl;
  646. #endif
  647. }
  648. void SemSegRegionBased::getFeats ( const NICE::ColorImage &cimg, const NICE::Matrix &mask, const RegionGraph &rg, vector<vector< FeatureType> > &feats ) const
  649. {
  650. #ifdef DEBUG_PRINTS
  651. cout << "SemSegRegionBased::getFeats starts" << endl;
  652. #endif
  653. string fn = Globals::getCurrentImgFN();
  654. NICE::Image img ( fn );
  655. int featnb = 0;
  656. const int rgcount = rg.size();
  657. if ( rfc != NULL )
  658. {
  659. if ( ( int ) feats.size() <= featnb )
  660. {
  661. vector<FeatureType> ftv;
  662. feats.push_back ( ftv );
  663. }
  664. VVector features;
  665. rfc->extractRGB ( cimg, rg, mask, features );
  666. assert ( ( int ) features.size() == rgcount );
  667. for ( int j = 0; j < ( int ) features.size(); j++ )
  668. {
  669. feats[featnb].push_back ( FeatureType ( features[j] ) );
  670. }
  671. #ifdef DEMO
  672. LFColorWeijer lfc ( conf );
  673. lfc.visualizeFeatures ( cimg );
  674. #endif
  675. featnb++;
  676. }
  677. if ( rfbov != NULL )
  678. {
  679. if ( ( int ) feats.size() <= featnb )
  680. {
  681. vector<FeatureType> ftv;
  682. feats.push_back ( ftv );
  683. }
  684. VVector features;
  685. rfbov->extractRGB ( cimg, rg, mask, features );
  686. assert ( ( int ) features.size() == rgcount );
  687. for ( int j = 0; j < ( int ) features.size(); j++ )
  688. {
  689. feats[featnb].push_back ( FeatureType ( features[j] ) );
  690. }
  691. featnb++;
  692. }
  693. if ( rfhog != NULL )
  694. {
  695. if ( ( int ) feats.size() <= featnb )
  696. {
  697. vector<FeatureType> ftv;
  698. feats.push_back ( ftv );
  699. }
  700. VVector features;
  701. rfhog->extractRGB ( cimg, rg, mask, features );
  702. assert ( ( int ) features.size() == rgcount );
  703. for ( int j = 0; j < ( int ) features.size(); j++ )
  704. {
  705. feats[featnb].push_back ( FeatureType ( features[j] ) );
  706. }
  707. featnb++;
  708. }
  709. if ( rfstruct != NULL )
  710. {
  711. if ( ( int ) feats.size() <= featnb )
  712. {
  713. vector<FeatureType> ftv;
  714. feats.push_back ( ftv );
  715. }
  716. VVector features;
  717. rfstruct->extractRGB ( cimg, rg, mask, features );
  718. for ( int j = 0; j < ( int ) features.size(); j++ )
  719. {
  720. feats[featnb].push_back ( FeatureType ( features[j] ) );
  721. }
  722. featnb++;
  723. }
  724. if ( rfbovcrdf != NULL )
  725. {
  726. if ( ( int ) feats.size() <= featnb )
  727. {
  728. vector<FeatureType> ftv;
  729. feats.push_back ( ftv );
  730. }
  731. VVector features;
  732. rfbovcrdf->extractRGB ( cimg, rg, mask, features );
  733. assert ( ( int ) features.size() == rgcount );
  734. for ( int j = 0; j < ( int ) features.size(); j++ )
  735. {
  736. feats[featnb].push_back ( FeatureType ( features[j] ) );
  737. }
  738. featnb++;
  739. }
  740. if ( rfCsurka != NULL )
  741. {
  742. if ( ( int ) feats.size() <= featnb )
  743. {
  744. vector<FeatureType> ftv;
  745. feats.push_back ( ftv );
  746. }
  747. VVector features;
  748. rfCsurka->extractRGB ( cimg, rg, mask, features );
  749. assert ( ( int ) features.size() == rgcount );
  750. for ( int j = 0; j < ( int ) features.size(); j++ )
  751. {
  752. feats[featnb].push_back ( FeatureType ( features[j] ) );
  753. }
  754. featnb++;
  755. }
  756. /* Dummy for new features:
  757. if(siftFeats != NULL)
  758. {
  759. if((int)feats.size() <= featnb)
  760. {
  761. vector<FeatureType> ftv;
  762. feats.push_back(ftv);
  763. }
  764. featnb++;
  765. }
  766. */
  767. #ifdef DEBUG_PRINTS
  768. cout << "SemSegRegionBased::getFeats finished" << endl;
  769. #endif
  770. }
  771. void SemSegRegionBased::computeLF ( LabeledSet::Permutation perm, VVector &feats, vector<int> &label, Examples &examples, int mode )
  772. {
  773. #ifdef DEBUG_PRINTS
  774. cout << "SemSegRegionBased::computeLF starts" << endl;
  775. #endif
  776. string sscales = conf->gS ( "SIFTTrain", "scales", "1+2.0+3.0" );
  777. int grid = conf->gI ( "SIFTTrain", "grid", 20 );
  778. double fraction = conf->gD ( "SIFTTrain", "fraction", 1.0 );
  779. set<int> forbidden_classes;
  780. std::string forbidden_classes_s = conf->gS ( "analysis", "donttrain", "" );
  781. if ( forbidden_classes_s == "" )
  782. {
  783. forbidden_classes_s = conf->gS ( "analysis", "forbidden_classes", "" );
  784. }
  785. cn.getSelection ( forbidden_classes_s, forbidden_classes );
  786. cerr << "forbidden: " << forbidden_classes_s << endl;
  787. vector<double> scales;
  788. string::size_type pos = 0;
  789. string::size_type oldpos = 0;
  790. while ( pos != string::npos )
  791. {
  792. pos = sscales.find ( "+", oldpos );
  793. string val;
  794. if ( pos == string::npos )
  795. val = sscales.substr ( oldpos );
  796. else
  797. val = sscales.substr ( oldpos, pos - oldpos );
  798. double d = atof ( val.c_str() );
  799. scales.push_back ( d );
  800. oldpos = pos + 1;
  801. }
  802. int fsize = 0;
  803. string save = cache + "/siftTRAIN.dat";
  804. string savep = cache + "/siftPostions.dat";
  805. if ( !FileMgt::fileExists ( save ) || !FileMgt::fileExists ( savep ) )
  806. {
  807. //FIXME: entfernen
  808. // vector<int> counter(9,0);
  809. for ( LabeledSet::Permutation::const_iterator i = perm.begin();
  810. i != perm.end(); i++ )
  811. {
  812. const string fn = i->second->img();
  813. Globals::setCurrentImgFN ( fn );
  814. NICE::Image img ( fn );
  815. NICE::ColorImage cimg ( fn );
  816. VVector features;
  817. VVector positions;
  818. int x0 = grid / 2;
  819. for ( int y = 0; y < ( int ) img.height(); y += grid )
  820. {
  821. for ( int x = x0; x < ( int ) img.width(); x += grid )
  822. {
  823. for ( int s = 0; s < ( int ) scales.size(); s++ )
  824. {
  825. double r = ( double ) rand() / ( double ) RAND_MAX;
  826. if ( r < fraction )
  827. {
  828. fsize++;
  829. NICE::Vector vec ( 3 );
  830. vec[0] = x;
  831. vec[1] = y;
  832. vec[2] = scales[s];
  833. positions.push_back ( vec );
  834. }
  835. }
  836. }
  837. if ( x0 == 0 )
  838. {
  839. x0 = grid / 2;
  840. }
  841. else
  842. {
  843. x0 = 0;
  844. }
  845. }
  846. siftFeats->getDescriptors ( cimg, positions, features );
  847. assert ( positions.size() == features.size() );
  848. const LocalizationResult *locResult = i->second->localization();
  849. NICE::Image pixelLabels ( cimg.width(), cimg.height() );
  850. pixelLabels.set ( 0 );
  851. locResult->calcLabeledImage ( pixelLabels, ( *classNames ).getBackgroundClass() );
  852. for ( int i = 0; i < ( int ) features.size(); i++ )
  853. {
  854. int classno = pixelLabels ( positions[i][0], positions[i][1] );
  855. // if ( cn.text ( classno ) == "various")
  856. // continue;
  857. if ( forbidden_classes.find ( classno ) != forbidden_classes.end() )
  858. continue;
  859. // counter[classno]++;
  860. label.push_back ( classno );
  861. feats.push_back ( features[i] );
  862. }
  863. assert ( label.size() == feats.size() );
  864. }
  865. /* cout << "samples for class: " << endl;
  866. for(int i = 0; i < 9; i++)
  867. {
  868. cout << i << ": " << counter[i] << endl;
  869. }
  870. */
  871. feats.save ( save, 1 );
  872. ofstream lout ( savep.c_str(), ios_base::out );
  873. for ( uint i = 0; i < label.size(); i++ )
  874. {
  875. lout << label[i] << " ";
  876. }
  877. lout.close();
  878. }
  879. else
  880. {
  881. feats.read ( save, 1 );
  882. ifstream lin ( savep.c_str(), ios_base::in );
  883. label.clear();
  884. for ( int i = 0; i < ( int ) feats.size(); i++ )
  885. {
  886. int l;
  887. lin >> l;
  888. label.push_back ( l );
  889. }
  890. }
  891. if ( mode == 1 )
  892. {
  893. convertVVectorToExamples ( feats, examples, label );
  894. }
  895. #ifdef DEBUG_PRINTS
  896. cout << "SemSegRegionBased::computeLF finished" << endl;
  897. #endif
  898. }
  899. void SemSegRegionBased::learnHighLevel ( LabeledSet::Permutation perm )
  900. {
  901. #ifdef DEBUG_PRINTS
  902. cout << "SemSegRegionBased::learnHighLevel starts" << endl;
  903. #endif
  904. srand ( time ( NULL ) );
  905. if ( rfbov != NULL || rfbovcrdf != NULL || rfCsurka != NULL )
  906. {
  907. if ( rfbov != NULL )
  908. {
  909. RFBoV *rfbovdyn = dynamic_cast< RFBoV * > ( rfbov );
  910. int gaussians = conf->gI ( "SIFTTrain", "gaussians", 512 );
  911. int pcadim = conf->gI ( "SIFTTrain", "pcadim", 50 );
  912. GMM *g = new GMM ( conf, gaussians );
  913. PCA *pca = new PCA ( pcadim );
  914. string pcadst = cache + "/pca.txt";
  915. if ( !g->loadData ( cache + "/gmmSIFT" ) || !FileMgt::fileExists ( pcadst ) )
  916. {
  917. VVector feats;
  918. vector<int> label;
  919. Examples ex;
  920. computeLF ( perm, feats, label, ex, 0 );
  921. assert ( feats.size() > 0 );
  922. initializePCA ( feats, *pca, pcadim, pcadst );
  923. transformFeats ( feats, *pca );
  924. cout << "nb of feats for learning gmm: " << feats.size() << endl;
  925. g->computeMixture ( feats );
  926. if ( save_cache )
  927. g->saveData ( cache + "/gmmSIFT" );
  928. }
  929. else
  930. {
  931. pca->read ( pcadst );
  932. }
  933. rfbovdyn->setPCA ( pca );
  934. rfbovdyn->setGMM ( g );
  935. }
  936. if ( rfbovcrdf != NULL || rfCsurka != NULL )
  937. {
  938. Examples examples;
  939. VVector feats;
  940. vector<int> label;
  941. computeLF ( perm, feats, label, examples , 1 );
  942. FeaturePool fp;
  943. FeaturePool fpsparse;
  944. int dimension = examples[0].second.vec->size();
  945. for ( int i = 0 ; i < dimension ; i++ )
  946. {
  947. VectorFeature *f = new VectorFeature ( dimension, i );
  948. fp.addFeature ( f, 1.0 / dimension );
  949. SparseVectorFeature *fs = new SparseVectorFeature ( dimension, i );
  950. //fs->feature_index = i;
  951. fpsparse.addFeature ( fs, 1.0 / dimension );
  952. }
  953. if ( rfbovcrdf != NULL )
  954. {
  955. RFBoVCodebook *rfbovdyn = dynamic_cast< RFBoVCodebook * > ( rfbovcrdf );
  956. int maxDepth = conf->gI ( "BoVMoosmann", "maxdepth", 10 );
  957. int csize = conf->gI ( "BoVMoosmann", "codebooksize", 1024 );
  958. CodebookRandomForest *crdf = new CodebookRandomForest ( maxDepth, csize );
  959. //RF anlernen
  960. FPCRandomForests *fpcrfmoos = new FPCRandomForests ( conf, "MoosForest" );
  961. fpcrfmoos->train ( fp, examples );
  962. crdf->setClusterForest ( fpcrfmoos );
  963. for ( int i = 0; i < ( int ) examples.size(); i++ )
  964. {
  965. if ( examples[i].second.vec != NULL )
  966. {
  967. delete examples[i].second.vec;
  968. examples[i].second.vec = NULL;
  969. }
  970. }
  971. rfbovdyn->setCodebook ( crdf );
  972. }
  973. if ( rfCsurka != NULL )
  974. {
  975. bool usegmm = conf->gB ( "Csurka", "usegmm", false );
  976. bool usepca = conf->gB ( "Csurka", "usepca", false );
  977. PCA *pca = NULL;
  978. GMM *g = NULL;
  979. string classifierdst = cache + "/csurka.";
  980. if ( usepca || usegmm )
  981. {
  982. RFCsurka *_rfcsurka = dynamic_cast< RFCsurka * > ( rfCsurka );
  983. bool create = false;
  984. string gmmdst = cache + "/csurka.gmm";
  985. string pcadst = cache + "/csurka.pca";
  986. int pcadim = conf->gI ( "Csurka", "pcadim", 100 );
  987. if ( usepca )
  988. {
  989. pca = new PCA ( pcadim );
  990. if ( !FileMgt::fileExists ( pcadst ) )
  991. {
  992. create = true;
  993. }
  994. else
  995. {
  996. pca->read ( pcadst );
  997. }
  998. }
  999. if ( usegmm )
  1000. {
  1001. int gaussians = conf->gI ( "Csurka", "gaussians", 1024 );
  1002. g = new GMM ( conf, gaussians );
  1003. if ( !g->loadData ( gmmdst ) )
  1004. {
  1005. create = true;
  1006. }
  1007. }
  1008. if ( create )
  1009. {
  1010. if ( usepca )
  1011. {
  1012. convertExamplesToVVector ( feats, examples, label );
  1013. initializePCA ( feats, *pca, pcadim, pcadst );
  1014. transformFeats ( feats, *pca );
  1015. convertVVectorToExamples ( feats, examples, label );
  1016. }
  1017. if ( usegmm )
  1018. {
  1019. g->computeMixture ( examples );
  1020. if ( save_cache )
  1021. g->saveData ( gmmdst );
  1022. }
  1023. }
  1024. if ( usepca )
  1025. _rfcsurka->setPCA ( pca );
  1026. if ( usegmm )
  1027. _rfcsurka->setGMM ( g );
  1028. }
  1029. string classifiertype = conf->gS ( "Csurka", "classifier", "SMLR" );
  1030. FeaturePoolClassifier *fpcrfCs = NULL;
  1031. VecClassifier *vecClassifier = NULL;
  1032. if ( classifiertype == "SMLR" )
  1033. {
  1034. fpcrfCs = new FPCSMLR ( conf, "CsurkaSMLR" );
  1035. classifierdst += "smlr";
  1036. }
  1037. else if ( classifiertype == "RF" )
  1038. {
  1039. fpcrfCs = new FPCRandomForests ( conf, "CsurkaForest" );
  1040. classifierdst += "rf";
  1041. }
  1042. else
  1043. {
  1044. vecClassifier = GenericClassifierSelection::selectVecClassifier ( conf, classifiertype );
  1045. classifierdst += "other";
  1046. }
  1047. RFCsurka *rfcsurka = dynamic_cast< RFCsurka * > ( rfCsurka );
  1048. if ( usepca )
  1049. {
  1050. assert ( examples.size() > 0 );
  1051. if ( ( int ) examples[0].second.vec->size() != pca->getTargetDim() )
  1052. {
  1053. for ( int i = 0; i < ( int ) examples.size(); ++i )
  1054. {
  1055. *examples[i].second.vec = pca->getFeatureVector ( *examples[i].second.vec, true );
  1056. }
  1057. }
  1058. }
  1059. if ( !FileMgt::fileExists ( classifierdst ) )
  1060. {
  1061. if ( usegmm )
  1062. {
  1063. if ( classifiertype == "SMLR" )
  1064. {
  1065. for ( int i = 0; i < ( int ) examples.size(); ++i )
  1066. {
  1067. examples[i].second.svec = new SparseVector();
  1068. g->getProbs ( *examples[i].second.vec, *examples[i].second.svec );
  1069. delete examples[i].second.vec;
  1070. examples[i].second.vec = NULL;
  1071. }
  1072. }
  1073. else
  1074. {
  1075. for ( int i = 0; i < ( int ) examples.size(); ++i )
  1076. {
  1077. g->getProbs ( *examples[i].second.vec, *examples[i].second.vec );
  1078. }
  1079. }
  1080. if ( fpcrfCs != NULL )
  1081. {
  1082. fpcrfCs->train ( fpsparse, examples );
  1083. }
  1084. else
  1085. {
  1086. LabeledSetVector lvec;
  1087. convertExamplesToLSet ( examples, lvec );
  1088. vecClassifier->teach ( lvec );
  1089. convertLSetToExamples ( examples, lvec );
  1090. vecClassifier->finishTeaching();
  1091. }
  1092. }
  1093. else
  1094. {
  1095. if ( fpcrfCs != NULL )
  1096. {
  1097. fpcrfCs->train ( fp, examples );
  1098. }
  1099. else
  1100. {
  1101. LabeledSetVector lvec;
  1102. convertExamplesToLSet ( examples, lvec );
  1103. vecClassifier->teach ( lvec );
  1104. convertLSetToExamples ( examples, lvec );
  1105. vecClassifier->finishTeaching();
  1106. }
  1107. }
  1108. if ( fpcrfCs != NULL )
  1109. {
  1110. fpcrfCs->setMaxClassNo ( classNames->getMaxClassno() );
  1111. fpcrfCs->save ( classifierdst );
  1112. }
  1113. else
  1114. {
  1115. vecClassifier->setMaxClassNo ( classNames->getMaxClassno() );
  1116. vecClassifier->save ( classifierdst );
  1117. }
  1118. }
  1119. else
  1120. {
  1121. if ( fpcrfCs != NULL )
  1122. {
  1123. fpcrfCs->setMaxClassNo ( classNames->getMaxClassno() );
  1124. fpcrfCs->read ( classifierdst );
  1125. }
  1126. else
  1127. {
  1128. vecClassifier->setMaxClassNo ( classNames->getMaxClassno() );
  1129. vecClassifier->read ( classifierdst );
  1130. }
  1131. }
  1132. if ( fpcrfCs != NULL )
  1133. {
  1134. rfcsurka->setClassifier ( fpcrfCs );
  1135. }
  1136. else
  1137. {
  1138. rfcsurka->setClassifier ( vecClassifier );
  1139. }
  1140. }
  1141. fp.destroy();
  1142. for ( int i = 0; i < ( int ) examples.size(); i++ )
  1143. {
  1144. if ( examples[i].second.vec != NULL )
  1145. {
  1146. delete examples[i].second.vec;
  1147. examples[i].second.vec = NULL;
  1148. }
  1149. }
  1150. }
  1151. }
  1152. #ifdef DEBUG_PRINTS
  1153. cerr << "SemSegRegionBased::learnHighLevel finished" << endl;
  1154. #endif
  1155. }
  1156. void SemSegRegionBased::transformFeats ( VVector &feats, PCA &pca )
  1157. {
  1158. #ifdef DEBUG_PRINTS
  1159. cerr << "SemSegRegionBased::transformFeats starts" << endl;
  1160. #endif
  1161. for ( int i = 0; i < ( int ) feats.size(); i++ )
  1162. {
  1163. feats[i] = pca.getFeatureVector ( feats[i], true );
  1164. }
  1165. #ifdef DEBUG_PRINTS
  1166. cerr << "SemSegRegionBased::transformFeats finished" << endl;
  1167. #endif
  1168. }
  1169. void SemSegRegionBased::initializePCA ( const VVector &feats, PCA &pca, int dim, string &fn )
  1170. {
  1171. #ifdef DEBUG_PRINTS
  1172. cerr << "SemSegRegionBased::initializePCA starts" << endl;
  1173. #endif
  1174. pca = PCA ( dim );
  1175. if ( !FileMgt::fileExists ( fn ) )
  1176. {
  1177. srand ( time ( NULL ) );
  1178. int featsize = ( int ) feats.size();
  1179. int maxfeatures = std::min ( dim * 20, featsize );
  1180. NICE::Matrix features ( maxfeatures, ( int ) feats[0].size() );
  1181. for ( int i = 0; i < maxfeatures; i++ )
  1182. {
  1183. int k = rand() % featsize;
  1184. int vsize = ( int ) feats[k].size();
  1185. for ( int j = 0; j < vsize; j++ )
  1186. {
  1187. features ( i, j ) = feats[k][j];
  1188. }
  1189. }
  1190. pca.calculateBasis ( features, dim );
  1191. if ( save_cache )
  1192. pca.save ( fn );
  1193. }
  1194. else
  1195. {
  1196. pca.read ( fn );
  1197. }
  1198. #ifdef DEBUG_PRINTS
  1199. cerr << "SemSegRegionBased::initializePCA finished" << endl;
  1200. #endif
  1201. }