TRW-S.cpp 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116
  1. #include <stdio.h>
  2. #include <stdlib.h>
  3. #include <string.h>
  4. #include <assert.h>
  5. #include <new>
  6. #include "TRW-S.h"
  7. #define private public
  8. #include "typeTruncatedQuadratic2D.h"
  9. using namespace OBJREC;
  10. #undef private
  // Flattened table accessors, expanded inside TRWS member functions where
  // m_D, m_V and m_nLabels are members:
  //   m_D(pix,l)  -> data cost of label l at pixel pix (row of m_nLabels per pixel)
  //   m_V(l1,l2)  -> entry (l1,l2) of the m_nLabels x m_nLabels smoothness table
  #define m_D(pix,l) m_D[(pix)*m_nLabels+(l)]
  #define m_V(l1,l2) m_V[(l1)*m_nLabels+(l2)]
  // Plain min/max; each argument may be evaluated more than once.
  #define MIN(a,b) (((a) < (b)) ? (a) : (b))
  #define MAX(a,b) (((a) > (b)) ? (a) : (b))
  // In-place clamps: TRUNCATE_MIN lowers a to b when b is smaller,
  // TRUNCATE_MAX raises a to b when b is larger. The do { } while (0) wrapper
  // makes "TRUNCATE(a, b);" a single statement, so it is also safe as the body
  // of an un-braced if/else (a bare { } block followed by ';' breaks the else).
  #define TRUNCATE_MIN(a,b) do { if ((a) > (b)) (a) = (b); } while (0)
  #define TRUNCATE_MAX(a,b) do { if ((a) < (b)) (a) = (b); } while (0)
  #define TRUNCATE TRUNCATE_MIN
  18. /////////////////////////////////////////////////////////////////////////////
  19. // Operations on vectors (arrays of size K) //
  20. /////////////////////////////////////////////////////////////////////////////
  21. inline void CopyVector(TRWS::REAL* to, MRF::CostVal* from, int K)
  22. {
  23. TRWS::REAL* to_finish = to + K;
  24. do
  25. {
  26. *to ++ = *from ++;
  27. } while (to < to_finish);
  28. }
  29. inline void AddVector(TRWS::REAL* to, TRWS::REAL* from, int K)
  30. {
  31. TRWS::REAL* to_finish = to + K;
  32. do
  33. {
  34. *to ++ += *from ++;
  35. } while (to < to_finish);
  36. }
  37. inline TRWS::REAL SubtractMin(TRWS::REAL *D, int K)
  38. {
  39. int k;
  40. TRWS::REAL delta;
  41. delta = D[0];
  42. for (k = 1; k < K; k++) TRUNCATE(delta, D[k]);
  43. for (k = 0; k < K; k++) D[k] -= delta;
  44. return delta;
  45. }
  46. // Functions UpdateMessageTYPE (see the paper for details):
  47. //
  48. // - Set Di[ki] := gamma*Di_hat[ki] - M[ki]
  49. // - Set M[kj] := min_{ki} (Di[ki] + V[ki,kj])
  50. // - Normalize message:
  51. // delta := min_{kj} M[kj]
  52. // M[kj] := M[kj] - delta
  53. // return delta
  54. //
  55. // If dir = 1, then the meaning of i and j is swapped.
  56. ///////////////////////////////////////////
  57. // L1 //
  58. ///////////////////////////////////////////
  59. inline TRWS::REAL UpdateMessageL1(TRWS::REAL* M, TRWS::REAL* Di_hat, int K, TRWS::REAL gamma, MRF::CostVal lambda, MRF::CostVal smoothMax)
  60. {
  61. int k;
  62. TRWS::REAL delta;
  63. delta = M[0] = gamma * Di_hat[0] - M[0];
  64. for (k = 1; k < K; k++)
  65. {
  66. M[k] = gamma * Di_hat[k] - M[k];
  67. TRUNCATE(delta, M[k]);
  68. TRUNCATE(M[k], M[k-1] + lambda);
  69. }
  70. M[--k] -= delta;
  71. TRUNCATE(M[k], lambda*smoothMax);
  72. for (k--; k >= 0; k--)
  73. {
  74. M[k] -= delta;
  75. TRUNCATE(M[k], M[k+1] + lambda);
  76. TRUNCATE(M[k], lambda*smoothMax);
  77. }
  78. return delta;
  79. }
  80. ////////////////////////////////////////
  81. // L2 //
  82. ////////////////////////////////////////
  83. inline TRWS::REAL UpdateMessageL2(TRWS::REAL* M, TRWS::REAL* Di_hat, int K, TRWS::REAL gamma, MRF::CostVal lambda, MRF::CostVal smoothMax, void *buf)
  84. {
  85. TRWS::REAL* Di = (TRWS::REAL*) buf;
  86. int* parabolas = (int*) ((char*)buf + K * sizeof(TRWS::REAL));
  87. int* intersections = parabolas + K;
  88. TypeTruncatedQuadratic2D::Edge* tmp = NULL;
  89. int k;
  90. TRWS::REAL delta;
  91. assert(lambda >= 0);
  92. Di[0] = gamma * Di_hat[0] - M[0];
  93. delta = Di[0];
  94. for (k = 1; k < K; k++)
  95. {
  96. Di[k] = gamma * Di_hat[k] - M[k];
  97. TRUNCATE(delta, Di[k]);
  98. }
  99. if (lambda == 0)
  100. {
  101. for (k = 0; k < K; k++) M[k] = 0;
  102. return delta;
  103. }
  104. tmp->DistanceTransformL2(K, 1, lambda, Di, M, parabolas, intersections);
  105. for (k = 0; k < K; k++)
  106. {
  107. M[k] -= delta;
  108. TRUNCATE(M[k], lambda*smoothMax);
  109. }
  110. return delta;
  111. }
  112. //////////////////////////////////////////////////
  113. // FIXED_MATRIX //
  114. //////////////////////////////////////////////////
  115. inline TRWS::REAL UpdateMessageFIXED_MATRIX(TRWS::REAL* M, TRWS::REAL* Di_hat, int K, TRWS::REAL gamma, MRF::CostVal lambda, MRF::CostVal* V, void* buf)
  116. {
  117. TRWS::REAL* Di = (TRWS::REAL*) buf;
  118. int ki, kj;
  119. TRWS::REAL delta;
  120. if (lambda == 0)
  121. {
  122. delta = gamma * Di_hat[0] - M[0];
  123. M[0] = 0;
  124. for (ki = 1; ki < K; ki++)
  125. {
  126. TRUNCATE(delta, gamma*Di_hat[ki] - M[ki]);
  127. M[ki] = 0;
  128. }
  129. return delta;
  130. }
  131. for (ki = 0; ki < K; ki++)
  132. {
  133. Di[ki] = (gamma * Di_hat[ki] - M[ki]) * (1 / (TRWS::REAL)lambda);
  134. }
  135. if (lambda > 0)
  136. {
  137. for (kj = 0; kj < K; kj++)
  138. {
  139. M[kj] = Di[0] + V[0];
  140. V ++;
  141. for (ki = 1; ki < K; ki++)
  142. {
  143. TRUNCATE(M[kj], Di[ki] + V[0]);
  144. V ++;
  145. }
  146. M[kj] *= lambda;
  147. }
  148. }
  149. else
  150. {
  151. for (kj = 0; kj < K; kj++)
  152. {
  153. M[kj] = Di[0] + V[0];
  154. V ++;
  155. for (ki = 1; ki < K; ki++)
  156. {
  157. TRUNCATE_MAX(M[kj], Di[ki] + V[0]);
  158. V ++;
  159. }
  160. M[kj] *= lambda;
  161. }
  162. }
  163. delta = M[0];
  164. for (kj = 1; kj < K; kj++) TRUNCATE(delta, M[kj]);
  165. for (kj = 0; kj < K; kj++) M[kj] -= delta;
  166. return delta;
  167. }
  168. /////////////////////////////////////////////
  169. // GENERAL //
  170. /////////////////////////////////////////////
  171. inline TRWS::REAL UpdateMessageGENERAL(TRWS::REAL* M, TRWS::REAL* Di_hat, int K, TRWS::REAL gamma, int dir, MRF::CostVal* V, void* buf)
  172. {
  173. TRWS::REAL* Di = (TRWS::REAL*) buf;
  174. int ki, kj;
  175. TRWS::REAL delta;
  176. for (ki = 0; ki < K; ki++)
  177. {
  178. Di[ki] = (gamma * Di_hat[ki] - M[ki]);
  179. }
  180. if (dir == 0)
  181. {
  182. for (kj = 0; kj < K; kj++)
  183. {
  184. M[kj] = Di[0] + V[0];
  185. V ++;
  186. for (ki = 1; ki < K; ki++)
  187. {
  188. TRUNCATE(M[kj], Di[ki] + V[0]);
  189. V ++;
  190. }
  191. }
  192. }
  193. else
  194. {
  195. for (kj = 0; kj < K; kj++)
  196. {
  197. M[kj] = Di[0] + V[0];
  198. V += K;
  199. for (ki = 1; ki < K; ki++)
  200. {
  201. TRUNCATE(M[kj], Di[ki] + V[0]);
  202. V += K;
  203. }
  204. V -= K * K - 1;
  205. }
  206. }
  207. delta = M[0];
  208. for (kj = 1; kj < K; kj++) TRUNCATE(delta, M[kj]);
  209. for (kj = 0; kj < K; kj++) M[kj] -= delta;
  210. return delta;
  211. }
  212. inline TRWS::REAL UpdateMessageGENERAL(TRWS::REAL* M, TRWS::REAL* Di_hat, int K, TRWS::REAL gamma, TRWS::SmoothCostGeneralFn fn, int i, int j, void* buf)
  213. {
  214. TRWS::REAL* Di = (TRWS::REAL*) buf;
  215. int ki, kj;
  216. TRWS::REAL delta;
  217. for (ki = 0; ki < K; ki++)
  218. {
  219. Di[ki] = (gamma * Di_hat[ki] - M[ki]);
  220. }
  221. for (kj = 0; kj < K; kj++)
  222. {
  223. M[kj] = Di[0] + fn(i, j, 0, kj);
  224. for (ki = 1; ki < K; ki++)
  225. {
  226. delta = Di[ki] + fn(i, j, ki, kj);
  227. TRUNCATE(M[kj], delta);
  228. }
  229. }
  230. delta = M[0];
  231. for (kj = 1; kj < K; kj++) TRUNCATE(delta, M[kj]);
  232. for (kj = 0; kj < K; kj++) M[kj] -= delta;
  233. return delta;
  234. }
  235. TRWS::TRWS(int width, int height, int nLabels, EnergyFunction *eng): MRF(width, height, nLabels, eng)
  236. {
  237. Allocate();
  238. }
  239. TRWS::TRWS(int nPixels, int nLabels, EnergyFunction *eng): MRF(nPixels, nLabels, eng)
  240. {
  241. Allocate();
  242. }
  243. TRWS::~TRWS()
  244. {
  245. delete[] m_answer;
  246. if ( m_needToFreeD ) delete [] m_D;
  247. if ( m_needToFreeV ) delete [] m_V;
  248. if ( m_messages ) delete [] m_messages;
  249. if ( m_DBinary ) delete [] m_DBinary;
  250. if ( m_horzWeightsBinary ) delete [] m_horzWeightsBinary;
  251. if ( m_vertWeightsBinary ) delete [] m_vertWeightsBinary;
  252. }
  253. void TRWS::Allocate()
  254. {
  255. m_type = NONE;
  256. m_needToFreeV = false;
  257. m_needToFreeD = false;
  258. m_D = NULL;
  259. m_V = NULL;
  260. m_horzWeights = NULL;
  261. m_vertWeights = NULL;
  262. m_horzWeightsBinary = NULL;
  263. m_vertWeightsBinary = NULL;
  264. m_DBinary = NULL;
  265. m_messages = NULL;
  266. m_messageArraySizeInBytes = 0;
  267. m_answer = new Label[m_nPixels];
  268. }
  269. void TRWS::clearAnswer()
  270. {
  271. memset(m_answer, 0, m_nPixels*sizeof(Label));
  272. if (m_messages)
  273. {
  274. memset(m_messages, 0, m_messageArraySizeInBytes);
  275. }
  276. }
  277. MRF::EnergyVal TRWS::smoothnessEnergy()
  278. {
  279. EnergyVal eng = (EnergyVal) 0;
  280. EnergyVal weight;
  281. int x, y, pix;
  282. if ( m_grid_graph )
  283. {
  284. if ( m_smoothType != FUNCTION )
  285. {
  286. for ( y = 0; y < m_height; y++ )
  287. for ( x = 1; x < m_width; x++ )
  288. {
  289. pix = x + y * m_width;
  290. weight = m_varWeights ? m_horzWeights[pix-1] : 1;
  291. eng = eng + m_V(m_answer[pix], m_answer[pix-1]) * weight;
  292. }
  293. for ( y = 1; y < m_height; y++ )
  294. for ( x = 0; x < m_width; x++ )
  295. {
  296. pix = x + y * m_width;
  297. weight = m_varWeights ? m_vertWeights[pix-m_width] : 1;
  298. eng = eng + m_V(m_answer[pix], m_answer[pix-m_width]) * weight;
  299. }
  300. }
  301. else
  302. {
  303. for ( y = 0; y < m_height; y++ )
  304. for ( x = 1; x < m_width; x++ )
  305. {
  306. pix = x + y * m_width;
  307. eng = eng + m_smoothFn(pix, pix - 1, m_answer[pix], m_answer[pix-1]);
  308. }
  309. for ( y = 1; y < m_height; y++ )
  310. for ( x = 0; x < m_width; x++ )
  311. {
  312. pix = x + y * m_width;
  313. eng = eng + m_smoothFn(pix, pix - m_width, m_answer[pix], m_answer[pix-m_width]);
  314. }
  315. }
  316. }
  317. else
  318. {
  319. // not implemented
  320. }
  321. return(eng);
  322. }
  323. MRF::EnergyVal TRWS::dataEnergy()
  324. {
  325. EnergyVal eng = (EnergyVal) 0;
  326. if ( m_dataType == ARRAY)
  327. {
  328. for ( int i = 0; i < m_nPixels; i++ )
  329. eng = eng + m_D(i, m_answer[i]);
  330. }
  331. else
  332. {
  333. for ( int i = 0; i < m_nPixels; i++ )
  334. eng = eng + m_dataFn(i, m_answer[i]);
  335. }
  336. return(eng);
  337. }
  338. void TRWS::setData(DataCostFn dcost)
  339. {
  340. int i, k;
  341. m_dataFn = dcost;
  342. CostVal* ptr;
  343. m_D = new CostVal[m_nPixels*m_nLabels];
  344. for (ptr = m_D, i = 0; i < m_nPixels; i++)
  345. for (k = 0; k < m_nLabels; k++, ptr++)
  346. {
  347. *ptr = m_dataFn(i, k);
  348. }
  349. m_needToFreeD = true;
  350. }
  351. void TRWS::setData(CostVal* data)
  352. {
  353. m_D = data;
  354. m_needToFreeD = false;
  355. }
  356. void TRWS::setSmoothness(SmoothCostGeneralFn cost)
  357. {
  358. assert(m_horzWeights == NULL && m_vertWeights == NULL && m_V == NULL);
  359. int x, y, i, ki, kj;
  360. CostVal* ptr;
  361. m_smoothFn = cost;
  362. m_type = GENERAL;
  363. if (!m_allocateArrayForSmoothnessCostFn) return;
  364. // try to cache all the function values in an array for efficiency
  365. m_V = new(std::nothrow) CostVal[2*m_nPixels*m_nLabels*m_nLabels];
  366. if (!m_V) return; // if not enough space, just call the function directly
  367. m_needToFreeV = true;
  368. for (ptr = m_V, i = 0, y = 0; y < m_height; y++)
  369. for (x = 0; x < m_width; x++, i++)
  370. {
  371. if (x < m_width - 1)
  372. {
  373. for (kj = 0; kj < m_nLabels; kj++)
  374. for (ki = 0; ki < m_nLabels; ki++)
  375. {
  376. *ptr++ = cost(i, i + 1, ki, kj);
  377. }
  378. }
  379. else ptr += m_nLabels * m_nLabels;
  380. if (y < m_height - 1)
  381. {
  382. for (kj = 0; kj < m_nLabels; kj++)
  383. for (ki = 0; ki < m_nLabels; ki++)
  384. {
  385. *ptr++ = cost(i, i + m_width, ki, kj);
  386. }
  387. }
  388. else ptr += m_nLabels * m_nLabels;
  389. }
  390. }
  391. void TRWS::setSmoothness(CostVal* V)
  392. {
  393. m_type = FIXED_MATRIX;
  394. m_V = V;
  395. }
  396. void TRWS::setSmoothness(int smoothExp, CostVal smoothMax, CostVal lambda)
  397. {
  398. assert(smoothExp == 1 || smoothExp == 2);
  399. assert(lambda >= 0);
  400. m_type = (smoothExp == 1) ? L1 : L2;
  401. int ki, kj;
  402. CostVal cost;
  403. m_needToFreeV = true;
  404. m_V = new CostVal[m_nLabels*m_nLabels];
  405. for (ki = 0; ki < m_nLabels; ki++)
  406. for (kj = ki; kj < m_nLabels; kj++)
  407. {
  408. cost = (CostVal) ((smoothExp == 1) ? kj - ki : (kj - ki) * (kj - ki));
  409. if (cost > smoothMax) cost = smoothMax;
  410. m_V[ki*m_nLabels + kj] = m_V[kj*m_nLabels + ki] = cost * lambda;
  411. }
  412. m_smoothMax = smoothMax;
  413. m_lambda = lambda;
  414. }
  415. void TRWS::setCues(CostVal* hCue, CostVal* vCue)
  416. {
  417. m_horzWeights = hCue;
  418. m_vertWeights = vCue;
  419. }
  420. void TRWS::initializeAlg()
  421. {
  422. assert(m_type != NONE);
  423. int i;
  424. // determine type
  425. if (m_type == L1 && m_nLabels == 2)
  426. {
  427. m_type = BINARY;
  428. }
  429. // allocate messages
  430. int messageNum = (m_type == BINARY) ? 4 * m_nPixels : 4 * m_nPixels * m_nLabels;
  431. m_messageArraySizeInBytes = messageNum * sizeof(REAL);
  432. m_messages = new REAL[messageNum];
  433. memset(m_messages, 0, messageNum*sizeof(REAL));
  434. if (m_type == BINARY)
  435. {
  436. assert(m_DBinary == NULL && m_horzWeightsBinary == NULL && m_horzWeightsBinary == NULL);
  437. m_DBinary = new CostVal[m_nPixels];
  438. m_horzWeightsBinary = new CostVal[m_nPixels];
  439. m_vertWeightsBinary = new CostVal[m_nPixels];
  440. if ( m_dataType == ARRAY)
  441. {
  442. for (i = 0; i < m_nPixels; i++)
  443. {
  444. m_DBinary[i] = m_D[2*i+1] - m_D[2*i];
  445. }
  446. }
  447. else
  448. {
  449. for (i = 0; i < m_nPixels; i++)
  450. {
  451. m_DBinary[i] = m_dataFn(i, 1) - m_dataFn(i, 0);
  452. }
  453. }
  454. assert(m_V[0] == 0 && m_V[1] == m_V[2] && m_V[3] == 0);
  455. for (i = 0; i < m_nPixels; i++)
  456. {
  457. m_horzWeightsBinary[i] = (m_varWeights) ? m_V[1] * m_horzWeights[i] : m_V[1];
  458. m_vertWeightsBinary[i] = (m_varWeights) ? m_V[1] * m_vertWeights[i] : m_V[1];
  459. }
  460. }
  461. }
  462. void TRWS::optimizeAlg(int nIterations)
  463. {
  464. assert(m_type != NONE);
  465. if (m_grid_graph)
  466. {
  467. switch (m_type)
  468. {
  469. case L1:
  470. optimize_GRID_L1(nIterations);
  471. break;
  472. case L2:
  473. optimize_GRID_L2(nIterations);
  474. break;
  475. case FIXED_MATRIX:
  476. optimize_GRID_FIXED_MATRIX(nIterations);
  477. break;
  478. case GENERAL:
  479. optimize_GRID_GENERAL(nIterations);
  480. break;
  481. case BINARY:
  482. optimize_GRID_BINARY(nIterations);
  483. break;
  484. default:
  485. assert(0);
  486. exit(1);
  487. }
  488. }
  489. else {
  490. printf("\nNot implemented for general graphs yet, exiting!");
  491. exit(1);
  492. }
  493. // printf("lower bound = %f\n", m_lowerBound);
  494. ////////////////////////////////////////////////
  495. // computing solution //
  496. ////////////////////////////////////////////////
  497. if (m_type != BINARY)
  498. {
  499. int x, y, n, K = m_nLabels;
  500. CostVal* D_ptr;
  501. REAL* M_ptr;
  502. REAL* Di;
  503. REAL delta;
  504. int ki, kj;
  505. Di = new REAL[K];
  506. n = 0;
  507. D_ptr = m_D;
  508. M_ptr = m_messages;
  509. for (y = 0; y < m_height; y++)
  510. for (x = 0; x < m_width; x++, D_ptr += K, M_ptr += 2 * K, n++)
  511. {
  512. CopyVector(Di, D_ptr, K);
  513. if (m_type == GENERAL)
  514. {
  515. if (m_V)
  516. {
  517. CostVal* ptr = m_V + 2 * (x + y * m_width - 1) * K * K;
  518. if (x > 0)
  519. {
  520. kj = m_answer[n-1];
  521. for (ki = 0; ki < K; ki++)
  522. {
  523. Di[ki] += ptr[kj + ki*K];
  524. }
  525. }
  526. ptr -= (2 * m_width - 3) * K * K;
  527. if (y > 0)
  528. {
  529. kj = m_answer[n-m_width];
  530. for (ki = 0; ki < K; ki++)
  531. {
  532. Di[ki] += ptr[kj + ki*K];
  533. }
  534. }
  535. }
  536. else
  537. {
  538. if (x > 0)
  539. {
  540. kj = m_answer[n-1];
  541. for (ki = 0; ki < K; ki++)
  542. {
  543. Di[ki] += m_smoothFn(n, n - 1, ki, kj);
  544. }
  545. }
  546. if (y > 0)
  547. {
  548. kj = m_answer[n-m_width];
  549. for (ki = 0; ki < K; ki++)
  550. {
  551. Di[ki] += m_smoothFn(n, n - m_width, ki, kj);
  552. }
  553. }
  554. }
  555. }
  556. else // m_type == L1, L2 or FIXED_MATRIX
  557. {
  558. if (x > 0)
  559. {
  560. kj = m_answer[n-1];
  561. CostVal lambda = (m_varWeights) ? m_horzWeights[n-1] : 1;
  562. for (ki = 0; ki < K; ki++)
  563. {
  564. Di[ki] += lambda * m_V[kj*K + ki];
  565. }
  566. }
  567. if (y > 0)
  568. {
  569. kj = m_answer[n-m_width];
  570. CostVal lambda = (m_varWeights) ? m_vertWeights[n-m_width] : 1;
  571. for (ki = 0; ki < K; ki++)
  572. {
  573. Di[ki] += lambda * m_V[kj*K + ki];
  574. }
  575. }
  576. }
  577. if (x < m_width - 1) AddVector(Di, M_ptr, K); // message (x+1,y)->(x,y)
  578. if (y < m_height - 1) AddVector(Di, M_ptr + K, K); // message (x,y+1)->(x,y)
  579. // compute min
  580. delta = Di[0];
  581. m_answer[n] = 0;
  582. for (ki = 1; ki < K; ki++)
  583. {
  584. if (delta > Di[ki])
  585. {
  586. delta = Di[ki];
  587. m_answer[n] = ki;
  588. }
  589. }
  590. }
  591. delete [] Di;
  592. }
  593. else // m_type == BINARY
  594. {
  595. int x, y, n;
  596. REAL* M_ptr;
  597. REAL Di;
  598. n = 0;
  599. M_ptr = m_messages;
  600. for (y = 0; y < m_height; y++)
  601. for (x = 0; x < m_width; x++, M_ptr += 2, n++)
  602. {
  603. Di = m_DBinary[n];
  604. if (x > 0) Di += (m_answer[n-1] == 0) ? m_horzWeightsBinary[n-1] : -m_horzWeightsBinary[n-1];
  605. if (y > 0) Di += (m_answer[n-m_width] == 0) ? m_vertWeightsBinary[n-m_width] : -m_vertWeightsBinary[n-m_width];
  606. if (x < m_width - 1) Di += M_ptr[0]; // message (x+1,y)->(x,y)
  607. if (y < m_height - 1) Di += M_ptr[1]; // message (x,y+1)->(x,y)
  608. // compute min
  609. m_answer[n] = (Di >= 0) ? 0 : 1;
  610. }
  611. }
  612. }
  613. void TRWS::optimize_GRID_L1(int nIterations)
  614. {
  615. int x, y, n, K = m_nLabels;
  616. CostVal* D_ptr;
  617. REAL* M_ptr;
  618. REAL* Di;
  619. Di = new REAL[K];
  620. for ( ; nIterations > 0; nIterations --)
  621. {
  622. // forward pass
  623. n = 0;
  624. D_ptr = m_D;
  625. M_ptr = m_messages;
  626. for (y = 0; y < m_height; y++)
  627. for (x = 0; x < m_width; x++, D_ptr += K, M_ptr += 2 * K, n++)
  628. {
  629. CopyVector(Di, D_ptr, K);
  630. if (x > 0) AddVector(Di, M_ptr - 2*K, K); // message (x-1,y)->(x,y)
  631. if (y > 0) AddVector(Di, M_ptr - (2*m_width - 1)*K, K); // message (x,y-1)->(x,y)
  632. if (x < m_width - 1) AddVector(Di, M_ptr, K); // message (x+1,y)->(x,y)
  633. if (y < m_height - 1) AddVector(Di, M_ptr + K, K); // message (x,y+1)->(x,y)
  634. if (x < m_width - 1)
  635. {
  636. CostVal lambda = (m_varWeights) ? m_lambda * m_horzWeights[n] : m_lambda;
  637. UpdateMessageL1(M_ptr, Di, K, 0.5, lambda, m_smoothMax);
  638. }
  639. if (y < m_height - 1)
  640. {
  641. CostVal lambda = (m_varWeights) ? m_lambda * m_vertWeights[n] : m_lambda;
  642. UpdateMessageL1(M_ptr + K, Di, K, 0.5, lambda, m_smoothMax);
  643. }
  644. }
  645. // backward pass
  646. m_lowerBound = 0;
  647. n --;
  648. D_ptr -= K;
  649. M_ptr -= 2 * K;
  650. for (y = m_height - 1; y >= 0; y--)
  651. for (x = m_width - 1; x >= 0; x--, D_ptr -= K, M_ptr -= 2 * K, n--)
  652. {
  653. CopyVector(Di, D_ptr, K);
  654. if (x > 0) AddVector(Di, M_ptr - 2*K, K); // message (x-1,y)->(x,y)
  655. if (y > 0) AddVector(Di, M_ptr - (2*m_width - 1)*K, K); // message (x,y-1)->(x,y)
  656. if (x < m_width - 1) AddVector(Di, M_ptr, K); // message (x+1,y)->(x,y)
  657. if (y < m_height - 1) AddVector(Di, M_ptr + K, K); // message (x,y+1)->(x,y)
  658. m_lowerBound += SubtractMin(Di, K);
  659. if (x > 0)
  660. {
  661. CostVal lambda = (m_varWeights) ? m_lambda * m_horzWeights[n-1] : m_lambda;
  662. m_lowerBound += UpdateMessageL1(M_ptr - 2 * K, Di, K, 0.5, lambda, m_smoothMax);
  663. }
  664. if (y > 0)
  665. {
  666. CostVal lambda = (m_varWeights) ? m_lambda * m_vertWeights[n-m_width] : m_lambda;
  667. m_lowerBound += UpdateMessageL1(M_ptr - (2 * m_width - 1) * K, Di, K, 0.5, lambda, m_smoothMax);
  668. }
  669. }
  670. }
  671. delete [] Di;
  672. }
  673. void TRWS::optimize_GRID_L2(int nIterations)
  674. {
  675. int x, y, n, K = m_nLabels;
  676. CostVal* D_ptr;
  677. REAL* M_ptr;
  678. REAL* Di;
  679. void* buf;
  680. Di = new REAL[K];
  681. buf = new char[(2*K+1)*sizeof(int) + K*sizeof(REAL)];
  682. for ( ; nIterations > 0; nIterations --)
  683. {
  684. // forward pass
  685. n = 0;
  686. D_ptr = m_D;
  687. M_ptr = m_messages;
  688. for (y = 0; y < m_height; y++)
  689. for (x = 0; x < m_width; x++, D_ptr += K, M_ptr += 2 * K, n++)
  690. {
  691. CopyVector(Di, D_ptr, K);
  692. if (x > 0) AddVector(Di, M_ptr - 2*K, K); // message (x-1,y)->(x,y)
  693. if (y > 0) AddVector(Di, M_ptr - (2*m_width - 1)*K, K); // message (x,y-1)->(x,y)
  694. if (x < m_width - 1) AddVector(Di, M_ptr, K); // message (x+1,y)->(x,y)
  695. if (y < m_height - 1) AddVector(Di, M_ptr + K, K); // message (x,y+1)->(x,y)
  696. if (x < m_width - 1)
  697. {
  698. CostVal lambda = (m_varWeights) ? m_lambda * m_horzWeights[n] : m_lambda;
  699. UpdateMessageL2(M_ptr, Di, K, 0.5, lambda, m_smoothMax, buf);
  700. }
  701. if (y < m_height - 1)
  702. {
  703. CostVal lambda = (m_varWeights) ? m_lambda * m_vertWeights[n] : m_lambda;
  704. UpdateMessageL2(M_ptr + K, Di, K, 0.5, lambda, m_smoothMax, buf);
  705. }
  706. }
  707. // backward pass
  708. m_lowerBound = 0;
  709. n --;
  710. D_ptr -= K;
  711. M_ptr -= 2 * K;
  712. for (y = m_height - 1; y >= 0; y--)
  713. for (x = m_width - 1; x >= 0; x--, D_ptr -= K, M_ptr -= 2 * K, n--)
  714. {
  715. CopyVector(Di, D_ptr, K);
  716. if (x > 0) AddVector(Di, M_ptr - 2*K, K); // message (x-1,y)->(x,y)
  717. if (y > 0) AddVector(Di, M_ptr - (2*m_width - 1)*K, K); // message (x,y-1)->(x,y)
  718. if (x < m_width - 1) AddVector(Di, M_ptr, K); // message (x+1,y)->(x,y)
  719. if (y < m_height - 1) AddVector(Di, M_ptr + K, K); // message (x,y+1)->(x,y)
  720. m_lowerBound += SubtractMin(Di, K);
  721. if (x > 0)
  722. {
  723. CostVal lambda = (m_varWeights) ? m_lambda * m_horzWeights[n-1] : m_lambda;
  724. m_lowerBound += UpdateMessageL2(M_ptr - 2 * K, Di, K, 0.5, lambda, m_smoothMax, buf);
  725. }
  726. if (y > 0)
  727. {
  728. CostVal lambda = (m_varWeights) ? m_lambda * m_vertWeights[n-m_width] : m_lambda;
  729. m_lowerBound += UpdateMessageL2(M_ptr - (2 * m_width - 1) * K, Di, K, 0.5, lambda, m_smoothMax, buf);
  730. }
  731. }
  732. }
  733. delete [] Di;
  734. delete [] (char *)buf;
  735. }
  736. void TRWS::optimize_GRID_BINARY(int nIterations)
  737. {
  738. int x, y, n;
  739. REAL* M_ptr;
  740. REAL Di;
  741. for ( ; nIterations > 0; nIterations --)
  742. {
  743. // forward pass
  744. n = 0;
  745. M_ptr = m_messages;
  746. for (y = 0; y < m_height; y++)
  747. for (x = 0; x < m_width; x++, M_ptr += 2, n++)
  748. {
  749. Di = m_DBinary[n];
  750. if (x > 0) Di += M_ptr[-2]; // message (x-1,y)->(x,y)
  751. if (y > 0) Di += M_ptr[-2*m_width+1]; // message (x,y-1)->(x,y)
  752. if (x < m_width - 1) Di += M_ptr[0]; // message (x+1,y)->(x,y)
  753. if (y < m_height - 1) Di += M_ptr[1]; // message (x,y+1)->(x,y)
  754. REAL DiScaled = Di * 0.5;
  755. if (x < m_width - 1)
  756. {
  757. Di = DiScaled - M_ptr[0];
  758. CostVal lambda = m_horzWeightsBinary[n];
  759. if (lambda < 0) {
  760. Di = -Di;
  761. lambda = -lambda;
  762. }
  763. if (Di > lambda) M_ptr[0] = lambda;
  764. else M_ptr[0] = (Di < -lambda) ? -lambda : Di;
  765. }
  766. if (y < m_height - 1)
  767. {
  768. Di = DiScaled - M_ptr[1];
  769. CostVal lambda = m_vertWeightsBinary[n];
  770. if (lambda < 0) {
  771. Di = -Di;
  772. lambda = -lambda;
  773. }
  774. if (Di > lambda) M_ptr[1] = lambda;
  775. else M_ptr[1] = (Di < -lambda) ? -lambda : Di;
  776. }
  777. }
  778. // backward pass
  779. n --;
  780. M_ptr -= 2;
  781. for (y = m_height - 1; y >= 0; y--)
  782. for (x = m_width - 1; x >= 0; x--, M_ptr -= 2, n--)
  783. {
  784. Di = m_DBinary[n];
  785. if (x > 0) Di += M_ptr[-2]; // message (x-1,y)->(x,y)
  786. if (y > 0) Di += M_ptr[-2*m_width+1]; // message (x,y-1)->(x,y)
  787. if (x < m_width - 1) Di += M_ptr[0]; // message (x+1,y)->(x,y)
  788. if (y < m_height - 1) Di += M_ptr[1]; // message (x,y+1)->(x,y)
  789. REAL DiScaled = Di * 0.5;
  790. if (x > 0)
  791. {
  792. Di = DiScaled - M_ptr[-2];
  793. CostVal lambda = m_horzWeightsBinary[n-1];
  794. if (lambda < 0) {
  795. Di = -Di;
  796. lambda = -lambda;
  797. }
  798. if (Di > lambda) M_ptr[-2] = lambda;
  799. else M_ptr[-2] = (Di < -lambda) ? -lambda : Di;
  800. }
  801. if (y > 0)
  802. {
  803. Di = DiScaled - M_ptr[-2*m_width+1];
  804. CostVal lambda = m_vertWeightsBinary[n-m_width];
  805. if (lambda < 0) {
  806. Di = -Di;
  807. lambda = -lambda;
  808. }
  809. if (Di > lambda) M_ptr[-2*m_width+1] = lambda;
  810. else M_ptr[-2*m_width+1] = (Di < -lambda) ? -lambda : Di;
  811. }
  812. }
  813. }
  814. m_lowerBound = 0;
  815. }
  816. void TRWS::optimize_GRID_FIXED_MATRIX(int nIterations)
  817. {
  818. int x, y, n, K = m_nLabels;
  819. CostVal* D_ptr;
  820. REAL* M_ptr;
  821. REAL* Di;
  822. void* buf;
  823. Di = new REAL[K];
  824. buf = new REAL[K];
  825. for ( ; nIterations > 0; nIterations --)
  826. {
  827. // forward pass
  828. n = 0;
  829. D_ptr = m_D;
  830. M_ptr = m_messages;
  831. for (y = 0; y < m_height; y++)
  832. for (x = 0; x < m_width; x++, D_ptr += K, M_ptr += 2 * K, n++)
  833. {
  834. CopyVector(Di, D_ptr, K);
  835. if (x > 0) AddVector(Di, M_ptr - 2*K, K); // message (x-1,y)->(x,y)
  836. if (y > 0) AddVector(Di, M_ptr - (2*m_width - 1)*K, K); // message (x,y-1)->(x,y)
  837. if (x < m_width - 1) AddVector(Di, M_ptr, K); // message (x+1,y)->(x,y)
  838. if (y < m_height - 1) AddVector(Di, M_ptr + K, K); // message (x,y+1)->(x,y)
  839. if (x < m_width - 1)
  840. {
  841. CostVal lambda = (m_varWeights) ? m_horzWeights[n] : 1;
  842. UpdateMessageFIXED_MATRIX(M_ptr, Di, K, 0.5, lambda, m_V, buf);
  843. }
  844. if (y < m_height - 1)
  845. {
  846. CostVal lambda = (m_varWeights) ? m_vertWeights[n] : 1;
  847. UpdateMessageFIXED_MATRIX(M_ptr + K, Di, K, 0.5, lambda, m_V, buf);
  848. }
  849. }
  850. // backward pass
  851. m_lowerBound = 0;
  852. n --;
  853. D_ptr -= K;
  854. M_ptr -= 2 * K;
  855. for (y = m_height - 1; y >= 0; y--)
  856. for (x = m_width - 1; x >= 0; x--, D_ptr -= K, M_ptr -= 2 * K, n--)
  857. {
  858. CopyVector(Di, D_ptr, K);
  859. if (x > 0) AddVector(Di, M_ptr - 2*K, K); // message (x-1,y)->(x,y)
  860. if (y > 0) AddVector(Di, M_ptr - (2*m_width - 1)*K, K); // message (x,y-1)->(x,y)
  861. if (x < m_width - 1) AddVector(Di, M_ptr, K); // message (x+1,y)->(x,y)
  862. if (y < m_height - 1) AddVector(Di, M_ptr + K, K); // message (x,y+1)->(x,y)
  863. m_lowerBound += SubtractMin(Di, K);
  864. if (x > 0)
  865. {
  866. CostVal lambda = (m_varWeights) ? m_horzWeights[n-1] : 1;
  867. m_lowerBound += UpdateMessageFIXED_MATRIX(M_ptr - 2 * K, Di, K, 0.5, lambda, m_V, buf);
  868. }
  869. if (y > 0)
  870. {
  871. CostVal lambda = (m_varWeights) ? m_vertWeights[n-m_width] : 1;
  872. m_lowerBound += UpdateMessageFIXED_MATRIX(M_ptr - (2 * m_width - 1) * K, Di, K, 0.5, lambda, m_V, buf);
  873. }
  874. }
  875. }
  876. delete [] Di;
  877. delete [] (REAL *)buf;
  878. }
  879. void TRWS::optimize_GRID_GENERAL(int nIterations)
  880. {
  881. int x, y, n, K = m_nLabels;
  882. CostVal* D_ptr;
  883. REAL* M_ptr;
  884. REAL* Di;
  885. void* buf;
  886. Di = new REAL[K];
  887. buf = new REAL[K];
  888. for ( ; nIterations > 0; nIterations --)
  889. {
  890. // forward pass
  891. n = 0;
  892. D_ptr = m_D;
  893. M_ptr = m_messages;
  894. CostVal* V_ptr = m_V;
  895. for (y = 0; y < m_height; y++)
  896. for (x = 0; x < m_width; x++, D_ptr += K, M_ptr += 2 * K, V_ptr += 2 * K * K, n++)
  897. {
  898. CopyVector(Di, D_ptr, K);
  899. if (x > 0) AddVector(Di, M_ptr - 2*K, K); // message (x-1,y)->(x,y)
  900. if (y > 0) AddVector(Di, M_ptr - (2*m_width - 1)*K, K); // message (x,y-1)->(x,y)
  901. if (x < m_width - 1) AddVector(Di, M_ptr, K); // message (x+1,y)->(x,y)
  902. if (y < m_height - 1) AddVector(Di, M_ptr + K, K); // message (x,y+1)->(x,y)
  903. if (x < m_width - 1)
  904. {
  905. if (m_V) UpdateMessageGENERAL(M_ptr, Di, K, 0.5, /* forward dir*/ 0, V_ptr, buf);
  906. else UpdateMessageGENERAL(M_ptr, Di, K, 0.5, m_smoothFn, n, n + 1, buf);
  907. }
  908. if (y < m_height - 1)
  909. {
  910. if (m_V) UpdateMessageGENERAL(M_ptr + K, Di, K, 0.5, /* forward dir*/ 0, V_ptr + K*K, buf);
  911. else UpdateMessageGENERAL(M_ptr + K, Di, K, 0.5, m_smoothFn, n, n + m_width, buf);
  912. }
  913. }
  914. // backward pass
  915. m_lowerBound = 0;
  916. n --;
  917. D_ptr -= K;
  918. M_ptr -= 2 * K;
  919. V_ptr -= 2 * K * K;
  920. for (y = m_height - 1; y >= 0; y--)
  921. for (x = m_width - 1; x >= 0; x--, D_ptr -= K, M_ptr -= 2 * K, V_ptr -= 2 * K * K, n--)
  922. {
  923. CopyVector(Di, D_ptr, K);
  924. if (x > 0) AddVector(Di, M_ptr - 2*K, K); // message (x-1,y)->(x,y)
  925. if (y > 0) AddVector(Di, M_ptr - (2*m_width - 1)*K, K); // message (x,y-1)->(x,y)
  926. if (x < m_width - 1) AddVector(Di, M_ptr, K); // message (x+1,y)->(x,y)
  927. if (y < m_height - 1) AddVector(Di, M_ptr + K, K); // message (x,y+1)->(x,y)
  928. // normalize Di, update lower bound
  929. m_lowerBound += SubtractMin(Di, K);
  930. if (x > 0)
  931. {
  932. if (m_V) m_lowerBound += UpdateMessageGENERAL(M_ptr - 2 * K, Di, K, 0.5, /* backward dir */ 1, V_ptr - 2 * K * K, buf);
  933. else m_lowerBound += UpdateMessageGENERAL(M_ptr - 2 * K, Di, K, 0.5, m_smoothFn, n, n - 1, buf);
  934. }
  935. if (y > 0)
  936. {
  937. if (m_V) m_lowerBound += UpdateMessageGENERAL(M_ptr - (2 * m_width - 1) * K, Di, K, 0.5, /* backward dir */ 1, V_ptr - (2 * m_width - 1) * K * K, buf);
  938. else m_lowerBound += UpdateMessageGENERAL(M_ptr - (2 * m_width - 1) * K, Di, K, 0.5, m_smoothFn, n, n - m_width, buf);
  939. }
  940. }
  941. }
  942. delete [] Di;
  943. delete [] (REAL *)buf;
  944. }