TRW-S.h 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. #ifndef __TRWS_H__
  2. #define __TRWS_H__
  3. #include <stdio.h>
  4. #include <stdlib.h>
  5. #include <string.h>
  6. #include <assert.h>
  7. #include "mrf.h"
  8. #undef REAL
  9. namespace OBJREC {
  10. class TRWS : public MRF {
  11. public:
  12. typedef double REAL;
  13. TRWS(int width, int height, int nLabels, EnergyFunction *eng);
  14. TRWS(int nPixels, int nLabels, EnergyFunction *eng);
  15. ~TRWS();
  16. void setNeighbors(int /*pix1*/, int /*pix2*/, CostVal /*weight*/) {
  17. printf("Not implemented");
  18. exit(1);
  19. }
  20. Label getLabel(int pixel) {
  21. return(m_answer[pixel]);
  22. };
  23. void setLabel(int pixel, Label label) {
  24. m_answer[pixel] = label;
  25. };
  26. Label* getAnswerPtr() {
  27. return(m_answer);
  28. };
  29. void clearAnswer();
  30. void setParameters(int /*numParam*/, void * /*param*/) {
  31. printf("No optional parameters to set");
  32. exit(1);
  33. }
  34. EnergyVal smoothnessEnergy();
  35. EnergyVal dataEnergy();
  36. double lowerBound() {
  37. return (double)m_lowerBound;
  38. }
  39. // For general smoothness functions, this code tries to cache all function values in an array
  40. // for efficiency. To prevent this, call the following function before calling initialize():
  41. void dontCacheSmoothnessCosts() {
  42. m_allocateArrayForSmoothnessCostFn = false;
  43. }
  44. protected:
  45. void setData(DataCostFn dcost);
  46. void setData(CostVal* data);
  47. void setSmoothness(SmoothCostGeneralFn cost);
  48. void setSmoothness(CostVal* V);
  49. void setSmoothness(int smoothExp, CostVal smoothMax, CostVal lambda);
  50. void setCues(CostVal* hCue, CostVal* vCue);
  51. void Allocate();
  52. void initializeAlg();
  53. void optimizeAlg(int nIterations);
  54. private:
  55. enum
  56. {
  57. NONE,
  58. L1,
  59. L2,
  60. FIXED_MATRIX,
  61. GENERAL,
  62. BINARY,
  63. } m_type;
  64. CostVal m_smoothMax; // used only if
  65. CostVal m_lambda; // m_type == L1 or m_type == L2
  66. Label *m_answer;
  67. CostVal *m_V; // points to array of size nLabels^2 (if type==FIXED_MATRIX) or of size nEdges*nLabels^2 (if type==GENERAL)
  68. CostVal *m_D;
  69. CostVal *m_DBinary; // valid if type == BINARY
  70. CostVal *m_horzWeights;
  71. CostVal *m_vertWeights;
  72. CostVal *m_horzWeightsBinary;
  73. CostVal *m_vertWeightsBinary;
  74. DataCostFn m_dataFn;
  75. SmoothCostGeneralFn m_smoothFn;
  76. bool m_needToFreeV;
  77. bool m_needToFreeD;
  78. REAL* m_messages; // size of one message: N = 1 if m_type == BINARY, N = K otherwise
  79. // message between edges (x,y)-(x+1,y): m_messages+(2*x+2*y*m_width)*N
  80. // message between edges (x,y)-(x,y+1): m_messages+(2*x+2*y*m_width+1)*N
  81. int m_messageArraySizeInBytes;
  82. REAL m_lowerBound;
  83. void optimize_GRID_L1(int nIterations);
  84. void optimize_GRID_L2(int nIterations);
  85. void optimize_GRID_FIXED_MATRIX(int nIterations);
  86. void optimize_GRID_GENERAL(int nIterations);
  87. void optimize_GRID_BINARY(int nIterations);
  88. };
  89. }
  90. #endif /* __TRWS_H__ */