MaxProdBP.cpp 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386
  1. #include <stdio.h>
  2. #include <stdlib.h>
  3. #include <string.h>
  4. #include <assert.h>
  5. #include <math.h>
  6. #include "MaxProdBP.h"
  7. #include "regions-new.h"
  8. using namespace OBJREC;
  9. #define m_D(pix,l) m_D[(pix)*m_nLabels+(l)]
  10. #define m_V(l1,l2) m_V[(l1)*m_nLabels+(l2)]
  11. MaxProdBP::MaxProdBP(int width, int height, int nLabels, EnergyFunction *eng): MRF(width, height, nLabels, eng)
  12. {
  13. m_needToFreeV = 0;
  14. BPinitializeAlg();
  15. }
  16. MaxProdBP::MaxProdBP(int nPixels, int nLabels, EnergyFunction *eng): MRF(nPixels, nLabels, eng)
  17. {
  18. m_needToFreeV = 0;
  19. BPinitializeAlg();
  20. }
  21. MaxProdBP::~MaxProdBP()
  22. {
  23. delete[] m_answer;
  24. if (m_message_chunk) delete[] m_message_chunk;
  25. if (!m_grid_graph) delete[] m_neighbors;
  26. if ( m_needToFreeV ) delete[] m_V;
  27. }
  28. void MaxProdBP::initializeAlg()
  29. {
  30. }
  31. void MaxProdBP::BPinitializeAlg()
  32. {
  33. m_answer = (Label *) new Label[m_nPixels];
  34. if ( !m_answer ) {
  35. printf("\nNot enough memory, exiting");
  36. exit(0);
  37. }
  38. m_scratchMatrix = new FLOATTYPE[m_nLabels * m_nLabels];
  39. // MEMORY LEAK? where does this ever get deleted??
  40. nodeArray = new OneNodeCluster[m_nPixels];
  41. // MEMORY LEAK? where does this ever get deleted??
  42. OneNodeCluster::numStates = m_nLabels;
  43. if (!m_grid_graph)
  44. {
  45. assert(0);
  46. // Only Grid Graphs are supported
  47. m_neighbors = (LinkedBlockList *) new LinkedBlockList[m_nPixels];
  48. if (!m_neighbors) {
  49. printf("Not enough memory,exiting");
  50. exit(0);
  51. };
  52. }
  53. else
  54. {
  55. // const int clen = 4*m_nPixels * m_nLabels + 4*m_nPixels * m_nLabels * m_nLabels;
  56. const int clen = 4 * m_nPixels * m_nLabels ;
  57. //printf("clen:%d\n",clen/1024/1024);
  58. m_message_chunk = (FloatType *) new FloatType[clen];
  59. if ( !m_message_chunk ) {
  60. printf("\nNot enough memory for messages, exiting");
  61. exit(0);
  62. }
  63. for (int i = 0; i < clen; i++)
  64. m_message_chunk[i] = 0;
  65. initOneNodeMsgMem(nodeArray, m_message_chunk, m_nPixels, m_nLabels);
  66. }
  67. }
  68. MRF::InputType MaxProdBP::getSmoothType()
  69. {
  70. return m_smoothType;
  71. }
  72. EnergyFunction *MaxProdBP::getEnergyFunction()
  73. {
  74. return m_e;
  75. }
  76. void MaxProdBP::setExpScale(int expScale)
  77. {
  78. m_exp_scale = (float)expScale;
  79. }
  80. int MaxProdBP::getNLabels()
  81. {
  82. return m_nLabels;
  83. }
  84. int MaxProdBP::getWidth()
  85. {
  86. return m_width;
  87. }
  88. int MaxProdBP::getHeight()
  89. {
  90. return m_height;
  91. }
  92. FLOATTYPE MaxProdBP::getExpV(int i)
  93. {
  94. return m_ExpData[i];
  95. }
  96. FLOATTYPE *MaxProdBP::getExpV()
  97. {
  98. return m_ExpData;
  99. }
  100. MRF::CostVal MaxProdBP::getHorizWeight(int r, int c)
  101. {
  102. int x = c;
  103. int y = r;
  104. int pix = x + y * m_width;
  105. return m_varWeights ? m_horizWeights[pix] : 1;
  106. }
  107. MRF::CostVal MaxProdBP::getVertWeight(int r, int c)
  108. {
  109. int x = c;
  110. int y = r;
  111. int pix = x + y * m_width;
  112. return m_varWeights ? m_vertWeights[pix] : 1;
  113. }
  114. bool MaxProdBP::varWeights()
  115. {
  116. return m_varWeights;
  117. }
  118. FLOATTYPE *MaxProdBP::getScratchMatrix()
  119. {
  120. return m_scratchMatrix;
  121. }
  122. void MaxProdBP::clearAnswer()
  123. {
  124. memset(m_answer, 0, m_nPixels*sizeof(Label));
  125. }
  126. void MaxProdBP::setNeighbors(int pixel1, int pixel2, CostVal weight)
  127. {
  128. assert(0);
  129. //Only Grid Graphs are supported
  130. assert(!m_grid_graph);
  131. assert(pixel1 < m_nPixels && pixel1 >= 0 && pixel2 < m_nPixels && pixel2 >= 0);
  132. Neighbor *temp1 = (Neighbor *) new Neighbor;
  133. Neighbor *temp2 = (Neighbor *) new Neighbor;
  134. if ( !temp1 || ! temp2 ) {
  135. printf("\nNot enough memory, exiting");
  136. exit(0);
  137. }
  138. temp1->weight = weight;
  139. temp1->to_node = pixel2;
  140. temp2->weight = weight;
  141. temp2->to_node = pixel1;
  142. m_neighbors[pixel1].addFront(temp1);
  143. m_neighbors[pixel2].addFront(temp2);
  144. }
  145. MRF::EnergyVal MaxProdBP::smoothnessEnergy()
  146. {
  147. EnergyVal eng = (EnergyVal) 0;
  148. EnergyVal weight;
  149. int x, y, pix;
  150. if ( m_smoothType != FUNCTION )
  151. {
  152. for ( y = 0; y < m_height; y++ )
  153. for ( x = 1; x < m_width; x++ )
  154. {
  155. pix = x + y * m_width;
  156. weight = m_varWeights ? m_horizWeights[pix-1] : 1;
  157. eng = eng + m_V(m_answer[pix], m_answer[pix-1]) * weight;
  158. }
  159. for ( y = 1; y < m_height; y++ )
  160. for ( x = 0; x < m_width; x++ )
  161. {
  162. pix = x + y * m_width;
  163. weight = m_varWeights ? m_vertWeights[pix-m_width] : 1;
  164. eng = eng + m_V(m_answer[pix], m_answer[pix-m_width]) * weight;
  165. }
  166. }
  167. else
  168. {
  169. for ( y = 0; y < m_height; y++ )
  170. for ( x = 1; x < m_width; x++ )
  171. {
  172. pix = x + y * m_width;
  173. eng = eng + m_smoothFn(pix, pix - 1, m_answer[pix], m_answer[pix-1]);
  174. }
  175. for ( y = 1; y < m_height; y++ )
  176. for ( x = 0; x < m_width; x++ )
  177. {
  178. pix = x + y * m_width;
  179. eng = eng + m_smoothFn(pix, pix - m_width, m_answer[pix], m_answer[pix-m_width]);
  180. }
  181. }
  182. return(eng);
  183. }
  184. MRF::EnergyVal MaxProdBP::dataEnergy()
  185. {
  186. EnergyVal eng = (EnergyVal) 0;
  187. if ( m_dataType == ARRAY)
  188. {
  189. for ( int i = 0; i < m_nPixels; i++ )
  190. eng = eng + m_D(i, m_answer[i]);
  191. }
  192. else
  193. {
  194. for ( int i = 0; i < m_nPixels; i++ )
  195. eng = eng + m_dataFn(i, m_answer[i]);
  196. }
  197. return(eng);
  198. }
  199. void MaxProdBP::setData(DataCostFn dcost)
  200. {
  201. m_dataFn = dcost;
  202. int i;
  203. int j;
  204. m_ExpData = new FloatType[m_nPixels * m_nLabels];
  205. // MEMORY LEAK? where does this ever get deleted??
  206. if (!m_ExpData)
  207. {
  208. exit(0);
  209. }
  210. m_exp_scale = 1;//FLOATTYPE(cmax)*4.0;
  211. FloatType *cData = m_ExpData;
  212. for ( i = 0; i < m_nPixels; i++)
  213. {
  214. nodeArray[i].localEv = cData;
  215. for ( j = 0; j < m_nLabels; j++)
  216. {
  217. *cData = (float)m_dataFn(i, j);
  218. cData++;
  219. }
  220. }
  221. }
  222. void MaxProdBP::setData(CostVal* data)
  223. {
  224. int i;
  225. int j;
  226. m_D = data;
  227. m_ExpData = new FloatType[m_nPixels * m_nLabels];
  228. // MEMORY LEAK? where does this ever get deleted??
  229. if (!m_ExpData)
  230. {
  231. exit(0);
  232. }
  233. m_exp_scale = 1;//FLOATTYPE(cmax)*4.0;
  234. FloatType *cData = m_ExpData;
  235. for ( i = 0; i < m_nPixels; i++)
  236. {
  237. nodeArray[i].localEv = cData;
  238. for ( j = 0; j < m_nLabels; j++)
  239. {
  240. *cData = (float)m_D(i, j);
  241. cData++;
  242. }
  243. }
  244. }
  245. void MaxProdBP::setSmoothness(SmoothCostGeneralFn cost)
  246. {
  247. m_smoothFn = cost;
  248. }
  249. void MaxProdBP::setSmoothness(CostVal* V)
  250. {
  251. m_type = FIXED_MATRIX;
  252. m_V = V;
  253. }
  254. void MaxProdBP::setSmoothness(int smoothExp, CostVal smoothMax, CostVal lambda)
  255. {
  256. int i, j;
  257. CostVal cost;
  258. m_type = (smoothExp == 1) ? L1 : L2; //borrowed from BP-S.cpp from vnk
  259. m_lambda = lambda;
  260. m_smoothMax = smoothMax;
  261. m_smoothExp = smoothExp;
  262. m_needToFreeV = 1;
  263. m_V = (CostVal *) new CostVal[m_nLabels*m_nLabels*sizeof(CostVal)];
  264. if (!m_V) {
  265. fprintf(stderr, "Not enough memory!\n");
  266. exit(1);
  267. }
  268. for (i = 0; i < m_nLabels; i++)
  269. for (j = i; j < m_nLabels; j++)
  270. {
  271. cost = (MRF::CostVal)((smoothExp == 1) ? j - i : (j - i) * (j - i));
  272. if (cost > smoothMax) cost = smoothMax;
  273. m_V[i*m_nLabels + j] = m_V[j*m_nLabels + i] = cost * lambda;
  274. }
  275. }
  276. void MaxProdBP::setCues(CostVal* hCue, CostVal* vCue)
  277. {
  278. m_horizWeights = hCue;
  279. m_vertWeights = vCue;
  280. }
  281. void MaxProdBP::optimizeAlg(int nIterations)
  282. {
  283. //int x, y, i, j, n;
  284. //Label* l;
  285. //CostVal* dataPtr;
  286. if ( !m_grid_graph) {
  287. printf("\nMaxProdBP is not implemented for nongrids yet!");
  288. exit(1);
  289. }
  290. int numRows = getHeight();
  291. int numCols = getWidth();
  292. const FLOATTYPE alpha = 0.8f;
  293. for (int niter = 0; niter < nIterations; niter++)
  294. {
  295. for (int r = 0; r < numRows; r++)
  296. {
  297. computeMessagesLeftRight(nodeArray, numCols, numRows, r, alpha, this);
  298. }
  299. for (int c = 0; c < numCols; c++)
  300. {
  301. computeMessagesUpDown(nodeArray, numCols, numRows, c, alpha, this);
  302. }
  303. }
  304. Label *currAssign = m_answer;
  305. for (int m = 0; m < numRows; m++)
  306. {
  307. for (int n = 0; n < numCols; n++)
  308. {
  309. int maxInd = nodeArray[m*numCols+n].getBeliefMaxInd();
  310. currAssign[m * numCols +n] = maxInd;
  311. }
  312. }
  313. }