// regions-maxprod.cpp
// (C) 2002 Marshall Tappen, MIT AI Lab mtappen@mit.edu
  #include <stdio.h>
  #include <cassert>
  #include <limits>
  #include <vector>
  #include "MaxProdBP.h"
  5. using namespace OBJREC;
  6. int numIterRun;
  7. // Some of the GBP code has been disabled here
  8. #define mexPrintf printf
  9. #define mexErrMsgTxt printf
  10. #define UP 0
  11. #define DOWN 1
  12. #define LEFT 2
  13. #define RIGHT 3
  14. OneNodeCluster::OneNodeCluster()
  15. {
  16. }
  17. int OneNodeCluster::numStates;
  18. namespace OBJREC {
  19. FLOATTYPE vec_min(FLOATTYPE *vec, int length)
  20. {
  21. FLOATTYPE min = vec[0];
  22. for (int i = 0; i < length; i++)
  23. if (vec[i] < min)
  24. min = vec[i];
  25. return min;
  26. }
  27. FLOATTYPE vec_max(FLOATTYPE *vec, int length)
  28. {
  29. FLOATTYPE max = vec[0];
  30. for (int i = 0; i < length; i++)
  31. if (vec[i] > max)
  32. max = vec[i];
  33. return max;
  34. }
  35. void getPsiMat(OneNodeCluster &/*cluster*/, FLOATTYPE *&destMatrix,
  36. int r, int c, MaxProdBP *mrf, int direction, FLOATTYPE &var_weight)
  37. {
  38. int mrfHeight = mrf->getHeight();
  39. int mrfWidth = mrf->getWidth();
  40. int numLabels = mrf->getNLabels();
  41. int x = c;
  42. int y = r;
  43. int i;
  44. FLOATTYPE *currMatrix = mrf->getScratchMatrix();
  45. if (mrf->getSmoothType() != MRF::FUNCTION)
  46. {
  47. if (((direction == UP) && (r == 0)) ||
  48. ((direction == DOWN) && (r == (mrfHeight - 1))) ||
  49. ((direction == LEFT) && (c == 0)) ||
  50. ((direction == RIGHT) && (c == (mrfWidth - 1))))
  51. {
  52. for ( i = 0; i < numLabels * numLabels; i++)
  53. {
  54. currMatrix[i] = 0;
  55. }
  56. }
  57. else
  58. {
  59. MRF::CostVal weight_mod = 1;
  60. if (mrf->varWeights())
  61. {
  62. if (direction == LEFT)
  63. weight_mod = mrf->getHorizWeight(r, c - 1);
  64. else if (direction == RIGHT)
  65. weight_mod = mrf->getHorizWeight(r, c);
  66. else if (direction == UP)
  67. weight_mod = mrf->getVertWeight(r - 1, c);
  68. else if (direction == DOWN)
  69. weight_mod = mrf->getVertWeight(r, c);
  70. }
  71. for ( i = 0; i < numLabels*numLabels; i++)
  72. {
  73. if (weight_mod != 1)
  74. {
  75. currMatrix[i] = FLOATTYPE(mrf->m_V[i] * weight_mod);
  76. }
  77. else
  78. currMatrix[i] = FLOATTYPE(mrf->m_V[i]);
  79. }
  80. destMatrix = currMatrix;
  81. var_weight = (float)weight_mod;
  82. }
  83. }
  84. else
  85. {
  86. if (((direction == UP) && (r == 0)) ||
  87. ((direction == DOWN) && (r == (mrfHeight - 1))) ||
  88. ((direction == LEFT) && (c == 0)) ||
  89. ((direction == RIGHT) && (c == (mrfWidth - 1))))
  90. {
  91. for ( i = 0; i < numLabels * numLabels; i++)
  92. {
  93. currMatrix[i] = 0;
  94. }
  95. }
  96. else
  97. {
  98. for ( i = 0; i < numLabels; i++)
  99. {
  100. for (int j = 0; j < numLabels; j++)
  101. {
  102. MRF::CostVal cCost;
  103. if (direction == LEFT)
  104. cCost = mrf->m_smoothFn(x + y * mrf->m_width,
  105. x + y * mrf->m_width - 1 , j, i);
  106. else if (direction == RIGHT)
  107. cCost = mrf->m_smoothFn(x + y * mrf->m_width,
  108. x + y * mrf->m_width + 1 , i, j);
  109. else if (direction == UP)
  110. cCost = mrf->m_smoothFn(x + y * mrf->m_width,
  111. x + (y - 1) * mrf->m_width , j, i);
  112. else if (direction == DOWN)
  113. cCost = mrf->m_smoothFn(x + y * mrf->m_width,
  114. x + (y + 1) * mrf->m_width , i, j);
  115. else
  116. {
  117. cCost = mrf->m_smoothFn(x + y * mrf->m_width,
  118. x + (y + 1) * mrf->m_width - 1 , j, i);
  119. assert(0);
  120. }
  121. currMatrix[i*numLabels+j] = (float)cCost;
  122. }
  123. }
  124. }
  125. }
  126. destMatrix = currMatrix;
  127. }
  128. void getVarWeight(OneNodeCluster &/*cluster*/, int r, int c, MaxProdBP *mrf, int direction, FLOATTYPE &var_weight)
  129. {
  130. MRF::CostVal weight_mod = 1;
  131. if (mrf->varWeights())
  132. {
  133. if (direction == LEFT)
  134. weight_mod = mrf->getHorizWeight(r, c - 1);
  135. else if (direction == RIGHT)
  136. weight_mod = mrf->getHorizWeight(r, c);
  137. else if (direction == UP)
  138. weight_mod = mrf->getVertWeight(r - 1, c);
  139. else if (direction == DOWN)
  140. weight_mod = mrf->getVertWeight(r, c);
  141. }
  142. var_weight = (FLOATTYPE) weight_mod;
  143. // printf("%d\n",weight_mod);
  144. }
  145. void initOneNodeMsgMem(OneNodeCluster *nodeArray, FLOATTYPE *memChunk,
  146. const int numNodes, const int msgChunkSize)
  147. {
  148. FLOATTYPE *currPtr = memChunk;
  149. OneNodeCluster *currNode = nodeArray;
  150. FLOATTYPE *nextRoundChunk = new FLOATTYPE[nodeArray[1].numStates];
  151. // MEMORY LEAK? where does this ever get deleted??
  152. for (int i = 0; i < numNodes; i++)
  153. {
  154. currNode->receivedMsgs[0] = currPtr;
  155. currPtr += msgChunkSize;
  156. currNode->receivedMsgs[1] = currPtr;
  157. currPtr += msgChunkSize;
  158. currNode->receivedMsgs[2] = currPtr;
  159. currPtr += msgChunkSize;
  160. currNode->receivedMsgs[3] = currPtr;
  161. currPtr += msgChunkSize;
  162. currNode->nextRoundReceivedMsgs[0] = nextRoundChunk;
  163. currNode->nextRoundReceivedMsgs[1] = nextRoundChunk;
  164. currNode->nextRoundReceivedMsgs[2] = nextRoundChunk;
  165. currNode->nextRoundReceivedMsgs[3] = nextRoundChunk;
  166. currNode++;
  167. }
  168. }
  169. inline void l1_dist_trans_comp(FLOATTYPE smoothMax, FLOATTYPE c, FLOATTYPE* tmpMsgDest, FLOATTYPE * msgProd, int numStates)
  170. {
  171. int q;
  172. for (int i = 0; i < numStates; i++)
  173. tmpMsgDest[i] = msgProd[i];
  174. for (q = 1; q <= numStates - 1; q++)
  175. {
  176. if (tmpMsgDest[q] > tmpMsgDest[q-1] + c)
  177. tmpMsgDest[q] = tmpMsgDest[q-1] + c;
  178. }
  179. for (q = numStates - 2; q >= 0; q--)
  180. {
  181. if (tmpMsgDest[q] > tmpMsgDest[q+1] + c)
  182. tmpMsgDest[q] = tmpMsgDest[q+1] + c;
  183. }
  184. FLOATTYPE minPotts = msgProd[0] + smoothMax;
  185. for (q = 0; q <= numStates - 1; q++)
  186. {
  187. if ((msgProd[q] + smoothMax) < minPotts)
  188. minPotts = msgProd[q] + smoothMax;
  189. }
  190. for (q = 0; q <= numStates - 1; q++)
  191. {
  192. if ((tmpMsgDest[q]) > minPotts)
  193. tmpMsgDest[q] = minPotts;
  194. tmpMsgDest[q] = -tmpMsgDest[q];
  195. }
  196. // printf("%f %f %f\n",smoothMax,c,minPotts);
  197. }
  198. inline void l2_dist_trans_comp(FLOATTYPE smoothMax, FLOATTYPE c, FLOATTYPE* tmpMsgDest, FLOATTYPE * msgProd, int numStates)
  199. {
  200. FLOATTYPE *z = new FLOATTYPE[numStates];
  201. int *v = new int[numStates];
  202. int j = 0;
  203. FLOATTYPE INFINITY = std::numeric_limits<float>::infinity();
  204. z[0] = -1 * INFINITY;
  205. z[1] = INFINITY;
  206. v[0] = 0;
  207. int q;
  208. if (c == 0)
  209. {
  210. FLOATTYPE minVal = msgProd[0];
  211. for (q = 0; q < numStates; q++)
  212. {
  213. if (msgProd[q] < minVal)
  214. minVal = msgProd[q];
  215. }
  216. for (q = 0; q < numStates; q++)
  217. {
  218. tmpMsgDest[q] = -minVal;
  219. }
  220. delete [] z;
  221. delete [] v;
  222. return;
  223. }
  224. for (q = 1; q <= numStates - 1; q++)
  225. {
  226. FLOATTYPE s;
  227. while ( (s = ((msgProd[q] + c * q * q) - (msgProd[v[j]] + c * v[j] * v[j])) /
  228. (2 * c * q - 2 * c * v[j])) <= z[j])
  229. {
  230. j -= 1;
  231. }
  232. j += 1;
  233. v[j] = q;
  234. z[j] = s;
  235. z[j+1] = INFINITY;
  236. }
  237. j = 0;
  238. FLOATTYPE minPotts = msgProd[0] + smoothMax;
  239. for (q = 0; q <= numStates - 1; q++)
  240. {
  241. while (z[j+1] < q)
  242. {
  243. j += 1;
  244. }
  245. tmpMsgDest[q] = c * (q - v[j]) * (q - v[j]) + msgProd[v[j]];
  246. if ((msgProd[q] + smoothMax) < minPotts)
  247. minPotts = msgProd[q] + smoothMax;
  248. }
  249. for (q = 0; q <= numStates - 1; q++)
  250. {
  251. if ((tmpMsgDest[q]) > minPotts)
  252. tmpMsgDest[q] = minPotts;
  253. tmpMsgDest[q] = -tmpMsgDest[q];
  254. }
  255. delete [] z;
  256. delete [] v;
  257. }
  258. }
  259. void OneNodeCluster::ComputeMsgRight(FLOATTYPE *msgDest, int r, int c, MaxProdBP *mrf)
  260. {
  261. FLOATTYPE *nodeLeftMsg = receivedMsgs[LEFT],
  262. *nodeDownMsg = receivedMsgs[DOWN],
  263. *nodeUpMsg = receivedMsgs[UP];
  264. FLOATTYPE weight_mod;
  265. getVarWeight(*this, r, c, mrf, RIGHT, weight_mod);
  266. FLOATTYPE *tmpMsgDest = msgDest;
  267. if (mrf->m_type == MaxProdBP::L1 || mrf->m_type == MaxProdBP::L2)
  268. {
  269. FLOATTYPE *msgProd = new FLOATTYPE[numStates];
  270. const FLOATTYPE lambda = (FLOATTYPE)mrf->m_lambda;
  271. const FLOATTYPE smoothMax = (FLOATTYPE)mrf->m_smoothMax;
  272. for (int leftNodeInd = 0; leftNodeInd < numStates; leftNodeInd++)
  273. {
  274. msgProd[leftNodeInd] = -nodeLeftMsg[leftNodeInd] +
  275. -nodeUpMsg[leftNodeInd] +
  276. -nodeDownMsg[leftNodeInd] + localEv[leftNodeInd];
  277. }
  278. if (mrf->m_type == MaxProdBP::L1)
  279. {
  280. l1_dist_trans_comp( weight_mod*smoothMax*lambda, lambda*weight_mod, tmpMsgDest, msgProd, numStates);
  281. }
  282. else
  283. {
  284. l2_dist_trans_comp( weight_mod*smoothMax*lambda, lambda*weight_mod, tmpMsgDest, msgProd, numStates);
  285. }
  286. delete [] msgProd;
  287. }
  288. else if ((mrf->getSmoothType() == MRF::FUNCTION) || (mrf->getSmoothType() == MRF::ARRAY))
  289. {
  290. FLOATTYPE *psiMat, var_weight;
  291. getPsiMat(*this, psiMat, r, c, mrf, RIGHT, var_weight);
  292. FLOATTYPE *cmessage = msgDest;
  293. for (int rightNodeInd = 0; rightNodeInd < numStates; rightNodeInd++)
  294. {
  295. *cmessage = 0;
  296. for (int leftNodeInd = 0; leftNodeInd < numStates; leftNodeInd++)
  297. {
  298. FLOATTYPE tmp = nodeLeftMsg[leftNodeInd] +
  299. nodeUpMsg[leftNodeInd] +
  300. nodeDownMsg[leftNodeInd]
  301. - localEv[leftNodeInd]
  302. - psiMat[leftNodeInd * numStates + rightNodeInd];
  303. if ((tmp > *cmessage) || (leftNodeInd == 0))
  304. *cmessage = tmp;
  305. }
  306. cmessage++;
  307. }
  308. }
  309. else {
  310. fprintf(stderr, "not implemented!\n");
  311. exit(1);
  312. }
  313. FLOATTYPE max = msgDest[0];
  314. for (int i = 0; i < numStates; i++)
  315. msgDest[i] -= max;
  316. }
  317. // This means, "Compute the message to send left."
  318. void OneNodeCluster::ComputeMsgLeft(FLOATTYPE *msgDest, int r, int c, MaxProdBP *mrf)
  319. {
  320. FLOATTYPE *nodeRightMsg = receivedMsgs[RIGHT],
  321. *nodeDownMsg = receivedMsgs[DOWN],
  322. *nodeUpMsg = receivedMsgs[UP];
  323. int do_dist = (int)(mrf->getSmoothType() == MRF::THREE_PARAM);
  324. FLOATTYPE *tmpMsgDest = msgDest;
  325. if (do_dist)
  326. {
  327. FLOATTYPE weight_mod;
  328. getVarWeight(*this, r, c, mrf, LEFT, weight_mod);
  329. FLOATTYPE *msgProd = new FLOATTYPE[numStates];
  330. const FLOATTYPE lambda = (FLOATTYPE)mrf->m_lambda;
  331. const FLOATTYPE smoothMax = (FLOATTYPE)mrf->m_smoothMax;
  332. for (int rightNodeInd = 0; rightNodeInd < numStates; rightNodeInd++)
  333. {
  334. msgProd[rightNodeInd] = -nodeRightMsg[rightNodeInd] +
  335. -nodeUpMsg[rightNodeInd] +
  336. -nodeDownMsg[rightNodeInd]
  337. + localEv[rightNodeInd] ;
  338. }
  339. if (mrf->m_smoothExp == 1)
  340. {
  341. l1_dist_trans_comp( weight_mod*smoothMax*lambda, lambda*weight_mod, tmpMsgDest, msgProd, numStates);
  342. }
  343. else
  344. l2_dist_trans_comp( weight_mod*smoothMax*lambda, lambda*weight_mod, tmpMsgDest, msgProd, numStates);
  345. delete [] msgProd;
  346. }
  347. else if ((mrf->getSmoothType() == MRF::FUNCTION) || (mrf->getSmoothType() == MRF::ARRAY))
  348. {
  349. FLOATTYPE *psiMat, var_weight;
  350. getPsiMat(*this, psiMat, r, c, mrf, LEFT, var_weight);
  351. FLOATTYPE *cmessage = msgDest;
  352. for (int leftNodeInd = 0; leftNodeInd < numStates; leftNodeInd++)
  353. {
  354. *cmessage = 0;
  355. for (int rightNodeInd = 0; rightNodeInd < numStates; rightNodeInd++)
  356. {
  357. FLOATTYPE tmp = nodeRightMsg[rightNodeInd] +
  358. nodeUpMsg[rightNodeInd] +
  359. nodeDownMsg[rightNodeInd]
  360. - localEv[rightNodeInd]
  361. - psiMat[leftNodeInd * numStates + rightNodeInd] ;
  362. if ((tmp > *cmessage) || (rightNodeInd == 0))
  363. *cmessage = tmp;
  364. }
  365. cmessage++;
  366. }
  367. }
  368. else
  369. assert(0);
  370. // FLOATTYPE max = vec_max(msgDest,numStates);
  371. FLOATTYPE max = msgDest[0];
  372. for (int i = 0; i < numStates; i++)
  373. msgDest[i] -= max;
  374. }
  375. void OneNodeCluster::ComputeMsgUp(FLOATTYPE *msgDest, int r, int c, MaxProdBP *mrf)
  376. {
  377. FLOATTYPE *nodeRightMsg = receivedMsgs[RIGHT],
  378. *nodeDownMsg = receivedMsgs[DOWN],
  379. *nodeLeftMsg = receivedMsgs[LEFT];
  380. int do_dist = (int)(mrf->getSmoothType() == MRF::THREE_PARAM);
  381. if (do_dist)
  382. {
  383. FLOATTYPE weight_mod;
  384. getVarWeight(*this, r, c, mrf, UP, weight_mod);
  385. FLOATTYPE *tmpMsgDest = msgDest;
  386. FLOATTYPE *msgProd = new FLOATTYPE[numStates];
  387. const FLOATTYPE lambda = (FLOATTYPE)mrf->m_lambda;
  388. const FLOATTYPE smoothMax = (FLOATTYPE)mrf->m_smoothMax;
  389. for (int downNodeInd = 0; downNodeInd < numStates; downNodeInd++)
  390. {
  391. msgProd[downNodeInd] = -nodeRightMsg[downNodeInd] +
  392. -nodeLeftMsg[downNodeInd] +
  393. -nodeDownMsg[downNodeInd] +
  394. + localEv[downNodeInd] ;
  395. }
  396. // printf("%f %f %f %f\n",nodeLeftMsg[leftNodeInd] ,
  397. // nodeUpMsg[leftNodeInd] ,
  398. // nodeDownMsg[leftNodeInd] ,localEv[leftNodeInd]);
  399. if (mrf->m_smoothExp == 1)
  400. {
  401. l1_dist_trans_comp( weight_mod*smoothMax*lambda, lambda*weight_mod, tmpMsgDest, msgProd, numStates);
  402. }
  403. else
  404. l2_dist_trans_comp( weight_mod*smoothMax*lambda, lambda*weight_mod, tmpMsgDest, msgProd, numStates);
  405. delete [] msgProd;
  406. }
  407. else if ((mrf->getSmoothType() == MRF::FUNCTION) || (mrf->getSmoothType() == MRF::ARRAY))
  408. {
  409. FLOATTYPE *psiMat, var_weight;
  410. getPsiMat(*this, psiMat, r, c, mrf, UP, var_weight);
  411. FLOATTYPE *cmessage = msgDest;
  412. for (int upNodeInd = 0; upNodeInd < numStates; upNodeInd++)
  413. {
  414. *cmessage = 0;
  415. for (int downNodeInd = 0; downNodeInd < numStates; downNodeInd++)
  416. {
  417. FLOATTYPE tmp = nodeRightMsg[downNodeInd] +
  418. nodeLeftMsg[downNodeInd] +
  419. nodeDownMsg[downNodeInd] +
  420. -localEv[downNodeInd]
  421. - psiMat[upNodeInd * numStates + downNodeInd] ;
  422. if ((tmp > *cmessage) || (downNodeInd == 0))
  423. *cmessage = tmp;
  424. }
  425. cmessage++;
  426. }
  427. }
  428. else
  429. assert(0);
  430. FLOATTYPE max = msgDest[0];
  431. // FLOATTYPE max = vec_max(msgDest,numStates);
  432. for (int i = 0; i < numStates; i++)
  433. msgDest[i] -= max;
  434. }
  435. void OneNodeCluster::ComputeMsgDown(FLOATTYPE *msgDest, int r, int c, MaxProdBP *mrf)
  436. {
  437. FLOATTYPE *nodeRightMsg = receivedMsgs[RIGHT],
  438. *nodeUpMsg = receivedMsgs[UP],
  439. *nodeLeftMsg = receivedMsgs[LEFT];
  440. int do_dist = (int)(mrf->getSmoothType() == MRF::THREE_PARAM);
  441. if (do_dist)
  442. {
  443. FLOATTYPE weight_mod;
  444. getVarWeight(*this, r, c, mrf, DOWN, weight_mod);
  445. FLOATTYPE *tmpMsgDest = msgDest;
  446. FLOATTYPE *msgProd = new FLOATTYPE[numStates];
  447. const FLOATTYPE lambda = (FLOATTYPE)mrf->m_lambda;
  448. const FLOATTYPE smoothMax = (FLOATTYPE)mrf->m_smoothMax;
  449. for (int upNodeInd = 0; upNodeInd < numStates; upNodeInd++)
  450. {
  451. msgProd[upNodeInd] = -nodeRightMsg[upNodeInd] +
  452. -nodeLeftMsg[upNodeInd] +
  453. -nodeUpMsg[upNodeInd] +
  454. + localEv[upNodeInd] ;
  455. }
  456. if (mrf->m_smoothExp == 1)
  457. {
  458. l1_dist_trans_comp( weight_mod*smoothMax*lambda, lambda*weight_mod, tmpMsgDest, msgProd, numStates);
  459. }
  460. else
  461. l2_dist_trans_comp( weight_mod*smoothMax*lambda, lambda*weight_mod, tmpMsgDest, msgProd, numStates);
  462. delete [] msgProd;
  463. }
  464. else if ((mrf->getSmoothType() == MRF::FUNCTION) || (mrf->getSmoothType() == MRF::ARRAY))
  465. {
  466. FLOATTYPE *psiMat, var_weight;
  467. getPsiMat(*this, psiMat, r, c, mrf, DOWN, var_weight);
  468. FLOATTYPE *cmessage = msgDest;
  469. for (int downNodeInd = 0; downNodeInd < numStates; downNodeInd++)
  470. {
  471. *cmessage = 0;
  472. for (int upNodeInd = 0; upNodeInd < numStates; upNodeInd++)
  473. {
  474. FLOATTYPE tmp = nodeRightMsg[upNodeInd] +
  475. nodeLeftMsg[upNodeInd] +
  476. nodeUpMsg[upNodeInd] +
  477. -localEv[upNodeInd]
  478. - psiMat[upNodeInd * numStates + downNodeInd] ;
  479. if ((tmp > *cmessage) || (upNodeInd == 0))
  480. *cmessage = tmp;
  481. }
  482. cmessage++;
  483. }
  484. }
  485. else
  486. assert(0);
  487. FLOATTYPE max = msgDest[0];
  488. // FLOATTYPE max = vec_max(msgDest,numStates);
  489. for (int i = 0; i < numStates; i++)
  490. msgDest[i] -= max;
  491. }
  492. void OneNodeCluster::getBelief(FLOATTYPE *beliefVec)
  493. {
  494. for (int i = 0; i < numStates; i++)
  495. {
  496. beliefVec[i] = receivedMsgs[UP][i] + receivedMsgs[DOWN][i] +
  497. receivedMsgs[LEFT][i] + receivedMsgs[RIGHT][i] - localEv[i];
  498. }
  499. }
  500. int OneNodeCluster::getBeliefMaxInd()
  501. {
  502. FLOATTYPE currBelief, bestBelief;
  503. int bestInd = 0;
  504. {
  505. int i = 0;
  506. bestBelief = receivedMsgs[UP][i] + receivedMsgs[DOWN][i] +
  507. receivedMsgs[LEFT][i] + receivedMsgs[RIGHT][i] - localEv[i];
  508. }
  509. for (int i = 1; i < numStates; i++)
  510. {
  511. currBelief = receivedMsgs[UP][i] + receivedMsgs[DOWN][i] +
  512. receivedMsgs[LEFT][i] + receivedMsgs[RIGHT][i] - localEv[i];
  513. if (currBelief > bestBelief)
  514. {
  515. bestInd = i;
  516. bestBelief = currBelief;
  517. }
  518. }
  519. return bestInd;
  520. }
  521. namespace OBJREC {
  522. void computeMessagesLeftRight(OneNodeCluster *nodeArray, const int numCols, const int /*numRows*/, const int currRow, const FLOATTYPE alpha, MaxProdBP *mrf)
  523. {
  524. const int numStates = OneNodeCluster::numStates;
  525. const FLOATTYPE omalpha = 1.0f - alpha;
  526. int i;
  527. int col;
  528. for ( col = 0; col < numCols - 1; col++)
  529. {
  530. nodeArray[currRow * numCols + col].ComputeMsgRight(nodeArray[currRow * numCols + col+1].nextRoundReceivedMsgs[LEFT], currRow, col, mrf);
  531. for (i = 0; i < numStates; i++)
  532. {
  533. nodeArray[currRow * numCols + col+1].receivedMsgs[LEFT][i] =
  534. omalpha * nodeArray[currRow * numCols + col+1].receivedMsgs[LEFT][i] +
  535. alpha * nodeArray[currRow * numCols + col+1].nextRoundReceivedMsgs[LEFT][i];
  536. }
  537. }
  538. for ( col = numCols - 1; col > 0; col--)
  539. {
  540. nodeArray[currRow * numCols + col].ComputeMsgLeft(nodeArray[currRow * numCols + col-1].nextRoundReceivedMsgs[RIGHT], currRow, col, mrf);
  541. for (i = 0; i < numStates; i++)
  542. {
  543. nodeArray[currRow * numCols + col-1].receivedMsgs[RIGHT][i] =
  544. omalpha * nodeArray[currRow * numCols + col-1].receivedMsgs[RIGHT][i] +
  545. alpha * nodeArray[currRow * numCols + col-1].nextRoundReceivedMsgs[RIGHT][i];
  546. }
  547. }
  548. }
  549. void computeMessagesUpDown(OneNodeCluster *nodeArray, const int numCols, const int numRows, const int currCol, const FLOATTYPE alpha, MaxProdBP *mrf)
  550. {
  551. const int numStates = OneNodeCluster::numStates;
  552. const FLOATTYPE omalpha = 1.0f - alpha;
  553. int i;
  554. int row;
  555. for (row = 0; row < numRows - 1; row++)
  556. {
  557. nodeArray[row * numCols + currCol].ComputeMsgDown(nodeArray[(row+1) * numCols + currCol].nextRoundReceivedMsgs[UP], row, currCol, mrf);
  558. for (i = 0; i < numStates; i++)
  559. {
  560. nodeArray[(row+1) * numCols + currCol].receivedMsgs[UP][i] =
  561. omalpha * nodeArray[(row+1) * numCols + currCol].receivedMsgs[UP][i] +
  562. alpha * nodeArray[(row+1) * numCols + currCol].nextRoundReceivedMsgs[UP][i];
  563. }
  564. }
  565. for ( row = numRows - 1; row > 0; row--)
  566. {
  567. nodeArray[row * numCols + currCol].ComputeMsgUp(nodeArray[(row-1) * numCols + currCol].nextRoundReceivedMsgs[DOWN], row, currCol, mrf);
  568. for (i = 0; i < numStates; i++)
  569. {
  570. nodeArray[(row-1) * numCols + currCol].receivedMsgs[DOWN][i] =
  571. omalpha * nodeArray[(row-1) * numCols + currCol].receivedMsgs[DOWN][i] +
  572. alpha * nodeArray[(row-1) * numCols + currCol].nextRoundReceivedMsgs[DOWN][i];
  573. }
  574. }
  575. }
  576. }