BP-S.cpp 25 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057
#include <assert.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <new>
#include "BP-S.h"
  7. #define private public
  8. #include "typeTruncatedQuadratic2D.h"
  9. using namespace OBJREC;
  10. #undef private
  11. #define m_D(pix,l) m_D[(pix)*m_nLabels+(l)]
  12. #define m_V(l1,l2) m_V[(l1)*m_nLabels+(l2)]
  13. #define MIN(a,b) (((a) < (b)) ? (a) : (b))
  14. #define MAX(a,b) (((a) > (b)) ? (a) : (b))
  15. #define TRUNCATE_MIN(a,b) { if ((a) > (b)) (a) = (b); }
  16. #define TRUNCATE_MAX(a,b) { if ((a) < (b)) (a) = (b); }
  17. #define TRUNCATE TRUNCATE_MIN
  18. /////////////////////////////////////////////////////////////////////////////
  19. // Operations on vectors (arrays of size K) //
  20. /////////////////////////////////////////////////////////////////////////////
  21. inline void CopyVector(BPS::REAL* to, MRF::CostVal* from, int K)
  22. {
  23. BPS::REAL* to_finish = to + K;
  24. do
  25. {
  26. *to ++ = *from ++;
  27. } while (to < to_finish);
  28. }
  29. inline void AddVector(BPS::REAL* to, BPS::REAL* from, int K)
  30. {
  31. BPS::REAL* to_finish = to + K;
  32. do
  33. {
  34. *to ++ += *from ++;
  35. } while (to < to_finish);
  36. }
  37. inline BPS::REAL SubtractMin(BPS::REAL *D, int K)
  38. {
  39. int k;
  40. BPS::REAL delta;
  41. delta = D[0];
  42. for (k=1; k<K; k++) TRUNCATE(delta, D[k]);
  43. for (k=0; k<K; k++) D[k] -= delta;
  44. return delta;
  45. }
  46. // Functions UpdateMessageTYPE (see the paper for details):
  47. //
  48. // - Set Di[ki] := gamma*Di_hat[ki] - M[ki]
  49. // - Set M[kj] := min_{ki} (Di[ki] + V[ki,kj])
  50. // - Normalize message:
  51. // delta := min_{kj} M[kj]
  52. // M[kj] := M[kj] - delta
  53. // return delta
  54. //
  55. // If dir = 1, then the meaning of i and j is swapped.
  56. ///////////////////////////////////////////
  57. // L1 //
  58. ///////////////////////////////////////////
  59. inline BPS::REAL UpdateMessageL1(BPS::REAL* M, BPS::REAL* Di_hat, int K, BPS::REAL gamma, MRF::CostVal lambda, MRF::CostVal smoothMax)
  60. {
  61. int k;
  62. BPS::REAL delta;
  63. delta = M[0] = gamma*Di_hat[0] - M[0];
  64. for (k=1; k<K; k++)
  65. {
  66. M[k] = gamma*Di_hat[k] - M[k];
  67. TRUNCATE(delta, M[k]);
  68. TRUNCATE(M[k], M[k-1] + lambda);
  69. }
  70. M[--k] -= delta;
  71. TRUNCATE(M[k], lambda*smoothMax);
  72. for (k--; k>=0; k--)
  73. {
  74. M[k] -= delta;
  75. TRUNCATE(M[k], M[k+1] + lambda);
  76. TRUNCATE(M[k], lambda*smoothMax);
  77. }
  78. return delta;
  79. }
  80. ////////////////////////////////////////
  81. // L2 //
  82. ////////////////////////////////////////
  83. inline BPS::REAL UpdateMessageL2(BPS::REAL* M, BPS::REAL* Di_hat, int K, BPS::REAL gamma, MRF::CostVal lambda, MRF::CostVal smoothMax, void *buf)
  84. {
  85. BPS::REAL* Di = (BPS::REAL*) buf;
  86. int* parabolas = (int*) ((char*)buf + K*sizeof(BPS::REAL));
  87. int* intersections = parabolas + K;
  88. TypeTruncatedQuadratic2D::REAL* Di_tmp = (TypeTruncatedQuadratic2D::REAL*) (intersections + K + 1);
  89. TypeTruncatedQuadratic2D::REAL* M_tmp = Di_tmp + K;
  90. TypeTruncatedQuadratic2D::Edge* tmp = NULL;
  91. int k;
  92. BPS::REAL delta;
  93. assert(lambda >= 0);
  94. Di[0] = gamma*Di_hat[0] - M[0];
  95. delta = Di[0];
  96. for (k=1; k<K; k++)
  97. {
  98. Di[k] = gamma*Di_hat[k] - M[k];
  99. TRUNCATE(delta, Di[k]);
  100. }
  101. if (lambda == 0)
  102. {
  103. for (k=0; k<K; k++) M[k] = 0;
  104. return delta;
  105. }
  106. for (k=0; k<K; k++) Di_tmp[k] = Di[k];
  107. tmp->DistanceTransformL2(K, 1, lambda, Di_tmp, M_tmp, parabolas, intersections);
  108. for (k=0; k<K; k++) M[k] = (BPS::REAL) M_tmp[k];
  109. for (k=0; k<K; k++)
  110. {
  111. M[k] -= delta;
  112. TRUNCATE(M[k], lambda*smoothMax);
  113. }
  114. return delta;
  115. }
  116. //////////////////////////////////////////////////
  117. // FIXED_MATRIX //
  118. //////////////////////////////////////////////////
  119. inline BPS::REAL UpdateMessageFIXED_MATRIX(BPS::REAL* M, BPS::REAL* Di_hat, int K, BPS::REAL gamma, MRF::CostVal lambda, MRF::CostVal* V, void* buf)
  120. {
  121. BPS::REAL* Di = (BPS::REAL*) buf;
  122. int ki, kj;
  123. BPS::REAL delta;
  124. for (ki=0; ki<K; ki++)
  125. {
  126. Di[ki] = gamma*Di_hat[ki] - M[ki];
  127. }
  128. for (kj=0; kj<K; kj++)
  129. {
  130. M[kj] = Di[0] + lambda*V[0];
  131. V ++;
  132. for (ki=1; ki<K; ki++)
  133. {
  134. TRUNCATE(M[kj], Di[ki] + lambda*V[0]);
  135. V ++;
  136. }
  137. }
  138. delta = M[0];
  139. for (kj=1; kj<K; kj++) TRUNCATE(delta, M[kj]);
  140. for (kj=0; kj<K; kj++) M[kj] -= delta;
  141. return delta;
  142. }
  143. /////////////////////////////////////////////
  144. // GENERAL //
  145. /////////////////////////////////////////////
  146. inline BPS::REAL UpdateMessageGENERAL(BPS::REAL* M, BPS::REAL* Di_hat, int K, BPS::REAL gamma, int dir, MRF::CostVal* V, void* buf)
  147. {
  148. BPS::REAL* Di = (BPS::REAL*) buf;
  149. int ki, kj;
  150. BPS::REAL delta;
  151. for (ki=0; ki<K; ki++)
  152. {
  153. Di[ki] = (gamma*Di_hat[ki] - M[ki]);
  154. }
  155. if (dir == 0)
  156. {
  157. for (kj=0; kj<K; kj++)
  158. {
  159. M[kj] = Di[0] + V[0];
  160. V ++;
  161. for (ki=1; ki<K; ki++)
  162. {
  163. TRUNCATE(M[kj], Di[ki] + V[0]);
  164. V ++;
  165. }
  166. }
  167. }
  168. else
  169. {
  170. for (kj=0; kj<K; kj++)
  171. {
  172. M[kj] = Di[0] + V[0];
  173. V += K;
  174. for (ki=1; ki<K; ki++)
  175. {
  176. TRUNCATE(M[kj], Di[ki] + V[0]);
  177. V += K;
  178. }
  179. V -= K*K - 1;
  180. }
  181. }
  182. delta = M[0];
  183. for (kj=1; kj<K; kj++) TRUNCATE(delta, M[kj]);
  184. for (kj=0; kj<K; kj++) M[kj] -= delta;
  185. return delta;
  186. }
  187. inline BPS::REAL UpdateMessageGENERAL(BPS::REAL* M, BPS::REAL* Di_hat, int K, BPS::REAL gamma, BPS::SmoothCostGeneralFn fn, int i, int j, void* buf)
  188. {
  189. BPS::REAL* Di = (BPS::REAL*) buf;
  190. int ki, kj;
  191. BPS::REAL delta;
  192. for (ki=0; ki<K; ki++)
  193. {
  194. Di[ki] = (gamma*Di_hat[ki] - M[ki]);
  195. }
  196. for (kj=0; kj<K; kj++)
  197. {
  198. M[kj] = Di[0] + fn(i, j, 0, kj);
  199. for (ki=1; ki<K; ki++)
  200. {
  201. delta = Di[ki] + fn(i, j, ki, kj);
  202. TRUNCATE(M[kj], delta);
  203. }
  204. }
  205. delta = M[0];
  206. for (kj=1; kj<K; kj++) TRUNCATE(delta, M[kj]);
  207. for (kj=0; kj<K; kj++) M[kj] -= delta;
  208. return delta;
  209. }
  210. BPS::BPS(int width, int height, int nLabels,EnergyFunction *eng):MRF(width,height,nLabels,eng)
  211. {
  212. Allocate();
  213. }
  214. BPS::BPS(int nPixels, int nLabels,EnergyFunction *eng):MRF(nPixels,nLabels,eng)
  215. {
  216. Allocate();
  217. }
  218. BPS::~BPS()
  219. {
  220. delete[] m_answer;
  221. if ( m_needToFreeD ) delete [] m_D;
  222. if ( m_needToFreeV ) delete [] m_V;
  223. if ( m_messages ) delete [] m_messages;
  224. if ( m_DBinary ) delete [] m_DBinary;
  225. if ( m_horzWeightsBinary ) delete [] m_horzWeightsBinary;
  226. if ( m_vertWeightsBinary ) delete [] m_vertWeightsBinary;
  227. }
  228. void BPS::Allocate()
  229. {
  230. m_type = NONE;
  231. m_needToFreeV = false;
  232. m_needToFreeD = false;
  233. m_D = NULL;
  234. m_V = NULL;
  235. m_horzWeights = NULL;
  236. m_vertWeights = NULL;
  237. m_horzWeightsBinary = NULL;
  238. m_vertWeightsBinary = NULL;
  239. m_DBinary = NULL;
  240. m_messages = NULL;
  241. m_messageArraySizeInBytes = 0;
  242. m_answer = new Label[m_nPixels];
  243. }
  244. void BPS::clearAnswer()
  245. {
  246. memset(m_answer, 0, m_nPixels*sizeof(Label));
  247. if (m_messages)
  248. {
  249. memset(m_messages, 0, m_messageArraySizeInBytes);
  250. }
  251. }
  252. MRF::EnergyVal BPS::smoothnessEnergy()
  253. {
  254. EnergyVal eng = (EnergyVal) 0;
  255. EnergyVal weight;
  256. int x,y,pix;
  257. if ( m_grid_graph )
  258. {
  259. if ( m_smoothType != FUNCTION )
  260. {
  261. for ( y = 0; y < m_height; y++ )
  262. for ( x = 1; x < m_width; x++ )
  263. {
  264. pix = x+y*m_width;
  265. weight = m_varWeights ? m_horzWeights[pix-1] : 1;
  266. eng = eng + m_V(m_answer[pix],m_answer[pix-1])*weight;
  267. }
  268. for ( y = 1; y < m_height; y++ )
  269. for ( x = 0; x < m_width; x++ )
  270. {
  271. pix = x+y*m_width;
  272. weight = m_varWeights ? m_vertWeights[pix-m_width] : 1;
  273. eng = eng + m_V(m_answer[pix],m_answer[pix-m_width])*weight;
  274. }
  275. }
  276. else
  277. {
  278. for ( y = 0; y < m_height; y++ )
  279. for ( x = 1; x < m_width; x++ )
  280. {
  281. pix = x+y*m_width;
  282. eng = eng + m_smoothFn(pix,pix-1,m_answer[pix],m_answer[pix-1]);
  283. }
  284. for ( y = 1; y < m_height; y++ )
  285. for ( x = 0; x < m_width; x++ )
  286. {
  287. pix = x+y*m_width;
  288. eng = eng + m_smoothFn(pix,pix-m_width,m_answer[pix],m_answer[pix-m_width]);
  289. }
  290. }
  291. }
  292. else
  293. {
  294. // not implemented
  295. }
  296. return(eng);
  297. }
  298. MRF::EnergyVal BPS::dataEnergy()
  299. {
  300. EnergyVal eng = (EnergyVal) 0;
  301. if ( m_dataType == ARRAY)
  302. {
  303. for ( int i = 0; i < m_nPixels; i++ )
  304. eng = eng + m_D(i,m_answer[i]);
  305. }
  306. else
  307. {
  308. for ( int i = 0; i < m_nPixels; i++ )
  309. eng = eng + m_dataFn(i,m_answer[i]);
  310. }
  311. return(eng);
  312. }
  313. void BPS::setData(DataCostFn dcost)
  314. {
  315. int i, k;
  316. m_dataFn = dcost;
  317. CostVal* ptr;
  318. m_D = new CostVal[m_nPixels*m_nLabels];
  319. for (ptr=m_D, i=0; i<m_nPixels; i++)
  320. for (k=0; k<m_nLabels; k++, ptr++)
  321. {
  322. *ptr = m_dataFn(i,k);
  323. }
  324. m_needToFreeD = true;
  325. }
  326. void BPS::setData(CostVal* data)
  327. {
  328. m_D = data;
  329. m_needToFreeD = false;
  330. }
  331. void BPS::setSmoothness(SmoothCostGeneralFn cost)
  332. {
  333. assert(m_horzWeights == NULL && m_vertWeights == NULL && m_V == NULL);
  334. int x, y, i, ki, kj;
  335. CostVal* ptr;
  336. m_smoothFn = cost;
  337. m_type = GENERAL;
  338. if (!m_allocateArrayForSmoothnessCostFn) return;
  339. // try to cache all the function values in an array for efficiency
  340. m_V = new(std::nothrow) CostVal[2*m_nPixels*m_nLabels*m_nLabels];
  341. if (!m_V) {
  342. fprintf(stderr, "not caching smoothness cost values (not enough memory)\n");
  343. return; // if not enough space, just call the function directly
  344. }
  345. m_needToFreeV = true;
  346. for (ptr=m_V,i=0,y=0; y<m_height; y++)
  347. for (x=0; x<m_width; x++, i++)
  348. {
  349. if (x < m_width-1)
  350. {
  351. for (kj=0; kj<m_nLabels; kj++)
  352. for (ki=0; ki<m_nLabels; ki++)
  353. {
  354. *ptr++ = cost(i,i+1,ki,kj);
  355. }
  356. }
  357. else ptr += m_nLabels*m_nLabels;
  358. if (y < m_height-1)
  359. {
  360. for (kj=0; kj<m_nLabels; kj++)
  361. for (ki=0; ki<m_nLabels; ki++)
  362. {
  363. *ptr++ = cost(i,i+m_width,ki,kj);
  364. }
  365. }
  366. else ptr += m_nLabels*m_nLabels;
  367. }
  368. }
  369. void BPS::setSmoothness(CostVal* V)
  370. {
  371. m_type = FIXED_MATRIX;
  372. m_V = V;
  373. }
  374. void BPS::setSmoothness(int smoothExp,CostVal smoothMax, CostVal lambda)
  375. {
  376. assert(smoothExp == 1 || smoothExp == 2);
  377. assert(lambda >= 0);
  378. m_type = (smoothExp == 1) ? L1 : L2;
  379. int ki, kj;
  380. CostVal cost;
  381. m_needToFreeV = true;
  382. m_V = new CostVal[m_nLabels*m_nLabels];
  383. for (ki=0; ki<m_nLabels; ki++)
  384. for (kj=ki; kj<m_nLabels; kj++)
  385. {
  386. cost = (CostVal) ((smoothExp == 1) ? kj - ki : (kj - ki)*(kj - ki));
  387. if (cost > smoothMax) cost = smoothMax;
  388. m_V[ki*m_nLabels + kj] = m_V[kj*m_nLabels + ki] = cost*lambda;
  389. }
  390. m_smoothMax = smoothMax;
  391. m_lambda = lambda;
  392. }
  393. void BPS::setCues(CostVal* hCue, CostVal* vCue)
  394. {
  395. m_horzWeights = hCue;
  396. m_vertWeights = vCue;
  397. }
  398. void BPS::initializeAlg()
  399. {
  400. assert(m_type != NONE);
  401. int i;
  402. // determine type
  403. if (m_type == L1 && m_nLabels == 2)
  404. {
  405. m_type = BINARY;
  406. }
  407. // allocate messages
  408. int messageNum = (m_type == BINARY) ? 4*m_nPixels : 4*m_nPixels*m_nLabels;
  409. m_messageArraySizeInBytes = messageNum*sizeof(REAL);
  410. m_messages = new REAL[messageNum];
  411. memset(m_messages, 0, messageNum*sizeof(REAL));
  412. if (m_type == BINARY)
  413. {
  414. assert(m_DBinary == NULL && m_horzWeightsBinary == NULL && m_horzWeightsBinary == NULL);
  415. m_DBinary = new CostVal[m_nPixels];
  416. m_horzWeightsBinary = new CostVal[m_nPixels];
  417. m_vertWeightsBinary = new CostVal[m_nPixels];
  418. if ( m_dataType == ARRAY)
  419. {
  420. for (i=0; i<m_nPixels; i++)
  421. {
  422. m_DBinary[i] = m_D[2*i+1] - m_D[2*i];
  423. }
  424. }
  425. else
  426. {
  427. for (i=0; i<m_nPixels; i++)
  428. {
  429. m_DBinary[i] = m_dataFn(i,1) - m_dataFn(i,0);
  430. }
  431. }
  432. assert(m_V[0] == 0 && m_V[1] == m_V[2] && m_V[3] == 0);
  433. for (i=0; i<m_nPixels; i++)
  434. {
  435. m_horzWeightsBinary[i] = (m_varWeights) ? m_V[1]*m_horzWeights[i] : m_V[1];
  436. m_vertWeightsBinary[i] = (m_varWeights) ? m_V[1]*m_vertWeights[i] : m_V[1];
  437. }
  438. }
  439. }
  440. void BPS::optimizeAlg(int nIterations)
  441. {
  442. assert(m_type != NONE);
  443. if (m_grid_graph)
  444. {
  445. switch (m_type)
  446. {
  447. case L1: optimize_GRID_L1(nIterations); break;
  448. case L2: optimize_GRID_L2(nIterations); break;
  449. case FIXED_MATRIX: optimize_GRID_FIXED_MATRIX(nIterations); break;
  450. case GENERAL: optimize_GRID_GENERAL(nIterations); break;
  451. case BINARY: optimize_GRID_BINARY(nIterations); break;
  452. default: assert(0); exit(1);
  453. }
  454. }
  455. else {printf("\nNot implemented for general graphs yet, exiting!");exit(1);}
  456. // printf("lower bound = %f\n", m_lowerBound);
  457. ////////////////////////////////////////////////
  458. // computing solution //
  459. ////////////////////////////////////////////////
  460. if (m_type != BINARY)
  461. {
  462. int x, y, n, K = m_nLabels;
  463. CostVal* D_ptr;
  464. REAL* M_ptr;
  465. REAL* Di;
  466. REAL delta;
  467. int ki, kj;
  468. Di = new REAL[K];
  469. n = 0;
  470. D_ptr = m_D;
  471. M_ptr = m_messages;
  472. for (y=0; y<m_height; y++)
  473. for (x=0; x<m_width; x++, D_ptr+=K, M_ptr+=2*K, n++)
  474. {
  475. CopyVector(Di, D_ptr, K);
  476. if (m_type == GENERAL)
  477. {
  478. if (m_V)
  479. {
  480. CostVal* ptr = m_V + 2*(x+y*m_width-1)*K*K;
  481. if (x > 0)
  482. {
  483. kj = m_answer[n-1];
  484. for (ki=0; ki<K; ki++)
  485. {
  486. Di[ki] += ptr[kj + ki*K];
  487. }
  488. }
  489. ptr -= (2*m_width-3)*K*K;
  490. if (y > 0)
  491. {
  492. kj = m_answer[n-m_width];
  493. for (ki=0; ki<K; ki++)
  494. {
  495. Di[ki] += ptr[kj + ki*K];
  496. }
  497. }
  498. }
  499. else
  500. {
  501. if (x > 0)
  502. {
  503. kj = m_answer[n-1];
  504. for (ki=0; ki<K; ki++)
  505. {
  506. Di[ki] += m_smoothFn(n, n-1, ki, kj);
  507. }
  508. }
  509. if (y > 0)
  510. {
  511. kj = m_answer[n-m_width];
  512. for (ki=0; ki<K; ki++)
  513. {
  514. Di[ki] += m_smoothFn(n, n-m_width, ki, kj);
  515. }
  516. }
  517. }
  518. }
  519. else // m_type == L1, L2 or FIXED_MATRIX
  520. {
  521. if (x > 0)
  522. {
  523. kj = m_answer[n-1];
  524. CostVal lambda = (m_varWeights) ? m_horzWeights[n-1] : 1;
  525. for (ki=0; ki<K; ki++)
  526. {
  527. Di[ki] += lambda*m_V[kj*K + ki];
  528. }
  529. }
  530. if (y > 0)
  531. {
  532. kj = m_answer[n-m_width];
  533. CostVal lambda = (m_varWeights) ? m_vertWeights[n-m_width] : 1;
  534. for (ki=0; ki<K; ki++)
  535. {
  536. Di[ki] += lambda*m_V[kj*K + ki];
  537. }
  538. }
  539. }
  540. if (x < m_width-1) AddVector(Di, M_ptr, K); // message (x+1,y)->(x,y)
  541. if (y < m_height-1) AddVector(Di, M_ptr+K, K); // message (x,y+1)->(x,y)
  542. // compute min
  543. delta = Di[0];
  544. m_answer[n] = 0;
  545. for (ki=1; ki<K; ki++)
  546. {
  547. if (delta > Di[ki])
  548. {
  549. delta = Di[ki];
  550. m_answer[n] = ki;
  551. }
  552. }
  553. }
  554. delete [] Di;
  555. }
  556. else // m_type == BINARY
  557. {
  558. int x, y, n;
  559. REAL* M_ptr;
  560. REAL Di;
  561. n = 0;
  562. M_ptr = m_messages;
  563. for (y=0; y<m_height; y++)
  564. for (x=0; x<m_width; x++, M_ptr+=2, n++)
  565. {
  566. Di = m_DBinary[n];
  567. if (x > 0) Di += (m_answer[n-1] == 0) ? m_horzWeightsBinary[n-1] : -m_horzWeightsBinary[n-1];
  568. if (y > 0) Di += (m_answer[n-m_width] == 0) ? m_vertWeightsBinary[n-m_width] : -m_vertWeightsBinary[n-m_width];
  569. if (x < m_width-1) Di += M_ptr[0]; // message (x+1,y)->(x,y)
  570. if (y < m_height-1) Di += M_ptr[1]; // message (x,y+1)->(x,y)
  571. // compute min
  572. m_answer[n] = (Di >= 0) ? 0 : 1;
  573. }
  574. }
  575. }
  576. void BPS::optimize_GRID_L1(int nIterations)
  577. {
  578. int x, y, n, K = m_nLabels;
  579. CostVal* D_ptr;
  580. REAL* M_ptr;
  581. REAL* Di;
  582. Di = new REAL[K];
  583. for ( ; nIterations > 0; nIterations --)
  584. {
  585. // forward pass
  586. n = 0;
  587. D_ptr = m_D;
  588. M_ptr = m_messages;
  589. for (y=0; y<m_height; y++)
  590. for (x=0; x<m_width; x++, D_ptr+=K, M_ptr+=2*K, n++)
  591. {
  592. CopyVector(Di, D_ptr, K);
  593. if (x > 0) AddVector(Di, M_ptr-2*K, K); // message (x-1,y)->(x,y)
  594. if (y > 0) AddVector(Di, M_ptr-(2*m_width-1)*K, K); // message (x,y-1)->(x,y)
  595. if (x < m_width-1) AddVector(Di, M_ptr, K); // message (x+1,y)->(x,y)
  596. if (y < m_height-1) AddVector(Di, M_ptr+K, K); // message (x,y+1)->(x,y)
  597. if (x < m_width-1)
  598. {
  599. CostVal lambda = (m_varWeights) ? m_lambda*m_horzWeights[n] : m_lambda;
  600. UpdateMessageL1(M_ptr, Di, K, 1, lambda, m_smoothMax);
  601. }
  602. if (y < m_height-1)
  603. {
  604. CostVal lambda = (m_varWeights) ? m_lambda*m_vertWeights[n] : m_lambda;
  605. UpdateMessageL1(M_ptr+K, Di, K, 1, lambda, m_smoothMax);
  606. }
  607. }
  608. // backward pass
  609. n --;
  610. D_ptr -= K;
  611. M_ptr -= 2*K;
  612. for (y=m_height-1; y>=0; y--)
  613. for (x=m_width-1; x>=0; x--, D_ptr-=K, M_ptr-=2*K, n--)
  614. {
  615. CopyVector(Di, D_ptr, K);
  616. if (x > 0) AddVector(Di, M_ptr-2*K, K); // message (x-1,y)->(x,y)
  617. if (y > 0) AddVector(Di, M_ptr-(2*m_width-1)*K, K); // message (x,y-1)->(x,y)
  618. if (x < m_width-1) AddVector(Di, M_ptr, K); // message (x+1,y)->(x,y)
  619. if (y < m_height-1) AddVector(Di, M_ptr+K, K); // message (x,y+1)->(x,y)
  620. SubtractMin(Di, K);
  621. if (x > 0)
  622. {
  623. CostVal lambda = (m_varWeights) ? m_lambda*m_horzWeights[n-1] : m_lambda;
  624. UpdateMessageL1(M_ptr-2*K, Di, K, 1, lambda, m_smoothMax);
  625. }
  626. if (y > 0)
  627. {
  628. CostVal lambda = (m_varWeights) ? m_lambda*m_vertWeights[n-m_width] : m_lambda;
  629. UpdateMessageL1(M_ptr-(2*m_width-1)*K, Di, K, 1, lambda, m_smoothMax);
  630. }
  631. }
  632. }
  633. delete [] Di;
  634. }
  635. void BPS::optimize_GRID_L2(int nIterations)
  636. {
  637. int x, y, n, K = m_nLabels;
  638. CostVal* D_ptr;
  639. REAL* M_ptr;
  640. REAL* Di;
  641. void* buf;
  642. Di = new REAL[K];
  643. buf = new char[2*K*sizeof(TypeTruncatedQuadratic2D::REAL) + (2*K+1)*sizeof(int) + K*sizeof(REAL)];
  644. for ( ; nIterations > 0; nIterations --)
  645. {
  646. // forward pass
  647. n = 0;
  648. D_ptr = m_D;
  649. M_ptr = m_messages;
  650. for (y=0; y<m_height; y++)
  651. for (x=0; x<m_width; x++, D_ptr+=K, M_ptr+=2*K, n++)
  652. {
  653. CopyVector(Di, D_ptr, K);
  654. if (x > 0) AddVector(Di, M_ptr-2*K, K); // message (x-1,y)->(x,y)
  655. if (y > 0) AddVector(Di, M_ptr-(2*m_width-1)*K, K); // message (x,y-1)->(x,y)
  656. if (x < m_width-1) AddVector(Di, M_ptr, K); // message (x+1,y)->(x,y)
  657. if (y < m_height-1) AddVector(Di, M_ptr+K, K); // message (x,y+1)->(x,y)
  658. if (x < m_width-1)
  659. {
  660. CostVal lambda = (m_varWeights) ? m_lambda*m_horzWeights[n] : m_lambda;
  661. UpdateMessageL2(M_ptr, Di, K, 1, lambda, m_smoothMax, buf);
  662. }
  663. if (y < m_height-1)
  664. {
  665. CostVal lambda = (m_varWeights) ? m_lambda*m_vertWeights[n] : m_lambda;
  666. UpdateMessageL2(M_ptr+K, Di, K, 1, lambda, m_smoothMax, buf);
  667. }
  668. }
  669. // backward pass
  670. n --;
  671. D_ptr -= K;
  672. M_ptr -= 2*K;
  673. for (y=m_height-1; y>=0; y--)
  674. for (x=m_width-1; x>=0; x--, D_ptr-=K, M_ptr-=2*K, n--)
  675. {
  676. CopyVector(Di, D_ptr, K);
  677. if (x > 0) AddVector(Di, M_ptr-2*K, K); // message (x-1,y)->(x,y)
  678. if (y > 0) AddVector(Di, M_ptr-(2*m_width-1)*K, K); // message (x,y-1)->(x,y)
  679. if (x < m_width-1) AddVector(Di, M_ptr, K); // message (x+1,y)->(x,y)
  680. if (y < m_height-1) AddVector(Di, M_ptr+K, K); // message (x,y+1)->(x,y)
  681. SubtractMin(Di, K);
  682. if (x > 0)
  683. {
  684. CostVal lambda = (m_varWeights) ? m_lambda*m_horzWeights[n-1] : m_lambda;
  685. UpdateMessageL2(M_ptr-2*K, Di, K, 1, lambda, m_smoothMax, buf);
  686. }
  687. if (y > 0)
  688. {
  689. CostVal lambda = (m_varWeights) ? m_lambda*m_vertWeights[n-m_width] : m_lambda;
  690. UpdateMessageL2(M_ptr-(2*m_width-1)*K, Di, K, 1, lambda, m_smoothMax, buf);
  691. }
  692. }
  693. }
  694. delete [] Di;
  695. delete [] (REAL *)buf;
  696. }
  697. void BPS::optimize_GRID_BINARY(int nIterations)
  698. {
  699. int x, y, n;
  700. REAL* M_ptr;
  701. REAL Di;
  702. for ( ; nIterations > 0; nIterations --)
  703. {
  704. // forward pass
  705. n = 0;
  706. M_ptr = m_messages;
  707. for (y=0; y<m_height; y++)
  708. for (x=0; x<m_width; x++, M_ptr+=2, n++)
  709. {
  710. Di = m_DBinary[n];
  711. if (x > 0) Di += M_ptr[-2]; // message (x-1,y)->(x,y)
  712. if (y > 0) Di += M_ptr[-2*m_width+1]; // message (x,y-1)->(x,y)
  713. if (x < m_width-1) Di += M_ptr[0]; // message (x+1,y)->(x,y)
  714. if (y < m_height-1) Di += M_ptr[1]; // message (x,y+1)->(x,y)
  715. REAL DiScaled = Di * 1;
  716. if (x < m_width-1)
  717. {
  718. Di = DiScaled - M_ptr[0];
  719. CostVal lambda = m_horzWeightsBinary[n];
  720. if (lambda < 0) { Di = -Di; lambda = -lambda; }
  721. if (Di > lambda) M_ptr[0] = lambda;
  722. else M_ptr[0] = (Di < -lambda) ? -lambda : Di;
  723. }
  724. if (y < m_height-1)
  725. {
  726. Di = DiScaled - M_ptr[1];
  727. CostVal lambda = m_vertWeightsBinary[n];
  728. if (lambda < 0) { Di = -Di; lambda = -lambda; }
  729. if (Di > lambda) M_ptr[1] = lambda;
  730. else M_ptr[1] = (Di < -lambda) ? -lambda : Di;
  731. }
  732. }
  733. // backward pass
  734. n --;
  735. M_ptr -= 2;
  736. for (y=m_height-1; y>=0; y--)
  737. for (x=m_width-1; x>=0; x--, M_ptr-=2, n--)
  738. {
  739. Di = m_DBinary[n];
  740. if (x > 0) Di += M_ptr[-2]; // message (x-1,y)->(x,y)
  741. if (y > 0) Di += M_ptr[-2*m_width+1]; // message (x,y-1)->(x,y)
  742. if (x < m_width-1) Di += M_ptr[0]; // message (x+1,y)->(x,y)
  743. if (y < m_height-1) Di += M_ptr[1]; // message (x,y+1)->(x,y)
  744. REAL DiScaled = Di * 1;
  745. if (x > 0)
  746. {
  747. Di = DiScaled - M_ptr[-2];
  748. CostVal lambda = m_horzWeightsBinary[n-1];
  749. if (lambda < 0) { Di = -Di; lambda = -lambda; }
  750. if (Di > lambda) M_ptr[-2] = lambda;
  751. else M_ptr[-2] = (Di < -lambda) ? -lambda : Di;
  752. }
  753. if (y > 0)
  754. {
  755. Di = DiScaled - M_ptr[-2*m_width+1];
  756. CostVal lambda = m_vertWeightsBinary[n-m_width];
  757. if (lambda < 0) { Di = -Di; lambda = -lambda; }
  758. if (Di > lambda) M_ptr[-2*m_width+1] = lambda;
  759. else M_ptr[-2*m_width+1] = (Di < -lambda) ? -lambda : Di;
  760. }
  761. }
  762. }
  763. }
  764. void BPS::optimize_GRID_FIXED_MATRIX(int nIterations)
  765. {
  766. int x, y, n, K = m_nLabels;
  767. CostVal* D_ptr;
  768. REAL* M_ptr;
  769. REAL* Di;
  770. void* buf;
  771. Di = new REAL[K];
  772. buf = new REAL[K];
  773. for ( ; nIterations > 0; nIterations --)
  774. {
  775. // forward pass
  776. n = 0;
  777. D_ptr = m_D;
  778. M_ptr = m_messages;
  779. for (y=0; y<m_height; y++)
  780. for (x=0; x<m_width; x++, D_ptr+=K, M_ptr+=2*K, n++)
  781. {
  782. CopyVector(Di, D_ptr, K);
  783. if (x > 0) AddVector(Di, M_ptr-2*K, K); // message (x-1,y)->(x,y)
  784. if (y > 0) AddVector(Di, M_ptr-(2*m_width-1)*K, K); // message (x,y-1)->(x,y)
  785. if (x < m_width-1) AddVector(Di, M_ptr, K); // message (x+1,y)->(x,y)
  786. if (y < m_height-1) AddVector(Di, M_ptr+K, K); // message (x,y+1)->(x,y)
  787. if (x < m_width-1)
  788. {
  789. CostVal lambda = (m_varWeights) ? m_horzWeights[n] : 1;
  790. UpdateMessageFIXED_MATRIX(M_ptr, Di, K, 1, lambda, m_V, buf);
  791. }
  792. if (y < m_height-1)
  793. {
  794. CostVal lambda = (m_varWeights) ? m_vertWeights[n] : 1;
  795. UpdateMessageFIXED_MATRIX(M_ptr+K, Di, K, 1, lambda, m_V, buf);
  796. }
  797. }
  798. // backward pass
  799. n --;
  800. D_ptr -= K;
  801. M_ptr -= 2*K;
  802. for (y=m_height-1; y>=0; y--)
  803. for (x=m_width-1; x>=0; x--, D_ptr-=K, M_ptr-=2*K, n--)
  804. {
  805. CopyVector(Di, D_ptr, K);
  806. if (x > 0) AddVector(Di, M_ptr-2*K, K); // message (x-1,y)->(x,y)
  807. if (y > 0) AddVector(Di, M_ptr-(2*m_width-1)*K, K); // message (x,y-1)->(x,y)
  808. if (x < m_width-1) AddVector(Di, M_ptr, K); // message (x+1,y)->(x,y)
  809. if (y < m_height-1) AddVector(Di, M_ptr+K, K); // message (x,y+1)->(x,y)
  810. SubtractMin(Di, K);
  811. if (x > 0)
  812. {
  813. CostVal lambda = (m_varWeights) ? m_horzWeights[n-1] : 1;
  814. UpdateMessageFIXED_MATRIX(M_ptr-2*K, Di, K, 1, lambda, m_V, buf);
  815. }
  816. if (y > 0)
  817. {
  818. CostVal lambda = (m_varWeights) ? m_vertWeights[n-m_width] : 1;
  819. UpdateMessageFIXED_MATRIX(M_ptr-(2*m_width-1)*K, Di, K, 1, lambda, m_V, buf);
  820. }
  821. }
  822. }
  823. delete [] Di;
  824. delete [] (REAL *)buf;
  825. }
  826. void BPS::optimize_GRID_GENERAL(int nIterations)
  827. {
  828. int x, y, n, K = m_nLabels;
  829. CostVal* D_ptr;
  830. REAL* M_ptr;
  831. REAL* Di;
  832. void* buf;
  833. Di = new REAL[K];
  834. buf = new REAL[K];
  835. for ( ; nIterations > 0; nIterations --)
  836. {
  837. // forward pass
  838. n = 0;
  839. D_ptr = m_D;
  840. M_ptr = m_messages;
  841. CostVal* V_ptr = m_V;
  842. for (y=0; y<m_height; y++)
  843. for (x=0; x<m_width; x++, D_ptr+=K, M_ptr+=2*K, V_ptr+=2*K*K, n++)
  844. {
  845. CopyVector(Di, D_ptr, K);
  846. if (x > 0) AddVector(Di, M_ptr-2*K, K); // message (x-1,y)->(x,y)
  847. if (y > 0) AddVector(Di, M_ptr-(2*m_width-1)*K, K); // message (x,y-1)->(x,y)
  848. if (x < m_width-1) AddVector(Di, M_ptr, K); // message (x+1,y)->(x,y)
  849. if (y < m_height-1) AddVector(Di, M_ptr+K, K); // message (x,y+1)->(x,y)
  850. if (x < m_width-1)
  851. {
  852. if (m_V) UpdateMessageGENERAL(M_ptr, Di, K, 1, /* forward dir*/ 0, V_ptr, buf);
  853. else UpdateMessageGENERAL(M_ptr, Di, K, 1, m_smoothFn, n, n+1, buf);
  854. }
  855. if (y < m_height-1)
  856. {
  857. if (m_V) UpdateMessageGENERAL(M_ptr+K, Di, K, 1, /* forward dir*/ 0, V_ptr+K*K, buf);
  858. else UpdateMessageGENERAL(M_ptr+K, Di, K, 1, m_smoothFn, n, n+m_width, buf);
  859. }
  860. }
  861. // backward pass
  862. n --;
  863. D_ptr -= K;
  864. M_ptr -= 2*K;
  865. V_ptr -= 2*K*K;
  866. for (y=m_height-1; y>=0; y--)
  867. for (x=m_width-1; x>=0; x--, D_ptr-=K, M_ptr-=2*K, V_ptr-=2*K*K, n--)
  868. {
  869. CopyVector(Di, D_ptr, K);
  870. if (x > 0) AddVector(Di, M_ptr-2*K, K); // message (x-1,y)->(x,y)
  871. if (y > 0) AddVector(Di, M_ptr-(2*m_width-1)*K, K); // message (x,y-1)->(x,y)
  872. if (x < m_width-1) AddVector(Di, M_ptr, K); // message (x+1,y)->(x,y)
  873. if (y < m_height-1) AddVector(Di, M_ptr+K, K); // message (x,y+1)->(x,y)
  874. // normalize Di, update lower bound
  875. SubtractMin(Di, K);
  876. if (x > 0)
  877. {
  878. if (m_V) UpdateMessageGENERAL(M_ptr-2*K, Di, K, 1, /* backward dir */ 1, V_ptr-2*K*K, buf);
  879. else UpdateMessageGENERAL(M_ptr-2*K, Di, K, 1, m_smoothFn, n, n-1, buf);
  880. }
  881. if (y > 0)
  882. {
  883. if (m_V) UpdateMessageGENERAL(M_ptr-(2*m_width-1)*K, Di, K, 1, /* backward dir */ 1, V_ptr-(2*m_width-1)*K*K, buf);
  884. else UpdateMessageGENERAL(M_ptr-(2*m_width-1)*K, Di, K, 1, m_smoothFn, n, n-m_width, buf);
  885. }
  886. }
  887. }
  888. delete [] Di;
  889. delete [] (REAL *)buf;
  890. }