SemSegContextTree.cpp 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297
  1. #include "SemSegContextTree.h"
  2. #include "vislearning/baselib/Globals.h"
  3. #include "vislearning/baselib/ProgressBar.h"
  4. #include "core/basics/StringTools.h"
  5. #include "vislearning/cbaselib/CachedExample.h"
  6. #include "vislearning/cbaselib/PascalResults.h"
  7. #include "objrec/segmentation/RSMeanShift.h"
  8. #include "objrec/segmentation/RSGraphBased.h"
  9. #include "core/basics/numerictools.h"
  10. #include "core/basics/Timer.h"
  11. #include <omp.h>
  12. #include <iostream>
  13. #define BOUND(x,min,max) (((x)<(min))?(min):((x)>(max)?(max):(x)))
  14. #undef LOCALFEATS
  15. //#define LOCALFEATS
  16. using namespace OBJREC;
  17. using namespace std;
  18. using namespace NICE;
  19. class MCImageAccess:public ValueAccess
  20. {
  21. public:
  22. virtual double getVal(const Features &feats, const int &x, const int &y, const int &channel)
  23. {
  24. return feats.feats->get(x,y,channel);
  25. }
  26. virtual string writeInfos()
  27. {
  28. return "raw";
  29. }
  30. };
  31. class ClassificationResultAcess:public ValueAccess
  32. {
  33. public:
  34. virtual double getVal(const Features &feats, const int &x, const int &y, const int &channel)
  35. {
  36. return (*feats.tree)[feats.cfeats->get(x,y,feats.cTree)].dist[channel];
  37. }
  38. virtual string writeInfos()
  39. {
  40. return "context";
  41. }
  42. };
  43. class Minus:public Operation
  44. {
  45. public:
  46. virtual double getVal(const Features &feats, const int &x, const int &y)
  47. {
  48. int xsize, ysize;
  49. getXY(feats, xsize, ysize);
  50. double v1 = values->getVal(feats, BOUND(x+x1,0,xsize-1),BOUND(y+y1,0,ysize-1),channel1);
  51. double v2 = values->getVal(feats, BOUND(x+x2,0,xsize-1),BOUND(y+y2,0,ysize-1),channel2);
  52. return v1-v2;
  53. }
  54. virtual Operation* clone()
  55. {
  56. return new Minus();
  57. }
  58. virtual string writeInfos()
  59. {
  60. return "Minus"+values->writeInfos();
  61. }
  62. };
  63. class MinusAbs:public Operation
  64. {
  65. public:
  66. virtual double getVal(const Features &feats, const int &x, const int &y)
  67. {
  68. int xsize, ysize;
  69. getXY(feats, xsize, ysize);
  70. double v1 = values->getVal(feats, BOUND(x+x1,0,xsize-1),BOUND(y+y1,0,ysize-1),channel1);
  71. double v2 = values->getVal(feats, BOUND(x+x2,0,xsize-1),BOUND(y+y2,0,ysize-1),channel2);
  72. return abs(v1-v2);
  73. }
  74. virtual Operation* clone()
  75. {
  76. return new MinusAbs();
  77. };
  78. virtual string writeInfos()
  79. {
  80. return "MinusAbs"+values->writeInfos();
  81. }
  82. };
  83. class Addition:public Operation
  84. {
  85. public:
  86. virtual double getVal(const Features &feats, const int &x, const int &y)
  87. {
  88. int xsize, ysize;
  89. getXY(feats, xsize, ysize);
  90. double v1 = values->getVal(feats, BOUND(x+x1,0,xsize-1),BOUND(y+y1,0,ysize-1),channel1);
  91. double v2 = values->getVal(feats, BOUND(x+x2,0,xsize-1),BOUND(y+y2,0,ysize-1),channel2);
  92. return v1+v2;
  93. }
  94. virtual Operation* clone()
  95. {
  96. return new Addition();
  97. }
  98. virtual string writeInfos()
  99. {
  100. return "Addition"+values->writeInfos();
  101. }
  102. };
  103. class Only1:public Operation
  104. {
  105. public:
  106. virtual double getVal(const Features &feats, const int &x, const int &y)
  107. {
  108. int xsize, ysize;
  109. getXY(feats, xsize, ysize);
  110. double v1 = values->getVal(feats, BOUND(x+x1,0,xsize-1),BOUND(y+y1,0,ysize-1),channel1);
  111. return v1;
  112. }
  113. virtual Operation* clone()
  114. {
  115. return new Only1();
  116. }
  117. virtual string writeInfos()
  118. {
  119. return "Only1"+values->writeInfos();
  120. }
  121. };
  122. // uses mean of classification in window given by (x1,y1) (x2,y2)
  123. class IntegralOps:public Operation
  124. {
  125. public:
  126. virtual void set(int _x1, int _y1, int _x2, int _y2, int _channel1, int _channel2, ValueAccess *_values)
  127. {
  128. x1 = min(_x1,_x2);
  129. y1 = min(_y1,_y2);
  130. x2 = max(_x1,_x2);
  131. y2 = max(_y1,_y2);
  132. channel1 = _channel1;
  133. channel2 = _channel2;
  134. values = _values;
  135. }
  136. virtual double getVal(const Features &feats, const int &x, const int &y)
  137. {
  138. int xsize, ysize;
  139. getXY(feats, xsize, ysize);
  140. return computeMean(*feats.integralImg,BOUND(x+x1,0,xsize-1),BOUND(y+y1,0,ysize-1),BOUND(x+x2,0,xsize-1),BOUND(y+y2,0,ysize-1),channel1);
  141. }
  142. inline double computeMean(const NICE::MultiChannelImageT<double> &intImg, const int &uLx, const int &uLy, const int &lRx, const int &lRy, const int &chan)
  143. {
  144. double val1 = intImg.get(uLx,uLy, chan);
  145. double val2 = intImg.get(lRx,uLy, chan);
  146. double val3 = intImg.get(uLx,lRy, chan);
  147. double val4 = intImg.get(lRx,lRy, chan);
  148. double area = (lRx-uLx)*(lRy-uLy);
  149. return (val1+val4-val2-val3)/area;
  150. }
  151. virtual Operation* clone()
  152. {
  153. return new IntegralOps();
  154. }
  155. virtual string writeInfos()
  156. {
  157. return "IntegralOps";
  158. }
  159. };
  160. //uses mean of Integral image given by x1, y1 with current pixel as center
  161. class IntegralCenteredOps:public IntegralOps
  162. {
  163. public:
  164. virtual void set(int _x1, int _y1, int _x2, int _y2, int _channel1, int _channel2)
  165. {
  166. x1 = abs(_x1);
  167. y1 = abs(_y1);
  168. x2 = abs(_x2);
  169. y2 = abs(_y2);
  170. channel1 = _channel1;
  171. channel2 = _channel2;
  172. }
  173. virtual double getVal(const Features &feats, const int &x, const int &y)
  174. {
  175. int xsize, ysize;
  176. getXY(feats, xsize, ysize);
  177. return computeMean(*feats.integralImg,BOUND(x-x1,0,xsize-1),BOUND(y-y1,0,ysize-1),BOUND(x+x1,0,xsize-1),BOUND(y+y1,0,ysize-1),channel1);
  178. }
  179. virtual Operation* clone()
  180. {
  181. return new IntegralCenteredOps();
  182. }
  183. virtual string writeInfos()
  184. {
  185. return "IntegralCenteredOps";
  186. }
  187. };
  188. //uses different of mean of Integral image given by two windows, where (x1,y1) is the width and height of window1 and (x2,y2) of window 2
  189. class BiIntegralCenteredOps:public IntegralCenteredOps
  190. {
  191. public:
  192. virtual void set(int _x1, int _y1, int _x2, int _y2, int _channel1, int _channel2)
  193. {
  194. x1 = min(abs(_x1),abs(_x2));
  195. y1 = min(abs(_y1),abs(_y2));
  196. x2 = max(abs(_x1),abs(_x2));
  197. y2 = max(abs(_y1),abs(_y2));
  198. channel1 = _channel1;
  199. channel2 = _channel2;
  200. }
  201. virtual double getVal(const Features &feats, const int &x, const int &y)
  202. {
  203. int xsize, ysize;
  204. getXY(feats, xsize, ysize);
  205. return computeMean(*feats.integralImg,BOUND(x-x1,0,xsize-1),BOUND(y-y1,0,ysize-1),BOUND(x+x1,0,xsize-1),BOUND(y+y1,0,ysize-1),channel1) - computeMean(*feats.integralImg,BOUND(x-x2,0,xsize-1),BOUND(y-y2,0,ysize-1),BOUND(x+x2,0,xsize-1),BOUND(y+y2,0,ysize-1),channel1);
  206. }
  207. virtual Operation* clone()
  208. {
  209. return new BiIntegralCenteredOps();
  210. }
  211. virtual string writeInfos()
  212. {
  213. return "BiIntegralCenteredOps";
  214. }
  215. };
  216. /** horizontal Haar features
  217. * ++
  218. * --
  219. */
  220. class HaarHorizontal:public IntegralCenteredOps
  221. {
  222. virtual double getVal(const Features &feats, const int &x, const int &y)
  223. {
  224. int xsize, ysize;
  225. getXY(feats, xsize, ysize);
  226. int tlx = BOUND(x-x1,0,xsize-1);
  227. int tly = BOUND(y-y1,0,ysize-1);
  228. int lrx = BOUND(x+x1,0,xsize-1);
  229. int lry = BOUND(y+y1,0,ysize-1);
  230. return computeMean(*feats.integralImg,tlx,tly,lrx, y,channel1)-computeMean(*feats.integralImg,tlx,y,lrx, lry,channel1);
  231. }
  232. virtual string writeInfos()
  233. {
  234. return "HaarHorizontal";
  235. }
  236. };
  237. /** vertical Haar features
  238. * +-
  239. * +-
  240. */
  241. class HaarVertical:public IntegralCenteredOps
  242. {
  243. virtual double getVal(const Features &feats, const int &x, const int &y)
  244. {
  245. int xsize, ysize;
  246. getXY(feats, xsize, ysize);
  247. int tlx = BOUND(x-x1,0,xsize-1);
  248. int tly = BOUND(y-y1,0,ysize-1);
  249. int lrx = BOUND(x+x1,0,xsize-1);
  250. int lry = BOUND(y+y1,0,ysize-1);
  251. return computeMean(*feats.integralImg,tlx,tly,x, lry,channel1)-computeMean(*feats.integralImg,x,tly,lrx, lry,channel1);
  252. }
  253. virtual string writeInfos()
  254. {
  255. return "HaarVertical";
  256. }
  257. };
  258. /** vertical Haar features
  259. * +-
  260. * -+
  261. */
  262. class HaarDiag:public IntegralCenteredOps
  263. {
  264. virtual double getVal(const Features &feats, const int &x, const int &y)
  265. {
  266. int xsize, ysize;
  267. getXY(feats, xsize, ysize);
  268. int tlx = BOUND(x-x1,0,xsize-1);
  269. int tly = BOUND(y-y1,0,ysize-1);
  270. int lrx = BOUND(x+x1,0,xsize-1);
  271. int lry = BOUND(y+y1,0,ysize-1);
  272. return computeMean(*feats.integralImg,tlx,tly,x, y,channel1)+computeMean(*feats.integralImg,x,y,lrx, lry,channel1) - computeMean(*feats.integralImg,tlx,y,x, lry,channel1)-computeMean(*feats.integralImg,x,tly,lrx, y,channel1);
  273. }
  274. virtual string writeInfos()
  275. {
  276. return "HaarDiag";
  277. }
  278. };
  279. /** horizontal Haar features
  280. * +++
  281. * ---
  282. * +++
  283. */
  284. class Haar3Horiz:public BiIntegralCenteredOps
  285. {
  286. virtual double getVal(const Features &feats, const int &x, const int &y)
  287. {
  288. int xsize, ysize;
  289. getXY(feats, xsize, ysize);
  290. int tlx = BOUND(x-x2,0,xsize-1);
  291. int tly = BOUND(y-y2,0,ysize-1);
  292. int mtly = BOUND(y-y1,0,ysize-1);
  293. int mlry = BOUND(y-y1,0,ysize-1);
  294. int lrx = BOUND(x+x2,0,xsize-1);
  295. int lry = BOUND(y+y2,0,ysize-1);
  296. return computeMean(*feats.integralImg,tlx,tly,lrx, mtly,channel1) -computeMean(*feats.integralImg,tlx,mtly,lrx, mlry,channel1) + computeMean(*feats.integralImg,tlx,mlry,lrx, lry,channel1);
  297. }
  298. virtual string writeInfos()
  299. {
  300. return "Haar3Horiz";
  301. }
  302. };
  303. /** vertical Haar features
  304. * +-+
  305. * +-+
  306. * +-+
  307. */
  308. class Haar3Vert:public BiIntegralCenteredOps
  309. {
  310. virtual double getVal(const Features &feats, const int &x, const int &y)
  311. {
  312. int xsize, ysize;
  313. getXY(feats, xsize, ysize);
  314. int tlx = BOUND(x-x2,0,xsize-1);
  315. int tly = BOUND(y-y2,0,ysize-1);
  316. int mtlx = BOUND(x-x1,0,xsize-1);
  317. int mlrx = BOUND(x-x1,0,xsize-1);
  318. int lrx = BOUND(x+x2,0,xsize-1);
  319. int lry = BOUND(y+y2,0,ysize-1);
  320. return computeMean(*feats.integralImg,tlx,tly,mtlx, lry,channel1) -computeMean(*feats.integralImg,mtlx,tly,mlrx, lry,channel1) + computeMean(*feats.integralImg,mlrx,tly,lrx, lry,channel1);
  321. }
  322. virtual string writeInfos()
  323. {
  324. return "Haar3Vert";
  325. }
  326. };
  327. SemSegContextTree::SemSegContextTree( const Config *conf, const MultiDataset *md )
  328. : SemanticSegmentation ( conf, &(md->getClassNames("train")) )
  329. {
  330. this->conf = conf;
  331. string section = "SSContextTree";
  332. lfcw = new LFColorWeijer(conf);
  333. grid = conf->gI(section, "grid", 10 );
  334. maxSamples = conf->gI(section, "max_samples", 2000);
  335. minFeats = conf->gI(section, "min_feats", 50 );
  336. maxDepth = conf->gI(section, "max_depth", 10 );
  337. windowSize = conf->gI(section, "window_size", 16);
  338. featsPerSplit = conf->gI(section, "feats_per_split", 200);
  339. useShannonEntropy = conf->gB(section, "use_shannon_entropy", true);
  340. nbTrees = conf->gI(section, "amount_trees", 1);
  341. string segmentationtype = conf->gS(section, "segmentation_type", "meanshift");
  342. useGaussian = conf->gB(section, "use_gaussian", true);
  343. if(useGaussian)
  344. throw("there something wrong with using gaussian! first fix it!");
  345. pixelWiseLabeling = false;
  346. if(segmentationtype == "meanshift")
  347. segmentation = new RSMeanShift(conf);
  348. else if (segmentationtype == "none")
  349. {
  350. segmentation = NULL;
  351. pixelWiseLabeling = true;
  352. }
  353. else if (segmentationtype == "felzenszwalb")
  354. segmentation = new RSGraphBased(conf);
  355. else
  356. throw("no valid segmenation_type\n please choose between none, meanshift and felzenszwalb\n");
  357. ftypes = conf->gI(section, "features", 2);;
  358. ops.push_back(new Minus());
  359. ops.push_back(new MinusAbs());
  360. ops.push_back(new Addition());
  361. ops.push_back(new Only1());
  362. cops.push_back(new BiIntegralCenteredOps());
  363. cops.push_back(new IntegralCenteredOps());
  364. cops.push_back(new IntegralOps());
  365. cops.push_back(new HaarHorizontal());
  366. cops.push_back(new HaarVertical());
  367. cops.push_back(new HaarDiag());
  368. cops.push_back(new Haar3Horiz());
  369. cops.push_back(new Haar3Vert());
  370. calcVal.push_back(new MCImageAccess());
  371. calcVal.push_back(new ClassificationResultAcess());
  372. classnames = md->getClassNames ( "train" );
  373. ///////////////////////////////////
  374. // Train Segmentation Context Trees
  375. ///////////////////////////////////
  376. train ( md );
  377. }
  // Destructor with an empty body: none of the objects created with new in
  // the constructor (lfcw, segmentation, the entries of ops, cops and
  // calcVal) are released here.
  // NOTE(review): unless ownership is handled somewhere outside this view,
  // these allocations leak - confirm before adding deletes.
  378. SemSegContextTree::~SemSegContextTree()
  379. {
  380. }
  381. double SemSegContextTree::getBestSplit(std::vector<NICE::MultiChannelImageT<double> > &feats, std::vector<NICE::MultiChannelImageT<int> > &currentfeats, std::vector<NICE::MultiChannelImageT<double> > &integralImgs, const std::vector<NICE::MatrixT<int> > &labels, int node, Operation *&splitop, double &splitval, const int &tree)
  382. {
  383. int imgCount = 0, featdim = 0;
  384. try
  385. {
  386. imgCount = (int)feats.size();
  387. featdim = feats[0].channels();
  388. }
  389. catch(Exception)
  390. {
  391. cerr << "no features computed?" << endl;
  392. }
  393. double bestig = -numeric_limits< double >::max();
  394. splitop = NULL;
  395. splitval = -1.0;
  396. set<vector<int> >selFeats;
  397. map<int,int> e;
  398. int featcounter = 0;
  399. for(int iCounter = 0; iCounter < imgCount; iCounter++)
  400. {
  401. int xsize = (int)currentfeats[iCounter].width();
  402. int ysize = (int)currentfeats[iCounter].height();
  403. for(int x = 0; x < xsize; x++)
  404. {
  405. for(int y = 0; y < ysize; y++)
  406. {
  407. if(currentfeats[iCounter].get(x,y,tree) == node)
  408. {
  409. featcounter++;
  410. }
  411. }
  412. }
  413. }
  414. if(featcounter < minFeats)
  415. {
  416. cout << "only " << featcounter << " feats in current node -> it's a leaf" << endl;
  417. return 0.0;
  418. }
  419. vector<double> fraction(a.size(),0.0);
  420. for(uint i = 0; i < fraction.size(); i++)
  421. {
  422. if ( forbidden_classes.find ( labelmapback[i] ) != forbidden_classes.end() )
  423. fraction[i] = 0;
  424. else
  425. fraction[i] = ((double)maxSamples)/((double)featcounter*a[i]*a.size());
  426. //cout << "fraction["<<i<<"]: "<< fraction[i] << " a[" << i << "]: " << a[i] << endl;
  427. }
  428. //cout << "a.size(): " << a.size() << endl;
  429. //getchar();
  430. featcounter = 0;
  431. for(int iCounter = 0; iCounter < imgCount; iCounter++)
  432. {
  433. int xsize = (int)currentfeats[iCounter].width();
  434. int ysize = (int)currentfeats[iCounter].height();
  435. for(int x = 0; x < xsize; x++)
  436. {
  437. for(int y = 0; y < ysize; y++)
  438. {
  439. if(currentfeats[iCounter].get(x,y,tree) == node)
  440. {
  441. int cn = labels[iCounter](x,y);
  442. double randD = (double)rand()/(double)RAND_MAX;
  443. if(randD < fraction[labelmap[cn]])
  444. {
  445. vector<int> tmp(3,0);
  446. tmp[0] = iCounter;
  447. tmp[1] = x;
  448. tmp[2] = y;
  449. featcounter++;
  450. selFeats.insert(tmp);
  451. e[cn]++;
  452. }
  453. }
  454. }
  455. }
  456. }
  457. //cout << "size: " << selFeats.size() << endl;
  458. //getchar();
  459. map<int,int>::iterator mapit;
  460. double globent = 0.0;
  461. for ( mapit=e.begin() ; mapit != e.end(); mapit++ )
  462. {
  463. //cout << "class: " << mapit->first << ": " << mapit->second << endl;
  464. double p = (double)(*mapit).second/(double)featcounter;
  465. globent += p*log2(p);
  466. }
  467. globent = -globent;
  468. if(globent < 0.5)
  469. {
  470. cout << "globent to small: " << globent << endl;
  471. return 0.0;
  472. }
  473. int classes = (int)forest[tree][0].dist.size();
  474. featsel.clear();
  475. for(int i = 0; i < featsPerSplit; i++)
  476. {
  477. int x1, x2, y1, y2;
  478. int ft = (int)((double)rand()/(double)RAND_MAX*(double)ftypes);
  479. int tmpws = windowSize;
  480. if(integralImgs[0].width() == 0)
  481. ft = 0;
  482. if(ft > 0)
  483. {
  484. tmpws *= 2;
  485. }
  486. if(useGaussian)
  487. {
  488. double sigma = (double)tmpws/2.0;
  489. x1 = randGaussDouble(sigma)*(double)tmpws;
  490. x2 = randGaussDouble(sigma)*(double)tmpws;
  491. y1 = randGaussDouble(sigma)*(double)tmpws;
  492. y2 = randGaussDouble(sigma)*(double)tmpws;
  493. }
  494. else
  495. {
  496. x1 = (int)((double)rand()/(double)RAND_MAX*(double)tmpws)-tmpws/2;
  497. x2 = (int)((double)rand()/(double)RAND_MAX*(double)tmpws)-tmpws/2;
  498. y1 = (int)((double)rand()/(double)RAND_MAX*(double)tmpws)-tmpws/2;
  499. y2 = (int)((double)rand()/(double)RAND_MAX*(double)tmpws)-tmpws/2;
  500. }
  501. if(ft == 0)
  502. {
  503. int f1 = (int)((double)rand()/(double)RAND_MAX*(double)featdim);
  504. int f2 = (int)((double)rand()/(double)RAND_MAX*(double)featdim);
  505. int o = (int)((double)rand()/(double)RAND_MAX*(double)ops.size());
  506. Operation *op = ops[o]->clone();
  507. op->set(x1,y1,x2,y2,f1,f2, calcVal[ft]);
  508. featsel.push_back(op);
  509. }
  510. else if(ft == 1)
  511. {
  512. int chans = integralImgs[0].channels;
  513. int opssize = (int)ops.size();
  514. int f1 = (int)((double)rand()/(double)RAND_MAX*(double)chans);
  515. int f2 = (int)((double)rand()/(double)RAND_MAX*(double)chans);
  516. int o = (int)((double)rand()/(double)RAND_MAX*((double)cops.size())+(double)opssize);
  517. Operation *op;
  518. if(o < opssize)
  519. {
  520. op = ops[o]->clone();
  521. op->set(x1,y1,x2,y2,f1,f2, calcVal[ft]);
  522. }
  523. else
  524. {
  525. o -= opssize;
  526. op = cops[o]->clone();
  527. op->set(x1,y1,x2,y2,f1,f2, calcVal[ft]);
  528. }
  529. featsel.push_back(op);
  530. }
  531. }
  532. #pragma omp parallel for private(mapit)
  533. for(int f = 0; f < featsPerSplit; f++)
  534. {
  535. double l_bestig = -numeric_limits< double >::max();
  536. double l_splitval = -1.0;
  537. set<vector<int> >::iterator it;
  538. vector<double> vals;
  539. for ( it=selFeats.begin() ; it != selFeats.end(); it++ )
  540. {
  541. Features feat;
  542. feat.feats = &feats[(*it)[0]];
  543. feat.cfeats = &currentfeats[(*it)[0]];
  544. feat.cTree = tree;
  545. feat.tree = &forest[tree];
  546. feat.integralImg = &integralImgs[(*it)[0]];
  547. vals.push_back(featsel[f]->getVal(feat, (*it)[1], (*it)[2]));
  548. }
  549. int counter = 0;
  550. for ( it=selFeats.begin() ; it != selFeats.end(); it++ , counter++)
  551. {
  552. set<vector<int> >::iterator it2;
  553. double val = vals[counter];
  554. map<int,int> eL, eR;
  555. int counterL = 0, counterR = 0;
  556. int counter2 = 0;
  557. for ( it2=selFeats.begin() ; it2 != selFeats.end(); it2++, counter2++ )
  558. {
  559. int cn = labels[(*it2)[0]]((*it2)[1], (*it2)[2]);
  560. //cout << "vals[counter2] " << vals[counter2] << " val: " << val << endl;
  561. if(vals[counter2] < val)
  562. {
  563. //left entropie:
  564. eL[cn] = eL[cn]+1;
  565. counterL++;
  566. }
  567. else
  568. {
  569. //right entropie:
  570. eR[cn] = eR[cn]+1;
  571. counterR++;
  572. }
  573. }
  574. double leftent = 0.0;
  575. for ( mapit=eL.begin() ; mapit != eL.end(); mapit++ )
  576. {
  577. double p = (double)(*mapit).second/(double)counterL;
  578. leftent -= p*log2(p);
  579. }
  580. double rightent = 0.0;
  581. for ( mapit=eR.begin() ; mapit != eR.end(); mapit++ )
  582. {
  583. double p = (double)(*mapit).second/(double)counterR;
  584. rightent -= p*log2(p);
  585. }
  586. //cout << "rightent: " << rightent << " leftent: " << leftent << endl;
  587. double pl = (double)counterL/(double)(counterL+counterR);
  588. double ig = globent - (1.0-pl) * rightent - pl*leftent;
  589. //double ig = globent - rightent - leftent;
  590. if(useShannonEntropy)
  591. {
  592. double esplit = - ( pl*log(pl) + (1-pl)*log(1-pl) );
  593. ig = 2*ig / ( globent + esplit );
  594. }
  595. if(ig > l_bestig)
  596. {
  597. l_bestig = ig;
  598. l_splitval = val;
  599. }
  600. }
  601. #pragma omp critical
  602. {
  603. //cout << "globent: " << globent << " bestig " << bestig << " splitfeat: " << splitfeat << " splitval: " << splitval << endl;
  604. //cout << "globent: " << globent << " l_bestig " << l_bestig << " f: " << p << " l_splitval: " << l_splitval << endl;
  605. //cout << "p: " << featsubset[f] << endl;
  606. if(l_bestig > bestig)
  607. {
  608. bestig = l_bestig;
  609. splitop = featsel[f];
  610. splitval = l_splitval;
  611. }
  612. }
  613. }
  614. //splitop->writeInfos();
  615. //cout<< "ig: " << bestig << endl;
  616. /*for(int i = 0; i < featsPerSplit; i++)
  617. {
  618. if(featsel[i] != splitop)
  619. delete featsel[i];
  620. }*/
  621. #ifdef debug
  622. cout << "globent: " << globent << " bestig " << bestig << " splitval: " << splitval << endl;
  623. #endif
  624. return bestig;
  625. }
  626. inline double SemSegContextTree::getMeanProb(const int &x,const int &y,const int &channel, const MultiChannelImageT<int> &currentfeats)
  627. {
  628. double val = 0.0;
  629. for(int tree = 0; tree < nbTrees; tree++)
  630. {
  631. val += forest[tree][currentfeats.get(x,y,tree)].dist[channel];
  632. }
  633. return val / (double)nbTrees;
  634. }
  635. void SemSegContextTree::computeIntegralImage(const NICE::MultiChannelImageT<int> &currentfeats, const NICE::MultiChannelImageT<int> &lfeats, NICE::MultiChannelImageT<double> &integralImage)
  636. {
  637. int xsize = currentfeats.width();
  638. int ysize = currentfeats.height();
  639. int channels = (int)forest[0][0].dist.size();
  640. #pragma omp parallel for
  641. for(int c = 0; c < channels; c++)
  642. {
  643. integralImage.set(0,0,getMeanProb(0,0,c, currentfeats), c);
  644. //first column
  645. for(int y = 1; y < ysize; y++)
  646. {
  647. integralImage.set(0,y,getMeanProb(0,y,c, currentfeats)+integralImage.get(0,y,c), c);
  648. }
  649. //first row
  650. for(int x = 1; x < xsize; x++)
  651. {
  652. integralImage.set(x,0,getMeanProb(x,0,c, currentfeats)+integralImage.get(x,0,c), c);
  653. }
  654. //rest
  655. for(int y = 1; y < ysize; y++)
  656. {
  657. for(int x = 1; x < xsize; x++)
  658. {
  659. double val = getMeanProb(x,y,c,currentfeats)+integralImage.get(x,y-1,c)+integralImage.get(x-1,y,c)-integralImage.get(x-1,y-1,c);
  660. integralImage.set(x, y, val, c);
  661. }
  662. }
  663. }
  664. int channels2 = (int)lfeats.size();
  665. if(lfeats.get(xsize-1,ysize-1,0) == 0)
  666. {
  667. #pragma omp parallel for
  668. for(int c = channels, int c1 = 0; c1 < channels2; c++, c1++)
  669. {
  670. integralImage.set(0,0,lfeats.get(0,0,c1), c);
  671. //first column
  672. for(int y = 1; y < ysize; y++)
  673. {
  674. integralImage.set(0,y,lfeats.get(0,y,c1)+integralImage.get(0,y,c), c);
  675. }
  676. //first row
  677. for(int x = 1; x < xsize; x++)
  678. {
  679. integralImage.set(x,0,lfeats.get(x,0,c1)+integralImage.get(x,0,c), c);
  680. }
  681. //rest
  682. for(int y = 1; y < ysize; y++)
  683. {
  684. for(int x = 1; x < xsize; x++)
  685. {
  686. double val = lfeats.get(x,y,c1)+integralImage.get(x,y-1,c)+integralImage.get(x-1,y,c)-integralImage.get(x-1,y-1,c);
  687. integralImage.set(x, y, val, c);
  688. }
  689. }
  690. }
  691. }
  692. }
  693. void SemSegContextTree::train ( const MultiDataset *md )
  694. {
  695. const LabeledSet train = * ( *md ) ["train"];
  696. const LabeledSet *trainp = &train;
  697. ProgressBar pb ( "compute feats" );
  698. pb.show();
  699. //TODO: Speichefresser!, lohnt sich sparse?
  700. vector<MultiChannelImageT<double> > allfeats;
  701. vector<MultiChannelImageT<int> > currentfeats;
  702. vector<MatrixT<int> > labels;
  703. std::string forbidden_classes_s = conf->gS ( "analysis", "donttrain", "" );
  704. if ( forbidden_classes_s == "" )
  705. {
  706. forbidden_classes_s = conf->gS ( "analysis", "forbidden_classes", "" );
  707. }
  708. classnames.getSelection ( forbidden_classes_s, forbidden_classes );
  709. int imgcounter = 0;
  710. LOOP_ALL_S ( *trainp )
  711. {
  712. EACH_INFO ( classno,info );
  713. NICE::ColorImage img;
  714. std::string currentFile = info.img();
  715. CachedExample *ce = new CachedExample ( currentFile );
  716. const LocalizationResult *locResult = info.localization();
  717. if ( locResult->size() <= 0 )
  718. {
  719. fprintf ( stderr, "WARNING: NO ground truth polygons found for %s !\n",
  720. currentFile.c_str() );
  721. continue;
  722. }
  723. fprintf ( stderr, "SemSegCsurka: Collecting pixel examples from localization info: %s\n", currentFile.c_str() );
  724. int xsize, ysize;
  725. ce->getImageSize ( xsize, ysize );
  726. MatrixT<int> tmpMat(xsize,ysize);
  727. currentfeats.push_back(MultiChannelImageT<int>(xsize,ysize,nbTrees));
  728. currentfeats[imgcounter].setAll(0);
  729. labels.push_back(tmpMat);
  730. try {
  731. img = ColorImage(currentFile);
  732. } catch (Exception) {
  733. cerr << "SemSeg: error opening image file <" << currentFile << ">" << endl;
  734. continue;
  735. }
  736. Globals::setCurrentImgFN ( currentFile );
  737. //TODO: resize image?!
  738. MultiChannelImageT<double> feats;
  739. allfeats.push_back(feats);
  740. #ifdef LOCALFEATS
  741. lfcw->getFeats(img, allfeats[imgcounter]);
  742. #else
  743. allfeats[imgcounter].reInit(xsize, ysize, 3, true);
  744. for(int x = 0; x < xsize; x++)
  745. {
  746. for(int y = 0; y < ysize; y++)
  747. {
  748. for(int r = 0; r < 3; r++)
  749. {
  750. allfeats[imgcounter].set(x,y,img.getPixel(x,y,r),r);
  751. }
  752. }
  753. }
  754. #endif
  755. // getting groundtruth
  756. NICE::Image pixelLabels (xsize, ysize);
  757. pixelLabels.set(0);
  758. locResult->calcLabeledImage ( pixelLabels, ( *classNames ).getBackgroundClass() );
  759. for(int x = 0; x < xsize; x++)
  760. {
  761. for(int y = 0; y < ysize; y++)
  762. {
  763. classno = pixelLabels.getPixel(x, y);
  764. labels[imgcounter](x,y) = classno;
  765. if ( forbidden_classes.find ( classno ) != forbidden_classes.end() )
  766. continue;
  767. labelcounter[classno]++;
  768. }
  769. }
  770. imgcounter++;
  771. pb.update ( trainp->count());
  772. delete ce;
  773. }
  774. pb.hide();
  775. map<int,int>::iterator mapit;
  776. int classes = 0;
  777. for(mapit = labelcounter.begin(); mapit != labelcounter.end(); mapit++)
  778. {
  779. labelmap[mapit->first] = classes;
  780. labelmapback[classes] = mapit->first;
  781. classes++;
  782. }
  783. //balancing
  784. int featcounter = 0;
  785. a = vector<double>(classes,0.0);
  786. for(int iCounter = 0; iCounter < imgcounter; iCounter++)
  787. {
  788. int xsize = (int)currentfeats[iCounter].width();
  789. int ysize = (int)currentfeats[iCounter].height();
  790. for(int x = 0; x < xsize; x++)
  791. {
  792. for(int y = 0; y < ysize; y++)
  793. {
  794. featcounter++;
  795. int cn = labels[iCounter](x,y);
  796. a[labelmap[cn]] ++;
  797. }
  798. }
  799. }
  800. for(int i = 0; i < (int)a.size(); i++)
  801. {
  802. a[i] /= (double)featcounter;
  803. }
  804. #ifdef DEBUG
  805. for(int i = 0; i < (int)a.size(); i++)
  806. {
  807. cout << "a["<<i<<"]: " << a[i] << endl;
  808. }
  809. cout << "a.size: " << a.size() << endl;
  810. #endif
  811. int depth = 0;
  812. for(int t = 0; t < nbTrees; t++)
  813. {
  814. vector<TreeNode> tree;
  815. tree.push_back(TreeNode());
  816. tree[0].dist = vector<double>(classes,0.0);
  817. tree[0].depth = depth;
  818. forest.push_back(tree);
  819. }
  820. vector<int> startnode(nbTrees,0);
  821. bool allleaf = false;
  822. //int baseFeatSize = allfeats[0].size();
  823. vector<MultiChannelImageT<double> > integralImgs(imgcounter,MultiChannelImageT<double>());
  824. while(!allleaf && depth < maxDepth)
  825. {
  826. allleaf = true;
  827. vector<MultiChannelImageT<int> > lastfeats = currentfeats;
  828. #if 1
  829. Timer timer;
  830. timer.start();
  831. #endif
  832. for(int tree = 0; tree < nbTrees; tree++)
  833. {
  834. int t = (int) forest[tree].size();
  835. int s = startnode[tree];
  836. startnode[tree] = t;
  837. //TODO vielleicht parallel wenn nächste schleife trotzdem noch parallelsiert würde, die hat mehr gewicht
  838. //#pragma omp parallel for
  839. for(int i = s; i < t; i++)
  840. {
  841. if(!forest[tree][i].isleaf && forest[tree][i].left < 0)
  842. {
  843. Operation *splitfeat = NULL;
  844. double splitval;
  845. double bestig = getBestSplit(allfeats, lastfeats, integralImgs, labels, i, splitfeat, splitval, tree);
  846. forest[tree][i].feat = splitfeat;
  847. forest[tree][i].decision = splitval;
  848. if(splitfeat != NULL)
  849. {
  850. allleaf = false;
  851. int left = forest[tree].size();
  852. forest[tree].push_back(TreeNode());
  853. forest[tree].push_back(TreeNode());
  854. int right = left+1;
  855. forest[tree][i].left = left;
  856. forest[tree][i].right = right;
  857. forest[tree][left].dist = vector<double>(classes, 0.0);
  858. forest[tree][right].dist = vector<double>(classes, 0.0);
  859. forest[tree][left].depth = depth+1;
  860. forest[tree][right].depth = depth+1;
  861. #pragma omp parallel for
  862. for(int iCounter = 0; iCounter < imgcounter; iCounter++)
  863. {
  864. int xsize = currentfeats[iCounter].width();
  865. int ysize = currentfeats[iCounter].height();
  866. for(int x = 0; x < xsize; x++)
  867. {
  868. for(int y = 0; y < ysize; y++)
  869. {
  870. if(currentfeats[iCounter].get(x, y, tree) == i)
  871. {
  872. Features feat;
  873. feat.feats = &allfeats[iCounter];
  874. feat.cfeats = &lastfeats[iCounter];
  875. feat.cTree = tree;
  876. feat.tree = &forest[tree];
  877. feat.integralImg = &integralImgs[iCounter];
  878. double val = splitfeat->getVal(feat,x,y);
  879. if(val < splitval)
  880. {
  881. currentfeats[iCounter].set(x,y,left,tree);
  882. forest[tree][left].dist[labelmap[labels[iCounter](x,y)]]++;
  883. }
  884. else
  885. {
  886. currentfeats[iCounter].set(x,y,right,tree);
  887. forest[tree][right].dist[labelmap[labels[iCounter](x,y)]]++;
  888. }
  889. }
  890. }
  891. }
  892. }
  893. double lcounter = 0.0, rcounter = 0.0;
  894. for(uint d = 0; d < forest[tree][left].dist.size(); d++)
  895. {
  896. if ( forbidden_classes.find ( labelmapback[d] ) != forbidden_classes.end() )
  897. {
  898. forest[tree][left].dist[d] = 0;
  899. forest[tree][right].dist[d] = 0;
  900. }
  901. else
  902. {
  903. forest[tree][left].dist[d]/=a[d];
  904. lcounter +=forest[tree][left].dist[d];
  905. forest[tree][right].dist[d]/=a[d];
  906. rcounter +=forest[tree][right].dist[d];
  907. }
  908. }
  909. assert(lcounter > 0 && rcounter > 0);
  910. for(uint d = 0; d < forest[tree][left].dist.size(); d++)
  911. {
  912. forest[tree][left].dist[d]/=lcounter;
  913. forest[tree][right].dist[d]/=rcounter;
  914. }
  915. }
  916. else
  917. {
  918. forest[tree][i].isleaf = true;
  919. }
  920. }
  921. }
  922. }
  923. //TODO: features neu berechnen!
  924. //compute integral image
  925. int channels = classes+allfeats.size();
  926. if(integralImgs[0].width() == 0)
  927. {
  928. for(int i = 0; i < imgcounter; i++)
  929. {
  930. int xsize = allfeats[i].width();
  931. int ysize = allfeats[i].height();
  932. integralImgs[i].reInit(xsize, ysize, channels);
  933. }
  934. }
  935. for(int i = 0; i < imgcounter; i++)
  936. {
  937. computeIntegralImage(currentfeats[i],allfeats[i], integralImgs[i]);
  938. }
  939. #if 1
  940. timer.stop();
  941. cout << "time for depth " << depth << ": " << timer.getLast() << endl;
  942. #endif
  943. depth++;
  944. #ifdef DEBUG
  945. cout << "depth: " << depth << endl;
  946. #endif
  947. }
  948. #ifdef DEBUG
  949. for(int tree = 0; tree < nbTrees; tree++)
  950. {
  951. int t = (int) forest[tree].size();
  952. for(int i = 0; i < t; i++)
  953. {
  954. printf("tree[%i]: left: %i, right: %i", i, forest[tree][i].left, forest[tree][i].right);
  955. if(!forest[tree][i].isleaf && forest[tree][i].left != -1)
  956. cout << ", feat: " << forest[tree][i].feat->writeInfos() << " ";
  957. for(int d = 0; d < (int)forest[tree][i].dist.size(); d++)
  958. {
  959. cout << " " << forest[tree][i].dist[d];
  960. }
  961. cout << endl;
  962. }
  963. }
  964. #endif
  965. }
  966. void SemSegContextTree::semanticseg ( CachedExample *ce, NICE::Image & segresult,NICE::MultiChannelImageT<double> & probabilities )
  967. {
  968. int xsize;
  969. int ysize;
  970. ce->getImageSize ( xsize, ysize );
  971. int numClasses = classNames->numClasses();
  972. fprintf (stderr, "ContextTree classification !\n");
  973. probabilities.reInit ( xsize, ysize, numClasses, true );
  974. probabilities.setAll ( 0 );
  975. NICE::ColorImage img;
  976. std::string currentFile = Globals::getCurrentImgFN();
  977. try {
  978. img = ColorImage(currentFile);
  979. } catch (Exception) {
  980. cerr << "SemSeg: error opening image file <" << currentFile << ">" << endl;
  981. return;
  982. }
  983. //TODO: resize image?!
  984. MultiChannelImageT<double> feats;
  985. #ifdef LOCALFEATS
  986. lfcw->getFeats(img, feats);
  987. #else
  988. feats.reInit (xsize, ysize, 3, true);
  989. for(int x = 0; x < xsize; x++)
  990. {
  991. for(int y = 0; y < ysize; y++)
  992. {
  993. for(int r = 0; r < 3; r++)
  994. {
  995. feats.set(x,y,img.getPixel(x,y,r),r);
  996. }
  997. }
  998. }
  999. #endif
  1000. bool allleaf = false;
  1001. MultiChannelImageT<double> integralImg;
  1002. MultiChannelImageT<int> currentfeats(xsize, ysize, nbTrees);
  1003. currentfeats.setAll(0);
  1004. int depth = 0;
  1005. while(!allleaf)
  1006. {
  1007. allleaf = true;
  1008. //TODO vielleicht parallel wenn nächste schleife auch noch parallelsiert würde, die hat mehr gewicht
  1009. //#pragma omp parallel for
  1010. MultiChannelImageT<int> lastfeats = currentfeats;
  1011. for(int tree = 0; tree < nbTrees; tree++)
  1012. {
  1013. for(int x = 0; x < xsize; x++)
  1014. {
  1015. for(int y = 0; y < ysize; y++)
  1016. {
  1017. int t = currentfeats.get(x,y,tree);
  1018. if(forest[tree][t].left > 0)
  1019. {
  1020. allleaf = false;
  1021. Features feat;
  1022. feat.feats = &feats;
  1023. feat.cfeats = &lastfeats;
  1024. feat.cTree = tree;
  1025. feat.tree = &forest[tree];
  1026. feat.integralImg = &integralImg;
  1027. double val = forest[tree][t].feat->getVal(feat,x,y);
  1028. if(val < forest[tree][t].decision)
  1029. {
  1030. currentfeats.set(x, y, forest[tree][t].left, tree);
  1031. }
  1032. else
  1033. {
  1034. currentfeats.set(x, y, forest[tree][t].right, tree);
  1035. }
  1036. }
  1037. }
  1038. }
  1039. //compute integral image
  1040. int channels = (int)labelmap.size()+feats.size();
  1041. if(integralImg.width() == 0)
  1042. {
  1043. int xsize = feats.width();
  1044. int ysize = feats.height();
  1045. integralImg.reInit(xsize, ysize, channels);
  1046. }
  1047. }
  1048. computeIntegralImage(currentfeats,feats, integralImg);
  1049. depth++;
  1050. }
  1051. if(pixelWiseLabeling)
  1052. {
  1053. //finales labeln:
  1054. long int offset = 0;
  1055. for(int x = 0; x < xsize; x++)
  1056. {
  1057. for(int y = 0; y < ysize; y++,offset++)
  1058. {
  1059. double maxvalue = - numeric_limits<double>::max(); //TODO: das muss nur pro knoten gemacht werden, nicht pro pixel
  1060. int maxindex = 0;
  1061. uint s = forest[0][0].dist.size();
  1062. for(uint i = 0; i < s; i++)
  1063. {
  1064. probabilities.data[labelmapback[i]][offset] = getMeanProb(x,y,i,currentfeats);
  1065. if(probabilities.data[labelmapback[i]][offset] > maxvalue)
  1066. {
  1067. maxvalue = probabilities.data[labelmapback[i]][offset];
  1068. maxindex = labelmapback[i];
  1069. }
  1070. segresult.setPixel(x,y,maxindex);
  1071. }
  1072. }
  1073. }
  1074. }
  1075. else
  1076. {
  1077. //final labeling using segmentation
  1078. //TODO: segmentation
  1079. Matrix regions;
  1080. int regionNumber = segmentation->segRegions(img,regions);
  1081. cout << "regions: " << regionNumber << endl;
  1082. int dSize = (int)labelmap.size();
  1083. vector<vector<double> > regionProbs(regionNumber, vector<double>(dSize,0.0));
  1084. vector<int> bestlabels(regionNumber, 0);
  1085. for(int y = 0; y < img.height(); y++)
  1086. {
  1087. for(int x = 0; x < img.width(); x++)
  1088. {
  1089. int cregion = regions(x,y);
  1090. for(int d = 0; d < dSize; d++)
  1091. {
  1092. regionProbs[cregion][d]+=getMeanProb(x,y,d,currentfeats);
  1093. }
  1094. }
  1095. }
  1096. for(int r = 0; r < regionNumber; r++)
  1097. {
  1098. double maxval = regionProbs[r][0];
  1099. for(int d = 1; d < dSize; d++)
  1100. {
  1101. if(maxval < regionProbs[r][d])
  1102. {
  1103. maxval = regionProbs[r][d];
  1104. bestlabels[r] = d;
  1105. }
  1106. }
  1107. bestlabels[r] = labelmapback[bestlabels[r]];
  1108. }
  1109. for(int y = 0; y < img.height(); y++)
  1110. {
  1111. for(int x = 0; x < img.width(); x++)
  1112. {
  1113. segresult.setPixel(x,y,bestlabels[regions(x,y)]);
  1114. }
  1115. }
  1116. }
  1117. }