ICM.cpp 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
  1. #include <stdio.h>
  2. #include <stdlib.h>
  3. #include <string.h>
  4. #include <assert.h>
  5. #include "ICM.h"
  6. using namespace OBJREC;
  // Index helpers for the flat, row-major cost tables. They are function-like
  // macros, so a bare m_D / m_V that is NOT followed by '(' still names the
  // member pointer itself (e.g. in assignments).
  //   m_D(pixel, label) -> data cost of giving `pixel` the label `label`
  //   m_V(a, b)         -> smoothness table entry for the label pair (a, b)
  // Both expect m_nLabels to be in scope where they are used.
  #define m_D(pixel, label) m_D[(pixel)*m_nLabels+(label)]
  #define m_V(a, b) m_V[(a)*m_nLabels+(b)]
  9. ICM::ICM(int width, int height, int nLabels,EnergyFunction *eng):MRF(width,height,nLabels,eng)
  10. {
  11. m_needToFreeV = 0;
  12. initializeAlg();
  13. }
  14. ICM::ICM(int nPixels, int nLabels,EnergyFunction *eng):MRF(nPixels,nLabels,eng)
  15. {
  16. m_needToFreeV = 0;
  17. initializeAlg();
  18. }
  19. ICM::~ICM()
  20. {
  21. delete[] m_answer;
  22. if (!m_grid_graph) delete[] m_neighbors;
  23. if ( m_needToFreeV ) delete[] m_V;
  24. }
  25. void ICM::initializeAlg()
  26. {
  27. m_answer = (Label *) new Label[m_nPixels];
  28. if ( !m_answer ){printf("\nNot enough memory, exiting");exit(0);}
  29. if (!m_grid_graph)
  30. {
  31. m_neighbors = (LinkedBlockList *) new LinkedBlockList[m_nPixels];
  32. if (!m_neighbors) {printf("Not enough memory,exiting");exit(0);};
  33. }
  34. }
  35. void ICM::clearAnswer()
  36. {
  37. memset(m_answer, 0, m_nPixels*sizeof(Label));
  38. }
  39. void ICM::setNeighbors(int pixel1, int pixel2, CostVal weight)
  40. {
  41. assert(!m_grid_graph);
  42. assert(pixel1 < m_nPixels && pixel1 >= 0 && pixel2 < m_nPixels && pixel2 >= 0);
  43. Neighbor *temp1 = (Neighbor *) new Neighbor;
  44. Neighbor *temp2 = (Neighbor *) new Neighbor;
  45. if ( !temp1 || ! temp2 ) {printf("\nNot enough memory, exiting");exit(0);}
  46. temp1->weight = weight;
  47. temp1->to_node = pixel2;
  48. temp2->weight = weight;
  49. temp2->to_node = pixel1;
  50. m_neighbors[pixel1].addFront(temp1);
  51. m_neighbors[pixel2].addFront(temp2);
  52. }
  53. MRF::EnergyVal ICM::smoothnessEnergy()
  54. {
  55. EnergyVal eng = (EnergyVal) 0;
  56. EnergyVal weight;
  57. int x,y,pix,i;
  58. if ( m_grid_graph )
  59. {
  60. if ( m_smoothType != FUNCTION )
  61. {
  62. for ( y = 0; y < m_height; y++ )
  63. for ( x = 1; x < m_width; x++ )
  64. {
  65. pix = x+y*m_width;
  66. weight = m_varWeights ? m_horizWeights[pix-1] : 1;
  67. eng = eng + m_V(m_answer[pix],m_answer[pix-1])*weight;
  68. }
  69. for ( y = 1; y < m_height; y++ )
  70. for ( x = 0; x < m_width; x++ )
  71. {
  72. pix = x+y*m_width;
  73. weight = m_varWeights ? m_vertWeights[pix-m_width] : 1;
  74. eng = eng + m_V(m_answer[pix],m_answer[pix-m_width])*weight;
  75. }
  76. }
  77. else
  78. {
  79. for ( y = 0; y < m_height; y++ )
  80. for ( x = 1; x < m_width; x++ )
  81. {
  82. pix = x+y*m_width;
  83. eng = eng + m_smoothFn(pix,pix-1,m_answer[pix],m_answer[pix-1]);
  84. }
  85. for ( y = 1; y < m_height; y++ )
  86. for ( x = 0; x < m_width; x++ )
  87. {
  88. pix = x+y*m_width;
  89. eng = eng + m_smoothFn(pix,pix-m_width,m_answer[pix],m_answer[pix-m_width]);
  90. }
  91. }
  92. }
  93. else
  94. {
  95. Neighbor *temp;
  96. if ( m_smoothType != FUNCTION )
  97. {
  98. for ( i = 0; i < m_nPixels; i++ )
  99. if ( !m_neighbors[i].isEmpty() )
  100. {
  101. m_neighbors[i].setCursorFront();
  102. while ( m_neighbors[i].hasNext() )
  103. {
  104. temp = (Neighbor *) m_neighbors[i].next();
  105. if ( i < temp->to_node )
  106. eng = eng + m_V(m_answer[i],m_answer[temp->to_node])*(temp->weight);
  107. }
  108. }
  109. }
  110. else
  111. {
  112. for ( i = 0; i < m_nPixels; i++ )
  113. if ( !m_neighbors[i].isEmpty() )
  114. {
  115. m_neighbors[i].setCursorFront();
  116. while ( m_neighbors[i].hasNext() )
  117. {
  118. temp = (Neighbor *) m_neighbors[i].next();
  119. if ( i < temp->to_node )
  120. eng = eng + m_smoothFn(i,temp->to_node, m_answer[i],m_answer[temp->to_node]);
  121. }
  122. }
  123. }
  124. }
  125. return(eng);
  126. }
  127. MRF::EnergyVal ICM::dataEnergy()
  128. {
  129. EnergyVal eng = (EnergyVal) 0;
  130. if ( m_dataType == ARRAY)
  131. {
  132. for ( int i = 0; i < m_nPixels; i++ )
  133. eng = eng + m_D(i,m_answer[i]);
  134. }
  135. else
  136. {
  137. for ( int i = 0; i < m_nPixels; i++ )
  138. eng = eng + m_dataFn(i,m_answer[i]);
  139. }
  140. return(eng);
  141. }
  142. void ICM::setData(DataCostFn dcost)
  143. {
  144. m_dataFn = dcost;
  145. }
  146. void ICM::setData(CostVal* data)
  147. {
  148. m_D = data;
  149. }
  150. void ICM::setSmoothness(SmoothCostGeneralFn cost)
  151. {
  152. m_smoothFn = cost;
  153. }
  154. void ICM::setSmoothness(CostVal* V)
  155. {
  156. m_V = V;
  157. }
  158. void ICM::setSmoothness(int smoothExp,CostVal smoothMax, CostVal lambda)
  159. {
  160. int i, j;
  161. CostVal cost;
  162. m_needToFreeV = 1;
  163. m_V = (CostVal *) new CostVal[m_nLabels*m_nLabels*sizeof(CostVal)];
  164. if (!m_V) { fprintf(stderr, "Not enough memory!\n"); exit(1); }
  165. for (i=0; i<m_nLabels; i++)
  166. for (j=i; j<m_nLabels; j++)
  167. {
  168. cost = (CostVal) ((smoothExp == 1) ? j - i : (j - i)*(j - i));
  169. if (cost > smoothMax) cost = smoothMax;
  170. m_V[i*m_nLabels + j] = m_V[j*m_nLabels + i] = cost*lambda;
  171. }
  172. }
  173. void ICM::setCues(CostVal* hCue, CostVal* vCue)
  174. {
  175. m_horizWeights = hCue;
  176. m_vertWeights = vCue;
  177. }
  178. void ICM::optimizeAlg(int nIterations)
  179. {
  180. int x, y, i, j, n;
  181. Label* l;
  182. CostVal* dataPtr;
  183. CostVal *D = (CostVal *) new CostVal[m_nLabels];
  184. if ( !D ) {printf("\nNot enough memory, exiting");exit(0);}
  185. if ( !m_grid_graph) {printf("\nICM is not implemented for nongrids yet!");exit(1);}
  186. for ( ; nIterations > 0; nIterations --)
  187. {
  188. n = 0;
  189. l = m_answer;
  190. dataPtr = m_D;
  191. for (y=0; y<m_height; y++)
  192. for (x=0; x<m_width; x++, l++, dataPtr+=m_nLabels, n++)
  193. {
  194. // set array D
  195. if (m_dataType == FUNCTION)
  196. {
  197. for (i=0; i<m_nLabels; i++)
  198. {
  199. D[i] = m_dataFn(x+y*m_width, i);
  200. }
  201. }
  202. else memcpy(D, dataPtr, m_nLabels*sizeof(CostVal));
  203. // add smoothness costs
  204. if (m_smoothType == FUNCTION)
  205. {
  206. if (x > 0)
  207. {
  208. j = *(l-1);
  209. for (i=0; i<m_nLabels; i++) D[i] += m_smoothFn(x+y*m_width-1, x+y*m_width, j, i);
  210. }
  211. if (y > 0)
  212. {
  213. j = *(l-m_width);
  214. for (i=0; i<m_nLabels; i++) D[i] += m_smoothFn(x+y*m_width-m_width,x+y*m_width , j, i);
  215. }
  216. if (x < m_width-1)
  217. {
  218. j = *(l+1);
  219. for (i=0; i<m_nLabels; i++) D[i] += m_smoothFn(x+y*m_width+1, x+y*m_width, i, j);
  220. }
  221. if (y < m_height-1)
  222. {
  223. j = *(l+m_width);
  224. for (i=0; i<m_nLabels; i++) D[i] += m_smoothFn(x+y*m_width+m_width, x+y*m_width, i, j);
  225. }
  226. }
  227. else
  228. {
  229. if (x > 0)
  230. {
  231. j = *(l-1);
  232. CostVal lambda = (m_varWeights) ? m_horizWeights[n-1] : 1;
  233. for (i=0; i<m_nLabels; i++) D[i] += lambda * m_V[j*m_nLabels + i];
  234. }
  235. if (y > 0)
  236. {
  237. j = *(l-m_width);
  238. CostVal lambda = (m_varWeights) ? m_vertWeights[n-m_width] : 1;
  239. for (i=0; i<m_nLabels; i++) D[i] += lambda * m_V[j*m_nLabels + i];
  240. }
  241. if (x < m_width-1)
  242. {
  243. j = *(l+1);
  244. CostVal lambda = (m_varWeights) ? m_horizWeights[n] : 1;
  245. for (i=0; i<m_nLabels; i++) D[i] += lambda * m_V[j*m_nLabels + i];
  246. }
  247. if (y < m_height-1)
  248. {
  249. j = *(l+m_width);
  250. CostVal lambda = (m_varWeights) ? m_vertWeights[n] : 1;
  251. for (i=0; i<m_nLabels; i++) D[i] += lambda * m_V[j*m_nLabels + i];
  252. }
  253. }
  254. // compute minimum of D, set new label for (x,y)
  255. CostVal D_min = D[0];
  256. *l = 0;
  257. for (i=1; i<m_nLabels; i++)
  258. {
  259. if (D_min > D[i])
  260. {
  261. D_min = D[i];
  262. *l = i;
  263. }
  264. }
  265. }
  266. }
  267. delete[] D;
  268. }