SemSegContextTree.cpp 37 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521
  1. #include "SemSegContextTree.h"
  2. #include "vislearning/baselib/Globals.h"
  3. #include "vislearning/baselib/ProgressBar.h"
  4. #include "core/basics/StringTools.h"
  5. #include "vislearning/cbaselib/CachedExample.h"
  6. #include "vislearning/cbaselib/PascalResults.h"
  7. #include "objrec/segmentation/RSMeanShift.h"
  8. #include "objrec/segmentation/RSGraphBased.h"
  9. #include "core/basics/numerictools.h"
  10. #include "core/basics/Timer.h"
  11. #include <omp.h>
  12. #include <iostream>
  13. #define BOUND(x,min,max) (((x)<(min))?(min):((x)>(max)?(max):(x)))
  14. #undef LOCALFEATS
  15. //#define LOCALFEATS
  16. using namespace OBJREC;
  17. using namespace std;
  18. using namespace NICE;
  19. class MCImageAccess:public ValueAccess
  20. {
  21. public:
  22. virtual double getVal(const Features &feats, const int &x, const int &y, const int &channel)
  23. {
  24. return feats.feats->get(x,y,channel);
  25. }
  26. virtual string writeInfos()
  27. {
  28. return "raw";
  29. }
  30. };
  31. class ClassificationResultAcess:public ValueAccess
  32. {
  33. public:
  34. virtual double getVal(const Features &feats, const int &x, const int &y, const int &channel)
  35. {
  36. return (*feats.tree)[feats.cfeats->get(x,y,feats.cTree)].dist[channel];
  37. }
  38. virtual string writeInfos()
  39. {
  40. return "context";
  41. }
  42. };
  43. class Minus:public Operation
  44. {
  45. public:
  46. virtual double getVal(const Features &feats, const int &x, const int &y)
  47. {
  48. int xsize, ysize;
  49. getXY(feats, xsize, ysize);
  50. double v1 = values->getVal(feats, BOUND(x+x1,0,xsize-1),BOUND(y+y1,0,ysize-1),channel1);
  51. double v2 = values->getVal(feats, BOUND(x+x2,0,xsize-1),BOUND(y+y2,0,ysize-1),channel2);
  52. return v1-v2;
  53. }
  54. virtual Operation* clone()
  55. {
  56. return new Minus();
  57. }
  58. virtual string writeInfos()
  59. {
  60. string out = "Minus";
  61. if(values !=NULL)
  62. out+=values->writeInfos();
  63. return out;
  64. }
  65. virtual OperationTypes getOps()
  66. {
  67. return MINUS;
  68. }
  69. };
  70. class MinusAbs:public Operation
  71. {
  72. public:
  73. virtual double getVal(const Features &feats, const int &x, const int &y)
  74. {
  75. int xsize, ysize;
  76. getXY(feats, xsize, ysize);
  77. double v1 = values->getVal(feats, BOUND(x+x1,0,xsize-1),BOUND(y+y1,0,ysize-1),channel1);
  78. double v2 = values->getVal(feats, BOUND(x+x2,0,xsize-1),BOUND(y+y2,0,ysize-1),channel2);
  79. return abs(v1-v2);
  80. }
  81. virtual Operation* clone()
  82. {
  83. return new MinusAbs();
  84. };
  85. virtual string writeInfos()
  86. {
  87. string out = "MinusAbs";
  88. if(values !=NULL)
  89. out+=values->writeInfos();
  90. return out;
  91. }
  92. virtual OperationTypes getOps()
  93. {
  94. return MINUSABS;
  95. }
  96. };
  97. class Addition:public Operation
  98. {
  99. public:
  100. virtual double getVal(const Features &feats, const int &x, const int &y)
  101. {
  102. int xsize, ysize;
  103. getXY(feats, xsize, ysize);
  104. double v1 = values->getVal(feats, BOUND(x+x1,0,xsize-1),BOUND(y+y1,0,ysize-1),channel1);
  105. double v2 = values->getVal(feats, BOUND(x+x2,0,xsize-1),BOUND(y+y2,0,ysize-1),channel2);
  106. return v1+v2;
  107. }
  108. virtual Operation* clone()
  109. {
  110. return new Addition();
  111. }
  112. virtual string writeInfos()
  113. {
  114. string out = "Addition";
  115. if(values !=NULL)
  116. out+=values->writeInfos();
  117. return out;
  118. }
  119. virtual OperationTypes getOps()
  120. {
  121. return ADDITION;
  122. }
  123. };
  124. class Only1:public Operation
  125. {
  126. public:
  127. virtual double getVal(const Features &feats, const int &x, const int &y)
  128. {
  129. int xsize, ysize;
  130. getXY(feats, xsize, ysize);
  131. double v1 = values->getVal(feats, BOUND(x+x1,0,xsize-1),BOUND(y+y1,0,ysize-1),channel1);
  132. return v1;
  133. }
  134. virtual Operation* clone()
  135. {
  136. return new Only1();
  137. }
  138. virtual string writeInfos()
  139. {
  140. string out = "Only1";
  141. if(values !=NULL)
  142. out+=values->writeInfos();
  143. return out;
  144. }
  145. virtual OperationTypes getOps()
  146. {
  147. return ONLY1;
  148. }
  149. };
  150. // uses mean of classification in window given by (x1,y1) (x2,y2)
  151. class IntegralOps:public Operation
  152. {
  153. public:
  154. virtual void set(int _x1, int _y1, int _x2, int _y2, int _channel1, int _channel2, ValueAccess *_values)
  155. {
  156. x1 = min(_x1,_x2);
  157. y1 = min(_y1,_y2);
  158. x2 = max(_x1,_x2);
  159. y2 = max(_y1,_y2);
  160. channel1 = _channel1;
  161. channel2 = _channel2;
  162. values = _values;
  163. }
  164. virtual double getVal(const Features &feats, const int &x, const int &y)
  165. {
  166. int xsize, ysize;
  167. getXY(feats, xsize, ysize);
  168. return computeMean(*feats.integralImg,BOUND(x+x1,0,xsize-1),BOUND(y+y1,0,ysize-1),BOUND(x+x2,0,xsize-1),BOUND(y+y2,0,ysize-1),channel1);
  169. }
  170. inline double computeMean(const NICE::MultiChannelImageT<double> &intImg, const int &uLx, const int &uLy, const int &lRx, const int &lRy, const int &chan)
  171. {
  172. double val1 = intImg.get(uLx,uLy, chan);
  173. double val2 = intImg.get(lRx,uLy, chan);
  174. double val3 = intImg.get(uLx,lRy, chan);
  175. double val4 = intImg.get(lRx,lRy, chan);
  176. double area = (lRx-uLx)*(lRy-uLy);
  177. if(area == 0)
  178. return 0.0;
  179. return (val1+val4-val2-val3)/area;
  180. }
  181. virtual Operation* clone()
  182. {
  183. return new IntegralOps();
  184. }
  185. virtual string writeInfos()
  186. {
  187. return "IntegralOps";
  188. }
  189. virtual OperationTypes getOps()
  190. {
  191. return INTEGRAL;
  192. }
  193. };
  194. //uses mean of Integral image given by x1, y1 with current pixel as center
  195. class IntegralCenteredOps:public IntegralOps
  196. {
  197. public:
  198. virtual void set(int _x1, int _y1, int _x2, int _y2, int _channel1, int _channel2, ValueAccess *_values)
  199. {
  200. x1 = abs(_x1);
  201. y1 = abs(_y1);
  202. x2 = abs(_x2);
  203. y2 = abs(_y2);
  204. channel1 = _channel1;
  205. channel2 = _channel2;
  206. values = _values;
  207. }
  208. virtual double getVal(const Features &feats, const int &x, const int &y)
  209. {
  210. int xsize, ysize;
  211. getXY(feats, xsize, ysize);
  212. return computeMean(*feats.integralImg,BOUND(x-x1,0,xsize-1),BOUND(y-y1,0,ysize-1),BOUND(x+x1,0,xsize-1),BOUND(y+y1,0,ysize-1),channel1);
  213. }
  214. virtual Operation* clone()
  215. {
  216. return new IntegralCenteredOps();
  217. }
  218. virtual string writeInfos()
  219. {
  220. return "IntegralCenteredOps";
  221. }
  222. virtual OperationTypes getOps()
  223. {
  224. return INTEGRALCENT;
  225. }
  226. };
  227. //uses different of mean of Integral image given by two windows, where (x1,y1) is the width and height of window1 and (x2,y2) of window 2
  228. class BiIntegralCenteredOps:public IntegralCenteredOps
  229. {
  230. public:
  231. virtual void set(int _x1, int _y1, int _x2, int _y2, int _channel1, int _channel2, ValueAccess *_values)
  232. {
  233. x1 = min(abs(_x1),abs(_x2));
  234. y1 = min(abs(_y1),abs(_y2));
  235. x2 = max(abs(_x1),abs(_x2));
  236. y2 = max(abs(_y1),abs(_y2));
  237. channel1 = _channel1;
  238. channel2 = _channel2;
  239. values = _values;
  240. }
  241. virtual double getVal(const Features &feats, const int &x, const int &y)
  242. {
  243. int xsize, ysize;
  244. getXY(feats, xsize, ysize);
  245. return computeMean(*feats.integralImg,BOUND(x-x1,0,xsize-1),BOUND(y-y1,0,ysize-1),BOUND(x+x1,0,xsize-1),BOUND(y+y1,0,ysize-1),channel1) - computeMean(*feats.integralImg,BOUND(x-x2,0,xsize-1),BOUND(y-y2,0,ysize-1),BOUND(x+x2,0,xsize-1),BOUND(y+y2,0,ysize-1),channel1);
  246. }
  247. virtual Operation* clone()
  248. {
  249. return new BiIntegralCenteredOps();
  250. }
  251. virtual string writeInfos()
  252. {
  253. return "BiIntegralCenteredOps";
  254. }
  255. virtual OperationTypes getOps()
  256. {
  257. return BIINTEGRALCENT;
  258. }
  259. };
  260. /** horizontal Haar features
  261. * ++
  262. * --
  263. */
  264. class HaarHorizontal:public IntegralCenteredOps
  265. {
  266. virtual double getVal(const Features &feats, const int &x, const int &y)
  267. {
  268. int xsize, ysize;
  269. getXY(feats, xsize, ysize);
  270. int tlx = BOUND(x-x1,0,xsize-1);
  271. int tly = BOUND(y-y1,0,ysize-1);
  272. int lrx = BOUND(x+x1,0,xsize-1);
  273. int lry = BOUND(y+y1,0,ysize-1);
  274. return computeMean(*feats.integralImg,tlx,tly,lrx, y,channel1)-computeMean(*feats.integralImg,tlx,y,lrx, lry,channel1);
  275. }
  276. virtual Operation* clone()
  277. {
  278. return new HaarHorizontal();
  279. }
  280. virtual string writeInfos()
  281. {
  282. return "HaarHorizontal";
  283. }
  284. virtual OperationTypes getOps()
  285. {
  286. return HAARHORIZ;
  287. }
  288. };
  289. /** vertical Haar features
  290. * +-
  291. * +-
  292. */
  293. class HaarVertical:public IntegralCenteredOps
  294. {
  295. virtual double getVal(const Features &feats, const int &x, const int &y)
  296. {
  297. int xsize, ysize;
  298. getXY(feats, xsize, ysize);
  299. int tlx = BOUND(x-x1,0,xsize-1);
  300. int tly = BOUND(y-y1,0,ysize-1);
  301. int lrx = BOUND(x+x1,0,xsize-1);
  302. int lry = BOUND(y+y1,0,ysize-1);
  303. return computeMean(*feats.integralImg,tlx,tly,x, lry,channel1)-computeMean(*feats.integralImg,x,tly,lrx, lry,channel1);
  304. }
  305. virtual Operation* clone()
  306. {
  307. return new HaarVertical();
  308. }
  309. virtual string writeInfos()
  310. {
  311. return "HaarVertical";
  312. }
  313. virtual OperationTypes getOps()
  314. {
  315. return HAARVERT;
  316. }
  317. };
  318. /** vertical Haar features
  319. * +-
  320. * -+
  321. */
  322. class HaarDiag:public IntegralCenteredOps
  323. {
  324. virtual double getVal(const Features &feats, const int &x, const int &y)
  325. {
  326. int xsize, ysize;
  327. getXY(feats, xsize, ysize);
  328. int tlx = BOUND(x-x1,0,xsize-1);
  329. int tly = BOUND(y-y1,0,ysize-1);
  330. int lrx = BOUND(x+x1,0,xsize-1);
  331. int lry = BOUND(y+y1,0,ysize-1);
  332. return computeMean(*feats.integralImg,tlx,tly,x, y,channel1)+computeMean(*feats.integralImg,x,y,lrx, lry,channel1) - computeMean(*feats.integralImg,tlx,y,x, lry,channel1)-computeMean(*feats.integralImg,x,tly,lrx, y,channel1);
  333. }
  334. virtual Operation* clone()
  335. {
  336. return new HaarDiag();
  337. }
  338. virtual string writeInfos()
  339. {
  340. return "HaarDiag";
  341. }
  342. virtual OperationTypes getOps()
  343. {
  344. return HAARDIAG;
  345. }
  346. };
  347. /** horizontal Haar features
  348. * +++
  349. * ---
  350. * +++
  351. */
  352. class Haar3Horiz:public BiIntegralCenteredOps
  353. {
  354. virtual double getVal(const Features &feats, const int &x, const int &y)
  355. {
  356. int xsize, ysize;
  357. getXY(feats, xsize, ysize);
  358. int tlx = BOUND(x-x2,0,xsize-1);
  359. int tly = BOUND(y-y2,0,ysize-1);
  360. int mtly = BOUND(y-y1,0,ysize-1);
  361. int mlry = BOUND(y+y1,0,ysize-1);
  362. int lrx = BOUND(x+x2,0,xsize-1);
  363. int lry = BOUND(y+y2,0,ysize-1);
  364. return computeMean(*feats.integralImg,tlx,tly,lrx, mtly,channel1) -computeMean(*feats.integralImg,tlx,mtly,lrx, mlry,channel1) + computeMean(*feats.integralImg,tlx,mlry,lrx, lry,channel1);
  365. }
  366. virtual Operation* clone()
  367. {
  368. return new Haar3Horiz();
  369. }
  370. virtual string writeInfos()
  371. {
  372. return "Haar3Horiz";
  373. }
  374. virtual OperationTypes getOps()
  375. {
  376. return HAAR3HORIZ;
  377. }
  378. };
  379. /** vertical Haar features
  380. * +-+
  381. * +-+
  382. * +-+
  383. */
  384. class Haar3Vert:public BiIntegralCenteredOps
  385. {
  386. virtual double getVal(const Features &feats, const int &x, const int &y)
  387. {
  388. int xsize, ysize;
  389. getXY(feats, xsize, ysize);
  390. int tlx = BOUND(x-x2,0,xsize-1);
  391. int tly = BOUND(y-y2,0,ysize-1);
  392. int mtlx = BOUND(x-x1,0,xsize-1);
  393. int mlrx = BOUND(x+x1,0,xsize-1);
  394. int lrx = BOUND(x+x2,0,xsize-1);
  395. int lry = BOUND(y+y2,0,ysize-1);
  396. return computeMean(*feats.integralImg,tlx,tly,mtlx, lry,channel1) -computeMean(*feats.integralImg,mtlx,tly,mlrx, lry,channel1) + computeMean(*feats.integralImg,mlrx,tly,lrx, lry,channel1);
  397. }
  398. virtual Operation* clone()
  399. {
  400. return new Haar3Vert();
  401. }
  402. virtual string writeInfos()
  403. {
  404. return "Haar3Vert";
  405. }
  406. virtual OperationTypes getOps()
  407. {
  408. return HAAR3VERT;
  409. }
  410. };
  411. SemSegContextTree::SemSegContextTree( const Config *conf, const MultiDataset *md )
  412. : SemanticSegmentation ( conf, &(md->getClassNames("train")) )
  413. {
  414. this->conf = conf;
  415. string section = "SSContextTree";
  416. lfcw = new LFColorWeijer(conf);
  417. grid = conf->gI(section, "grid", 10 );
  418. maxSamples = conf->gI(section, "max_samples", 2000);
  419. minFeats = conf->gI(section, "min_feats", 50 );
  420. maxDepth = conf->gI(section, "max_depth", 10 );
  421. windowSize = conf->gI(section, "window_size", 16);
  422. featsPerSplit = conf->gI(section, "feats_per_split", 200);
  423. useShannonEntropy = conf->gB(section, "use_shannon_entropy", true);
  424. nbTrees = conf->gI(section, "amount_trees", 1);
  425. string segmentationtype = conf->gS(section, "segmentation_type", "meanshift");
  426. useGaussian = conf->gB(section, "use_gaussian", true);
  427. if(useGaussian)
  428. throw("there something wrong with using gaussian! first fix it!");
  429. pixelWiseLabeling = false;
  430. if(segmentationtype == "meanshift")
  431. segmentation = new RSMeanShift(conf);
  432. else if (segmentationtype == "none")
  433. {
  434. segmentation = NULL;
  435. pixelWiseLabeling = true;
  436. }
  437. else if (segmentationtype == "felzenszwalb")
  438. segmentation = new RSGraphBased(conf);
  439. else
  440. throw("no valid segmenation_type\n please choose between none, meanshift and felzenszwalb\n");
  441. ftypes = conf->gI(section, "features", 2);;
  442. ops.push_back(new Minus());
  443. ops.push_back(new MinusAbs());
  444. ops.push_back(new Addition());
  445. ops.push_back(new Only1());
  446. cops.push_back(new BiIntegralCenteredOps());
  447. cops.push_back(new IntegralCenteredOps());
  448. cops.push_back(new IntegralOps());
  449. cops.push_back(new HaarHorizontal());
  450. cops.push_back(new HaarVertical());
  451. cops.push_back(new HaarDiag());
  452. cops.push_back(new Haar3Horiz());
  453. cops.push_back(new Haar3Vert());
  454. opOverview = vector<int>(NBOPERATIONS, 0);
  455. calcVal.push_back(new MCImageAccess());
  456. calcVal.push_back(new ClassificationResultAcess());
  457. classnames = md->getClassNames ( "train" );
  458. ///////////////////////////////////
  459. // Train Segmentation Context Trees
  460. ///////////////////////////////////
  461. train ( md );
  462. }
  463. SemSegContextTree::~SemSegContextTree()
  464. {
  465. }
  466. double SemSegContextTree::getBestSplit(std::vector<NICE::MultiChannelImageT<double> > &feats, std::vector<NICE::MultiChannelImageT<int> > &currentfeats, std::vector<NICE::MultiChannelImageT<double> > &integralImgs, const std::vector<NICE::MatrixT<int> > &labels, int node, Operation *&splitop, double &splitval, const int &tree)
  467. {
  468. int imgCount = 0, featdim = 0;
  469. try
  470. {
  471. imgCount = (int)feats.size();
  472. featdim = feats[0].channels();
  473. }
  474. catch(Exception)
  475. {
  476. cerr << "no features computed?" << endl;
  477. }
  478. double bestig = -numeric_limits< double >::max();
  479. splitop = NULL;
  480. splitval = -1.0;
  481. set<vector<int> >selFeats;
  482. map<int,int> e;
  483. int featcounter = 0;
  484. for(int iCounter = 0; iCounter < imgCount; iCounter++)
  485. {
  486. int xsize = (int)currentfeats[iCounter].width();
  487. int ysize = (int)currentfeats[iCounter].height();
  488. for(int x = 0; x < xsize; x++)
  489. {
  490. for(int y = 0; y < ysize; y++)
  491. {
  492. if(currentfeats[iCounter].get(x,y,tree) == node)
  493. {
  494. featcounter++;
  495. }
  496. }
  497. }
  498. }
  499. if(featcounter < minFeats)
  500. {
  501. cout << "only " << featcounter << " feats in current node -> it's a leaf" << endl;
  502. return 0.0;
  503. }
  504. vector<double> fraction(a.size(),0.0);
  505. for(uint i = 0; i < fraction.size(); i++)
  506. {
  507. if ( forbidden_classes.find ( labelmapback[i] ) != forbidden_classes.end() )
  508. fraction[i] = 0;
  509. else
  510. fraction[i] = ((double)maxSamples)/((double)featcounter*a[i]*a.size());
  511. //cout << "fraction["<<i<<"]: "<< fraction[i] << " a[" << i << "]: " << a[i] << endl;
  512. }
  513. //cout << "a.size(): " << a.size() << endl;
  514. //getchar();
  515. featcounter = 0;
  516. for(int iCounter = 0; iCounter < imgCount; iCounter++)
  517. {
  518. int xsize = (int)currentfeats[iCounter].width();
  519. int ysize = (int)currentfeats[iCounter].height();
  520. for(int x = 0; x < xsize; x++)
  521. {
  522. for(int y = 0; y < ysize; y++)
  523. {
  524. if(currentfeats[iCounter].get(x,y,tree) == node)
  525. {
  526. int cn = labels[iCounter](x,y);
  527. double randD = (double)rand()/(double)RAND_MAX;
  528. if(randD < fraction[labelmap[cn]])
  529. {
  530. vector<int> tmp(3,0);
  531. tmp[0] = iCounter;
  532. tmp[1] = x;
  533. tmp[2] = y;
  534. featcounter++;
  535. selFeats.insert(tmp);
  536. e[cn]++;
  537. }
  538. }
  539. }
  540. }
  541. }
  542. //cout << "size: " << selFeats.size() << endl;
  543. //getchar();
  544. map<int,int>::iterator mapit;
  545. double globent = 0.0;
  546. for ( mapit=e.begin() ; mapit != e.end(); mapit++ )
  547. {
  548. //cout << "class: " << mapit->first << ": " << mapit->second << endl;
  549. double p = (double)(*mapit).second/(double)featcounter;
  550. globent += p*log2(p);
  551. }
  552. globent = -globent;
  553. if(globent < 0.5)
  554. {
  555. cout << "globent to small: " << globent << endl;
  556. return 0.0;
  557. }
  558. int classes = (int)forest[tree][0].dist.size();
  559. featsel.clear();
  560. for(int i = 0; i < featsPerSplit; i++)
  561. {
  562. int x1, x2, y1, y2;
  563. int ft = (int)((double)rand()/(double)RAND_MAX*(double)ftypes);
  564. int tmpws = windowSize;
  565. if(integralImgs[0].width() == 0)
  566. ft = 0;
  567. if(ft > 0)
  568. {
  569. tmpws *= 4;
  570. }
  571. if(useGaussian)
  572. {
  573. double sigma = (double)tmpws/2.0;
  574. x1 = randGaussDouble(sigma)*(double)tmpws;
  575. x2 = randGaussDouble(sigma)*(double)tmpws;
  576. y1 = randGaussDouble(sigma)*(double)tmpws;
  577. y2 = randGaussDouble(sigma)*(double)tmpws;
  578. }
  579. else
  580. {
  581. x1 = (int)((double)rand()/(double)RAND_MAX*(double)tmpws)-tmpws/2;
  582. x2 = (int)((double)rand()/(double)RAND_MAX*(double)tmpws)-tmpws/2;
  583. y1 = (int)((double)rand()/(double)RAND_MAX*(double)tmpws)-tmpws/2;
  584. y2 = (int)((double)rand()/(double)RAND_MAX*(double)tmpws)-tmpws/2;
  585. }
  586. if(ft == 0)
  587. {
  588. int f1 = (int)((double)rand()/(double)RAND_MAX*(double)featdim);
  589. int f2 = (int)((double)rand()/(double)RAND_MAX*(double)featdim);
  590. int o = (int)((double)rand()/(double)RAND_MAX*(double)ops.size());
  591. Operation *op = ops[o]->clone();
  592. op->set(x1,y1,x2,y2,f1,f2, calcVal[ft]);
  593. featsel.push_back(op);
  594. }
  595. else if(ft == 1)
  596. {
  597. int opssize = (int)ops.size();
  598. //opssize = 0;
  599. int o = (int)((double)rand()/(double)RAND_MAX*(((double)cops.size())+(double)opssize));
  600. Operation *op;
  601. if(o < opssize)
  602. {
  603. int chans = (int)forest[0][0].dist.size();
  604. int f1 = (int)((double)rand()/(double)RAND_MAX*(double)chans);
  605. int f2 = (int)((double)rand()/(double)RAND_MAX*(double)chans);
  606. op = ops[o]->clone();
  607. op->set(x1,y1,x2,y2,f1,f2, calcVal[ft]);
  608. }
  609. else
  610. {
  611. int chans = integralImgs[0].channels();
  612. int f1 = (int)((double)rand()/(double)RAND_MAX*(double)chans);
  613. int f2 = (int)((double)rand()/(double)RAND_MAX*(double)chans);
  614. o -= opssize;
  615. op = cops[o]->clone();
  616. op->set(x1,y1,x2,y2,f1,f2, calcVal[ft]);
  617. }
  618. featsel.push_back(op);
  619. }
  620. }
  621. #pragma omp parallel for private(mapit)
  622. for(int f = 0; f < featsPerSplit; f++)
  623. {
  624. double l_bestig = -numeric_limits< double >::max();
  625. double l_splitval = -1.0;
  626. set<vector<int> >::iterator it;
  627. vector<double> vals;
  628. for ( it=selFeats.begin() ; it != selFeats.end(); it++ )
  629. {
  630. Features feat;
  631. feat.feats = &feats[(*it)[0]];
  632. feat.cfeats = &currentfeats[(*it)[0]];
  633. feat.cTree = tree;
  634. feat.tree = &forest[tree];
  635. feat.integralImg = &integralImgs[(*it)[0]];
  636. vals.push_back(featsel[f]->getVal(feat, (*it)[1], (*it)[2]));
  637. }
  638. int counter = 0;
  639. for ( it=selFeats.begin() ; it != selFeats.end(); it++ , counter++)
  640. {
  641. set<vector<int> >::iterator it2;
  642. double val = vals[counter];
  643. map<int,int> eL, eR;
  644. int counterL = 0, counterR = 0;
  645. int counter2 = 0;
  646. for ( it2=selFeats.begin() ; it2 != selFeats.end(); it2++, counter2++ )
  647. {
  648. int cn = labels[(*it2)[0]]((*it2)[1], (*it2)[2]);
  649. //cout << "vals[counter2] " << vals[counter2] << " val: " << val << endl;
  650. if(vals[counter2] < val)
  651. {
  652. //left entropie:
  653. eL[cn] = eL[cn]+1;
  654. counterL++;
  655. }
  656. else
  657. {
  658. //right entropie:
  659. eR[cn] = eR[cn]+1;
  660. counterR++;
  661. }
  662. }
  663. double leftent = 0.0;
  664. for ( mapit=eL.begin() ; mapit != eL.end(); mapit++ )
  665. {
  666. double p = (double)(*mapit).second/(double)counterL;
  667. leftent -= p*log2(p);
  668. }
  669. double rightent = 0.0;
  670. for ( mapit=eR.begin() ; mapit != eR.end(); mapit++ )
  671. {
  672. double p = (double)(*mapit).second/(double)counterR;
  673. rightent -= p*log2(p);
  674. }
  675. //cout << "rightent: " << rightent << " leftent: " << leftent << endl;
  676. double pl = (double)counterL/(double)(counterL+counterR);
  677. double ig = globent - (1.0-pl) * rightent - pl*leftent;
  678. //double ig = globent - rightent - leftent;
  679. if(useShannonEntropy)
  680. {
  681. double esplit = - ( pl*log(pl) + (1-pl)*log(1-pl) );
  682. ig = 2*ig / ( globent + esplit );
  683. }
  684. if(ig > l_bestig)
  685. {
  686. l_bestig = ig;
  687. l_splitval = val;
  688. }
  689. }
  690. #pragma omp critical
  691. {
  692. //cout << "globent: " << globent << " bestig " << bestig << " splitfeat: " << splitfeat << " splitval: " << splitval << endl;
  693. //cout << "globent: " << globent << " l_bestig " << l_bestig << " f: " << p << " l_splitval: " << l_splitval << endl;
  694. //cout << "p: " << featsubset[f] << endl;
  695. if(l_bestig > bestig)
  696. {
  697. bestig = l_bestig;
  698. splitop = featsel[f];
  699. splitval = l_splitval;
  700. }
  701. }
  702. }
  703. //splitop->writeInfos();
  704. //cout<< "ig: " << bestig << endl;
  705. /*for(int i = 0; i < featsPerSplit; i++)
  706. {
  707. if(featsel[i] != splitop)
  708. delete featsel[i];
  709. }*/
  710. #ifdef debug
  711. cout << "globent: " << globent << " bestig " << bestig << " splitval: " << splitval << endl;
  712. #endif
  713. return bestig;
  714. }
  715. inline double SemSegContextTree::getMeanProb(const int &x,const int &y,const int &channel, const MultiChannelImageT<int> &currentfeats)
  716. {
  717. double val = 0.0;
  718. for(int tree = 0; tree < nbTrees; tree++)
  719. {
  720. val += forest[tree][currentfeats.get(x,y,tree)].dist[channel];
  721. }
  722. return val / (double)nbTrees;
  723. }
  724. void SemSegContextTree::computeIntegralImage(const NICE::MultiChannelImageT<int> &currentfeats, const NICE::MultiChannelImageT<double> &lfeats, NICE::MultiChannelImageT<double> &integralImage)
  725. {
  726. int xsize = currentfeats.width();
  727. int ysize = currentfeats.height();
  728. int channels = (int)forest[0][0].dist.size();
  729. #pragma omp parallel for
  730. for(int c = 0; c < channels; c++)
  731. {
  732. integralImage.set(0,0,getMeanProb(0,0,c, currentfeats), c);
  733. //first column
  734. for(int y = 1; y < ysize; y++)
  735. {
  736. integralImage.set(0,y,getMeanProb(0,y,c, currentfeats)+integralImage.get(0,y,c), c);
  737. }
  738. //first row
  739. for(int x = 1; x < xsize; x++)
  740. {
  741. integralImage.set(x,0,getMeanProb(x,0,c, currentfeats)+integralImage.get(x,0,c), c);
  742. }
  743. //rest
  744. for(int y = 1; y < ysize; y++)
  745. {
  746. for(int x = 1; x < xsize; x++)
  747. {
  748. double val = getMeanProb(x,y,c,currentfeats)+integralImage.get(x,y-1,c)+integralImage.get(x-1,y,c)-integralImage.get(x-1,y-1,c);
  749. integralImage.set(x, y, val, c);
  750. }
  751. }
  752. }
  753. int channels2 = (int)lfeats.channels();
  754. xsize = lfeats.width();
  755. ysize = lfeats.height();
  756. if(integralImage.get(xsize-1,ysize-1,channels) == 0.0)
  757. {
  758. #pragma omp parallel for
  759. for(int c1 = 0; c1 < channels2; c1++)
  760. {
  761. int c = channels+c1;
  762. integralImage.set(0,0,lfeats.get(0,0,c1), c);
  763. //first column
  764. for(int y = 1; y < ysize; y++)
  765. {
  766. integralImage.set(0,y,lfeats.get(0,y,c1)+integralImage.get(0,y,c), c);
  767. }
  768. //first row
  769. for(int x = 1; x < xsize; x++)
  770. {
  771. integralImage.set(x,0,lfeats.get(x,0,c1)+integralImage.get(x,0,c), c);
  772. }
  773. //rest
  774. for(int y = 1; y < ysize; y++)
  775. {
  776. for(int x = 1; x < xsize; x++)
  777. {
  778. double val = lfeats.get(x,y,c1)+integralImage.get(x,y-1,c)+integralImage.get(x-1,y,c)-integralImage.get(x-1,y-1,c);
  779. integralImage.set(x, y, val, c);
  780. }
  781. }
  782. }
  783. }
  784. }
  785. void SemSegContextTree::train ( const MultiDataset *md )
  786. {
  787. const LabeledSet train = * ( *md ) ["train"];
  788. const LabeledSet *trainp = &train;
  789. ProgressBar pb ( "compute feats" );
  790. pb.show();
  791. //TODO: Speichefresser!, lohnt sich sparse?
  792. vector<MultiChannelImageT<double> > allfeats;
  793. vector<MultiChannelImageT<int> > currentfeats;
  794. vector<MatrixT<int> > labels;
  795. std::string forbidden_classes_s = conf->gS ( "analysis", "donttrain", "" );
  796. if ( forbidden_classes_s == "" )
  797. {
  798. forbidden_classes_s = conf->gS ( "analysis", "forbidden_classes", "" );
  799. }
  800. classnames.getSelection ( forbidden_classes_s, forbidden_classes );
  801. int imgcounter = 0;
  802. /*MultiChannelImageT<int> ttmp2(0,0,0);
  803. MultiChannelImageT<double> ttmp1(100,100,1);
  804. MultiChannelImageT<double> tint(100,100,1);
  805. ttmp1.setAll(1.0);
  806. tint.setAll(0.0);
  807. computeIntegralImage(ttmp2,ttmp1,tint);
  808. for(int i = 0; i < cops.size(); i++)
  809. {
  810. Features feats;
  811. feats.feats = &tint;
  812. feats.cfeats = &ttmp2;
  813. feats.cTree = 0;
  814. feats.tree = new vector<TreeNode>;
  815. feats.integralImg = &tint;
  816. cops[i]->set(-10, -6, 8, 9, 0, 0, new MCImageAccess());
  817. cout << "for: " << cops[i]->writeInfos() << endl;
  818. int y = 50;
  819. for(int x = 40; x < 44; x++)
  820. {
  821. cout << "x: " << x << " val: " << cops[i]->getVal(feats, x, y) << endl;
  822. }
  823. }
  824. getchar();*/
  825. LOOP_ALL_S ( *trainp )
  826. {
  827. EACH_INFO ( classno,info );
  828. NICE::ColorImage img;
  829. std::string currentFile = info.img();
  830. CachedExample *ce = new CachedExample ( currentFile );
  831. const LocalizationResult *locResult = info.localization();
  832. if ( locResult->size() <= 0 )
  833. {
  834. fprintf ( stderr, "WARNING: NO ground truth polygons found for %s !\n",
  835. currentFile.c_str() );
  836. continue;
  837. }
  838. fprintf ( stderr, "SemSegCsurka: Collecting pixel examples from localization info: %s\n", currentFile.c_str() );
  839. int xsize, ysize;
  840. ce->getImageSize ( xsize, ysize );
  841. MatrixT<int> tmpMat(xsize,ysize);
  842. currentfeats.push_back(MultiChannelImageT<int>(xsize,ysize,nbTrees));
  843. currentfeats[imgcounter].setAll(0);
  844. labels.push_back(tmpMat);
  845. try {
  846. img = ColorImage(currentFile);
  847. } catch (Exception) {
  848. cerr << "SemSeg: error opening image file <" << currentFile << ">" << endl;
  849. continue;
  850. }
  851. Globals::setCurrentImgFN ( currentFile );
  852. //TODO: resize image?!
  853. MultiChannelImageT<double> feats;
  854. allfeats.push_back(feats);
  855. #ifdef LOCALFEATS
  856. lfcw->getFeats(img, allfeats[imgcounter]);
  857. #else
  858. allfeats[imgcounter].reInit(xsize, ysize, 3, true);
  859. for(int x = 0; x < xsize; x++)
  860. {
  861. for(int y = 0; y < ysize; y++)
  862. {
  863. for(int r = 0; r < 3; r++)
  864. {
  865. allfeats[imgcounter].set(x,y,img.getPixel(x,y,r),r);
  866. }
  867. }
  868. }
  869. #endif
  870. // getting groundtruth
  871. NICE::Image pixelLabels (xsize, ysize);
  872. pixelLabels.set(0);
  873. locResult->calcLabeledImage ( pixelLabels, ( *classNames ).getBackgroundClass() );
  874. for(int x = 0; x < xsize; x++)
  875. {
  876. for(int y = 0; y < ysize; y++)
  877. {
  878. classno = pixelLabels.getPixel(x, y);
  879. labels[imgcounter](x,y) = classno;
  880. if ( forbidden_classes.find ( classno ) != forbidden_classes.end() )
  881. continue;
  882. labelcounter[classno]++;
  883. }
  884. }
  885. imgcounter++;
  886. pb.update ( trainp->count());
  887. delete ce;
  888. }
  889. pb.hide();
  890. map<int,int>::iterator mapit;
  891. int classes = 0;
  892. for(mapit = labelcounter.begin(); mapit != labelcounter.end(); mapit++)
  893. {
  894. labelmap[mapit->first] = classes;
  895. labelmapback[classes] = mapit->first;
  896. classes++;
  897. }
  898. //balancing
  899. int featcounter = 0;
  900. a = vector<double>(classes,0.0);
  901. for(int iCounter = 0; iCounter < imgcounter; iCounter++)
  902. {
  903. int xsize = (int)currentfeats[iCounter].width();
  904. int ysize = (int)currentfeats[iCounter].height();
  905. for(int x = 0; x < xsize; x++)
  906. {
  907. for(int y = 0; y < ysize; y++)
  908. {
  909. featcounter++;
  910. int cn = labels[iCounter](x,y);
  911. a[labelmap[cn]] ++;
  912. }
  913. }
  914. }
  915. for(int i = 0; i < (int)a.size(); i++)
  916. {
  917. a[i] /= (double)featcounter;
  918. }
  919. #ifdef DEBUG
  920. for(int i = 0; i < (int)a.size(); i++)
  921. {
  922. cout << "a["<<i<<"]: " << a[i] << endl;
  923. }
  924. cout << "a.size: " << a.size() << endl;
  925. #endif
  926. int depth = 0;
  927. for(int t = 0; t < nbTrees; t++)
  928. {
  929. vector<TreeNode> tree;
  930. tree.push_back(TreeNode());
  931. tree[0].dist = vector<double>(classes,0.0);
  932. tree[0].depth = depth;
  933. forest.push_back(tree);
  934. }
  935. vector<int> startnode(nbTrees,0);
  936. bool allleaf = false;
  937. //int baseFeatSize = allfeats[0].size();
  938. vector<MultiChannelImageT<double> > integralImgs(imgcounter,MultiChannelImageT<double>());
  939. while(!allleaf && depth < maxDepth)
  940. {
  941. allleaf = true;
  942. vector<MultiChannelImageT<int> > lastfeats = currentfeats;
  943. #if 1
  944. Timer timer;
  945. timer.start();
  946. #endif
  947. for(int tree = 0; tree < nbTrees; tree++)
  948. {
  949. int t = (int) forest[tree].size();
  950. int s = startnode[tree];
  951. startnode[tree] = t;
  952. //TODO vielleicht parallel wenn nächste schleife trotzdem noch parallelsiert würde, die hat mehr gewicht
  953. //#pragma omp parallel for
  954. for(int i = s; i < t; i++)
  955. {
  956. if(!forest[tree][i].isleaf && forest[tree][i].left < 0)
  957. {
  958. Operation *splitfeat = NULL;
  959. double splitval;
  960. double bestig = getBestSplit(allfeats, lastfeats, integralImgs, labels, i, splitfeat, splitval, tree);
  961. forest[tree][i].feat = splitfeat;
  962. forest[tree][i].decision = splitval;
  963. if(splitfeat != NULL)
  964. {
  965. allleaf = false;
  966. int left = forest[tree].size();
  967. forest[tree].push_back(TreeNode());
  968. forest[tree].push_back(TreeNode());
  969. int right = left+1;
  970. forest[tree][i].left = left;
  971. forest[tree][i].right = right;
  972. forest[tree][left].dist = vector<double>(classes, 0.0);
  973. forest[tree][right].dist = vector<double>(classes, 0.0);
  974. forest[tree][left].depth = depth+1;
  975. forest[tree][right].depth = depth+1;
  976. #pragma omp parallel for
  977. for(int iCounter = 0; iCounter < imgcounter; iCounter++)
  978. {
  979. int xsize = currentfeats[iCounter].width();
  980. int ysize = currentfeats[iCounter].height();
  981. for(int x = 0; x < xsize; x++)
  982. {
  983. for(int y = 0; y < ysize; y++)
  984. {
  985. if(currentfeats[iCounter].get(x, y, tree) == i)
  986. {
  987. Features feat;
  988. feat.feats = &allfeats[iCounter];
  989. feat.cfeats = &lastfeats[iCounter];
  990. feat.cTree = tree;
  991. feat.tree = &forest[tree];
  992. feat.integralImg = &integralImgs[iCounter];
  993. double val = splitfeat->getVal(feat,x,y);
  994. if(val < splitval)
  995. {
  996. currentfeats[iCounter].set(x,y,left,tree);
  997. forest[tree][left].dist[labelmap[labels[iCounter](x,y)]]++;
  998. }
  999. else
  1000. {
  1001. currentfeats[iCounter].set(x,y,right,tree);
  1002. forest[tree][right].dist[labelmap[labels[iCounter](x,y)]]++;
  1003. }
  1004. }
  1005. }
  1006. }
  1007. }
  1008. double lcounter = 0.0, rcounter = 0.0;
  1009. for(uint d = 0; d < forest[tree][left].dist.size(); d++)
  1010. {
  1011. if ( forbidden_classes.find ( labelmapback[d] ) != forbidden_classes.end() )
  1012. {
  1013. forest[tree][left].dist[d] = 0;
  1014. forest[tree][right].dist[d] = 0;
  1015. }
  1016. else
  1017. {
  1018. forest[tree][left].dist[d]/=a[d];
  1019. lcounter +=forest[tree][left].dist[d];
  1020. forest[tree][right].dist[d]/=a[d];
  1021. rcounter +=forest[tree][right].dist[d];
  1022. }
  1023. }
  1024. if(lcounter <= 0 || rcounter <= 0)
  1025. {
  1026. cout << "lcounter : " << lcounter << " rcounter: " << rcounter << endl;
  1027. cout << "splitval: " << splitval << " splittype: " << splitfeat->writeInfos() << endl;
  1028. cout << "bestig: " << bestig << endl;
  1029. for(int iCounter = 0; iCounter < imgcounter; iCounter++)
  1030. {
  1031. int xsize = currentfeats[iCounter].width();
  1032. int ysize = currentfeats[iCounter].height();
  1033. int counter = 0;
  1034. for(int x = 0; x < xsize; x++)
  1035. {
  1036. for(int y = 0; y < ysize; y++)
  1037. {
  1038. if(lastfeats[iCounter].get(x,y,tree) == i)
  1039. {
  1040. if(++counter > 30)
  1041. break;
  1042. Features feat;
  1043. feat.feats = &allfeats[iCounter];
  1044. feat.cfeats = &lastfeats[iCounter];
  1045. feat.cTree = tree;
  1046. feat.tree = &forest[tree];
  1047. feat.integralImg = &integralImgs[iCounter];
  1048. double val = splitfeat->getVal(feat,x,y);
  1049. cout << "splitval: " << splitval << " val: " << val << endl;
  1050. }
  1051. }
  1052. }
  1053. }
  1054. assert(lcounter > 0 && rcounter > 0);
  1055. }
  1056. for(uint d = 0; d < forest[tree][left].dist.size(); d++)
  1057. {
  1058. forest[tree][left].dist[d]/=lcounter;
  1059. forest[tree][right].dist[d]/=rcounter;
  1060. }
  1061. }
  1062. else
  1063. {
  1064. forest[tree][i].isleaf = true;
  1065. }
  1066. }
  1067. }
  1068. }
  1069. //TODO: features neu berechnen!
  1070. //compute integral image
  1071. int channels = classes+allfeats[0].channels();
  1072. if(integralImgs[0].width() == 0)
  1073. {
  1074. for(int i = 0; i < imgcounter; i++)
  1075. {
  1076. int xsize = allfeats[i].width();
  1077. int ysize = allfeats[i].height();
  1078. integralImgs[i].reInit(xsize, ysize, channels);
  1079. integralImgs[i].setAll(0.0);
  1080. }
  1081. }
  1082. for(int i = 0; i < imgcounter; i++)
  1083. {
  1084. computeIntegralImage(currentfeats[i],allfeats[i], integralImgs[i]);
  1085. }
  1086. #if 1
  1087. timer.stop();
  1088. cout << "time for depth " << depth << ": " << timer.getLast() << endl;
  1089. #endif
  1090. depth++;
  1091. #ifdef DEBUG
  1092. cout << "depth: " << depth << endl;
  1093. #endif
  1094. }
  1095. #ifdef DEBUG
  1096. for(int tree = 0; tree < nbTrees; tree++)
  1097. {
  1098. int t = (int) forest[tree].size();
  1099. for(int i = 0; i < t; i++)
  1100. {
  1101. printf("tree[%i]: left: %i, right: %i", i, forest[tree][i].left, forest[tree][i].right);
  1102. if(!forest[tree][i].isleaf && forest[tree][i].left != -1)
  1103. {
  1104. cout << ", feat: " << forest[tree][i].feat->writeInfos() << " ";
  1105. opOverview[forest[tree][i].feat->getOps()]++;
  1106. }
  1107. for(int d = 0; d < (int)forest[tree][i].dist.size(); d++)
  1108. {
  1109. cout << " " << forest[tree][i].dist[d];
  1110. }
  1111. cout << endl;
  1112. }
  1113. }
  1114. for(uint c = 0; c < ops.size(); c++)
  1115. {
  1116. cout << ops[c]->writeInfos() << ": " << opOverview[ops[c]->getOps()] << endl;
  1117. }
  1118. for(uint c = 0; c < cops.size(); c++)
  1119. {
  1120. cout << cops[c]->writeInfos() << ": " << opOverview[cops[c]->getOps()] << endl;
  1121. }
  1122. #endif
  1123. }
  1124. void SemSegContextTree::semanticseg ( CachedExample *ce, NICE::Image & segresult,NICE::MultiChannelImageT<double> & probabilities )
  1125. {
  1126. int xsize;
  1127. int ysize;
  1128. ce->getImageSize ( xsize, ysize );
  1129. int numClasses = classNames->numClasses();
  1130. fprintf (stderr, "ContextTree classification !\n");
  1131. probabilities.reInit ( xsize, ysize, numClasses, true );
  1132. probabilities.setAll ( 0 );
  1133. NICE::ColorImage img;
  1134. std::string currentFile = Globals::getCurrentImgFN();
  1135. try {
  1136. img = ColorImage(currentFile);
  1137. } catch (Exception) {
  1138. cerr << "SemSeg: error opening image file <" << currentFile << ">" << endl;
  1139. return;
  1140. }
  1141. //TODO: resize image?!
  1142. MultiChannelImageT<double> feats;
  1143. #ifdef LOCALFEATS
  1144. lfcw->getFeats(img, feats);
  1145. #else
  1146. feats.reInit (xsize, ysize, 3, true);
  1147. for(int x = 0; x < xsize; x++)
  1148. {
  1149. for(int y = 0; y < ysize; y++)
  1150. {
  1151. for(int r = 0; r < 3; r++)
  1152. {
  1153. feats.set(x,y,img.getPixel(x,y,r),r);
  1154. }
  1155. }
  1156. }
  1157. #endif
  1158. bool allleaf = false;
  1159. MultiChannelImageT<double> integralImg;
  1160. MultiChannelImageT<int> currentfeats(xsize, ysize, nbTrees);
  1161. currentfeats.setAll(0);
  1162. int depth = 0;
  1163. while(!allleaf)
  1164. {
  1165. allleaf = true;
  1166. //TODO vielleicht parallel wenn nächste schleife auch noch parallelsiert würde, die hat mehr gewicht
  1167. //#pragma omp parallel for
  1168. MultiChannelImageT<int> lastfeats = currentfeats;
  1169. for(int tree = 0; tree < nbTrees; tree++)
  1170. {
  1171. for(int x = 0; x < xsize; x++)
  1172. {
  1173. for(int y = 0; y < ysize; y++)
  1174. {
  1175. int t = currentfeats.get(x,y,tree);
  1176. if(forest[tree][t].left > 0)
  1177. {
  1178. allleaf = false;
  1179. Features feat;
  1180. feat.feats = &feats;
  1181. feat.cfeats = &lastfeats;
  1182. feat.cTree = tree;
  1183. feat.tree = &forest[tree];
  1184. feat.integralImg = &integralImg;
  1185. double val = forest[tree][t].feat->getVal(feat,x,y);
  1186. if(val < forest[tree][t].decision)
  1187. {
  1188. currentfeats.set(x, y, forest[tree][t].left, tree);
  1189. }
  1190. else
  1191. {
  1192. currentfeats.set(x, y, forest[tree][t].right, tree);
  1193. }
  1194. }
  1195. }
  1196. }
  1197. //compute integral image
  1198. int channels = (int)labelmap.size()+feats.channels();
  1199. if(integralImg.width() == 0)
  1200. {
  1201. int xsize = feats.width();
  1202. int ysize = feats.height();
  1203. integralImg.reInit(xsize, ysize, channels);
  1204. }
  1205. }
  1206. computeIntegralImage(currentfeats,feats, integralImg);
  1207. depth++;
  1208. }
  1209. if(pixelWiseLabeling)
  1210. {
  1211. //finales labeln:
  1212. long int offset = 0;
  1213. for(int x = 0; x < xsize; x++)
  1214. {
  1215. for(int y = 0; y < ysize; y++,offset++)
  1216. {
  1217. double maxvalue = - numeric_limits<double>::max(); //TODO: das muss nur pro knoten gemacht werden, nicht pro pixel
  1218. int maxindex = 0;
  1219. uint s = forest[0][0].dist.size();
  1220. for(uint i = 0; i < s; i++)
  1221. {
  1222. probabilities.data[labelmapback[i]][offset] = getMeanProb(x,y,i,currentfeats);
  1223. if(probabilities.data[labelmapback[i]][offset] > maxvalue)
  1224. {
  1225. maxvalue = probabilities.data[labelmapback[i]][offset];
  1226. maxindex = labelmapback[i];
  1227. }
  1228. segresult.setPixel(x,y,maxindex);
  1229. }
  1230. if(maxvalue > 1)
  1231. cout << "maxvalue: " << maxvalue << endl;
  1232. }
  1233. }
  1234. }
  1235. else
  1236. {
  1237. //final labeling using segmentation
  1238. //TODO: segmentation
  1239. Matrix regions;
  1240. int regionNumber = segmentation->segRegions(img,regions);
  1241. cout << "regions: " << regionNumber << endl;
  1242. int dSize = forest[0][0].dist.size();
  1243. vector<vector<double> > regionProbs(regionNumber, vector<double>(dSize,0.0));
  1244. vector<int> bestlabels(regionNumber, 0);
  1245. /*
  1246. for(int r = 0; r < regionNumber; r++)
  1247. {
  1248. Image over(img.width(), img.height());
  1249. for(int y = 0; y < img.height(); y++)
  1250. {
  1251. for(int x = 0; x < img.width(); x++)
  1252. {
  1253. if(((int)regions(x,y)) == r)
  1254. over.setPixel(x,y,1);
  1255. else
  1256. over.setPixel(x,y,0);
  1257. }
  1258. }
  1259. cout << "r: " << r << endl;
  1260. showImageOverlay(img, over);
  1261. }
  1262. */
  1263. for(int y = 0; y < img.height(); y++)
  1264. {
  1265. for(int x = 0; x < img.width(); x++)
  1266. {
  1267. int cregion = regions(x,y);
  1268. for(int d = 0; d < dSize; d++)
  1269. {
  1270. regionProbs[cregion][d]+=getMeanProb(x,y,d,currentfeats);
  1271. }
  1272. }
  1273. }
  1274. int roi = 38;
  1275. for(int r = 0; r < regionNumber; r++)
  1276. {
  1277. double maxval = regionProbs[r][0];
  1278. bestlabels[r] = 0;
  1279. if(roi == r)
  1280. {
  1281. cout << "r: " << r << endl;
  1282. cout << "0: " << regionProbs[r][0] << endl;
  1283. }
  1284. for(int d = 1; d < dSize; d++)
  1285. {
  1286. if(maxval < regionProbs[r][d])
  1287. {
  1288. maxval = regionProbs[r][d];
  1289. bestlabels[r] = d;
  1290. }
  1291. if(roi == r)
  1292. {
  1293. cout << d << ": " << regionProbs[r][d] << endl;
  1294. }
  1295. }
  1296. if(roi == r)
  1297. {
  1298. cout << "bestlabel: " << bestlabels[r] << " danach: " << labelmapback[bestlabels[r]] << endl;
  1299. }
  1300. bestlabels[r] = labelmapback[bestlabels[r]];
  1301. }
  1302. for(int y = 0; y < img.height(); y++)
  1303. {
  1304. for(int x = 0; x < img.width(); x++)
  1305. {
  1306. segresult.setPixel(x,y,bestlabels[regions(x,y)]);
  1307. }
  1308. }
  1309. }
  1310. }