MaxProdBP.h 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. #ifndef __MAXPRODBP_H__
  2. #define __MAXPRODBP_H__
  3. #include <stdio.h>
  4. #include <stdlib.h>
  5. #include <string.h>
  6. #include <assert.h>
  7. #include "mrf.h"
  8. #include "LinkedBlockList.h"
  9. #include "regions-new.h"
  10. namespace OBJREC {
  11. #define FloatType float
  12. #define FLOATTYPE float
  13. class MaxProdBP;
  14. class MaxProdBP : public MRF {
  15. public:
  16. MaxProdBP(int width, int height, int nLabels, EnergyFunction *eng);
  17. MaxProdBP(int nPixels, int nLabels, EnergyFunction *eng);
  18. ~MaxProdBP();
  19. void setNeighbors(int pix1, int pix2, CostVal weight);
  20. Label getLabel(int pixel) {
  21. return(m_answer[pixel]);
  22. };
  23. void setLabel(int pixel, Label label) {
  24. m_answer[pixel] = label;
  25. };
  26. Label* getAnswerPtr() {
  27. return(m_answer);
  28. };
  29. void clearAnswer();
  30. void setParameters(int , void *) {
  31. printf("No optional parameters to set");
  32. }
  33. EnergyVal smoothnessEnergy();
  34. EnergyVal dataEnergy();
  35. EnergyFunction *getEnergyFunction();
  36. int getWidth();
  37. int getHeight();
  38. FLOATTYPE *getScratchMatrix();
  39. int getNLabels();
  40. bool varWeights();
  41. void setExpScale(int expScale);
  42. friend void getPsiMat(OneNodeCluster &cluster, FLOATTYPE *&destMatrix,
  43. int r, int c, MaxProdBP *mrf, int direction, FLOATTYPE &var_weight);
  44. InputType getSmoothType();
  45. FLOATTYPE getExpV(int i);
  46. FLOATTYPE *getExpV();
  47. CostVal getHorizWeight(int r, int c);
  48. CostVal getVertWeight(int r, int c);
  49. CostVal m_lambda;
  50. CostVal m_smoothMax;
  51. int m_smoothExp;
  52. enum //Borrowed from BP-S.h by vnk
  53. {
  54. NONE,
  55. L1,
  56. L2,
  57. FIXED_MATRIX,
  58. GENERAL,
  59. BINARY,
  60. } m_type;
  61. protected:
  62. void setData(DataCostFn dcost);
  63. void setData(CostVal* data);
  64. void setSmoothness(SmoothCostGeneralFn cost);
  65. void setSmoothness(CostVal* V);
  66. void setSmoothness(int smoothExp, CostVal smoothMax, CostVal lambda);
  67. void setCues(CostVal* hCue, CostVal* vCue);
  68. void initializeAlg();
  69. void BPinitializeAlg();
  70. void optimizeAlg(int nIterations);
  71. private:
  72. Label *m_answer;
  73. CostVal *m_V;
  74. CostVal *m_D;
  75. CostVal *m_horizWeights;
  76. CostVal *m_vertWeights;
  77. FLOATTYPE m_exp_scale;
  78. DataCostFn m_dataFn;
  79. SmoothCostGeneralFn m_smoothFn;
  80. bool m_needToFreeV;
  81. FLOATTYPE *m_scratchMatrix;
  82. FLOATTYPE *m_ExpData;
  83. FLOATTYPE *m_message_chunk;
  84. OneNodeCluster *nodeArray;
  85. typedef struct NeighborStruct {
  86. int to_node;
  87. CostVal weight;
  88. } Neighbor;
  89. LinkedBlockList *m_neighbors;
  90. };
  91. }
#endif /* __MAXPRODBP_H__ */