#include #include #include #include #include #include "TRW-S.h" #define private public #include "typeTruncatedQuadratic2D.h" using namespace OBJREC; #undef private #define m_D(pix,l) m_D[(pix)*m_nLabels+(l)] #define m_V(l1,l2) m_V[(l1)*m_nLabels+(l2)] #define MIN(a,b) (((a) < (b)) ? (a) : (b)) #define MAX(a,b) (((a) > (b)) ? (a) : (b)) #define TRUNCATE_MIN(a,b) { if ((a) > (b)) (a) = (b); } #define TRUNCATE_MAX(a,b) { if ((a) < (b)) (a) = (b); } #define TRUNCATE TRUNCATE_MIN ///////////////////////////////////////////////////////////////////////////// // Operations on vectors (arrays of size K) // ///////////////////////////////////////////////////////////////////////////// inline void CopyVector(TRWS::REAL* to, MRF::CostVal* from, int K) { TRWS::REAL* to_finish = to + K; do { *to ++ = *from ++; } while (to < to_finish); } inline void AddVector(TRWS::REAL* to, TRWS::REAL* from, int K) { TRWS::REAL* to_finish = to + K; do { *to ++ += *from ++; } while (to < to_finish); } inline TRWS::REAL SubtractMin(TRWS::REAL *D, int K) { int k; TRWS::REAL delta; delta = D[0]; for (k = 1; k < K; k++) TRUNCATE(delta, D[k]); for (k = 0; k < K; k++) D[k] -= delta; return delta; } // Functions UpdateMessageTYPE (see the paper for details): // // - Set Di[ki] := gamma*Di_hat[ki] - M[ki] // - Set M[kj] := min_{ki} (Di[ki] + V[ki,kj]) // - Normalize message: // delta := min_{kj} M[kj] // M[kj] := M[kj] - delta // return delta // // If dir = 1, then the meaning of i and j is swapped. /////////////////////////////////////////// // L1 // /////////////////////////////////////////// inline TRWS::REAL UpdateMessageL1(TRWS::REAL* M, TRWS::REAL* Di_hat, int K, TRWS::REAL gamma, MRF::CostVal lambda, MRF::CostVal smoothMax) { int k; TRWS::REAL delta; delta = M[0] = gamma * Di_hat[0] - M[0]; for (k = 1; k < K; k++) { M[k] = gamma * Di_hat[k] - M[k]; TRUNCATE(delta, M[k]); TRUNCATE(M[k], M[k-1] + lambda); } M[--k] -= delta; TRUNCATE(M[k], lambda*smoothMax); for (k--; k >= 0; k--) { M[k] -= delta; TRUNCATE(M[k], M[k+1] + lambda); TRUNCATE(M[k], lambda*smoothMax); } return delta; } //////////////////////////////////////// // L2 // //////////////////////////////////////// inline TRWS::REAL UpdateMessageL2(TRWS::REAL* M, TRWS::REAL* Di_hat, int K, TRWS::REAL gamma, MRF::CostVal lambda, MRF::CostVal smoothMax, void *buf) { TRWS::REAL* Di = (TRWS::REAL*) buf; int* parabolas = (int*) ((char*)buf + K * sizeof(TRWS::REAL)); int* intersections = parabolas + K; TypeTruncatedQuadratic2D::Edge* tmp = NULL; int k; TRWS::REAL delta; assert(lambda >= 0); Di[0] = gamma * Di_hat[0] - M[0]; delta = Di[0]; for (k = 1; k < K; k++) { Di[k] = gamma * Di_hat[k] - M[k]; TRUNCATE(delta, Di[k]); } if (lambda == 0) { for (k = 0; k < K; k++) M[k] = 0; return delta; } tmp->DistanceTransformL2(K, 1, lambda, Di, M, parabolas, intersections); for (k = 0; k < K; k++) { M[k] -= delta; TRUNCATE(M[k], lambda*smoothMax); } return delta; } ////////////////////////////////////////////////// // FIXED_MATRIX // ////////////////////////////////////////////////// inline TRWS::REAL UpdateMessageFIXED_MATRIX(TRWS::REAL* M, TRWS::REAL* Di_hat, int K, TRWS::REAL gamma, MRF::CostVal lambda, MRF::CostVal* V, void* buf) { TRWS::REAL* Di = (TRWS::REAL*) buf; int ki, kj; TRWS::REAL delta; if (lambda == 0) { delta = gamma * Di_hat[0] - M[0]; M[0] = 0; for (ki = 1; ki < K; ki++) { TRUNCATE(delta, gamma*Di_hat[ki] - M[ki]); M[ki] = 0; } return delta; } for (ki = 0; ki < K; ki++) { Di[ki] = (gamma * Di_hat[ki] - M[ki]) * (1 / (TRWS::REAL)lambda); } if (lambda > 0) { for (kj = 0; kj < K; kj++) { M[kj] = Di[0] + V[0]; V ++; for (ki = 1; ki < K; ki++) { TRUNCATE(M[kj], Di[ki] + V[0]); V ++; } M[kj] *= lambda; } } else { for (kj = 0; kj < K; kj++) { M[kj] = Di[0] + V[0]; V ++; for (ki = 1; ki < K; ki++) { TRUNCATE_MAX(M[kj], Di[ki] + V[0]); V ++; } M[kj] *= lambda; } } delta = M[0]; for (kj = 1; kj < K; kj++) TRUNCATE(delta, M[kj]); for (kj = 0; kj < K; kj++) M[kj] -= delta; return delta; } ///////////////////////////////////////////// // GENERAL // ///////////////////////////////////////////// inline TRWS::REAL UpdateMessageGENERAL(TRWS::REAL* M, TRWS::REAL* Di_hat, int K, TRWS::REAL gamma, int dir, MRF::CostVal* V, void* buf) { TRWS::REAL* Di = (TRWS::REAL*) buf; int ki, kj; TRWS::REAL delta; for (ki = 0; ki < K; ki++) { Di[ki] = (gamma * Di_hat[ki] - M[ki]); } if (dir == 0) { for (kj = 0; kj < K; kj++) { M[kj] = Di[0] + V[0]; V ++; for (ki = 1; ki < K; ki++) { TRUNCATE(M[kj], Di[ki] + V[0]); V ++; } } } else { for (kj = 0; kj < K; kj++) { M[kj] = Di[0] + V[0]; V += K; for (ki = 1; ki < K; ki++) { TRUNCATE(M[kj], Di[ki] + V[0]); V += K; } V -= K * K - 1; } } delta = M[0]; for (kj = 1; kj < K; kj++) TRUNCATE(delta, M[kj]); for (kj = 0; kj < K; kj++) M[kj] -= delta; return delta; } inline TRWS::REAL UpdateMessageGENERAL(TRWS::REAL* M, TRWS::REAL* Di_hat, int K, TRWS::REAL gamma, TRWS::SmoothCostGeneralFn fn, int i, int j, void* buf) { TRWS::REAL* Di = (TRWS::REAL*) buf; int ki, kj; TRWS::REAL delta; for (ki = 0; ki < K; ki++) { Di[ki] = (gamma * Di_hat[ki] - M[ki]); } for (kj = 0; kj < K; kj++) { M[kj] = Di[0] + fn(i, j, 0, kj); for (ki = 1; ki < K; ki++) { delta = Di[ki] + fn(i, j, ki, kj); TRUNCATE(M[kj], delta); } } delta = M[0]; for (kj = 1; kj < K; kj++) TRUNCATE(delta, M[kj]); for (kj = 0; kj < K; kj++) M[kj] -= delta; return delta; } TRWS::TRWS(int width, int height, int nLabels, EnergyFunction *eng): MRF(width, height, nLabels, eng) { Allocate(); } TRWS::TRWS(int nPixels, int nLabels, EnergyFunction *eng): MRF(nPixels, nLabels, eng) { Allocate(); } TRWS::~TRWS() { delete[] m_answer; if ( m_needToFreeD ) delete [] m_D; if ( m_needToFreeV ) delete [] m_V; if ( m_messages ) delete [] m_messages; if ( m_DBinary ) delete [] m_DBinary; if ( m_horzWeightsBinary ) delete [] m_horzWeightsBinary; if ( m_vertWeightsBinary ) delete [] m_vertWeightsBinary; } void TRWS::Allocate() { m_type = NONE; m_needToFreeV = false; m_needToFreeD = false; m_D = NULL; m_V = NULL; m_horzWeights = NULL; m_vertWeights = NULL; m_horzWeightsBinary = NULL; m_vertWeightsBinary = NULL; m_DBinary = NULL; m_messages = NULL; m_messageArraySizeInBytes = 0; m_answer = new Label[m_nPixels]; } void TRWS::clearAnswer() { memset(m_answer, 0, m_nPixels*sizeof(Label)); if (m_messages) { memset(m_messages, 0, m_messageArraySizeInBytes); } } MRF::EnergyVal TRWS::smoothnessEnergy() { EnergyVal eng = (EnergyVal) 0; EnergyVal weight; int x, y, pix; if ( m_grid_graph ) { if ( m_smoothType != FUNCTION ) { for ( y = 0; y < m_height; y++ ) for ( x = 1; x < m_width; x++ ) { pix = x + y * m_width; weight = m_varWeights ? m_horzWeights[pix-1] : 1; eng = eng + m_V(m_answer[pix], m_answer[pix-1]) * weight; } for ( y = 1; y < m_height; y++ ) for ( x = 0; x < m_width; x++ ) { pix = x + y * m_width; weight = m_varWeights ? m_vertWeights[pix-m_width] : 1; eng = eng + m_V(m_answer[pix], m_answer[pix-m_width]) * weight; } } else { for ( y = 0; y < m_height; y++ ) for ( x = 1; x < m_width; x++ ) { pix = x + y * m_width; eng = eng + m_smoothFn(pix, pix - 1, m_answer[pix], m_answer[pix-1]); } for ( y = 1; y < m_height; y++ ) for ( x = 0; x < m_width; x++ ) { pix = x + y * m_width; eng = eng + m_smoothFn(pix, pix - m_width, m_answer[pix], m_answer[pix-m_width]); } } } else { // not implemented } return(eng); } MRF::EnergyVal TRWS::dataEnergy() { EnergyVal eng = (EnergyVal) 0; if ( m_dataType == ARRAY) { for ( int i = 0; i < m_nPixels; i++ ) eng = eng + m_D(i, m_answer[i]); } else { for ( int i = 0; i < m_nPixels; i++ ) eng = eng + m_dataFn(i, m_answer[i]); } return(eng); } void TRWS::setData(DataCostFn dcost) { int i, k; m_dataFn = dcost; CostVal* ptr; m_D = new CostVal[m_nPixels*m_nLabels]; for (ptr = m_D, i = 0; i < m_nPixels; i++) for (k = 0; k < m_nLabels; k++, ptr++) { *ptr = m_dataFn(i, k); } m_needToFreeD = true; } void TRWS::setData(CostVal* data) { m_D = data; m_needToFreeD = false; } void TRWS::setSmoothness(SmoothCostGeneralFn cost) { assert(m_horzWeights == NULL && m_vertWeights == NULL && m_V == NULL); int x, y, i, ki, kj; CostVal* ptr; m_smoothFn = cost; m_type = GENERAL; if (!m_allocateArrayForSmoothnessCostFn) return; // try to cache all the function values in an array for efficiency m_V = new(std::nothrow) CostVal[2*m_nPixels*m_nLabels*m_nLabels]; if (!m_V) return; // if not enough space, just call the function directly m_needToFreeV = true; for (ptr = m_V, i = 0, y = 0; y < m_height; y++) for (x = 0; x < m_width; x++, i++) { if (x < m_width - 1) { for (kj = 0; kj < m_nLabels; kj++) for (ki = 0; ki < m_nLabels; ki++) { *ptr++ = cost(i, i + 1, ki, kj); } } else ptr += m_nLabels * m_nLabels; if (y < m_height - 1) { for (kj = 0; kj < m_nLabels; kj++) for (ki = 0; ki < m_nLabels; ki++) { *ptr++ = cost(i, i + m_width, ki, kj); } } else ptr += m_nLabels * m_nLabels; } } void TRWS::setSmoothness(CostVal* V) { m_type = FIXED_MATRIX; m_V = V; } void TRWS::setSmoothness(int smoothExp, CostVal smoothMax, CostVal lambda) { assert(smoothExp == 1 || smoothExp == 2); assert(lambda >= 0); m_type = (smoothExp == 1) ? L1 : L2; int ki, kj; CostVal cost; m_needToFreeV = true; m_V = new CostVal[m_nLabels*m_nLabels]; for (ki = 0; ki < m_nLabels; ki++) for (kj = ki; kj < m_nLabels; kj++) { cost = (CostVal) ((smoothExp == 1) ? kj - ki : (kj - ki) * (kj - ki)); if (cost > smoothMax) cost = smoothMax; m_V[ki*m_nLabels + kj] = m_V[kj*m_nLabels + ki] = cost * lambda; } m_smoothMax = smoothMax; m_lambda = lambda; } void TRWS::setCues(CostVal* hCue, CostVal* vCue) { m_horzWeights = hCue; m_vertWeights = vCue; } void TRWS::initializeAlg() { assert(m_type != NONE); int i; // determine type if (m_type == L1 && m_nLabels == 2) { m_type = BINARY; } // allocate messages int messageNum = (m_type == BINARY) ? 4 * m_nPixels : 4 * m_nPixels * m_nLabels; m_messageArraySizeInBytes = messageNum * sizeof(REAL); m_messages = new REAL[messageNum]; memset(m_messages, 0, messageNum*sizeof(REAL)); if (m_type == BINARY) { assert(m_DBinary == NULL && m_horzWeightsBinary == NULL && m_horzWeightsBinary == NULL); m_DBinary = new CostVal[m_nPixels]; m_horzWeightsBinary = new CostVal[m_nPixels]; m_vertWeightsBinary = new CostVal[m_nPixels]; if ( m_dataType == ARRAY) { for (i = 0; i < m_nPixels; i++) { m_DBinary[i] = m_D[2*i+1] - m_D[2*i]; } } else { for (i = 0; i < m_nPixels; i++) { m_DBinary[i] = m_dataFn(i, 1) - m_dataFn(i, 0); } } assert(m_V[0] == 0 && m_V[1] == m_V[2] && m_V[3] == 0); for (i = 0; i < m_nPixels; i++) { m_horzWeightsBinary[i] = (m_varWeights) ? m_V[1] * m_horzWeights[i] : m_V[1]; m_vertWeightsBinary[i] = (m_varWeights) ? m_V[1] * m_vertWeights[i] : m_V[1]; } } } void TRWS::optimizeAlg(int nIterations) { assert(m_type != NONE); if (m_grid_graph) { switch (m_type) { case L1: optimize_GRID_L1(nIterations); break; case L2: optimize_GRID_L2(nIterations); break; case FIXED_MATRIX: optimize_GRID_FIXED_MATRIX(nIterations); break; case GENERAL: optimize_GRID_GENERAL(nIterations); break; case BINARY: optimize_GRID_BINARY(nIterations); break; default: assert(0); exit(1); } } else { printf("\nNot implemented for general graphs yet, exiting!"); exit(1); } // printf("lower bound = %f\n", m_lowerBound); //////////////////////////////////////////////// // computing solution // //////////////////////////////////////////////// if (m_type != BINARY) { int x, y, n, K = m_nLabels; CostVal* D_ptr; REAL* M_ptr; REAL* Di; REAL delta; int ki, kj; Di = new REAL[K]; n = 0; D_ptr = m_D; M_ptr = m_messages; for (y = 0; y < m_height; y++) for (x = 0; x < m_width; x++, D_ptr += K, M_ptr += 2 * K, n++) { CopyVector(Di, D_ptr, K); if (m_type == GENERAL) { if (m_V) { CostVal* ptr = m_V + 2 * (x + y * m_width - 1) * K * K; if (x > 0) { kj = m_answer[n-1]; for (ki = 0; ki < K; ki++) { Di[ki] += ptr[kj + ki*K]; } } ptr -= (2 * m_width - 3) * K * K; if (y > 0) { kj = m_answer[n-m_width]; for (ki = 0; ki < K; ki++) { Di[ki] += ptr[kj + ki*K]; } } } else { if (x > 0) { kj = m_answer[n-1]; for (ki = 0; ki < K; ki++) { Di[ki] += m_smoothFn(n, n - 1, ki, kj); } } if (y > 0) { kj = m_answer[n-m_width]; for (ki = 0; ki < K; ki++) { Di[ki] += m_smoothFn(n, n - m_width, ki, kj); } } } } else // m_type == L1, L2 or FIXED_MATRIX { if (x > 0) { kj = m_answer[n-1]; CostVal lambda = (m_varWeights) ? m_horzWeights[n-1] : 1; for (ki = 0; ki < K; ki++) { Di[ki] += lambda * m_V[kj*K + ki]; } } if (y > 0) { kj = m_answer[n-m_width]; CostVal lambda = (m_varWeights) ? m_vertWeights[n-m_width] : 1; for (ki = 0; ki < K; ki++) { Di[ki] += lambda * m_V[kj*K + ki]; } } } if (x < m_width - 1) AddVector(Di, M_ptr, K); // message (x+1,y)->(x,y) if (y < m_height - 1) AddVector(Di, M_ptr + K, K); // message (x,y+1)->(x,y) // compute min delta = Di[0]; m_answer[n] = 0; for (ki = 1; ki < K; ki++) { if (delta > Di[ki]) { delta = Di[ki]; m_answer[n] = ki; } } } delete [] Di; } else // m_type == BINARY { int x, y, n; REAL* M_ptr; REAL Di; n = 0; M_ptr = m_messages; for (y = 0; y < m_height; y++) for (x = 0; x < m_width; x++, M_ptr += 2, n++) { Di = m_DBinary[n]; if (x > 0) Di += (m_answer[n-1] == 0) ? m_horzWeightsBinary[n-1] : -m_horzWeightsBinary[n-1]; if (y > 0) Di += (m_answer[n-m_width] == 0) ? m_vertWeightsBinary[n-m_width] : -m_vertWeightsBinary[n-m_width]; if (x < m_width - 1) Di += M_ptr[0]; // message (x+1,y)->(x,y) if (y < m_height - 1) Di += M_ptr[1]; // message (x,y+1)->(x,y) // compute min m_answer[n] = (Di >= 0) ? 0 : 1; } } } void TRWS::optimize_GRID_L1(int nIterations) { int x, y, n, K = m_nLabels; CostVal* D_ptr; REAL* M_ptr; REAL* Di; Di = new REAL[K]; for ( ; nIterations > 0; nIterations --) { // forward pass n = 0; D_ptr = m_D; M_ptr = m_messages; for (y = 0; y < m_height; y++) for (x = 0; x < m_width; x++, D_ptr += K, M_ptr += 2 * K, n++) { CopyVector(Di, D_ptr, K); if (x > 0) AddVector(Di, M_ptr - 2*K, K); // message (x-1,y)->(x,y) if (y > 0) AddVector(Di, M_ptr - (2*m_width - 1)*K, K); // message (x,y-1)->(x,y) if (x < m_width - 1) AddVector(Di, M_ptr, K); // message (x+1,y)->(x,y) if (y < m_height - 1) AddVector(Di, M_ptr + K, K); // message (x,y+1)->(x,y) if (x < m_width - 1) { CostVal lambda = (m_varWeights) ? m_lambda * m_horzWeights[n] : m_lambda; UpdateMessageL1(M_ptr, Di, K, 0.5, lambda, m_smoothMax); } if (y < m_height - 1) { CostVal lambda = (m_varWeights) ? m_lambda * m_vertWeights[n] : m_lambda; UpdateMessageL1(M_ptr + K, Di, K, 0.5, lambda, m_smoothMax); } } // backward pass m_lowerBound = 0; n --; D_ptr -= K; M_ptr -= 2 * K; for (y = m_height - 1; y >= 0; y--) for (x = m_width - 1; x >= 0; x--, D_ptr -= K, M_ptr -= 2 * K, n--) { CopyVector(Di, D_ptr, K); if (x > 0) AddVector(Di, M_ptr - 2*K, K); // message (x-1,y)->(x,y) if (y > 0) AddVector(Di, M_ptr - (2*m_width - 1)*K, K); // message (x,y-1)->(x,y) if (x < m_width - 1) AddVector(Di, M_ptr, K); // message (x+1,y)->(x,y) if (y < m_height - 1) AddVector(Di, M_ptr + K, K); // message (x,y+1)->(x,y) m_lowerBound += SubtractMin(Di, K); if (x > 0) { CostVal lambda = (m_varWeights) ? m_lambda * m_horzWeights[n-1] : m_lambda; m_lowerBound += UpdateMessageL1(M_ptr - 2 * K, Di, K, 0.5, lambda, m_smoothMax); } if (y > 0) { CostVal lambda = (m_varWeights) ? m_lambda * m_vertWeights[n-m_width] : m_lambda; m_lowerBound += UpdateMessageL1(M_ptr - (2 * m_width - 1) * K, Di, K, 0.5, lambda, m_smoothMax); } } } delete [] Di; } void TRWS::optimize_GRID_L2(int nIterations) { int x, y, n, K = m_nLabels; CostVal* D_ptr; REAL* M_ptr; REAL* Di; void* buf; Di = new REAL[K]; buf = new char[(2*K+1)*sizeof(int) + K*sizeof(REAL)]; for ( ; nIterations > 0; nIterations --) { // forward pass n = 0; D_ptr = m_D; M_ptr = m_messages; for (y = 0; y < m_height; y++) for (x = 0; x < m_width; x++, D_ptr += K, M_ptr += 2 * K, n++) { CopyVector(Di, D_ptr, K); if (x > 0) AddVector(Di, M_ptr - 2*K, K); // message (x-1,y)->(x,y) if (y > 0) AddVector(Di, M_ptr - (2*m_width - 1)*K, K); // message (x,y-1)->(x,y) if (x < m_width - 1) AddVector(Di, M_ptr, K); // message (x+1,y)->(x,y) if (y < m_height - 1) AddVector(Di, M_ptr + K, K); // message (x,y+1)->(x,y) if (x < m_width - 1) { CostVal lambda = (m_varWeights) ? m_lambda * m_horzWeights[n] : m_lambda; UpdateMessageL2(M_ptr, Di, K, 0.5, lambda, m_smoothMax, buf); } if (y < m_height - 1) { CostVal lambda = (m_varWeights) ? m_lambda * m_vertWeights[n] : m_lambda; UpdateMessageL2(M_ptr + K, Di, K, 0.5, lambda, m_smoothMax, buf); } } // backward pass m_lowerBound = 0; n --; D_ptr -= K; M_ptr -= 2 * K; for (y = m_height - 1; y >= 0; y--) for (x = m_width - 1; x >= 0; x--, D_ptr -= K, M_ptr -= 2 * K, n--) { CopyVector(Di, D_ptr, K); if (x > 0) AddVector(Di, M_ptr - 2*K, K); // message (x-1,y)->(x,y) if (y > 0) AddVector(Di, M_ptr - (2*m_width - 1)*K, K); // message (x,y-1)->(x,y) if (x < m_width - 1) AddVector(Di, M_ptr, K); // message (x+1,y)->(x,y) if (y < m_height - 1) AddVector(Di, M_ptr + K, K); // message (x,y+1)->(x,y) m_lowerBound += SubtractMin(Di, K); if (x > 0) { CostVal lambda = (m_varWeights) ? m_lambda * m_horzWeights[n-1] : m_lambda; m_lowerBound += UpdateMessageL2(M_ptr - 2 * K, Di, K, 0.5, lambda, m_smoothMax, buf); } if (y > 0) { CostVal lambda = (m_varWeights) ? m_lambda * m_vertWeights[n-m_width] : m_lambda; m_lowerBound += UpdateMessageL2(M_ptr - (2 * m_width - 1) * K, Di, K, 0.5, lambda, m_smoothMax, buf); } } } delete [] Di; delete [] (char *)buf; } void TRWS::optimize_GRID_BINARY(int nIterations) { int x, y, n; REAL* M_ptr; REAL Di; for ( ; nIterations > 0; nIterations --) { // forward pass n = 0; M_ptr = m_messages; for (y = 0; y < m_height; y++) for (x = 0; x < m_width; x++, M_ptr += 2, n++) { Di = m_DBinary[n]; if (x > 0) Di += M_ptr[-2]; // message (x-1,y)->(x,y) if (y > 0) Di += M_ptr[-2*m_width+1]; // message (x,y-1)->(x,y) if (x < m_width - 1) Di += M_ptr[0]; // message (x+1,y)->(x,y) if (y < m_height - 1) Di += M_ptr[1]; // message (x,y+1)->(x,y) REAL DiScaled = Di * 0.5; if (x < m_width - 1) { Di = DiScaled - M_ptr[0]; CostVal lambda = m_horzWeightsBinary[n]; if (lambda < 0) { Di = -Di; lambda = -lambda; } if (Di > lambda) M_ptr[0] = lambda; else M_ptr[0] = (Di < -lambda) ? -lambda : Di; } if (y < m_height - 1) { Di = DiScaled - M_ptr[1]; CostVal lambda = m_vertWeightsBinary[n]; if (lambda < 0) { Di = -Di; lambda = -lambda; } if (Di > lambda) M_ptr[1] = lambda; else M_ptr[1] = (Di < -lambda) ? -lambda : Di; } } // backward pass n --; M_ptr -= 2; for (y = m_height - 1; y >= 0; y--) for (x = m_width - 1; x >= 0; x--, M_ptr -= 2, n--) { Di = m_DBinary[n]; if (x > 0) Di += M_ptr[-2]; // message (x-1,y)->(x,y) if (y > 0) Di += M_ptr[-2*m_width+1]; // message (x,y-1)->(x,y) if (x < m_width - 1) Di += M_ptr[0]; // message (x+1,y)->(x,y) if (y < m_height - 1) Di += M_ptr[1]; // message (x,y+1)->(x,y) REAL DiScaled = Di * 0.5; if (x > 0) { Di = DiScaled - M_ptr[-2]; CostVal lambda = m_horzWeightsBinary[n-1]; if (lambda < 0) { Di = -Di; lambda = -lambda; } if (Di > lambda) M_ptr[-2] = lambda; else M_ptr[-2] = (Di < -lambda) ? -lambda : Di; } if (y > 0) { Di = DiScaled - M_ptr[-2*m_width+1]; CostVal lambda = m_vertWeightsBinary[n-m_width]; if (lambda < 0) { Di = -Di; lambda = -lambda; } if (Di > lambda) M_ptr[-2*m_width+1] = lambda; else M_ptr[-2*m_width+1] = (Di < -lambda) ? -lambda : Di; } } } m_lowerBound = 0; } void TRWS::optimize_GRID_FIXED_MATRIX(int nIterations) { int x, y, n, K = m_nLabels; CostVal* D_ptr; REAL* M_ptr; REAL* Di; void* buf; Di = new REAL[K]; buf = new REAL[K]; for ( ; nIterations > 0; nIterations --) { // forward pass n = 0; D_ptr = m_D; M_ptr = m_messages; for (y = 0; y < m_height; y++) for (x = 0; x < m_width; x++, D_ptr += K, M_ptr += 2 * K, n++) { CopyVector(Di, D_ptr, K); if (x > 0) AddVector(Di, M_ptr - 2*K, K); // message (x-1,y)->(x,y) if (y > 0) AddVector(Di, M_ptr - (2*m_width - 1)*K, K); // message (x,y-1)->(x,y) if (x < m_width - 1) AddVector(Di, M_ptr, K); // message (x+1,y)->(x,y) if (y < m_height - 1) AddVector(Di, M_ptr + K, K); // message (x,y+1)->(x,y) if (x < m_width - 1) { CostVal lambda = (m_varWeights) ? m_horzWeights[n] : 1; UpdateMessageFIXED_MATRIX(M_ptr, Di, K, 0.5, lambda, m_V, buf); } if (y < m_height - 1) { CostVal lambda = (m_varWeights) ? m_vertWeights[n] : 1; UpdateMessageFIXED_MATRIX(M_ptr + K, Di, K, 0.5, lambda, m_V, buf); } } // backward pass m_lowerBound = 0; n --; D_ptr -= K; M_ptr -= 2 * K; for (y = m_height - 1; y >= 0; y--) for (x = m_width - 1; x >= 0; x--, D_ptr -= K, M_ptr -= 2 * K, n--) { CopyVector(Di, D_ptr, K); if (x > 0) AddVector(Di, M_ptr - 2*K, K); // message (x-1,y)->(x,y) if (y > 0) AddVector(Di, M_ptr - (2*m_width - 1)*K, K); // message (x,y-1)->(x,y) if (x < m_width - 1) AddVector(Di, M_ptr, K); // message (x+1,y)->(x,y) if (y < m_height - 1) AddVector(Di, M_ptr + K, K); // message (x,y+1)->(x,y) m_lowerBound += SubtractMin(Di, K); if (x > 0) { CostVal lambda = (m_varWeights) ? m_horzWeights[n-1] : 1; m_lowerBound += UpdateMessageFIXED_MATRIX(M_ptr - 2 * K, Di, K, 0.5, lambda, m_V, buf); } if (y > 0) { CostVal lambda = (m_varWeights) ? m_vertWeights[n-m_width] : 1; m_lowerBound += UpdateMessageFIXED_MATRIX(M_ptr - (2 * m_width - 1) * K, Di, K, 0.5, lambda, m_V, buf); } } } delete [] Di; delete [] (REAL *)buf; } void TRWS::optimize_GRID_GENERAL(int nIterations) { int x, y, n, K = m_nLabels; CostVal* D_ptr; REAL* M_ptr; REAL* Di; void* buf; Di = new REAL[K]; buf = new REAL[K]; for ( ; nIterations > 0; nIterations --) { // forward pass n = 0; D_ptr = m_D; M_ptr = m_messages; CostVal* V_ptr = m_V; for (y = 0; y < m_height; y++) for (x = 0; x < m_width; x++, D_ptr += K, M_ptr += 2 * K, V_ptr += 2 * K * K, n++) { CopyVector(Di, D_ptr, K); if (x > 0) AddVector(Di, M_ptr - 2*K, K); // message (x-1,y)->(x,y) if (y > 0) AddVector(Di, M_ptr - (2*m_width - 1)*K, K); // message (x,y-1)->(x,y) if (x < m_width - 1) AddVector(Di, M_ptr, K); // message (x+1,y)->(x,y) if (y < m_height - 1) AddVector(Di, M_ptr + K, K); // message (x,y+1)->(x,y) if (x < m_width - 1) { if (m_V) UpdateMessageGENERAL(M_ptr, Di, K, 0.5, /* forward dir*/ 0, V_ptr, buf); else UpdateMessageGENERAL(M_ptr, Di, K, 0.5, m_smoothFn, n, n + 1, buf); } if (y < m_height - 1) { if (m_V) UpdateMessageGENERAL(M_ptr + K, Di, K, 0.5, /* forward dir*/ 0, V_ptr + K*K, buf); else UpdateMessageGENERAL(M_ptr + K, Di, K, 0.5, m_smoothFn, n, n + m_width, buf); } } // backward pass m_lowerBound = 0; n --; D_ptr -= K; M_ptr -= 2 * K; V_ptr -= 2 * K * K; for (y = m_height - 1; y >= 0; y--) for (x = m_width - 1; x >= 0; x--, D_ptr -= K, M_ptr -= 2 * K, V_ptr -= 2 * K * K, n--) { CopyVector(Di, D_ptr, K); if (x > 0) AddVector(Di, M_ptr - 2*K, K); // message (x-1,y)->(x,y) if (y > 0) AddVector(Di, M_ptr - (2*m_width - 1)*K, K); // message (x,y-1)->(x,y) if (x < m_width - 1) AddVector(Di, M_ptr, K); // message (x+1,y)->(x,y) if (y < m_height - 1) AddVector(Di, M_ptr + K, K); // message (x,y+1)->(x,y) // normalize Di, update lower bound m_lowerBound += SubtractMin(Di, K); if (x > 0) { if (m_V) m_lowerBound += UpdateMessageGENERAL(M_ptr - 2 * K, Di, K, 0.5, /* backward dir */ 1, V_ptr - 2 * K * K, buf); else m_lowerBound += UpdateMessageGENERAL(M_ptr - 2 * K, Di, K, 0.5, m_smoothFn, n, n - 1, buf); } if (y > 0) { if (m_V) m_lowerBound += UpdateMessageGENERAL(M_ptr - (2 * m_width - 1) * K, Di, K, 0.5, /* backward dir */ 1, V_ptr - (2 * m_width - 1) * K * K, buf); else m_lowerBound += UpdateMessageGENERAL(M_ptr - (2 * m_width - 1) * K, Di, K, 0.5, m_smoothFn, n, n - m_width, buf); } } } delete [] Di; delete [] (REAL *)buf; }