12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708 |
- // (C) 2002 Marshall Tappen, MIT AI Lab mtappen@mit.edu
- #include <limits>
- #include <stdio.h>
- #include "MaxProdBP.h"
- using namespace OBJREC;
// File-scope counter; nothing in this part of the file reads or writes it.
// NOTE(review): presumably the number of BP iterations run - confirm at the use site.
- int numIterRun;
- // Some of the GBP code has been disabled here
// The mex* names (as used by the MATLAB MEX API) are redirected to printf.
- #define mexPrintf printf
- #define mexErrMsgTxt printf
// Direction codes. They index receivedMsgs[] / nextRoundReceivedMsgs[] and are
// passed as the `direction` argument of getPsiMat() and getVarWeight().
- #define UP 0
- #define DOWN 1
- #define LEFT 2
- #define RIGHT 3
- OneNodeCluster::OneNodeCluster()
- {
- }
- int OneNodeCluster::numStates;
- namespace OBJREC {
// Returns the smallest of the first `length` entries of `vec`.
// vec[0] is read unconditionally, so the array must hold at least one value;
// for length <= 1 the result is vec[0].
FLOATTYPE vec_min(FLOATTYPE *vec, int length)
{
  FLOATTYPE smallest = vec[0];
  int k = 1;
  while (k < length)
  {
    const FLOATTYPE candidate = vec[k];
    if (candidate < smallest)
    {
      smallest = candidate;
    }
    ++k;
  }
  return smallest;
}
// Returns the largest of the first `length` entries of `vec`.
// vec[0] is read unconditionally, so the array must hold at least one value;
// for length <= 1 the result is vec[0].
FLOATTYPE vec_max(FLOATTYPE *vec, int length)
{
  FLOATTYPE largest = vec[0];
  int k = 1;
  while (k < length)
  {
    const FLOATTYPE candidate = vec[k];
    if (candidate > largest)
    {
      largest = candidate;
    }
    ++k;
  }
  return largest;
}
- void getPsiMat(OneNodeCluster &/*cluster*/, FLOATTYPE *&destMatrix,
- int r, int c, MaxProdBP *mrf, int direction, FLOATTYPE &var_weight)
- {
- int mrfHeight = mrf->getHeight();
- int mrfWidth = mrf->getWidth();
- int numLabels = mrf->getNLabels();
- int x = c;
- int y = r;
- int i;
- FLOATTYPE *currMatrix = mrf->getScratchMatrix();
- if (mrf->getSmoothType() != MRF::FUNCTION)
- {
- if (((direction == UP) && (r == 0)) ||
- ((direction == DOWN) && (r == (mrfHeight - 1))) ||
- ((direction == LEFT) && (c == 0)) ||
- ((direction == RIGHT) && (c == (mrfWidth - 1))))
- {
- for ( i = 0; i < numLabels * numLabels; i++)
- {
- currMatrix[i] = 0;
- }
- }
- else
- {
- MRF::CostVal weight_mod = 1;
- if (mrf->varWeights())
- {
- if (direction == LEFT)
- weight_mod = mrf->getHorizWeight(r, c - 1);
- else if (direction == RIGHT)
- weight_mod = mrf->getHorizWeight(r, c);
- else if (direction == UP)
- weight_mod = mrf->getVertWeight(r - 1, c);
- else if (direction == DOWN)
- weight_mod = mrf->getVertWeight(r, c);
- }
- for ( i = 0; i < numLabels*numLabels; i++)
- {
- if (weight_mod != 1)
- {
- currMatrix[i] = FLOATTYPE(mrf->m_V[i] * weight_mod);
- }
- else
- currMatrix[i] = FLOATTYPE(mrf->m_V[i]);
- }
- destMatrix = currMatrix;
- var_weight = (float)weight_mod;
- }
- }
- else
- {
- if (((direction == UP) && (r == 0)) ||
- ((direction == DOWN) && (r == (mrfHeight - 1))) ||
- ((direction == LEFT) && (c == 0)) ||
- ((direction == RIGHT) && (c == (mrfWidth - 1))))
- {
- for ( i = 0; i < numLabels * numLabels; i++)
- {
- currMatrix[i] = 0;
- }
- }
- else
- {
- for ( i = 0; i < numLabels; i++)
- {
- for (int j = 0; j < numLabels; j++)
- {
- MRF::CostVal cCost;
- if (direction == LEFT)
- cCost = mrf->m_smoothFn(x + y * mrf->m_width,
- x + y * mrf->m_width - 1 , j, i);
- else if (direction == RIGHT)
- cCost = mrf->m_smoothFn(x + y * mrf->m_width,
- x + y * mrf->m_width + 1 , i, j);
- else if (direction == UP)
- cCost = mrf->m_smoothFn(x + y * mrf->m_width,
- x + (y - 1) * mrf->m_width , j, i);
- else if (direction == DOWN)
- cCost = mrf->m_smoothFn(x + y * mrf->m_width,
- x + (y + 1) * mrf->m_width , i, j);
- else
- {
- cCost = mrf->m_smoothFn(x + y * mrf->m_width,
- x + (y + 1) * mrf->m_width - 1 , j, i);
- assert(0);
- }
- currMatrix[i*numLabels+j] = (float)cCost;
- }
- }
- }
- }
- destMatrix = currMatrix;
- }
- void getVarWeight(OneNodeCluster &/*cluster*/, int r, int c, MaxProdBP *mrf, int direction, FLOATTYPE &var_weight)
- {
- MRF::CostVal weight_mod = 1;
- if (mrf->varWeights())
- {
- if (direction == LEFT)
- weight_mod = mrf->getHorizWeight(r, c - 1);
- else if (direction == RIGHT)
- weight_mod = mrf->getHorizWeight(r, c);
- else if (direction == UP)
- weight_mod = mrf->getVertWeight(r - 1, c);
- else if (direction == DOWN)
- weight_mod = mrf->getVertWeight(r, c);
- }
- var_weight = (FLOATTYPE) weight_mod;
- // printf("%d\n",weight_mod);
- }
- void initOneNodeMsgMem(OneNodeCluster *nodeArray, FLOATTYPE *memChunk,
- const int numNodes, const int msgChunkSize)
- {
- FLOATTYPE *currPtr = memChunk;
- OneNodeCluster *currNode = nodeArray;
- FLOATTYPE *nextRoundChunk = new FLOATTYPE[nodeArray[1].numStates];
- // MEMORY LEAK? where does this ever get deleted??
- for (int i = 0; i < numNodes; i++)
- {
- currNode->receivedMsgs[0] = currPtr;
- currPtr += msgChunkSize;
- currNode->receivedMsgs[1] = currPtr;
- currPtr += msgChunkSize;
- currNode->receivedMsgs[2] = currPtr;
- currPtr += msgChunkSize;
- currNode->receivedMsgs[3] = currPtr;
- currPtr += msgChunkSize;
- currNode->nextRoundReceivedMsgs[0] = nextRoundChunk;
- currNode->nextRoundReceivedMsgs[1] = nextRoundChunk;
- currNode->nextRoundReceivedMsgs[2] = nextRoundChunk;
- currNode->nextRoundReceivedMsgs[3] = nextRoundChunk;
- currNode++;
- }
- }
// Truncated-linear (L1) distance transform of a cost vector.
//
// For every state q this writes
//   tmpMsgDest[q] = -min( min_p (msgProd[p] + c * |q - p|),
//                         min_p  msgProd[p] + smoothMax )
// using one forward and one backward sweep.
//
//   smoothMax   offset that defines the truncation level
//   c           cost per unit label difference
//   tmpMsgDest  out: numStates values
//   msgProd     in:  numStates costs (msgProd[0] is always read)
inline void l1_dist_trans_comp(FLOATTYPE smoothMax, FLOATTYPE c, FLOATTYPE* tmpMsgDest, FLOATTYPE * msgProd, int numStates)
{
  const int last = numStates - 1;

  // Start from the unmodified costs.
  for (int s = 0; s <= last; ++s)
  {
    tmpMsgDest[s] = msgProd[s];
  }

  // Forward sweep: reaching s from s-1 costs c.
  for (int s = 1; s <= last; ++s)
  {
    const FLOATTYPE viaPrev = tmpMsgDest[s - 1] + c;
    if (viaPrev < tmpMsgDest[s])
    {
      tmpMsgDest[s] = viaPrev;
    }
  }

  // Backward sweep: reaching s from s+1 costs c.
  for (int s = last - 1; s >= 0; --s)
  {
    const FLOATTYPE viaNext = tmpMsgDest[s + 1] + c;
    if (viaNext < tmpMsgDest[s])
    {
      tmpMsgDest[s] = viaNext;
    }
  }

  // Truncation level: cheapest input cost plus smoothMax.
  FLOATTYPE cap = msgProd[0] + smoothMax;
  for (int s = 0; s <= last; ++s)
  {
    if ((msgProd[s] + smoothMax) < cap)
    {
      cap = msgProd[s] + smoothMax;
    }
  }

  // Clamp to the truncation level and flip the sign.
  for (int s = 0; s <= last; ++s)
  {
    const FLOATTYPE clamped = (tmpMsgDest[s] > cap) ? cap : tmpMsgDest[s];
    tmpMsgDest[s] = -clamped;
  }
}
// Truncated-quadratic (L2) distance transform of a cost vector.
//
// For every state q this writes
//   tmpMsgDest[q] = -min( min_p (msgProd[p] + c * (q - p)^2),
//                         min_p  msgProd[p] + smoothMax )
// by building the lower envelope of the parabolas rooted at each state.
// When c == 0 every output is -min_p msgProd[p] and smoothMax is not applied.
//
//   smoothMax   offset that defines the truncation level
//   c           cost per squared unit of label difference
//   tmpMsgDest  out: numStates values
//   msgProd     in:  numStates costs
//
// Nothing is written when numStates <= 0.
inline void l2_dist_trans_comp(FLOATTYPE smoothMax, FLOATTYPE c, FLOATTYPE* tmpMsgDest, FLOATTYPE * msgProd, int numStates)
{
  if (numStates <= 0)
    return;

  int q;

  // Degenerate case: a flat penalty, handled before any buffer is allocated.
  if (c == 0)
  {
    FLOATTYPE minVal = msgProd[0];
    for (q = 1; q < numStates; q++)
    {
      if (msgProd[q] < minVal)
        minVal = msgProd[q];
    }
    for (q = 0; q < numStates; q++)
    {
      tmpMsgDest[q] = -minVal;
    }
    return;
  }

  // Not named INFINITY: that identifier is a macro in <math.h>.
  const FLOATTYPE posInf = std::numeric_limits<FLOATTYPE>::infinity();

  // v[k] is the root of the k-th envelope parabola and z[k]..z[k+1] the range
  // where it is lowest. With up to numStates parabolas, z needs
  // numStates + 1 boundaries.
  FLOATTYPE *z = new FLOATTYPE[numStates + 1];
  int *v = new int[numStates];

  int k = 0;
  v[0] = 0;
  z[0] = -posInf;
  z[1] = posInf;

  for (q = 1; q <= numStates - 1; q++)
  {
    FLOATTYPE s;
    for (;;)
    {
      // Abscissa where parabola q meets the current last envelope parabola.
      s = ((msgProd[q] + c * q * q) - (msgProd[v[k]] + c * v[k] * v[k])) /
          (2 * c * q - 2 * c * v[k]);
      if (!(s <= z[k]))
      {
        k += 1; // parabola q is appended to the envelope
        break;
      }
      if (k == 0)
        break; // parabola q replaces the very first one (e.g. infinite cost)
      k -= 1;  // the last envelope parabola is hidden; drop it and retry
    }
    v[k] = q;
    z[k] = s;
    z[k + 1] = posInf;
  }

  // Read the envelope back and find the truncation level in the same pass.
  k = 0;
  FLOATTYPE minPotts = msgProd[0] + smoothMax;
  for (q = 0; q <= numStates - 1; q++)
  {
    while (z[k + 1] < q)
    {
      k += 1;
    }
    tmpMsgDest[q] = c * (q - v[k]) * (q - v[k]) + msgProd[v[k]];
    if ((msgProd[q] + smoothMax) < minPotts)
      minPotts = msgProd[q] + smoothMax;
  }

  // Clamp to the truncation level and flip the sign.
  for (q = 0; q <= numStates - 1; q++)
  {
    if ((tmpMsgDest[q]) > minPotts)
      tmpMsgDest[q] = minPotts;
    tmpMsgDest[q] = -tmpMsgDest[q];
  }

  delete [] z;
  delete [] v;
}
- }
- void OneNodeCluster::ComputeMsgRight(FLOATTYPE *msgDest, int r, int c, MaxProdBP *mrf)
- {
- FLOATTYPE *nodeLeftMsg = receivedMsgs[LEFT],
- *nodeDownMsg = receivedMsgs[DOWN],
- *nodeUpMsg = receivedMsgs[UP];
- FLOATTYPE weight_mod;
- getVarWeight(*this, r, c, mrf, RIGHT, weight_mod);
- FLOATTYPE *tmpMsgDest = msgDest;
- if (mrf->m_type == MaxProdBP::L1 || mrf->m_type == MaxProdBP::L2)
- {
- FLOATTYPE *msgProd = new FLOATTYPE[numStates];
- const FLOATTYPE lambda = (FLOATTYPE)mrf->m_lambda;
- const FLOATTYPE smoothMax = (FLOATTYPE)mrf->m_smoothMax;
- for (int leftNodeInd = 0; leftNodeInd < numStates; leftNodeInd++)
- {
- msgProd[leftNodeInd] = -nodeLeftMsg[leftNodeInd] +
- -nodeUpMsg[leftNodeInd] +
- -nodeDownMsg[leftNodeInd] + localEv[leftNodeInd];
- }
- if (mrf->m_type == MaxProdBP::L1)
- {
- l1_dist_trans_comp( weight_mod*smoothMax*lambda, lambda*weight_mod, tmpMsgDest, msgProd, numStates);
- }
- else
- {
- l2_dist_trans_comp( weight_mod*smoothMax*lambda, lambda*weight_mod, tmpMsgDest, msgProd, numStates);
- }
- delete [] msgProd;
- }
- else if ((mrf->getSmoothType() == MRF::FUNCTION) || (mrf->getSmoothType() == MRF::ARRAY))
- {
- FLOATTYPE *psiMat, var_weight;
- getPsiMat(*this, psiMat, r, c, mrf, RIGHT, var_weight);
- FLOATTYPE *cmessage = msgDest;
- for (int rightNodeInd = 0; rightNodeInd < numStates; rightNodeInd++)
- {
- *cmessage = 0;
- for (int leftNodeInd = 0; leftNodeInd < numStates; leftNodeInd++)
- {
- FLOATTYPE tmp = nodeLeftMsg[leftNodeInd] +
- nodeUpMsg[leftNodeInd] +
- nodeDownMsg[leftNodeInd]
- - localEv[leftNodeInd]
- - psiMat[leftNodeInd * numStates + rightNodeInd];
- if ((tmp > *cmessage) || (leftNodeInd == 0))
- *cmessage = tmp;
- }
- cmessage++;
- }
- }
- else {
- fprintf(stderr, "not implemented!\n");
- exit(1);
- }
- FLOATTYPE max = msgDest[0];
- for (int i = 0; i < numStates; i++)
- msgDest[i] -= max;
- }
- // This means, "Compute the message to send left."
- void OneNodeCluster::ComputeMsgLeft(FLOATTYPE *msgDest, int r, int c, MaxProdBP *mrf)
- {
- FLOATTYPE *nodeRightMsg = receivedMsgs[RIGHT],
- *nodeDownMsg = receivedMsgs[DOWN],
- *nodeUpMsg = receivedMsgs[UP];
- int do_dist = (int)(mrf->getSmoothType() == MRF::THREE_PARAM);
- FLOATTYPE *tmpMsgDest = msgDest;
- if (do_dist)
- {
- FLOATTYPE weight_mod;
- getVarWeight(*this, r, c, mrf, LEFT, weight_mod);
- FLOATTYPE *msgProd = new FLOATTYPE[numStates];
- const FLOATTYPE lambda = (FLOATTYPE)mrf->m_lambda;
- const FLOATTYPE smoothMax = (FLOATTYPE)mrf->m_smoothMax;
- for (int rightNodeInd = 0; rightNodeInd < numStates; rightNodeInd++)
- {
- msgProd[rightNodeInd] = -nodeRightMsg[rightNodeInd] +
- -nodeUpMsg[rightNodeInd] +
- -nodeDownMsg[rightNodeInd]
- + localEv[rightNodeInd] ;
- }
- if (mrf->m_smoothExp == 1)
- {
- l1_dist_trans_comp( weight_mod*smoothMax*lambda, lambda*weight_mod, tmpMsgDest, msgProd, numStates);
- }
- else
- l2_dist_trans_comp( weight_mod*smoothMax*lambda, lambda*weight_mod, tmpMsgDest, msgProd, numStates);
- delete [] msgProd;
- }
- else if ((mrf->getSmoothType() == MRF::FUNCTION) || (mrf->getSmoothType() == MRF::ARRAY))
- {
- FLOATTYPE *psiMat, var_weight;
- getPsiMat(*this, psiMat, r, c, mrf, LEFT, var_weight);
- FLOATTYPE *cmessage = msgDest;
- for (int leftNodeInd = 0; leftNodeInd < numStates; leftNodeInd++)
- {
- *cmessage = 0;
- for (int rightNodeInd = 0; rightNodeInd < numStates; rightNodeInd++)
- {
- FLOATTYPE tmp = nodeRightMsg[rightNodeInd] +
- nodeUpMsg[rightNodeInd] +
- nodeDownMsg[rightNodeInd]
- - localEv[rightNodeInd]
- - psiMat[leftNodeInd * numStates + rightNodeInd] ;
- if ((tmp > *cmessage) || (rightNodeInd == 0))
- *cmessage = tmp;
- }
- cmessage++;
- }
- }
- else
- assert(0);
- // FLOATTYPE max = vec_max(msgDest,numStates);
- FLOATTYPE max = msgDest[0];
- for (int i = 0; i < numStates; i++)
- msgDest[i] -= max;
- }
- void OneNodeCluster::ComputeMsgUp(FLOATTYPE *msgDest, int r, int c, MaxProdBP *mrf)
- {
- FLOATTYPE *nodeRightMsg = receivedMsgs[RIGHT],
- *nodeDownMsg = receivedMsgs[DOWN],
- *nodeLeftMsg = receivedMsgs[LEFT];
- int do_dist = (int)(mrf->getSmoothType() == MRF::THREE_PARAM);
- if (do_dist)
- {
- FLOATTYPE weight_mod;
- getVarWeight(*this, r, c, mrf, UP, weight_mod);
- FLOATTYPE *tmpMsgDest = msgDest;
- FLOATTYPE *msgProd = new FLOATTYPE[numStates];
- const FLOATTYPE lambda = (FLOATTYPE)mrf->m_lambda;
- const FLOATTYPE smoothMax = (FLOATTYPE)mrf->m_smoothMax;
- for (int downNodeInd = 0; downNodeInd < numStates; downNodeInd++)
- {
- msgProd[downNodeInd] = -nodeRightMsg[downNodeInd] +
- -nodeLeftMsg[downNodeInd] +
- -nodeDownMsg[downNodeInd] +
- + localEv[downNodeInd] ;
- }
- // printf("%f %f %f %f\n",nodeLeftMsg[leftNodeInd] ,
- // nodeUpMsg[leftNodeInd] ,
- // nodeDownMsg[leftNodeInd] ,localEv[leftNodeInd]);
- if (mrf->m_smoothExp == 1)
- {
- l1_dist_trans_comp( weight_mod*smoothMax*lambda, lambda*weight_mod, tmpMsgDest, msgProd, numStates);
- }
- else
- l2_dist_trans_comp( weight_mod*smoothMax*lambda, lambda*weight_mod, tmpMsgDest, msgProd, numStates);
- delete [] msgProd;
- }
- else if ((mrf->getSmoothType() == MRF::FUNCTION) || (mrf->getSmoothType() == MRF::ARRAY))
- {
- FLOATTYPE *psiMat, var_weight;
- getPsiMat(*this, psiMat, r, c, mrf, UP, var_weight);
- FLOATTYPE *cmessage = msgDest;
- for (int upNodeInd = 0; upNodeInd < numStates; upNodeInd++)
- {
- *cmessage = 0;
- for (int downNodeInd = 0; downNodeInd < numStates; downNodeInd++)
- {
- FLOATTYPE tmp = nodeRightMsg[downNodeInd] +
- nodeLeftMsg[downNodeInd] +
- nodeDownMsg[downNodeInd] +
- -localEv[downNodeInd]
- - psiMat[upNodeInd * numStates + downNodeInd] ;
- if ((tmp > *cmessage) || (downNodeInd == 0))
- *cmessage = tmp;
- }
- cmessage++;
- }
- }
- else
- assert(0);
- FLOATTYPE max = msgDest[0];
- // FLOATTYPE max = vec_max(msgDest,numStates);
- for (int i = 0; i < numStates; i++)
- msgDest[i] -= max;
- }
- void OneNodeCluster::ComputeMsgDown(FLOATTYPE *msgDest, int r, int c, MaxProdBP *mrf)
- {
- FLOATTYPE *nodeRightMsg = receivedMsgs[RIGHT],
- *nodeUpMsg = receivedMsgs[UP],
- *nodeLeftMsg = receivedMsgs[LEFT];
- int do_dist = (int)(mrf->getSmoothType() == MRF::THREE_PARAM);
- if (do_dist)
- {
- FLOATTYPE weight_mod;
- getVarWeight(*this, r, c, mrf, DOWN, weight_mod);
- FLOATTYPE *tmpMsgDest = msgDest;
- FLOATTYPE *msgProd = new FLOATTYPE[numStates];
- const FLOATTYPE lambda = (FLOATTYPE)mrf->m_lambda;
- const FLOATTYPE smoothMax = (FLOATTYPE)mrf->m_smoothMax;
- for (int upNodeInd = 0; upNodeInd < numStates; upNodeInd++)
- {
- msgProd[upNodeInd] = -nodeRightMsg[upNodeInd] +
- -nodeLeftMsg[upNodeInd] +
- -nodeUpMsg[upNodeInd] +
- + localEv[upNodeInd] ;
- }
- if (mrf->m_smoothExp == 1)
- {
- l1_dist_trans_comp( weight_mod*smoothMax*lambda, lambda*weight_mod, tmpMsgDest, msgProd, numStates);
- }
- else
- l2_dist_trans_comp( weight_mod*smoothMax*lambda, lambda*weight_mod, tmpMsgDest, msgProd, numStates);
- delete [] msgProd;
- }
- else if ((mrf->getSmoothType() == MRF::FUNCTION) || (mrf->getSmoothType() == MRF::ARRAY))
- {
- FLOATTYPE *psiMat, var_weight;
- getPsiMat(*this, psiMat, r, c, mrf, DOWN, var_weight);
- FLOATTYPE *cmessage = msgDest;
- for (int downNodeInd = 0; downNodeInd < numStates; downNodeInd++)
- {
- *cmessage = 0;
- for (int upNodeInd = 0; upNodeInd < numStates; upNodeInd++)
- {
- FLOATTYPE tmp = nodeRightMsg[upNodeInd] +
- nodeLeftMsg[upNodeInd] +
- nodeUpMsg[upNodeInd] +
- -localEv[upNodeInd]
- - psiMat[upNodeInd * numStates + downNodeInd] ;
- if ((tmp > *cmessage) || (upNodeInd == 0))
- *cmessage = tmp;
- }
- cmessage++;
- }
- }
- else
- assert(0);
- FLOATTYPE max = msgDest[0];
- // FLOATTYPE max = vec_max(msgDest,numStates);
- for (int i = 0; i < numStates; i++)
- msgDest[i] -= max;
- }
- void OneNodeCluster::getBelief(FLOATTYPE *beliefVec)
- {
- for (int i = 0; i < numStates; i++)
- {
- beliefVec[i] = receivedMsgs[UP][i] + receivedMsgs[DOWN][i] +
- receivedMsgs[LEFT][i] + receivedMsgs[RIGHT][i] - localEv[i];
- }
- }
- int OneNodeCluster::getBeliefMaxInd()
- {
- FLOATTYPE currBelief, bestBelief;
- int bestInd = 0;
- {
- int i = 0;
- bestBelief = receivedMsgs[UP][i] + receivedMsgs[DOWN][i] +
- receivedMsgs[LEFT][i] + receivedMsgs[RIGHT][i] - localEv[i];
- }
- for (int i = 1; i < numStates; i++)
- {
- currBelief = receivedMsgs[UP][i] + receivedMsgs[DOWN][i] +
- receivedMsgs[LEFT][i] + receivedMsgs[RIGHT][i] - localEv[i];
- if (currBelief > bestBelief)
- {
- bestInd = i;
- bestBelief = currBelief;
- }
- }
- return bestInd;
- }
- namespace OBJREC {
- void computeMessagesLeftRight(OneNodeCluster *nodeArray, const int numCols, const int /*numRows*/, const int currRow, const FLOATTYPE alpha, MaxProdBP *mrf)
- {
- const int numStates = OneNodeCluster::numStates;
- const FLOATTYPE omalpha = 1.0f - alpha;
- int i;
- int col;
- for ( col = 0; col < numCols - 1; col++)
- {
- nodeArray[currRow * numCols + col].ComputeMsgRight(nodeArray[currRow * numCols + col+1].nextRoundReceivedMsgs[LEFT], currRow, col, mrf);
- for (i = 0; i < numStates; i++)
- {
- nodeArray[currRow * numCols + col+1].receivedMsgs[LEFT][i] =
- omalpha * nodeArray[currRow * numCols + col+1].receivedMsgs[LEFT][i] +
- alpha * nodeArray[currRow * numCols + col+1].nextRoundReceivedMsgs[LEFT][i];
- }
- }
- for ( col = numCols - 1; col > 0; col--)
- {
- nodeArray[currRow * numCols + col].ComputeMsgLeft(nodeArray[currRow * numCols + col-1].nextRoundReceivedMsgs[RIGHT], currRow, col, mrf);
- for (i = 0; i < numStates; i++)
- {
- nodeArray[currRow * numCols + col-1].receivedMsgs[RIGHT][i] =
- omalpha * nodeArray[currRow * numCols + col-1].receivedMsgs[RIGHT][i] +
- alpha * nodeArray[currRow * numCols + col-1].nextRoundReceivedMsgs[RIGHT][i];
- }
- }
- }
- void computeMessagesUpDown(OneNodeCluster *nodeArray, const int numCols, const int numRows, const int currCol, const FLOATTYPE alpha, MaxProdBP *mrf)
- {
- const int numStates = OneNodeCluster::numStates;
- const FLOATTYPE omalpha = 1.0f - alpha;
- int i;
- int row;
- for (row = 0; row < numRows - 1; row++)
- {
- nodeArray[row * numCols + currCol].ComputeMsgDown(nodeArray[(row+1) * numCols + currCol].nextRoundReceivedMsgs[UP], row, currCol, mrf);
- for (i = 0; i < numStates; i++)
- {
- nodeArray[(row+1) * numCols + currCol].receivedMsgs[UP][i] =
- omalpha * nodeArray[(row+1) * numCols + currCol].receivedMsgs[UP][i] +
- alpha * nodeArray[(row+1) * numCols + currCol].nextRoundReceivedMsgs[UP][i];
- }
- }
- for ( row = numRows - 1; row > 0; row--)
- {
- nodeArray[row * numCols + currCol].ComputeMsgUp(nodeArray[(row-1) * numCols + currCol].nextRoundReceivedMsgs[DOWN], row, currCol, mrf);
- for (i = 0; i < numStates; i++)
- {
- nodeArray[(row-1) * numCols + currCol].receivedMsgs[DOWN][i] =
- omalpha * nodeArray[(row-1) * numCols + currCol].receivedMsgs[DOWN][i] +
- alpha * nodeArray[(row-1) * numCols + currCol].nextRoundReceivedMsgs[DOWN][i];
- }
- }
- }
- }
|