// BP-S.h

  1. #ifndef __BPS_H__
  2. #define __BPS_H__
  3. #include <stdio.h>
  4. #include <stdlib.h>
  5. #include <string.h>
  6. #include <assert.h>
  7. #include "mrf.h"
  8. namespace OBJREC {
  9. class BPS : public MRF {
  10. public:
  11. typedef CostVal REAL;
  12. BPS(int width, int height, int nLabels, EnergyFunction *eng);
  13. BPS(int nPixels, int nLabels, EnergyFunction *eng);
  14. ~BPS();
  15. void setNeighbors(int /*pix1*/, int /*pix2*/, CostVal /*weight*/) {
  16. printf("Not implemented");
  17. exit(1);
  18. }
  19. Label getLabel(int pixel) {
  20. return(m_answer[pixel]);
  21. };
  22. void setLabel(int pixel, Label label) {
  23. m_answer[pixel] = label;
  24. };
  25. Label* getAnswerPtr() {
  26. return(m_answer);
  27. };
  28. void clearAnswer();
  29. void setParameters(int /*numParam*/, void * /*param*/) {
  30. printf("No optional parameters to set");
  31. exit(1);
  32. }
  33. EnergyVal smoothnessEnergy();
  34. EnergyVal dataEnergy();
  35. protected:
  36. void setData(DataCostFn dcost);
  37. void setData(CostVal* data);
  38. void setSmoothness(SmoothCostGeneralFn cost);
  39. void setSmoothness(CostVal* V);
  40. void setSmoothness(int smoothExp, CostVal smoothMax, CostVal lambda);
  41. void setCues(CostVal* hCue, CostVal* vCue);
  42. void Allocate();
  43. void initializeAlg();
  44. void optimizeAlg(int nIterations);
  45. private:
  46. enum
  47. {
  48. NONE,
  49. L1,
  50. L2,
  51. FIXED_MATRIX,
  52. GENERAL,
  53. BINARY,
  54. } m_type;
  55. CostVal m_smoothMax; // used only if
  56. CostVal m_lambda; // m_type == L1 or m_type == L2
  57. Label *m_answer;
  58. CostVal *m_V; // points to array of size nLabels^2 (if type==FIXED_MATRIX) or of size nEdges*nLabels^2 (if type==GENERAL)
  59. CostVal *m_D;
  60. CostVal *m_DBinary; // valid if type == BINARY
  61. CostVal *m_horzWeights;
  62. CostVal *m_vertWeights;
  63. CostVal *m_horzWeightsBinary;
  64. CostVal *m_vertWeightsBinary;
  65. DataCostFn m_dataFn;
  66. SmoothCostGeneralFn m_smoothFn;
  67. bool m_needToFreeV;
  68. bool m_needToFreeD;
  69. REAL* m_messages; // size of one message: N = 1 if m_type == BINARY, N = K otherwise
  70. // message between edges (x,y)-(x+1,y): m_messages+(2*x+2*y*m_width)*N
  71. // message between edges (x,y)-(x,y+1): m_messages+(2*x+2*y*m_width+1)*N
  72. int m_messageArraySizeInBytes;
  73. void optimize_GRID_L1(int nIterations);
  74. void optimize_GRID_L2(int nIterations);
  75. void optimize_GRID_FIXED_MATRIX(int nIterations);
  76. void optimize_GRID_GENERAL(int nIterations);
  77. void optimize_GRID_BINARY(int nIterations);
  78. };
  79. }
  80. #endif /* __BPS_H__ */