|
@@ -0,0 +1,1116 @@
|
|
|
+#include <stdio.h>
|
|
|
+#include <stdlib.h>
|
|
|
+#include <string.h>
|
|
|
+#include <assert.h>
|
|
|
+#include <new>
|
|
|
+#include "TRW-S.h"
|
|
|
+
|
|
|
+#define private public
|
|
|
+#include "typeTruncatedQuadratic2D.h"
|
|
|
+
|
|
|
+using namespace OBJREC;
|
|
|
+#undef private
|
|
|
+
|
|
|
+
|
|
|
+#define m_D(pix,l) m_D[(pix)*m_nLabels+(l)]
|
|
|
+#define m_V(l1,l2) m_V[(l1)*m_nLabels+(l2)]
|
|
|
+
|
|
|
+
|
|
|
+#define MIN(a,b) (((a) < (b)) ? (a) : (b))
|
|
|
+#define MAX(a,b) (((a) > (b)) ? (a) : (b))
|
|
|
+#define TRUNCATE_MIN(a,b) { if ((a) > (b)) (a) = (b); }
|
|
|
+#define TRUNCATE_MAX(a,b) { if ((a) < (b)) (a) = (b); }
|
|
|
+#define TRUNCATE TRUNCATE_MIN
|
|
|
+
|
|
|
+/////////////////////////////////////////////////////////////////////////////
|
|
|
+// Operations on vectors (arrays of size K) //
|
|
|
+/////////////////////////////////////////////////////////////////////////////
|
|
|
+
|
|
|
+inline void CopyVector(TRWS::REAL* to, MRF::CostVal* from, int K)
|
|
|
+{
|
|
|
+ TRWS::REAL* to_finish = to + K;
|
|
|
+ do
|
|
|
+ {
|
|
|
+ *to ++ = *from ++;
|
|
|
+ } while (to < to_finish);
|
|
|
+}
|
|
|
+
|
|
|
+inline void AddVector(TRWS::REAL* to, TRWS::REAL* from, int K)
|
|
|
+{
|
|
|
+ TRWS::REAL* to_finish = to + K;
|
|
|
+ do
|
|
|
+ {
|
|
|
+ *to ++ += *from ++;
|
|
|
+ } while (to < to_finish);
|
|
|
+}
|
|
|
+
|
|
|
+inline TRWS::REAL SubtractMin(TRWS::REAL *D, int K)
|
|
|
+{
|
|
|
+ int k;
|
|
|
+ TRWS::REAL delta;
|
|
|
+
|
|
|
+ delta = D[0];
|
|
|
+ for (k = 1; k < K; k++) TRUNCATE(delta, D[k]);
|
|
|
+ for (k = 0; k < K; k++) D[k] -= delta;
|
|
|
+
|
|
|
+ return delta;
|
|
|
+}
|
|
|
+
|
|
|
+// Functions UpdateMessageTYPE (see the paper for details):
|
|
|
+//
|
|
|
+// - Set Di[ki] := gamma*Di_hat[ki] - M[ki]
|
|
|
+// - Set M[kj] := min_{ki} (Di[ki] + V[ki,kj])
|
|
|
+// - Normalize message:
|
|
|
+// delta := min_{kj} M[kj]
|
|
|
+// M[kj] := M[kj] - delta
|
|
|
+// return delta
|
|
|
+//
|
|
|
+// If dir = 1, then the meaning of i and j is swapped.
|
|
|
+
|
|
|
+///////////////////////////////////////////
|
|
|
+// L1 //
|
|
|
+///////////////////////////////////////////
|
|
|
+
|
|
|
+inline TRWS::REAL UpdateMessageL1(TRWS::REAL* M, TRWS::REAL* Di_hat, int K, TRWS::REAL gamma, MRF::CostVal lambda, MRF::CostVal smoothMax)
|
|
|
+{
|
|
|
+ int k;
|
|
|
+ TRWS::REAL delta;
|
|
|
+
|
|
|
+ delta = M[0] = gamma * Di_hat[0] - M[0];
|
|
|
+ for (k = 1; k < K; k++)
|
|
|
+ {
|
|
|
+ M[k] = gamma * Di_hat[k] - M[k];
|
|
|
+ TRUNCATE(delta, M[k]);
|
|
|
+ TRUNCATE(M[k], M[k-1] + lambda);
|
|
|
+ }
|
|
|
+
|
|
|
+ M[--k] -= delta;
|
|
|
+ TRUNCATE(M[k], lambda*smoothMax);
|
|
|
+ for (k--; k >= 0; k--)
|
|
|
+ {
|
|
|
+ M[k] -= delta;
|
|
|
+ TRUNCATE(M[k], M[k+1] + lambda);
|
|
|
+ TRUNCATE(M[k], lambda*smoothMax);
|
|
|
+ }
|
|
|
+
|
|
|
+ return delta;
|
|
|
+}
|
|
|
+
|
|
|
+////////////////////////////////////////
|
|
|
+// L2 //
|
|
|
+////////////////////////////////////////
|
|
|
+
|
|
|
+inline TRWS::REAL UpdateMessageL2(TRWS::REAL* M, TRWS::REAL* Di_hat, int K, TRWS::REAL gamma, MRF::CostVal lambda, MRF::CostVal smoothMax, void *buf)
|
|
|
+{
|
|
|
+ TRWS::REAL* Di = (TRWS::REAL*) buf;
|
|
|
+ int* parabolas = (int*) ((char*)buf + K * sizeof(TRWS::REAL));
|
|
|
+ int* intersections = parabolas + K;
|
|
|
+ TypeTruncatedQuadratic2D::Edge* tmp = NULL;
|
|
|
+
|
|
|
+ int k;
|
|
|
+ TRWS::REAL delta;
|
|
|
+
|
|
|
+ assert(lambda >= 0);
|
|
|
+
|
|
|
+ Di[0] = gamma * Di_hat[0] - M[0];
|
|
|
+ delta = Di[0];
|
|
|
+ for (k = 1; k < K; k++)
|
|
|
+ {
|
|
|
+ Di[k] = gamma * Di_hat[k] - M[k];
|
|
|
+ TRUNCATE(delta, Di[k]);
|
|
|
+ }
|
|
|
+
|
|
|
+ if (lambda == 0)
|
|
|
+ {
|
|
|
+ for (k = 0; k < K; k++) M[k] = 0;
|
|
|
+ return delta;
|
|
|
+ }
|
|
|
+
|
|
|
+ tmp->DistanceTransformL2(K, 1, lambda, Di, M, parabolas, intersections);
|
|
|
+
|
|
|
+ for (k = 0; k < K; k++)
|
|
|
+ {
|
|
|
+ M[k] -= delta;
|
|
|
+ TRUNCATE(M[k], lambda*smoothMax);
|
|
|
+ }
|
|
|
+
|
|
|
+ return delta;
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+//////////////////////////////////////////////////
|
|
|
+// FIXED_MATRIX //
|
|
|
+//////////////////////////////////////////////////
|
|
|
+
|
|
|
+inline TRWS::REAL UpdateMessageFIXED_MATRIX(TRWS::REAL* M, TRWS::REAL* Di_hat, int K, TRWS::REAL gamma, MRF::CostVal lambda, MRF::CostVal* V, void* buf)
|
|
|
+{
|
|
|
+ TRWS::REAL* Di = (TRWS::REAL*) buf;
|
|
|
+ int ki, kj;
|
|
|
+ TRWS::REAL delta;
|
|
|
+
|
|
|
+ if (lambda == 0)
|
|
|
+ {
|
|
|
+ delta = gamma * Di_hat[0] - M[0];
|
|
|
+ M[0] = 0;
|
|
|
+ for (ki = 1; ki < K; ki++)
|
|
|
+ {
|
|
|
+ TRUNCATE(delta, gamma*Di_hat[ki] - M[ki]);
|
|
|
+ M[ki] = 0;
|
|
|
+ }
|
|
|
+ return delta;
|
|
|
+ }
|
|
|
+
|
|
|
+ for (ki = 0; ki < K; ki++)
|
|
|
+ {
|
|
|
+ Di[ki] = (gamma * Di_hat[ki] - M[ki]) * (1 / (TRWS::REAL)lambda);
|
|
|
+ }
|
|
|
+
|
|
|
+ if (lambda > 0)
|
|
|
+ {
|
|
|
+ for (kj = 0; kj < K; kj++)
|
|
|
+ {
|
|
|
+ M[kj] = Di[0] + V[0];
|
|
|
+ V ++;
|
|
|
+ for (ki = 1; ki < K; ki++)
|
|
|
+ {
|
|
|
+ TRUNCATE(M[kj], Di[ki] + V[0]);
|
|
|
+ V ++;
|
|
|
+ }
|
|
|
+ M[kj] *= lambda;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ else
|
|
|
+ {
|
|
|
+ for (kj = 0; kj < K; kj++)
|
|
|
+ {
|
|
|
+ M[kj] = Di[0] + V[0];
|
|
|
+ V ++;
|
|
|
+ for (ki = 1; ki < K; ki++)
|
|
|
+ {
|
|
|
+ TRUNCATE_MAX(M[kj], Di[ki] + V[0]);
|
|
|
+ V ++;
|
|
|
+ }
|
|
|
+ M[kj] *= lambda;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ delta = M[0];
|
|
|
+ for (kj = 1; kj < K; kj++) TRUNCATE(delta, M[kj]);
|
|
|
+ for (kj = 0; kj < K; kj++) M[kj] -= delta;
|
|
|
+
|
|
|
+ return delta;
|
|
|
+}
|
|
|
+
|
|
|
+/////////////////////////////////////////////
|
|
|
+// GENERAL //
|
|
|
+/////////////////////////////////////////////
|
|
|
+
|
|
|
+inline TRWS::REAL UpdateMessageGENERAL(TRWS::REAL* M, TRWS::REAL* Di_hat, int K, TRWS::REAL gamma, int dir, MRF::CostVal* V, void* buf)
|
|
|
+{
|
|
|
+ TRWS::REAL* Di = (TRWS::REAL*) buf;
|
|
|
+ int ki, kj;
|
|
|
+ TRWS::REAL delta;
|
|
|
+
|
|
|
+ for (ki = 0; ki < K; ki++)
|
|
|
+ {
|
|
|
+ Di[ki] = (gamma * Di_hat[ki] - M[ki]);
|
|
|
+ }
|
|
|
+
|
|
|
+ if (dir == 0)
|
|
|
+ {
|
|
|
+ for (kj = 0; kj < K; kj++)
|
|
|
+ {
|
|
|
+ M[kj] = Di[0] + V[0];
|
|
|
+ V ++;
|
|
|
+ for (ki = 1; ki < K; ki++)
|
|
|
+ {
|
|
|
+ TRUNCATE(M[kj], Di[ki] + V[0]);
|
|
|
+ V ++;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ else
|
|
|
+ {
|
|
|
+ for (kj = 0; kj < K; kj++)
|
|
|
+ {
|
|
|
+ M[kj] = Di[0] + V[0];
|
|
|
+ V += K;
|
|
|
+ for (ki = 1; ki < K; ki++)
|
|
|
+ {
|
|
|
+ TRUNCATE(M[kj], Di[ki] + V[0]);
|
|
|
+ V += K;
|
|
|
+ }
|
|
|
+ V -= K * K - 1;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ delta = M[0];
|
|
|
+ for (kj = 1; kj < K; kj++) TRUNCATE(delta, M[kj]);
|
|
|
+ for (kj = 0; kj < K; kj++) M[kj] -= delta;
|
|
|
+
|
|
|
+ return delta;
|
|
|
+}
|
|
|
+
|
|
|
+inline TRWS::REAL UpdateMessageGENERAL(TRWS::REAL* M, TRWS::REAL* Di_hat, int K, TRWS::REAL gamma, TRWS::SmoothCostGeneralFn fn, int i, int j, void* buf)
|
|
|
+{
|
|
|
+ TRWS::REAL* Di = (TRWS::REAL*) buf;
|
|
|
+ int ki, kj;
|
|
|
+ TRWS::REAL delta;
|
|
|
+
|
|
|
+ for (ki = 0; ki < K; ki++)
|
|
|
+ {
|
|
|
+ Di[ki] = (gamma * Di_hat[ki] - M[ki]);
|
|
|
+ }
|
|
|
+
|
|
|
+ for (kj = 0; kj < K; kj++)
|
|
|
+ {
|
|
|
+ M[kj] = Di[0] + fn(i, j, 0, kj);
|
|
|
+ for (ki = 1; ki < K; ki++)
|
|
|
+ {
|
|
|
+ delta = Di[ki] + fn(i, j, ki, kj);
|
|
|
+ TRUNCATE(M[kj], delta);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ delta = M[0];
|
|
|
+ for (kj = 1; kj < K; kj++) TRUNCATE(delta, M[kj]);
|
|
|
+ for (kj = 0; kj < K; kj++) M[kj] -= delta;
|
|
|
+
|
|
|
+ return delta;
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+TRWS::TRWS(int width, int height, int nLabels, EnergyFunction *eng): MRF(width, height, nLabels, eng)
|
|
|
+{
|
|
|
+ Allocate();
|
|
|
+}
|
|
|
+TRWS::TRWS(int nPixels, int nLabels, EnergyFunction *eng): MRF(nPixels, nLabels, eng)
|
|
|
+{
|
|
|
+ Allocate();
|
|
|
+}
|
|
|
+
|
|
|
+TRWS::~TRWS()
|
|
|
+{
|
|
|
+ delete[] m_answer;
|
|
|
+ if ( m_needToFreeD ) delete [] m_D;
|
|
|
+ if ( m_needToFreeV ) delete [] m_V;
|
|
|
+ if ( m_messages ) delete [] m_messages;
|
|
|
+ if ( m_DBinary ) delete [] m_DBinary;
|
|
|
+ if ( m_horzWeightsBinary ) delete [] m_horzWeightsBinary;
|
|
|
+ if ( m_vertWeightsBinary ) delete [] m_vertWeightsBinary;
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+void TRWS::Allocate()
|
|
|
+{
|
|
|
+ m_type = NONE;
|
|
|
+ m_needToFreeV = false;
|
|
|
+ m_needToFreeD = false;
|
|
|
+
|
|
|
+ m_D = NULL;
|
|
|
+ m_V = NULL;
|
|
|
+ m_horzWeights = NULL;
|
|
|
+ m_vertWeights = NULL;
|
|
|
+ m_horzWeightsBinary = NULL;
|
|
|
+ m_vertWeightsBinary = NULL;
|
|
|
+
|
|
|
+ m_DBinary = NULL;
|
|
|
+ m_messages = NULL;
|
|
|
+ m_messageArraySizeInBytes = 0;
|
|
|
+
|
|
|
+ m_answer = new Label[m_nPixels];
|
|
|
+}
|
|
|
+
|
|
|
+void TRWS::clearAnswer()
|
|
|
+{
|
|
|
+ memset(m_answer, 0, m_nPixels*sizeof(Label));
|
|
|
+ if (m_messages)
|
|
|
+ {
|
|
|
+ memset(m_messages, 0, m_messageArraySizeInBytes);
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+MRF::EnergyVal TRWS::smoothnessEnergy()
|
|
|
+{
|
|
|
+ EnergyVal eng = (EnergyVal) 0;
|
|
|
+ EnergyVal weight;
|
|
|
+ int x, y, pix;
|
|
|
+
|
|
|
+ if ( m_grid_graph )
|
|
|
+ {
|
|
|
+ if ( m_smoothType != FUNCTION )
|
|
|
+ {
|
|
|
+ for ( y = 0; y < m_height; y++ )
|
|
|
+ for ( x = 1; x < m_width; x++ )
|
|
|
+ {
|
|
|
+ pix = x + y * m_width;
|
|
|
+ weight = m_varWeights ? m_horzWeights[pix-1] : 1;
|
|
|
+ eng = eng + m_V(m_answer[pix], m_answer[pix-1]) * weight;
|
|
|
+ }
|
|
|
+
|
|
|
+ for ( y = 1; y < m_height; y++ )
|
|
|
+ for ( x = 0; x < m_width; x++ )
|
|
|
+ {
|
|
|
+ pix = x + y * m_width;
|
|
|
+ weight = m_varWeights ? m_vertWeights[pix-m_width] : 1;
|
|
|
+ eng = eng + m_V(m_answer[pix], m_answer[pix-m_width]) * weight;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ else
|
|
|
+ {
|
|
|
+ for ( y = 0; y < m_height; y++ )
|
|
|
+ for ( x = 1; x < m_width; x++ )
|
|
|
+ {
|
|
|
+ pix = x + y * m_width;
|
|
|
+ eng = eng + m_smoothFn(pix, pix - 1, m_answer[pix], m_answer[pix-1]);
|
|
|
+ }
|
|
|
+
|
|
|
+ for ( y = 1; y < m_height; y++ )
|
|
|
+ for ( x = 0; x < m_width; x++ )
|
|
|
+ {
|
|
|
+ pix = x + y * m_width;
|
|
|
+ eng = eng + m_smoothFn(pix, pix - m_width, m_answer[pix], m_answer[pix-m_width]);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ else
|
|
|
+ {
|
|
|
+ // not implemented
|
|
|
+ }
|
|
|
+
|
|
|
+ return(eng);
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+MRF::EnergyVal TRWS::dataEnergy()
|
|
|
+{
|
|
|
+ EnergyVal eng = (EnergyVal) 0;
|
|
|
+
|
|
|
+
|
|
|
+ if ( m_dataType == ARRAY)
|
|
|
+ {
|
|
|
+ for ( int i = 0; i < m_nPixels; i++ )
|
|
|
+ eng = eng + m_D(i, m_answer[i]);
|
|
|
+ }
|
|
|
+ else
|
|
|
+ {
|
|
|
+ for ( int i = 0; i < m_nPixels; i++ )
|
|
|
+ eng = eng + m_dataFn(i, m_answer[i]);
|
|
|
+ }
|
|
|
+ return(eng);
|
|
|
+
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+void TRWS::setData(DataCostFn dcost)
|
|
|
+{
|
|
|
+ int i, k;
|
|
|
+
|
|
|
+ m_dataFn = dcost;
|
|
|
+ CostVal* ptr;
|
|
|
+ m_D = new CostVal[m_nPixels*m_nLabels];
|
|
|
+
|
|
|
+ for (ptr = m_D, i = 0; i < m_nPixels; i++)
|
|
|
+ for (k = 0; k < m_nLabels; k++, ptr++)
|
|
|
+ {
|
|
|
+ *ptr = m_dataFn(i, k);
|
|
|
+ }
|
|
|
+ m_needToFreeD = true;
|
|
|
+}
|
|
|
+
|
|
|
+void TRWS::setData(CostVal* data)
|
|
|
+{
|
|
|
+ m_D = data;
|
|
|
+ m_needToFreeD = false;
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+void TRWS::setSmoothness(SmoothCostGeneralFn cost)
|
|
|
+{
|
|
|
+ assert(m_horzWeights == NULL && m_vertWeights == NULL && m_V == NULL);
|
|
|
+
|
|
|
+ int x, y, i, ki, kj;
|
|
|
+ CostVal* ptr;
|
|
|
+
|
|
|
+ m_smoothFn = cost;
|
|
|
+ m_type = GENERAL;
|
|
|
+
|
|
|
+ if (!m_allocateArrayForSmoothnessCostFn) return;
|
|
|
+
|
|
|
+ // try to cache all the function values in an array for efficiency
|
|
|
+ m_V = new(std::nothrow) CostVal[2*m_nPixels*m_nLabels*m_nLabels];
|
|
|
+ if (!m_V) return; // if not enough space, just call the function directly
|
|
|
+
|
|
|
+ m_needToFreeV = true;
|
|
|
+
|
|
|
+ for (ptr = m_V, i = 0, y = 0; y < m_height; y++)
|
|
|
+ for (x = 0; x < m_width; x++, i++)
|
|
|
+ {
|
|
|
+ if (x < m_width - 1)
|
|
|
+ {
|
|
|
+ for (kj = 0; kj < m_nLabels; kj++)
|
|
|
+ for (ki = 0; ki < m_nLabels; ki++)
|
|
|
+ {
|
|
|
+ *ptr++ = cost(i, i + 1, ki, kj);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ else ptr += m_nLabels * m_nLabels;
|
|
|
+
|
|
|
+ if (y < m_height - 1)
|
|
|
+ {
|
|
|
+ for (kj = 0; kj < m_nLabels; kj++)
|
|
|
+ for (ki = 0; ki < m_nLabels; ki++)
|
|
|
+ {
|
|
|
+ *ptr++ = cost(i, i + m_width, ki, kj);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ else ptr += m_nLabels * m_nLabels;
|
|
|
+ }
|
|
|
+}
|
|
|
+void TRWS::setSmoothness(CostVal* V)
|
|
|
+{
|
|
|
+ m_type = FIXED_MATRIX;
|
|
|
+ m_V = V;
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+void TRWS::setSmoothness(int smoothExp, CostVal smoothMax, CostVal lambda)
|
|
|
+{
|
|
|
+ assert(smoothExp == 1 || smoothExp == 2);
|
|
|
+ assert(lambda >= 0);
|
|
|
+
|
|
|
+ m_type = (smoothExp == 1) ? L1 : L2;
|
|
|
+
|
|
|
+ int ki, kj;
|
|
|
+ CostVal cost;
|
|
|
+
|
|
|
+ m_needToFreeV = true;
|
|
|
+
|
|
|
+ m_V = new CostVal[m_nLabels*m_nLabels];
|
|
|
+
|
|
|
+ for (ki = 0; ki < m_nLabels; ki++)
|
|
|
+ for (kj = ki; kj < m_nLabels; kj++)
|
|
|
+ {
|
|
|
+ cost = (CostVal) ((smoothExp == 1) ? kj - ki : (kj - ki) * (kj - ki));
|
|
|
+ if (cost > smoothMax) cost = smoothMax;
|
|
|
+ m_V[ki*m_nLabels + kj] = m_V[kj*m_nLabels + ki] = cost * lambda;
|
|
|
+ }
|
|
|
+
|
|
|
+ m_smoothMax = smoothMax;
|
|
|
+ m_lambda = lambda;
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+void TRWS::setCues(CostVal* hCue, CostVal* vCue)
|
|
|
+{
|
|
|
+ m_horzWeights = hCue;
|
|
|
+ m_vertWeights = vCue;
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+void TRWS::initializeAlg()
|
|
|
+{
|
|
|
+ assert(m_type != NONE);
|
|
|
+
|
|
|
+ int i;
|
|
|
+
|
|
|
+ // determine type
|
|
|
+ if (m_type == L1 && m_nLabels == 2)
|
|
|
+ {
|
|
|
+ m_type = BINARY;
|
|
|
+ }
|
|
|
+
|
|
|
+ // allocate messages
|
|
|
+ int messageNum = (m_type == BINARY) ? 4 * m_nPixels : 4 * m_nPixels * m_nLabels;
|
|
|
+ m_messageArraySizeInBytes = messageNum * sizeof(REAL);
|
|
|
+ m_messages = new REAL[messageNum];
|
|
|
+ memset(m_messages, 0, messageNum*sizeof(REAL));
|
|
|
+
|
|
|
+ if (m_type == BINARY)
|
|
|
+ {
|
|
|
+ assert(m_DBinary == NULL && m_horzWeightsBinary == NULL && m_horzWeightsBinary == NULL);
|
|
|
+ m_DBinary = new CostVal[m_nPixels];
|
|
|
+ m_horzWeightsBinary = new CostVal[m_nPixels];
|
|
|
+ m_vertWeightsBinary = new CostVal[m_nPixels];
|
|
|
+
|
|
|
+ if ( m_dataType == ARRAY)
|
|
|
+ {
|
|
|
+ for (i = 0; i < m_nPixels; i++)
|
|
|
+ {
|
|
|
+ m_DBinary[i] = m_D[2*i+1] - m_D[2*i];
|
|
|
+ }
|
|
|
+ }
|
|
|
+ else
|
|
|
+ {
|
|
|
+ for (i = 0; i < m_nPixels; i++)
|
|
|
+ {
|
|
|
+ m_DBinary[i] = m_dataFn(i, 1) - m_dataFn(i, 0);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ assert(m_V[0] == 0 && m_V[1] == m_V[2] && m_V[3] == 0);
|
|
|
+ for (i = 0; i < m_nPixels; i++)
|
|
|
+ {
|
|
|
+ m_horzWeightsBinary[i] = (m_varWeights) ? m_V[1] * m_horzWeights[i] : m_V[1];
|
|
|
+ m_vertWeightsBinary[i] = (m_varWeights) ? m_V[1] * m_vertWeights[i] : m_V[1];
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+void TRWS::optimizeAlg(int nIterations)
|
|
|
+{
|
|
|
+ assert(m_type != NONE);
|
|
|
+
|
|
|
+ if (m_grid_graph)
|
|
|
+ {
|
|
|
+ switch (m_type)
|
|
|
+ {
|
|
|
+ case L1:
|
|
|
+ optimize_GRID_L1(nIterations);
|
|
|
+ break;
|
|
|
+ case L2:
|
|
|
+ optimize_GRID_L2(nIterations);
|
|
|
+ break;
|
|
|
+ case FIXED_MATRIX:
|
|
|
+ optimize_GRID_FIXED_MATRIX(nIterations);
|
|
|
+ break;
|
|
|
+ case GENERAL:
|
|
|
+ optimize_GRID_GENERAL(nIterations);
|
|
|
+ break;
|
|
|
+ case BINARY:
|
|
|
+ optimize_GRID_BINARY(nIterations);
|
|
|
+ break;
|
|
|
+ default:
|
|
|
+ assert(0);
|
|
|
+ exit(1);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ else {
|
|
|
+ printf("\nNot implemented for general graphs yet, exiting!");
|
|
|
+ exit(1);
|
|
|
+ }
|
|
|
+
|
|
|
+ // printf("lower bound = %f\n", m_lowerBound);
|
|
|
+
|
|
|
+ ////////////////////////////////////////////////
|
|
|
+ // computing solution //
|
|
|
+ ////////////////////////////////////////////////
|
|
|
+
|
|
|
+ if (m_type != BINARY)
|
|
|
+ {
|
|
|
+ int x, y, n, K = m_nLabels;
|
|
|
+ CostVal* D_ptr;
|
|
|
+ REAL* M_ptr;
|
|
|
+ REAL* Di;
|
|
|
+ REAL delta;
|
|
|
+ int ki, kj;
|
|
|
+
|
|
|
+ Di = new REAL[K];
|
|
|
+
|
|
|
+ n = 0;
|
|
|
+ D_ptr = m_D;
|
|
|
+ M_ptr = m_messages;
|
|
|
+
|
|
|
+ for (y = 0; y < m_height; y++)
|
|
|
+ for (x = 0; x < m_width; x++, D_ptr += K, M_ptr += 2 * K, n++)
|
|
|
+ {
|
|
|
+ CopyVector(Di, D_ptr, K);
|
|
|
+
|
|
|
+ if (m_type == GENERAL)
|
|
|
+ {
|
|
|
+ if (m_V)
|
|
|
+ {
|
|
|
+ CostVal* ptr = m_V + 2 * (x + y * m_width - 1) * K * K;
|
|
|
+ if (x > 0)
|
|
|
+ {
|
|
|
+ kj = m_answer[n-1];
|
|
|
+ for (ki = 0; ki < K; ki++)
|
|
|
+ {
|
|
|
+ Di[ki] += ptr[kj + ki*K];
|
|
|
+ }
|
|
|
+ }
|
|
|
+ ptr -= (2 * m_width - 3) * K * K;
|
|
|
+ if (y > 0)
|
|
|
+ {
|
|
|
+ kj = m_answer[n-m_width];
|
|
|
+ for (ki = 0; ki < K; ki++)
|
|
|
+ {
|
|
|
+ Di[ki] += ptr[kj + ki*K];
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ else
|
|
|
+ {
|
|
|
+ if (x > 0)
|
|
|
+ {
|
|
|
+ kj = m_answer[n-1];
|
|
|
+ for (ki = 0; ki < K; ki++)
|
|
|
+ {
|
|
|
+ Di[ki] += m_smoothFn(n, n - 1, ki, kj);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if (y > 0)
|
|
|
+ {
|
|
|
+ kj = m_answer[n-m_width];
|
|
|
+ for (ki = 0; ki < K; ki++)
|
|
|
+ {
|
|
|
+ Di[ki] += m_smoothFn(n, n - m_width, ki, kj);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ else // m_type == L1, L2 or FIXED_MATRIX
|
|
|
+ {
|
|
|
+ if (x > 0)
|
|
|
+ {
|
|
|
+ kj = m_answer[n-1];
|
|
|
+ CostVal lambda = (m_varWeights) ? m_horzWeights[n-1] : 1;
|
|
|
+ for (ki = 0; ki < K; ki++)
|
|
|
+ {
|
|
|
+ Di[ki] += lambda * m_V[kj*K + ki];
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if (y > 0)
|
|
|
+ {
|
|
|
+ kj = m_answer[n-m_width];
|
|
|
+ CostVal lambda = (m_varWeights) ? m_vertWeights[n-m_width] : 1;
|
|
|
+ for (ki = 0; ki < K; ki++)
|
|
|
+ {
|
|
|
+ Di[ki] += lambda * m_V[kj*K + ki];
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if (x < m_width - 1) AddVector(Di, M_ptr, K); // message (x+1,y)->(x,y)
|
|
|
+ if (y < m_height - 1) AddVector(Di, M_ptr + K, K); // message (x,y+1)->(x,y)
|
|
|
+
|
|
|
+ // compute min
|
|
|
+ delta = Di[0];
|
|
|
+ m_answer[n] = 0;
|
|
|
+ for (ki = 1; ki < K; ki++)
|
|
|
+ {
|
|
|
+ if (delta > Di[ki])
|
|
|
+ {
|
|
|
+ delta = Di[ki];
|
|
|
+ m_answer[n] = ki;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ delete [] Di;
|
|
|
+ }
|
|
|
+ else // m_type == BINARY
|
|
|
+ {
|
|
|
+ int x, y, n;
|
|
|
+ REAL* M_ptr;
|
|
|
+ REAL Di;
|
|
|
+
|
|
|
+ n = 0;
|
|
|
+ M_ptr = m_messages;
|
|
|
+
|
|
|
+ for (y = 0; y < m_height; y++)
|
|
|
+ for (x = 0; x < m_width; x++, M_ptr += 2, n++)
|
|
|
+ {
|
|
|
+ Di = m_DBinary[n];
|
|
|
+ if (x > 0) Di += (m_answer[n-1] == 0) ? m_horzWeightsBinary[n-1] : -m_horzWeightsBinary[n-1];
|
|
|
+ if (y > 0) Di += (m_answer[n-m_width] == 0) ? m_vertWeightsBinary[n-m_width] : -m_vertWeightsBinary[n-m_width];
|
|
|
+
|
|
|
+ if (x < m_width - 1) Di += M_ptr[0]; // message (x+1,y)->(x,y)
|
|
|
+ if (y < m_height - 1) Di += M_ptr[1]; // message (x,y+1)->(x,y)
|
|
|
+
|
|
|
+ // compute min
|
|
|
+ m_answer[n] = (Di >= 0) ? 0 : 1;
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+void TRWS::optimize_GRID_L1(int nIterations)
|
|
|
+{
|
|
|
+ int x, y, n, K = m_nLabels;
|
|
|
+ CostVal* D_ptr;
|
|
|
+ REAL* M_ptr;
|
|
|
+ REAL* Di;
|
|
|
+
|
|
|
+ Di = new REAL[K];
|
|
|
+
|
|
|
+ for ( ; nIterations > 0; nIterations --)
|
|
|
+ {
|
|
|
+ // forward pass
|
|
|
+ n = 0;
|
|
|
+ D_ptr = m_D;
|
|
|
+ M_ptr = m_messages;
|
|
|
+
|
|
|
+ for (y = 0; y < m_height; y++)
|
|
|
+ for (x = 0; x < m_width; x++, D_ptr += K, M_ptr += 2 * K, n++)
|
|
|
+ {
|
|
|
+ CopyVector(Di, D_ptr, K);
|
|
|
+ if (x > 0) AddVector(Di, M_ptr - 2*K, K); // message (x-1,y)->(x,y)
|
|
|
+ if (y > 0) AddVector(Di, M_ptr - (2*m_width - 1)*K, K); // message (x,y-1)->(x,y)
|
|
|
+ if (x < m_width - 1) AddVector(Di, M_ptr, K); // message (x+1,y)->(x,y)
|
|
|
+ if (y < m_height - 1) AddVector(Di, M_ptr + K, K); // message (x,y+1)->(x,y)
|
|
|
+
|
|
|
+ if (x < m_width - 1)
|
|
|
+ {
|
|
|
+ CostVal lambda = (m_varWeights) ? m_lambda * m_horzWeights[n] : m_lambda;
|
|
|
+ UpdateMessageL1(M_ptr, Di, K, 0.5, lambda, m_smoothMax);
|
|
|
+ }
|
|
|
+ if (y < m_height - 1)
|
|
|
+ {
|
|
|
+ CostVal lambda = (m_varWeights) ? m_lambda * m_vertWeights[n] : m_lambda;
|
|
|
+ UpdateMessageL1(M_ptr + K, Di, K, 0.5, lambda, m_smoothMax);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // backward pass
|
|
|
+ m_lowerBound = 0;
|
|
|
+
|
|
|
+ n --;
|
|
|
+ D_ptr -= K;
|
|
|
+ M_ptr -= 2 * K;
|
|
|
+
|
|
|
+ for (y = m_height - 1; y >= 0; y--)
|
|
|
+ for (x = m_width - 1; x >= 0; x--, D_ptr -= K, M_ptr -= 2 * K, n--)
|
|
|
+ {
|
|
|
+ CopyVector(Di, D_ptr, K);
|
|
|
+ if (x > 0) AddVector(Di, M_ptr - 2*K, K); // message (x-1,y)->(x,y)
|
|
|
+ if (y > 0) AddVector(Di, M_ptr - (2*m_width - 1)*K, K); // message (x,y-1)->(x,y)
|
|
|
+ if (x < m_width - 1) AddVector(Di, M_ptr, K); // message (x+1,y)->(x,y)
|
|
|
+ if (y < m_height - 1) AddVector(Di, M_ptr + K, K); // message (x,y+1)->(x,y)
|
|
|
+
|
|
|
+ m_lowerBound += SubtractMin(Di, K);
|
|
|
+
|
|
|
+ if (x > 0)
|
|
|
+ {
|
|
|
+ CostVal lambda = (m_varWeights) ? m_lambda * m_horzWeights[n-1] : m_lambda;
|
|
|
+ m_lowerBound += UpdateMessageL1(M_ptr - 2 * K, Di, K, 0.5, lambda, m_smoothMax);
|
|
|
+ }
|
|
|
+ if (y > 0)
|
|
|
+ {
|
|
|
+ CostVal lambda = (m_varWeights) ? m_lambda * m_vertWeights[n-m_width] : m_lambda;
|
|
|
+ m_lowerBound += UpdateMessageL1(M_ptr - (2 * m_width - 1) * K, Di, K, 0.5, lambda, m_smoothMax);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ delete [] Di;
|
|
|
+}
|
|
|
+
|
|
|
+void TRWS::optimize_GRID_L2(int nIterations)
|
|
|
+{
|
|
|
+ int x, y, n, K = m_nLabels;
|
|
|
+ CostVal* D_ptr;
|
|
|
+ REAL* M_ptr;
|
|
|
+ REAL* Di;
|
|
|
+ void* buf;
|
|
|
+
|
|
|
+ Di = new REAL[K];
|
|
|
+ buf = new char[(2*K+1)*sizeof(int) + K*sizeof(REAL)];
|
|
|
+
|
|
|
+ for ( ; nIterations > 0; nIterations --)
|
|
|
+ {
|
|
|
+ // forward pass
|
|
|
+ n = 0;
|
|
|
+ D_ptr = m_D;
|
|
|
+ M_ptr = m_messages;
|
|
|
+
|
|
|
+ for (y = 0; y < m_height; y++)
|
|
|
+ for (x = 0; x < m_width; x++, D_ptr += K, M_ptr += 2 * K, n++)
|
|
|
+ {
|
|
|
+ CopyVector(Di, D_ptr, K);
|
|
|
+ if (x > 0) AddVector(Di, M_ptr - 2*K, K); // message (x-1,y)->(x,y)
|
|
|
+ if (y > 0) AddVector(Di, M_ptr - (2*m_width - 1)*K, K); // message (x,y-1)->(x,y)
|
|
|
+ if (x < m_width - 1) AddVector(Di, M_ptr, K); // message (x+1,y)->(x,y)
|
|
|
+ if (y < m_height - 1) AddVector(Di, M_ptr + K, K); // message (x,y+1)->(x,y)
|
|
|
+
|
|
|
+ if (x < m_width - 1)
|
|
|
+ {
|
|
|
+ CostVal lambda = (m_varWeights) ? m_lambda * m_horzWeights[n] : m_lambda;
|
|
|
+ UpdateMessageL2(M_ptr, Di, K, 0.5, lambda, m_smoothMax, buf);
|
|
|
+ }
|
|
|
+ if (y < m_height - 1)
|
|
|
+ {
|
|
|
+ CostVal lambda = (m_varWeights) ? m_lambda * m_vertWeights[n] : m_lambda;
|
|
|
+ UpdateMessageL2(M_ptr + K, Di, K, 0.5, lambda, m_smoothMax, buf);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // backward pass
|
|
|
+ m_lowerBound = 0;
|
|
|
+
|
|
|
+ n --;
|
|
|
+ D_ptr -= K;
|
|
|
+ M_ptr -= 2 * K;
|
|
|
+
|
|
|
+ for (y = m_height - 1; y >= 0; y--)
|
|
|
+ for (x = m_width - 1; x >= 0; x--, D_ptr -= K, M_ptr -= 2 * K, n--)
|
|
|
+ {
|
|
|
+ CopyVector(Di, D_ptr, K);
|
|
|
+ if (x > 0) AddVector(Di, M_ptr - 2*K, K); // message (x-1,y)->(x,y)
|
|
|
+ if (y > 0) AddVector(Di, M_ptr - (2*m_width - 1)*K, K); // message (x,y-1)->(x,y)
|
|
|
+ if (x < m_width - 1) AddVector(Di, M_ptr, K); // message (x+1,y)->(x,y)
|
|
|
+ if (y < m_height - 1) AddVector(Di, M_ptr + K, K); // message (x,y+1)->(x,y)
|
|
|
+
|
|
|
+ m_lowerBound += SubtractMin(Di, K);
|
|
|
+
|
|
|
+ if (x > 0)
|
|
|
+ {
|
|
|
+ CostVal lambda = (m_varWeights) ? m_lambda * m_horzWeights[n-1] : m_lambda;
|
|
|
+ m_lowerBound += UpdateMessageL2(M_ptr - 2 * K, Di, K, 0.5, lambda, m_smoothMax, buf);
|
|
|
+ }
|
|
|
+ if (y > 0)
|
|
|
+ {
|
|
|
+ CostVal lambda = (m_varWeights) ? m_lambda * m_vertWeights[n-m_width] : m_lambda;
|
|
|
+ m_lowerBound += UpdateMessageL2(M_ptr - (2 * m_width - 1) * K, Di, K, 0.5, lambda, m_smoothMax, buf);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ delete [] Di;
|
|
|
+ delete [] (char *)buf;
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+void TRWS::optimize_GRID_BINARY(int nIterations)
|
|
|
+{
|
|
|
+ int x, y, n;
|
|
|
+ REAL* M_ptr;
|
|
|
+ REAL Di;
|
|
|
+
|
|
|
+ for ( ; nIterations > 0; nIterations --)
|
|
|
+ {
|
|
|
+ // forward pass
|
|
|
+ n = 0;
|
|
|
+ M_ptr = m_messages;
|
|
|
+
|
|
|
+ for (y = 0; y < m_height; y++)
|
|
|
+ for (x = 0; x < m_width; x++, M_ptr += 2, n++)
|
|
|
+ {
|
|
|
+ Di = m_DBinary[n];
|
|
|
+ if (x > 0) Di += M_ptr[-2]; // message (x-1,y)->(x,y)
|
|
|
+ if (y > 0) Di += M_ptr[-2*m_width+1]; // message (x,y-1)->(x,y)
|
|
|
+ if (x < m_width - 1) Di += M_ptr[0]; // message (x+1,y)->(x,y)
|
|
|
+ if (y < m_height - 1) Di += M_ptr[1]; // message (x,y+1)->(x,y)
|
|
|
+
|
|
|
+ REAL DiScaled = Di * 0.5;
|
|
|
+ if (x < m_width - 1)
|
|
|
+ {
|
|
|
+ Di = DiScaled - M_ptr[0];
|
|
|
+ CostVal lambda = m_horzWeightsBinary[n];
|
|
|
+ if (lambda < 0) {
|
|
|
+ Di = -Di;
|
|
|
+ lambda = -lambda;
|
|
|
+ }
|
|
|
+ if (Di > lambda) M_ptr[0] = lambda;
|
|
|
+ else M_ptr[0] = (Di < -lambda) ? -lambda : Di;
|
|
|
+ }
|
|
|
+ if (y < m_height - 1)
|
|
|
+ {
|
|
|
+ Di = DiScaled - M_ptr[1];
|
|
|
+ CostVal lambda = m_vertWeightsBinary[n];
|
|
|
+ if (lambda < 0) {
|
|
|
+ Di = -Di;
|
|
|
+ lambda = -lambda;
|
|
|
+ }
|
|
|
+ if (Di > lambda) M_ptr[1] = lambda;
|
|
|
+ else M_ptr[1] = (Di < -lambda) ? -lambda : Di;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // backward pass
|
|
|
+ n --;
|
|
|
+ M_ptr -= 2;
|
|
|
+
|
|
|
+ for (y = m_height - 1; y >= 0; y--)
|
|
|
+ for (x = m_width - 1; x >= 0; x--, M_ptr -= 2, n--)
|
|
|
+ {
|
|
|
+ Di = m_DBinary[n];
|
|
|
+ if (x > 0) Di += M_ptr[-2]; // message (x-1,y)->(x,y)
|
|
|
+ if (y > 0) Di += M_ptr[-2*m_width+1]; // message (x,y-1)->(x,y)
|
|
|
+ if (x < m_width - 1) Di += M_ptr[0]; // message (x+1,y)->(x,y)
|
|
|
+ if (y < m_height - 1) Di += M_ptr[1]; // message (x,y+1)->(x,y)
|
|
|
+
|
|
|
+ REAL DiScaled = Di * 0.5;
|
|
|
+ if (x > 0)
|
|
|
+ {
|
|
|
+ Di = DiScaled - M_ptr[-2];
|
|
|
+ CostVal lambda = m_horzWeightsBinary[n-1];
|
|
|
+ if (lambda < 0) {
|
|
|
+ Di = -Di;
|
|
|
+ lambda = -lambda;
|
|
|
+ }
|
|
|
+ if (Di > lambda) M_ptr[-2] = lambda;
|
|
|
+ else M_ptr[-2] = (Di < -lambda) ? -lambda : Di;
|
|
|
+ }
|
|
|
+ if (y > 0)
|
|
|
+ {
|
|
|
+ Di = DiScaled - M_ptr[-2*m_width+1];
|
|
|
+ CostVal lambda = m_vertWeightsBinary[n-m_width];
|
|
|
+ if (lambda < 0) {
|
|
|
+ Di = -Di;
|
|
|
+ lambda = -lambda;
|
|
|
+ }
|
|
|
+ if (Di > lambda) M_ptr[-2*m_width+1] = lambda;
|
|
|
+ else M_ptr[-2*m_width+1] = (Di < -lambda) ? -lambda : Di;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ m_lowerBound = 0;
|
|
|
+}
|
|
|
+
|
|
|
+void TRWS::optimize_GRID_FIXED_MATRIX(int nIterations)
|
|
|
+{
|
|
|
+ int x, y, n, K = m_nLabels;
|
|
|
+ CostVal* D_ptr;
|
|
|
+ REAL* M_ptr;
|
|
|
+ REAL* Di;
|
|
|
+ void* buf;
|
|
|
+
|
|
|
+ Di = new REAL[K];
|
|
|
+ buf = new REAL[K];
|
|
|
+
|
|
|
+ for ( ; nIterations > 0; nIterations --)
|
|
|
+ {
|
|
|
+ // forward pass
|
|
|
+ n = 0;
|
|
|
+ D_ptr = m_D;
|
|
|
+ M_ptr = m_messages;
|
|
|
+
|
|
|
+ for (y = 0; y < m_height; y++)
|
|
|
+ for (x = 0; x < m_width; x++, D_ptr += K, M_ptr += 2 * K, n++)
|
|
|
+ {
|
|
|
+ CopyVector(Di, D_ptr, K);
|
|
|
+ if (x > 0) AddVector(Di, M_ptr - 2*K, K); // message (x-1,y)->(x,y)
|
|
|
+ if (y > 0) AddVector(Di, M_ptr - (2*m_width - 1)*K, K); // message (x,y-1)->(x,y)
|
|
|
+ if (x < m_width - 1) AddVector(Di, M_ptr, K); // message (x+1,y)->(x,y)
|
|
|
+ if (y < m_height - 1) AddVector(Di, M_ptr + K, K); // message (x,y+1)->(x,y)
|
|
|
+
|
|
|
+ if (x < m_width - 1)
|
|
|
+ {
|
|
|
+ CostVal lambda = (m_varWeights) ? m_horzWeights[n] : 1;
|
|
|
+ UpdateMessageFIXED_MATRIX(M_ptr, Di, K, 0.5, lambda, m_V, buf);
|
|
|
+ }
|
|
|
+ if (y < m_height - 1)
|
|
|
+ {
|
|
|
+ CostVal lambda = (m_varWeights) ? m_vertWeights[n] : 1;
|
|
|
+ UpdateMessageFIXED_MATRIX(M_ptr + K, Di, K, 0.5, lambda, m_V, buf);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // backward pass
|
|
|
+ m_lowerBound = 0;
|
|
|
+
|
|
|
+ n --;
|
|
|
+ D_ptr -= K;
|
|
|
+ M_ptr -= 2 * K;
|
|
|
+
|
|
|
+ for (y = m_height - 1; y >= 0; y--)
|
|
|
+ for (x = m_width - 1; x >= 0; x--, D_ptr -= K, M_ptr -= 2 * K, n--)
|
|
|
+ {
|
|
|
+ CopyVector(Di, D_ptr, K);
|
|
|
+ if (x > 0) AddVector(Di, M_ptr - 2*K, K); // message (x-1,y)->(x,y)
|
|
|
+ if (y > 0) AddVector(Di, M_ptr - (2*m_width - 1)*K, K); // message (x,y-1)->(x,y)
|
|
|
+ if (x < m_width - 1) AddVector(Di, M_ptr, K); // message (x+1,y)->(x,y)
|
|
|
+ if (y < m_height - 1) AddVector(Di, M_ptr + K, K); // message (x,y+1)->(x,y)
|
|
|
+
|
|
|
+ m_lowerBound += SubtractMin(Di, K);
|
|
|
+
|
|
|
+ if (x > 0)
|
|
|
+ {
|
|
|
+ CostVal lambda = (m_varWeights) ? m_horzWeights[n-1] : 1;
|
|
|
+ m_lowerBound += UpdateMessageFIXED_MATRIX(M_ptr - 2 * K, Di, K, 0.5, lambda, m_V, buf);
|
|
|
+ }
|
|
|
+ if (y > 0)
|
|
|
+ {
|
|
|
+ CostVal lambda = (m_varWeights) ? m_vertWeights[n-m_width] : 1;
|
|
|
+ m_lowerBound += UpdateMessageFIXED_MATRIX(M_ptr - (2 * m_width - 1) * K, Di, K, 0.5, lambda, m_V, buf);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ delete [] Di;
|
|
|
+ delete [] (REAL *)buf;
|
|
|
+}
|
|
|
+
|
|
|
+void TRWS::optimize_GRID_GENERAL(int nIterations)
|
|
|
+{
|
|
|
+ int x, y, n, K = m_nLabels;
|
|
|
+ CostVal* D_ptr;
|
|
|
+ REAL* M_ptr;
|
|
|
+ REAL* Di;
|
|
|
+ void* buf;
|
|
|
+
|
|
|
+ Di = new REAL[K];
|
|
|
+ buf = new REAL[K];
|
|
|
+
|
|
|
+ for ( ; nIterations > 0; nIterations --)
|
|
|
+ {
|
|
|
+ // forward pass
|
|
|
+ n = 0;
|
|
|
+ D_ptr = m_D;
|
|
|
+ M_ptr = m_messages;
|
|
|
+ CostVal* V_ptr = m_V;
|
|
|
+
|
|
|
+ for (y = 0; y < m_height; y++)
|
|
|
+ for (x = 0; x < m_width; x++, D_ptr += K, M_ptr += 2 * K, V_ptr += 2 * K * K, n++)
|
|
|
+ {
|
|
|
+ CopyVector(Di, D_ptr, K);
|
|
|
+ if (x > 0) AddVector(Di, M_ptr - 2*K, K); // message (x-1,y)->(x,y)
|
|
|
+ if (y > 0) AddVector(Di, M_ptr - (2*m_width - 1)*K, K); // message (x,y-1)->(x,y)
|
|
|
+ if (x < m_width - 1) AddVector(Di, M_ptr, K); // message (x+1,y)->(x,y)
|
|
|
+ if (y < m_height - 1) AddVector(Di, M_ptr + K, K); // message (x,y+1)->(x,y)
|
|
|
+
|
|
|
+ if (x < m_width - 1)
|
|
|
+ {
|
|
|
+ if (m_V) UpdateMessageGENERAL(M_ptr, Di, K, 0.5, /* forward dir*/ 0, V_ptr, buf);
|
|
|
+ else UpdateMessageGENERAL(M_ptr, Di, K, 0.5, m_smoothFn, n, n + 1, buf);
|
|
|
+ }
|
|
|
+ if (y < m_height - 1)
|
|
|
+ {
|
|
|
+ if (m_V) UpdateMessageGENERAL(M_ptr + K, Di, K, 0.5, /* forward dir*/ 0, V_ptr + K*K, buf);
|
|
|
+ else UpdateMessageGENERAL(M_ptr + K, Di, K, 0.5, m_smoothFn, n, n + m_width, buf);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // backward pass
|
|
|
+ m_lowerBound = 0;
|
|
|
+
|
|
|
+ n --;
|
|
|
+ D_ptr -= K;
|
|
|
+ M_ptr -= 2 * K;
|
|
|
+ V_ptr -= 2 * K * K;
|
|
|
+
|
|
|
+ for (y = m_height - 1; y >= 0; y--)
|
|
|
+ for (x = m_width - 1; x >= 0; x--, D_ptr -= K, M_ptr -= 2 * K, V_ptr -= 2 * K * K, n--)
|
|
|
+ {
|
|
|
+ CopyVector(Di, D_ptr, K);
|
|
|
+ if (x > 0) AddVector(Di, M_ptr - 2*K, K); // message (x-1,y)->(x,y)
|
|
|
+ if (y > 0) AddVector(Di, M_ptr - (2*m_width - 1)*K, K); // message (x,y-1)->(x,y)
|
|
|
+ if (x < m_width - 1) AddVector(Di, M_ptr, K); // message (x+1,y)->(x,y)
|
|
|
+ if (y < m_height - 1) AddVector(Di, M_ptr + K, K); // message (x,y+1)->(x,y)
|
|
|
+
|
|
|
+ // normalize Di, update lower bound
|
|
|
+ m_lowerBound += SubtractMin(Di, K);
|
|
|
+
|
|
|
+ if (x > 0)
|
|
|
+ {
|
|
|
+ if (m_V) m_lowerBound += UpdateMessageGENERAL(M_ptr - 2 * K, Di, K, 0.5, /* backward dir */ 1, V_ptr - 2 * K * K, buf);
|
|
|
+ else m_lowerBound += UpdateMessageGENERAL(M_ptr - 2 * K, Di, K, 0.5, m_smoothFn, n, n - 1, buf);
|
|
|
+ }
|
|
|
+ if (y > 0)
|
|
|
+ {
|
|
|
+ if (m_V) m_lowerBound += UpdateMessageGENERAL(M_ptr - (2 * m_width - 1) * K, Di, K, 0.5, /* backward dir */ 1, V_ptr - (2 * m_width - 1) * K * K, buf);
|
|
|
+ else m_lowerBound += UpdateMessageGENERAL(M_ptr - (2 * m_width - 1) * K, Di, K, 0.5, m_smoothFn, n, n - m_width, buf);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ delete [] Di;
|
|
|
+ delete [] (REAL *)buf;
|
|
|
+}
|