// SemSegContextTree.cpp
  #include "SemSegContextTree.h"

  #include <omp.h>

  #include <cassert>
  #include <cmath>
  #include <cstdio>
  #include <cstdlib>
  #include <iostream>
  #include <limits>
  #include <map>
  #include <set>
  #include <string>
  #include <vector>

  #include "vislearning/baselib/Globals.h"
  #include "vislearning/baselib/ProgressBar.h"
  #include "core/basics/StringTools.h"
  #include "vislearning/cbaselib/CachedExample.h"
  #include "vislearning/cbaselib/PascalResults.h"
  // Clamps v into the closed interval [lo, hi].
  // This is a macro, so every argument may be evaluated more than once:
  // only pass expressions without side effects.
  #define BOUND(v,lo,hi) (((v)<(lo))?(lo):((v)>(hi)?(hi):(v)))

  // Feature source switch: when defined, train() and semanticseg() obtain the
  // per-pixel features from lfcw->getFeats(); otherwise they copy the three
  // raw colour channels of the image.
  #define LOCALFEATS
  11. using namespace OBJREC;
  12. using namespace std;
  13. using namespace NICE;
  14. class Minus:public Operation
  15. {
  16. public:
  17. virtual double getVal(const NICE::MultiChannelImageT<double> &feats, const std::vector<std::vector<int> > &cfeats, const std::vector<TreeNode> &tree, const int &x, const int &y)
  18. {
  19. int xsize = feats.width();
  20. int ysize = feats.height();
  21. double v1 = feats.get(BOUND(x+x1,0,xsize-1),BOUND(y+y1,0,ysize-1),channel1);
  22. double v2 = feats.get(BOUND(x+x2,0,xsize-1),BOUND(y+y2,0,ysize-1),channel2);
  23. return v1-v2;
  24. }
  25. virtual Operation* clone()
  26. {
  27. return new Minus();
  28. };
  29. virtual void writeInfos()
  30. {
  31. cout << "Minus: " << endl;
  32. }
  33. };
  34. class MinusAbs:public Operation
  35. {
  36. public:
  37. virtual double getVal(const NICE::MultiChannelImageT<double> &feats, const std::vector<std::vector<int> > &cfeats, const std::vector<TreeNode> &tree, const int &x, const int &y)
  38. {
  39. int xsize = feats.width();
  40. int ysize = feats.height();
  41. double v1 = feats.get(BOUND(x+x1,0,xsize-1),BOUND(y+y1,0,ysize-1),channel1);
  42. double v2 = feats.get(BOUND(x+x2,0,xsize-1),BOUND(y+y2,0,ysize-1),channel2);
  43. return abs(v1-v2);
  44. }
  45. virtual Operation* clone()
  46. {
  47. return new MinusAbs();
  48. };
  49. virtual void writeInfos()
  50. {
  51. cout << "MinusAbs: " << endl;
  52. }
  53. };
  54. class Addition:public Operation
  55. {
  56. public:
  57. virtual double getVal(const NICE::MultiChannelImageT<double> &feats, const std::vector<std::vector<int> > &cfeats, const std::vector<TreeNode> &tree, const int &x, const int &y)
  58. {
  59. int xsize = feats.width();
  60. int ysize = feats.height();
  61. double v1 = feats.get(BOUND(x+x1,0,xsize-1),BOUND(y+y1,0,ysize-1),channel1);
  62. double v2 = feats.get(BOUND(x+x2,0,xsize-1),BOUND(y+y2,0,ysize-1),channel2);
  63. return v1+v2;
  64. }
  65. virtual Operation* clone()
  66. {
  67. return new Addition();
  68. };
  69. virtual void writeInfos()
  70. {
  71. cout << "Addition: " << endl;
  72. }
  73. };
  74. class Only1:public Operation
  75. {
  76. public:
  77. virtual double getVal(const NICE::MultiChannelImageT<double> &feats, const std::vector<std::vector<int> > &cfeats, const std::vector<TreeNode> &tree, const int &x, const int &y)
  78. {
  79. int xsize = feats.width();
  80. int ysize = feats.height();
  81. double v1 = feats.get(BOUND(x+x1,0,xsize-1),BOUND(y+y1,0,ysize-1),channel1);
  82. return v1;
  83. }
  84. virtual Operation* clone()
  85. {
  86. return new Only1();
  87. };
  88. virtual void writeInfos()
  89. {
  90. cout << "Only1: " << endl;
  91. }
  92. };
  93. class ContextMinus:public Operation
  94. {
  95. public:
  96. virtual double getVal(const NICE::MultiChannelImageT<double> &feats, const std::vector<std::vector<int> > &cfeats, const std::vector<TreeNode> &tree, const int &x, const int &y)
  97. {
  98. int xsize = feats.width();
  99. int ysize = feats.height();
  100. double v1 = tree[cfeats[BOUND(x+x1,0,xsize-1)][BOUND(y+y1,0,ysize-1)]].dist[channel1];
  101. double v2 = tree[cfeats[BOUND(x+x2,0,xsize-1)][BOUND(y+y2,0,ysize-1)]].dist[channel2];
  102. return v1-v2;
  103. }
  104. virtual Operation* clone()
  105. {
  106. return new ContextMinus();
  107. };
  108. virtual void writeInfos()
  109. {
  110. cout << "ContextMinus: " << endl;
  111. }
  112. };
  113. class ContextMinusAbs:public Operation
  114. {
  115. public:
  116. virtual double getVal(const NICE::MultiChannelImageT<double> &feats, const std::vector<std::vector<int> > &cfeats, const std::vector<TreeNode> &tree, const int &x, const int &y)
  117. {
  118. int xsize = feats.width();
  119. int ysize = feats.height();
  120. double v1 = tree[cfeats[BOUND(x+x1,0,xsize-1)][BOUND(y+y1,0,ysize-1)]].dist[channel1];
  121. double v2 = tree[cfeats[BOUND(x+x2,0,xsize-1)][BOUND(y+y2,0,ysize-1)]].dist[channel2];
  122. return abs(v1-v2);
  123. }
  124. virtual Operation* clone()
  125. {
  126. return new ContextMinusAbs();
  127. };
  128. virtual void writeInfos()
  129. {
  130. cout << "ContextMinusAbs: " << endl;
  131. }
  132. };
  133. class ContextAddition:public Operation
  134. {
  135. public:
  136. virtual double getVal(const NICE::MultiChannelImageT<double> &feats, const std::vector<std::vector<int> > &cfeats, const std::vector<TreeNode> &tree, const int &x, const int &y)
  137. {
  138. int xsize = feats.width();
  139. int ysize = feats.height();
  140. double v1 = tree[cfeats[BOUND(x+x1,0,xsize-1)][BOUND(y+y1,0,ysize-1)]].dist[channel1];
  141. double v2 = tree[cfeats[BOUND(x+x2,0,xsize-1)][BOUND(y+y2,0,ysize-1)]].dist[channel2];
  142. return v1+v2;
  143. }
  144. virtual Operation* clone()
  145. {
  146. return new ContextAddition();
  147. };
  148. virtual void writeInfos()
  149. {
  150. cout << "ContextAddition: " << endl;
  151. }
  152. };
  153. class ContextOnly1:public Operation
  154. {
  155. public:
  156. virtual double getVal(const NICE::MultiChannelImageT<double> &feats, const std::vector<std::vector<int> > &cfeats, const std::vector<TreeNode> &tree, const int &x, const int &y)
  157. {
  158. int xsize = feats.width();
  159. int ysize = feats.height();
  160. double v1 = tree[cfeats[BOUND(x+x1,0,xsize-1)][BOUND(y+y1,0,ysize-1)]].dist[channel1];
  161. return v1;
  162. }
  163. virtual Operation* clone()
  164. {
  165. return new ContextOnly1();
  166. };
  167. virtual void writeInfos()
  168. {
  169. cout << "ContextOnly1: " << endl;
  170. }
  171. };
  172. SemSegContextTree::SemSegContextTree( const Config *conf, const MultiDataset *md )
  173. : SemanticSegmentation ( conf, &(md->getClassNames("train")) )
  174. {
  175. this->conf = conf;
  176. string section = "SSContextTree";
  177. lfcw = new LFColorWeijer(conf);
  178. grid = conf->gI(section, "grid", 10 );
  179. maxSamples = conf->gI(section, "max_samples", 2000);
  180. minFeats = conf->gI(section, "min_feats", 50 );
  181. maxDepth = conf->gI(section, "max_depth", 10 );
  182. windowSize = conf->gI(section, "window_size", 16);
  183. featsPerSplit = conf->gI(section, "feats_per_split", 200);
  184. useShannonEntropy = conf->gB(section, "use_shannon_entropy", true);
  185. ftypes = conf->gI(section, "features", 2);;
  186. ops.push_back(new Minus());
  187. ops.push_back(new MinusAbs());
  188. ops.push_back(new Addition());
  189. ops.push_back(new Only1());
  190. cops.push_back(new ContextMinus());
  191. cops.push_back(new ContextMinusAbs());
  192. cops.push_back(new ContextAddition());
  193. cops.push_back(new ContextOnly1());
  194. classnames = md->getClassNames ( "train" );
  195. ///////////////////////////////////
  196. // Train Segmentation Context Trees
  197. ///////////////////////////////////
  198. train ( md );
  199. }
  200. SemSegContextTree::~SemSegContextTree()
  201. {
  202. }
  203. void SemSegContextTree::getBestSplit(const vector<MultiChannelImageT<double> > &feats, vector<vector<vector<int> > > &currentfeats,const vector<vector<vector<int> > > &labels, int node, Operation *&splitop, double &splitval)
  204. {
  205. int imgCount, featdim;
  206. try
  207. {
  208. imgCount = (int)feats.size();
  209. featdim = feats[0].channels();
  210. }
  211. catch(Exception)
  212. {
  213. cerr << "no features computed?" << endl;
  214. }
  215. double bestig = -numeric_limits< double >::max();
  216. splitop = NULL;
  217. splitval = -1.0;
  218. set<vector<int> >selFeats;
  219. map<int,int> e;
  220. int featcounter = 0;
  221. for(int iCounter = 0; iCounter < imgCount; iCounter++)
  222. {
  223. int xsize = (int)currentfeats[iCounter].size();
  224. int ysize = (int)currentfeats[iCounter][0].size();
  225. for(int x = 0; x < xsize; x++)
  226. {
  227. for(int y = 0; y < ysize; y++)
  228. {
  229. if(currentfeats[iCounter][x][y] == node)
  230. {
  231. featcounter++;
  232. }
  233. }
  234. }
  235. }
  236. if(featcounter < minFeats)
  237. {
  238. cout << "only " << featcounter << " feats in current node -> it's a leaf" << endl;
  239. return;
  240. }
  241. vector<double> fraction(a.size(),0.0);
  242. for(uint i = 0; i < fraction.size(); i++)
  243. {
  244. if ( forbidden_classes.find ( labelmapback[i] ) != forbidden_classes.end() )
  245. fraction[i] = 0;
  246. else
  247. fraction[i] = ((double)maxSamples)/((double)featcounter*a[i]*a.size());
  248. //cout << "fraction["<<i<<"]: "<< fraction[i] << " a[" << i << "]: " << a[i] << endl;
  249. }
  250. //cout << "a.size(): " << a.size() << endl;
  251. //getchar();
  252. featcounter = 0;
  253. for(int iCounter = 0; iCounter < imgCount; iCounter++)
  254. {
  255. int xsize = (int)currentfeats[iCounter].size();
  256. int ysize = (int)currentfeats[iCounter][0].size();
  257. for(int x = 0; x < xsize; x++)
  258. {
  259. for(int y = 0; y < ysize; y++)
  260. {
  261. if(currentfeats[iCounter][x][y] == node)
  262. {
  263. int cn = labels[iCounter][x][y];
  264. double randD = (double)rand()/(double)RAND_MAX;
  265. if(randD < fraction[labelmap[cn]])
  266. {
  267. vector<int> tmp(3,0);
  268. tmp[0] = iCounter;
  269. tmp[1] = x;
  270. tmp[2] = y;
  271. featcounter++;
  272. selFeats.insert(tmp);
  273. e[cn]++;
  274. }
  275. }
  276. }
  277. }
  278. }
  279. //cout << "size: " << selFeats.size() << endl;
  280. //getchar();
  281. map<int,int>::iterator mapit;
  282. double globent = 0.0;
  283. for ( mapit=e.begin() ; mapit != e.end(); mapit++ )
  284. {
  285. //cout << "class: " << mapit->first << ": " << mapit->second << endl;
  286. double p = (double)(*mapit).second/(double)featcounter;
  287. globent += p*log2(p);
  288. }
  289. globent = -globent;
  290. if(globent < 0.5)
  291. {
  292. cout << "globent to small: " << globent << endl;
  293. return;
  294. }
  295. int classes = (int)labelmap.size();
  296. featsel.clear();
  297. for(int i = 0; i < featsPerSplit; i++)
  298. {
  299. int x1 = (int)((double)rand()/(double)RAND_MAX*(double)windowSize)-windowSize/2;
  300. int x2 = (int)((double)rand()/(double)RAND_MAX*(double)windowSize)-windowSize/2;
  301. int y1 = (int)((double)rand()/(double)RAND_MAX*(double)windowSize)-windowSize/2;
  302. int y2 = (int)((double)rand()/(double)RAND_MAX*(double)windowSize)-windowSize/2;
  303. int ft = (int)((double)rand()/(double)RAND_MAX*(double)ftypes);
  304. if(ft == 0)
  305. {
  306. int f1 = (int)((double)rand()/(double)RAND_MAX*(double)featdim);
  307. int f2 = (int)((double)rand()/(double)RAND_MAX*(double)featdim);
  308. int o = (int)((double)rand()/(double)RAND_MAX*(double)ops.size());
  309. Operation *op = ops[o]->clone();
  310. op->set(x1,y1,x2,y2,f1,f2);
  311. featsel.push_back(op);
  312. }
  313. else if(ft == 1)
  314. {
  315. int f1 = (int)((double)rand()/(double)RAND_MAX*(double)classes);
  316. int f2 = (int)((double)rand()/(double)RAND_MAX*(double)classes);
  317. int o = (int)((double)rand()/(double)RAND_MAX*(double)cops.size());
  318. Operation *op = cops[o]->clone();
  319. op->set(x1,y1,x2,y2,f1,f2);
  320. featsel.push_back(op);
  321. }
  322. }
  323. #pragma omp parallel for private(mapit)
  324. for(int f = 0; f < featsPerSplit; f++)
  325. {
  326. double l_bestig = -numeric_limits< double >::max();
  327. double l_splitval = -1.0;
  328. set<vector<int> >::iterator it;
  329. vector<double> vals;
  330. for ( it=selFeats.begin() ; it != selFeats.end(); it++ )
  331. {
  332. vals.push_back(featsel[f]->getVal(feats[(*it)[0]],currentfeats[(*it)[0]],tree,(*it)[1], (*it)[2]));
  333. }
  334. int counter = 0;
  335. for ( it=selFeats.begin() ; it != selFeats.end(); it++ , counter++)
  336. {
  337. set<vector<int> >::iterator it2;
  338. double val = vals[counter];
  339. map<int,int> eL, eR;
  340. int counterL = 0, counterR = 0;
  341. int counter2 = 0;
  342. for ( it2=selFeats.begin() ; it2 != selFeats.end(); it2++, counter2++ )
  343. {
  344. int cn = labels[(*it2)[0]][(*it2)[1]][(*it2)[2]];
  345. //cout << "vals[counter2] " << vals[counter2] << " val: " << val << endl;
  346. if(vals[counter2] < val)
  347. {
  348. //left entropie:
  349. eL[cn] = eL[cn]+1;
  350. counterL++;
  351. }
  352. else
  353. {
  354. //right entropie:
  355. eR[cn] = eR[cn]+1;
  356. counterR++;
  357. }
  358. }
  359. double leftent = 0.0;
  360. for ( mapit=eL.begin() ; mapit != eL.end(); mapit++ )
  361. {
  362. double p = (double)(*mapit).second/(double)counterL;
  363. leftent -= p*log2(p);
  364. }
  365. double rightent = 0.0;
  366. for ( mapit=eR.begin() ; mapit != eR.end(); mapit++ )
  367. {
  368. double p = (double)(*mapit).second/(double)counterR;
  369. rightent -= p*log2(p);
  370. }
  371. //cout << "rightent: " << rightent << " leftent: " << leftent << endl;
  372. double pl = (double)counterL/(double)(counterL+counterR);
  373. double ig = globent - (1.0-pl) * rightent - pl*leftent;
  374. //double ig = globent - rightent - leftent;
  375. if(useShannonEntropy)
  376. {
  377. double esplit = - ( pl*log(pl) + (1-pl)*log(1-pl) );
  378. ig = 2*ig / ( globent + esplit );
  379. }
  380. if(ig > l_bestig)
  381. {
  382. l_bestig = ig;
  383. l_splitval = val;
  384. }
  385. }
  386. #pragma omp critical
  387. {
  388. //cout << "globent: " << globent << " bestig " << bestig << " splitfeat: " << splitfeat << " splitval: " << splitval << endl;
  389. //cout << "globent: " << globent << " l_bestig " << l_bestig << " f: " << p << " l_splitval: " << l_splitval << endl;
  390. //cout << "p: " << featsubset[f] << endl;
  391. if(l_bestig > bestig)
  392. {
  393. bestig = l_bestig;
  394. splitop = featsel[f];
  395. splitval = l_splitval;
  396. }
  397. }
  398. }
  399. splitop->writeInfos();
  400. cout<< "ig: " << bestig << endl;
  401. /*for(int i = 0; i < featsPerSplit; i++)
  402. {
  403. if(featsel[i] != splitop)
  404. delete featsel[i];
  405. }*/
  406. #ifdef debug
  407. cout << "globent: " << globent << " bestig " << bestig << " splitval: " << splitval << endl;
  408. #endif
  409. }
  410. void SemSegContextTree::train ( const MultiDataset *md )
  411. {
  412. const LabeledSet train = * ( *md ) ["train"];
  413. const LabeledSet *trainp = &train;
  414. ProgressBar pb ( "compute feats" );
  415. pb.show();
  416. //TODO: Speichefresser!, lohnt sich sparse?
  417. vector<MultiChannelImageT<double> > allfeats;
  418. vector<vector<vector<int> > > currentfeats;
  419. vector<vector<vector<int> > > labels;
  420. std::string forbidden_classes_s = conf->gS ( "analysis", "donttrain", "" );
  421. if ( forbidden_classes_s == "" )
  422. {
  423. forbidden_classes_s = conf->gS ( "analysis", "forbidden_classes", "" );
  424. }
  425. classnames.getSelection ( forbidden_classes_s, forbidden_classes );
  426. int imgcounter = 0;
  427. LOOP_ALL_S ( *trainp )
  428. {
  429. EACH_INFO ( classno,info );
  430. NICE::ColorImage img;
  431. std::string currentFile = info.img();
  432. CachedExample *ce = new CachedExample ( currentFile );
  433. const LocalizationResult *locResult = info.localization();
  434. if ( locResult->size() <= 0 )
  435. {
  436. fprintf ( stderr, "WARNING: NO ground truth polygons found for %s !\n",
  437. currentFile.c_str() );
  438. continue;
  439. }
  440. fprintf ( stderr, "SemSegCsurka: Collecting pixel examples from localization info: %s\n", currentFile.c_str() );
  441. int xsize, ysize;
  442. ce->getImageSize ( xsize, ysize );
  443. vector<vector<int> > tmp = vector<vector<int> >(xsize, vector<int>(ysize,0));
  444. currentfeats.push_back(tmp);
  445. labels.push_back(tmp);
  446. try {
  447. img = ColorImage(currentFile);
  448. } catch (Exception) {
  449. cerr << "SemSeg: error opening image file <" << currentFile << ">" << endl;
  450. continue;
  451. }
  452. Globals::setCurrentImgFN ( currentFile );
  453. //TODO: resize image?!
  454. MultiChannelImageT<double> feats;
  455. allfeats.push_back(feats);
  456. #ifdef LOCALFEATS
  457. lfcw->getFeats(img, allfeats[imgcounter]);
  458. #else
  459. allfeats[imgcounter].reInit(xsize, ysize, 3, true);
  460. for(int x = 0; x < xsize; x++)
  461. {
  462. for(int y = 0; y < ysize; y++)
  463. {
  464. for(int r = 0; r < 3; r++)
  465. {
  466. allfeats[imgcounter].set(x,y,img.getPixel(x,y,r),r);
  467. }
  468. }
  469. }
  470. #endif
  471. // getting groundtruth
  472. NICE::Image pixelLabels (xsize, ysize);
  473. pixelLabels.set(0);
  474. locResult->calcLabeledImage ( pixelLabels, ( *classNames ).getBackgroundClass() );
  475. for(int x = 0; x < xsize; x++)
  476. {
  477. for(int y = 0; y < ysize; y++)
  478. {
  479. classno = pixelLabels.getPixel(x, y);
  480. labels[imgcounter][x][y] = classno;
  481. if ( forbidden_classes.find ( classno ) != forbidden_classes.end() )
  482. continue;
  483. labelcounter[classno]++;
  484. }
  485. }
  486. imgcounter++;
  487. pb.update ( trainp->count());
  488. delete ce;
  489. }
  490. pb.hide();
  491. map<int,int>::iterator mapit;
  492. int classes = 0;
  493. for(mapit = labelcounter.begin(); mapit != labelcounter.end(); mapit++)
  494. {
  495. labelmap[mapit->first] = classes;
  496. labelmapback[classes] = mapit->first;
  497. classes++;
  498. }
  499. //balancing
  500. int featcounter = 0;
  501. a = vector<double>(classes,0.0);
  502. for(int iCounter = 0; iCounter < imgcounter; iCounter++)
  503. {
  504. int xsize = (int)currentfeats[iCounter].size();
  505. int ysize = (int)currentfeats[iCounter][0].size();
  506. for(int x = 0; x < xsize; x++)
  507. {
  508. for(int y = 0; y < ysize; y++)
  509. {
  510. featcounter++;
  511. int cn = labels[iCounter][x][y];
  512. a[labelmap[cn]] ++;
  513. }
  514. }
  515. }
  516. for(int i = 0; i < (int)a.size(); i++)
  517. {
  518. a[i] /= (double)featcounter;
  519. }
  520. #ifdef DEBUG
  521. for(int i = 0; i < (int)a.size(); i++)
  522. {
  523. cout << "a["<<i<<"]: " << a[i] << endl;
  524. }
  525. cout << "a.size: " << a.size() << endl;
  526. #endif
  527. tree.push_back(TreeNode());
  528. tree[0].dist = vector<double>(classes,0.0);
  529. int depth = 0;
  530. tree[0].depth = depth;
  531. int startnode = 0;
  532. bool allleaf = false;
  533. //int baseFeatSize = allfeats[0].size();
  534. while(!allleaf && depth < maxDepth)
  535. {
  536. allleaf = true;
  537. //TODO vielleicht parallel wenn nächste schleife trotzdem noch parallelsiert würde, die hat mehr gewicht
  538. int t = (int) tree.size();
  539. int s = startnode;
  540. startnode = t;
  541. //#pragma omp parallel for
  542. for(int i = s; i < t; i++)
  543. {
  544. if(!tree[i].isleaf && tree[i].left < 0)
  545. {
  546. Operation *splitfeat = NULL;
  547. double splitval;
  548. getBestSplit(allfeats, currentfeats,labels, i, splitfeat, splitval);
  549. tree[i].feat = splitfeat;
  550. tree[i].decision = splitval;
  551. if(splitfeat != NULL)
  552. {
  553. allleaf = false;
  554. int left = tree.size();
  555. tree.push_back(TreeNode());
  556. tree.push_back(TreeNode());
  557. int right = left+1;
  558. tree[i].left = left;
  559. tree[i].right = right;
  560. tree[left].dist = vector<double>(classes, 0.0);
  561. tree[right].dist = vector<double>(classes, 0.0);
  562. tree[left].depth = depth+1;
  563. tree[right].depth = depth+1;
  564. #pragma omp parallel for
  565. for(int iCounter = 0; iCounter < imgcounter; iCounter++)
  566. {
  567. int xsize = currentfeats[iCounter].size();
  568. int ysize = currentfeats[iCounter][0].size();
  569. for(int x = 0; x < xsize; x++)
  570. {
  571. for(int y = 0; y < ysize; y++)
  572. {
  573. if(currentfeats[iCounter][x][y] == i)
  574. {
  575. double val = splitfeat->getVal(allfeats[iCounter],currentfeats[iCounter],tree,x,y);
  576. if(val < splitval)
  577. {
  578. currentfeats[iCounter][x][y] = left;
  579. tree[left].dist[labelmap[labels[iCounter][x][y]]]++;
  580. }
  581. else
  582. {
  583. currentfeats[iCounter][x][y] = right;
  584. tree[right].dist[labelmap[labels[iCounter][x][y]]]++;
  585. }
  586. }
  587. }
  588. }
  589. }
  590. double lcounter = 0.0, rcounter = 0.0;
  591. for(uint d = 0; d < tree[left].dist.size(); d++)
  592. {
  593. if ( forbidden_classes.find ( labelmapback[d] ) != forbidden_classes.end() )
  594. {
  595. tree[left].dist[d] = 0;
  596. tree[right].dist[d] = 0;
  597. }
  598. else
  599. {
  600. tree[left].dist[d]/=a[d];
  601. lcounter +=tree[left].dist[d];
  602. tree[right].dist[d]/=a[d];
  603. rcounter +=tree[right].dist[d];
  604. }
  605. }
  606. if(lcounter <= 0 || rcounter <= 0)
  607. {
  608. cout << "lcounter : " << lcounter << " rcounter: " << rcounter << endl;
  609. cout << "splitval: " << splitval << endl;
  610. assert(lcounter > 0 && rcounter > 0);
  611. }
  612. for(uint d = 0; d < tree[left].dist.size(); d++)
  613. {
  614. tree[left].dist[d]/=lcounter;
  615. tree[right].dist[d]/=rcounter;
  616. }
  617. }
  618. else
  619. {
  620. tree[i].isleaf = true;
  621. }
  622. }
  623. }
  624. //TODO: features neu berechnen!
  625. depth++;
  626. #ifdef DEBUG
  627. cout << "depth: " << depth << endl;
  628. #endif
  629. }
  630. #ifdef DEBUG
  631. int t = (int) tree.size();
  632. for(int i = 0; i < t; i++)
  633. {
  634. printf("tree[%i]: left: %i, right: %i ", i, tree[i].left, tree[i].right);
  635. for(int d = 0; d < (int)tree[i].dist.size(); d++)
  636. {
  637. cout << " " << tree[i].dist[d];
  638. }
  639. cout << endl;
  640. }
  641. #endif
  642. }
  643. void SemSegContextTree::semanticseg ( CachedExample *ce, NICE::Image & segresult,NICE::MultiChannelImageT<double> & probabilities )
  644. {
  645. int xsize;
  646. int ysize;
  647. ce->getImageSize ( xsize, ysize );
  648. int numClasses = classNames->numClasses();
  649. fprintf (stderr, "ContextTree classification !\n");
  650. probabilities.reInit ( xsize, ysize, numClasses, true );
  651. probabilities.setAll ( 0 );
  652. NICE::ColorImage img;
  653. std::string currentFile = Globals::getCurrentImgFN();
  654. try {
  655. img = ColorImage(currentFile);
  656. } catch (Exception) {
  657. cerr << "SemSeg: error opening image file <" << currentFile << ">" << endl;
  658. return;
  659. }
  660. //TODO: resize image?!
  661. MultiChannelImageT<double> feats;
  662. #ifdef LOCALFEATS
  663. lfcw->getFeats(img, feats);
  664. #else
  665. feats.reInit (xsize, ysize, 3, true);
  666. for(int x = 0; x < xsize; x++)
  667. {
  668. for(int y = 0; y < ysize; y++)
  669. {
  670. for(int r = 0; r < 3; r++)
  671. {
  672. feats.set(x,y,img.getPixel(x,y,r),r);
  673. }
  674. }
  675. }
  676. #endif
  677. bool allleaf = false;
  678. vector<vector<int> > currentfeats = vector<vector<int> >(xsize, vector<int>(ysize,0));
  679. int depth = 0;
  680. while(!allleaf)
  681. {
  682. allleaf = true;
  683. //TODO vielleicht parallel wenn nächste schleife auch noch parallelsiert würde, die hat mehr gewicht
  684. //#pragma omp parallel for
  685. int t = (int) tree.size();
  686. for(int i = 0; i < t; i++)
  687. {
  688. for(int x = 0; x < xsize; x++)
  689. {
  690. for(int y = 0; y < ysize; y++)
  691. {
  692. int t = currentfeats[x][y];
  693. if(tree[t].left > 0)
  694. {
  695. allleaf = false;
  696. double val = tree[t].feat->getVal(feats,currentfeats,tree,x,y);
  697. if(val < tree[t].decision)
  698. {
  699. currentfeats[x][y] = tree[t].left;
  700. }
  701. else
  702. {
  703. currentfeats[x][y] = tree[t].right;
  704. }
  705. }
  706. }
  707. }
  708. }
  709. //TODO: features neu berechnen! analog zum training
  710. depth++;
  711. }
  712. //finales labeln:
  713. long int offset = 0;
  714. for(int x = 0; x < xsize; x++)
  715. {
  716. for(int y = 0; y < ysize; y++,offset++)
  717. {
  718. int t = currentfeats[x][y];
  719. double maxvalue = - numeric_limits<double>::max(); //TODO: das muss nur pro knoten gemacht werden, nicht pro pixel
  720. int maxindex = 0;
  721. for(uint i = 0; i < tree[i].dist.size(); i++)
  722. {
  723. probabilities.data[labelmapback[i]][offset] = tree[t].dist[i];
  724. if(tree[t].dist[i] > maxvalue)
  725. {
  726. maxvalue = tree[t].dist[i];
  727. maxindex = labelmapback[i];
  728. }
  729. segresult.setPixel(x,y,maxindex);
  730. }
  731. }
  732. }
  733. }