#ifndef __TRWS_H__ #define __TRWS_H__ #include #include #include #include #include "mrf.h" #undef REAL namespace OBJREC { class TRWS : public MRF { public: typedef double REAL; TRWS(int width, int height, int nLabels, EnergyFunction *eng); TRWS(int nPixels, int nLabels, EnergyFunction *eng); ~TRWS(); void setNeighbors(int /*pix1*/, int /*pix2*/, CostVal /*weight*/) { printf("Not implemented"); exit(1); } Label getLabel(int pixel) { return(m_answer[pixel]); }; void setLabel(int pixel, Label label) { m_answer[pixel] = label; }; Label* getAnswerPtr() { return(m_answer); }; void clearAnswer(); void setParameters(int /*numParam*/, void * /*param*/) { printf("No optional parameters to set"); exit(1); } EnergyVal smoothnessEnergy(); EnergyVal dataEnergy(); double lowerBound() { return (double)m_lowerBound; } // For general smoothness functions, this code tries to cache all function values in an array // for efficiency. To prevent this, call the following function before calling initialize(): void dontCacheSmoothnessCosts() { m_allocateArrayForSmoothnessCostFn = false; } protected: void setData(DataCostFn dcost); void setData(CostVal* data); void setSmoothness(SmoothCostGeneralFn cost); void setSmoothness(CostVal* V); void setSmoothness(int smoothExp, CostVal smoothMax, CostVal lambda); void setCues(CostVal* hCue, CostVal* vCue); void Allocate(); void initializeAlg(); void optimizeAlg(int nIterations); private: enum { NONE, L1, L2, FIXED_MATRIX, GENERAL, BINARY, } m_type; CostVal m_smoothMax; // used only if CostVal m_lambda; // m_type == L1 or m_type == L2 Label *m_answer; CostVal *m_V; // points to array of size nLabels^2 (if type==FIXED_MATRIX) or of size nEdges*nLabels^2 (if type==GENERAL) CostVal *m_D; CostVal *m_DBinary; // valid if type == BINARY CostVal *m_horzWeights; CostVal *m_vertWeights; CostVal *m_horzWeightsBinary; CostVal *m_vertWeightsBinary; DataCostFn m_dataFn; SmoothCostGeneralFn m_smoothFn; bool m_needToFreeV; bool m_needToFreeD; REAL* m_messages; // size of one message: N = 1 if m_type == BINARY, N = K otherwise // message between edges (x,y)-(x+1,y): m_messages+(2*x+2*y*m_width)*N // message between edges (x,y)-(x,y+1): m_messages+(2*x+2*y*m_width+1)*N int m_messageArraySizeInBytes; REAL m_lowerBound; void optimize_GRID_L1(int nIterations); void optimize_GRID_L2(int nIterations); void optimize_GRID_FIXED_MATRIX(int nIterations); void optimize_GRID_GENERAL(int nIterations); void optimize_GRID_BINARY(int nIterations); }; } #endif /* __TRWS_H__ */