mrf.h 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. /* Copyright Olga Veksler, Ramin Zabih, Vladimir Kolmogorov, and Daniel Scharstein
  2. * Send any questions to schar@middlebury.edu
  3. */
  4. #ifndef __MRF_H__
  5. #define __MRF_H__
  6. #include <stdio.h>
  7. namespace OBJREC {
  8. class EnergyFunction;
  9. class MRF
  10. {
  11. public:
  12. // *********** CONSTRUCTORS/DESTRUCTOR
  13. // Constructor. After you call this, you must call setData and setSmoothness
  14. // Use this constructor for 2D grid graphs of size width by height Standard 4-connected
  15. // neighborhood system is assumed. Labels are in the range 0,1,...nLabels - 1
  16. // Width is in the range 0,1,...width-1 and height is in the range 0,1,...height-1
  17. // Input parameter eng specifies the data and smoothness parts of the energy
  18. // For 2D grids, since 4 connected neighborhood structure is assumed, this
  19. // fully specifies the energy
  20. MRF(int width, int height, int nLabels, EnergyFunction *eng);
  21. // Use this constructor for a general neighborhood system. Pixels are in the range
  22. // 0,1,..nPixels-1, and labels are in the range 0,1,...,nLabels-1
  23. // Input parameter eng specifies the data and smoothness parts of the energy
  24. // after this constructor you need to call setNeighbors() to specify the neighborhood system
  25. MRF(int nPixels, int nLabels, EnergyFunction *eng);
  26. virtual ~MRF() { }
  27. // Returns true if energy function has been specified, returns false otherwise
  28. // By default, it always returns true. Can be modified by the supplier of
  29. // optimization algorithm
  30. virtual int isValid(){return true;};
  31. // *********** EVALUATING THE ENERGY
  32. typedef int Label;
  33. typedef double EnergyVal; /* The total energy of a labeling */
  34. typedef double CostVal; /* costs of individual terms of the energy */
  35. EnergyVal totalEnergy(); /* returns energy of current labeling */
  36. virtual EnergyVal dataEnergy() = 0; /* returns the data part of the energy */
  37. virtual EnergyVal smoothnessEnergy() = 0; /* returns the smoothness part of the energy */
  38. //Functional representation for data costs
  39. typedef CostVal (*DataCostFn)(int pix, Label l);
  40. // Functional representation for the general cost function type
  41. typedef CostVal (*SmoothCostGeneralFn)(int pix1, int pix2, Label l1, Label l2);
  42. // For general smoothness functions, some implementations try to cache all function values in an array
  43. // for efficiency. To prevent this, call the following function before calling initialize():
  44. void dontCacheSmoothnessCosts() {m_allocateArrayForSmoothnessCostFn = false;}
  45. // Use this function only for non-grid graphs. Sets pix1 and pix2 to be neighbors
  46. // with the specified weight. Can be called ONLY once for each pair of pixels
  47. // That is if pixel1 and pixel2 are neihbors, call either setNeighbors(pixel1,pixel2,weight)
  48. // or setNeighbors(pixel2,pixel1,weight), but NOT BOTH
  49. virtual void setNeighbors(int pix1, int pix2, CostVal weight)= 0;
  50. void initialize();
  51. // Runs optimization for nIterations. Input parameter time returns the time it took to
  52. // perform nIterations of optimization
  53. void optimize(int nIterations, float& time);
  54. virtual void optimizeAlg(int nIterations)=0;
  55. // *********** ACCESS TO SOLUTION
  56. // Returns pointer to array of size nPixels. Client may then read/write solution (but not deallocate array).
  57. virtual Label* getAnswerPtr()= 0;
  58. // returns the label of the input pixel
  59. virtual Label getLabel(int pixel)= 0;
  60. // sets label of a pixel
  61. virtual void setLabel(int pixel,Label label)= 0;
  62. // sets all the labels to zero
  63. virtual void clearAnswer() = 0;
  64. // use this function to pass any parameters to optimization algorithm.
  65. // The first argument is the number of passed, parameters and
  66. // the second argument is the pointer to the array of parameters
  67. virtual void setParameters(int numParam, void *param) = 0;
  68. // This function returns lower bound computed by the algorithm (if any)
  69. // By default, it returns 0.
  70. virtual double lowerBound(){return((double) 0);};
  71. // Returns 0 if the energy is not suitable for current optimization algorithm
  72. // Returns 1 if the energy is suitable for current optimization algorithm
  73. // Returns 2 if current optimizaiton algorithm does not check the energy
  74. virtual char checkEnergy();
  75. typedef enum
  76. {
  77. FUNCTION,
  78. ARRAY,
  79. THREE_PARAM,
  80. NONE
  81. } InputType;
  82. protected:
  83. int m_width, m_height; // width and height of a grid,if graph is a grid
  84. int m_nPixels; // number of pixels, for both grid and non-grid graphs
  85. int m_nLabels; // number of labels, for both grid and non-grid graphs
  86. bool m_grid_graph; // true if the graph is a 2D grid
  87. bool m_varWeights; // true if weights are spatially varying. To be used only with 2D grids
  88. bool m_initialized; // true if array m_V is allocated memory.
  89. EnergyFunction *m_e;
  90. InputType m_dataType;
  91. InputType m_smoothType;
  92. // *********** SET THE DATA COSTS
  93. // Following 2 functions set the data costs
  94. virtual void setData(DataCostFn dcost)=0;
  95. virtual void setData(CostVal* data)=0;
  96. // *********** SET THE SMOOTHNESS COSTS
  97. // following 3 functions set the smoothness costs
  98. // there are 2 ways to represent the smoothness costs, one with array, one with function
  99. // In addition, for 2D grid graphs spacially varying weights can be specified by 2 arrays
  100. // Smoothness cost depends on labels V(l1,l2) for all edges (except for a multiplier - see setCues() ).
  101. // V must be symmetric: V(l1,l2) = V(l2,l1)
  102. // V must be an array of size nLabels*nLabels. It is NOT copied into internal memory
  103. virtual void setSmoothness(CostVal* V)=0;
  104. // General smoothness cost can be specified by passing pointer to a function
  105. virtual void setSmoothness(SmoothCostGeneralFn cost)=0;
  106. // To prevent implementations from caching all general smoothness costs values, the flag below
  107. // can be set to false by calling dontCacheSmoothnessCosts() before calling initialize():
  108. bool m_allocateArrayForSmoothnessCostFn;
  109. // Use if the smoothness is V(l1,l2) = lambda * min ( |l1-l2|^m_smoothExp, m_smoothMax )
  110. // Can also add optional spatially varying weights for 2D grid graphs using setCues()
  111. virtual void setSmoothness(int smoothExp,CostVal smoothMax, CostVal lambda)=0;
  112. // You are not required to call setCues, in which case there is no multiplier.
  113. // Function below cannot be called for general cost function.
  114. // This function can be only used for a 2D grid graph
  115. // hCue and vCue must be arrays of size width*height in row major order.
  116. // They are NOT copied into internal memory.
  117. // hCue(x,y) holds the variable weight for edge between pixels (x+1,y) and (x,y)
  118. // vCue(x,y) holds the variable weight for edge between pixels (x,y+1) and (x,y)
  119. virtual void setCues(CostVal* hCue, CostVal* vCue)=0;
  120. virtual void initializeAlg()=0; // called by initialize()
  121. void commonInitialization(EnergyFunction *e);
  122. void checkArray(CostVal *V);
  123. };
  124. // *********** This class is for data costs
  125. // Data costs can be specified eithe by an array or by a pointer to a function
  126. // If specified by an array, use constructor DataCost(cost) where
  127. // cost is the array of type CostVal. The cost of pixel p and label l is
  128. // stored at cost[p*nLabels+l] where nLabels is the number of labels
  129. // If data costs are to be specified by a function, pass
  130. // a pointer to a function
  131. // CostVal costFn(int pix, Label lab)
  132. // which returns the
  133. // data cost of pixel pix to be assigned label lab
  134. class DataCost
  135. {
  136. friend class MRF;
  137. public:
  138. typedef MRF::CostVal CostVal;
  139. typedef MRF::DataCostFn DataCostFn;
  140. DataCost(CostVal *cost){m_costArray = cost;m_type = MRF::ARRAY; };
  141. DataCost(DataCostFn costFn){m_costFn = costFn;m_type = MRF::FUNCTION;};
  142. private:
  143. MRF::CostVal *m_costArray;
  144. MRF::DataCostFn m_costFn;
  145. MRF::InputType m_type;
  146. };
  147. // ***************** This class represents smoothness costs
  148. // If the smoothness is V(l1,l2) = lambda * min ( |l1-l2|^m_smoothExp, m_smoothMax )
  149. // use constructor SmoothnessCost(smoothExp,smoothMax,lambda)
  150. // If, in addition, there are spacially varying weights use constructor
  151. // SmoothnessCost(smoothExp,smoothMax,lambda,hWeights,vWeights)
  152. // hWeights and vWeights can be only used for a 2D grid graph
  153. // hWeights and vWeights must be arrays of size width*height in row major order.
  154. // They are NOT copied into internal memory.
  155. // hWeights(x,y) holds the variable weight for edge between pixels (x+1,y) and (x,y)
  156. // vWeights(x,y) holds the variable weight for edge between pixels (x,y+1) and (x,y)
  157. // If the smoothness costs are specified by input array V of type CostVal and
  158. // size nLabels*nLabels, use consructor SmoothnessCost(V).
  159. // If in addition, there are
  160. // are spacially varying weights use constructor SmoothnessCost(V,hWeights,vWeights)
  161. // Note that array V must be of size nLabels*nLabels, and be symmetric.
  162. // That is V[i*nLabels+j] = V[j*nLabels+i]
  163. // Finally, if the smoothness term is specified by a general function, use
  164. // constructor SmoothnessCost(costFn)
  165. class SmoothnessCost
  166. {
  167. friend class MRF;
  168. public:
  169. typedef MRF::CostVal CostVal;
  170. // Can be used for 2D grids and for general graphs
  171. // In case if used for 2D grids, the smoothness term WILL NOT be spacially varying
  172. SmoothnessCost(int smoothExp,CostVal smoothMax,CostVal lambda)
  173. {m_type=MRF::THREE_PARAM;m_smoothMax = smoothMax;m_smoothExp = smoothExp;m_lambda=lambda;m_varWeights=false;};
  174. // Can be used only for 2D grids
  175. // the smoothness term WILL BE be spacially varying
  176. SmoothnessCost(int smoothExp,CostVal smoothMax,CostVal lambda,CostVal *hWeights, CostVal *vWeights)
  177. {m_type=MRF::THREE_PARAM;m_smoothMax = smoothMax;m_smoothExp = smoothExp;m_lambda=lambda;
  178. m_varWeights = true;m_hWeights = hWeights; m_vWeights = vWeights;};
  179. // Can be used 2D grids and for general graphs
  180. // In case if used for 2D grids, the smoothness term WILL NOT be spacially varying
  181. SmoothnessCost(CostVal *V){m_V = V;m_type = MRF::ARRAY;m_varWeights=false;};
  182. // Can be used only for 2D grids
  183. // the smoothness term WILL BE be spacially varying
  184. SmoothnessCost(CostVal *V,CostVal *hWeights, CostVal *vWeights )
  185. {m_V = V;m_hWeights = hWeights; m_vWeights = vWeights; m_varWeights = true; m_type=MRF::ARRAY;};
  186. // Can be used 2D grids and for general graphs
  187. SmoothnessCost(MRF::SmoothCostGeneralFn costFn){m_costFn = costFn;m_type = MRF::FUNCTION;m_varWeights=false;};
  188. private:
  189. CostVal *m_V,*m_hWeights, *m_vWeights;
  190. MRF::SmoothCostGeneralFn m_costFn;
  191. MRF::InputType m_type;
  192. int m_smoothExp;
  193. CostVal m_smoothMax,m_lambda;
  194. bool m_varWeights;
  195. EnergyFunction *m_eng;
  196. };
  197. class EnergyFunction
  198. {
  199. public:
  200. EnergyFunction(DataCost *dataCost,SmoothnessCost *smoothCost)
  201. {m_dataCost = dataCost;m_smoothCost = smoothCost;};
  202. DataCost *m_dataCost;
  203. SmoothnessCost *m_smoothCost;
  204. };
  205. }
  206. #endif /* __MRF_H__ */
/*
Reference summary of MRF's virtual interface, kept as a comment only. It may be
stale relative to the class declaration: there, setSmoothness(int, CostVal,
CostVal) is pure virtual (= 0), lowerBound() is spelled as returning double
(the same type as EnergyVal), and the pure virtual initializeAlg() is missing
from this list.
virtual EnergyVal dataEnergy() = 0;
virtual EnergyVal smoothnessEnergy() = 0;
virtual void setNeighbors(int pix1, int pix2, CostVal weight)= 0;
virtual void optimizeAlg(int nIterations)=0;
virtual Label* getAnswerPtr()= 0;
virtual Label getLabel(int pixel)= 0;
virtual void setLabel(int pixel,Label label)= 0;
virtual void clearAnswer() = 0;
virtual void setParameters(int numParam, void *param) = 0;
virtual void setData(DataCostFn dcost)=0;
virtual void setData(CostVal* data)=0;
virtual void setSmoothness(CostVal* V)=0;
virtual void setSmoothness(SmoothCostGeneralFn cost)=0;
virtual void setCues(CostVal* hCue, CostVal* vCue)=0;
virtual void setSmoothness(int smoothExp,CostVal smoothMax, CostVal lambda);
virtual EnergyVal lowerBound(){return((EnergyVal) 0);};
*/