123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107 |
- #ifndef __TRWS_H__
- #define __TRWS_H__
- #include <stdio.h>
- #include <stdlib.h>
- #include <string.h>
- #include <assert.h>
- #include "mrf.h"
- #undef REAL
- namespace OBJREC {
- class TRWS : public MRF {
- public:
- typedef double REAL;
- TRWS(int width, int height, int nLabels, EnergyFunction *eng);
- TRWS(int nPixels, int nLabels, EnergyFunction *eng);
- ~TRWS();
- void setNeighbors(int /*pix1*/, int /*pix2*/, CostVal /*weight*/) {
- printf("Not implemented");
- exit(1);
- }
- Label getLabel(int pixel) {
- return(m_answer[pixel]);
- };
- void setLabel(int pixel, Label label) {
- m_answer[pixel] = label;
- };
- Label* getAnswerPtr() {
- return(m_answer);
- };
- void clearAnswer();
- void setParameters(int /*numParam*/, void * /*param*/) {
- printf("No optional parameters to set");
- exit(1);
- }
- EnergyVal smoothnessEnergy();
- EnergyVal dataEnergy();
- double lowerBound() {
- return (double)m_lowerBound;
- }
- // For general smoothness functions, this code tries to cache all function values in an array
- // for efficiency. To prevent this, call the following function before calling initialize():
- void dontCacheSmoothnessCosts() {
- m_allocateArrayForSmoothnessCostFn = false;
- }
- protected:
- void setData(DataCostFn dcost);
- void setData(CostVal* data);
- void setSmoothness(SmoothCostGeneralFn cost);
- void setSmoothness(CostVal* V);
- void setSmoothness(int smoothExp, CostVal smoothMax, CostVal lambda);
- void setCues(CostVal* hCue, CostVal* vCue);
- void Allocate();
- void initializeAlg();
- void optimizeAlg(int nIterations);
- private:
- enum
- {
- NONE,
- L1,
- L2,
- FIXED_MATRIX,
- GENERAL,
- BINARY,
- } m_type;
- CostVal m_smoothMax; // used only if
- CostVal m_lambda; // m_type == L1 or m_type == L2
- Label *m_answer;
- CostVal *m_V; // points to array of size nLabels^2 (if type==FIXED_MATRIX) or of size nEdges*nLabels^2 (if type==GENERAL)
- CostVal *m_D;
- CostVal *m_DBinary; // valid if type == BINARY
- CostVal *m_horzWeights;
- CostVal *m_vertWeights;
- CostVal *m_horzWeightsBinary;
- CostVal *m_vertWeightsBinary;
- DataCostFn m_dataFn;
- SmoothCostGeneralFn m_smoothFn;
- bool m_needToFreeV;
- bool m_needToFreeD;
- REAL* m_messages; // size of one message: N = 1 if m_type == BINARY, N = K otherwise
- // message between edges (x,y)-(x+1,y): m_messages+(2*x+2*y*m_width)*N
- // message between edges (x,y)-(x,y+1): m_messages+(2*x+2*y*m_width+1)*N
- int m_messageArraySizeInBytes;
- REAL m_lowerBound;
- void optimize_GRID_L1(int nIterations);
- void optimize_GRID_L2(int nIterations);
- void optimize_GRID_FIXED_MATRIX(int nIterations);
- void optimize_GRID_GENERAL(int nIterations);
- void optimize_GRID_BINARY(int nIterations);
- };
- }
- #endif /* __TRWS_H__ */
|