compareGMM.cpp 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. #ifdef NOVISUAL
  2. #warning "testKMeans needs ICE with visualization !!"
  3. int main (int argc, char **argv) {};
  4. #else
  5. #ifndef M_PI
  6. #define M_PI 3.14159265358979323846f
  7. #endif
  8. #include <fstream>
  9. #include "core/vector/VectorT.h"
  10. #include "core/vector/MatrixT.h"
  11. #include "core/image/ImageT.h"
  12. #include "core/imagedisplay/ImageDisplay.h"
  13. #include <iostream>
  14. #include <core/image/CrossT.h>
  15. #include "vislearning/math/cluster/GMM.h"
  16. using namespace std;
  17. using namespace NICE;
  18. using namespace OBJREC;
  19. int main (int argc, char **argv)
  20. {
  21. Config conf ( argc, argv );
  22. std::string params = conf.gS("main", "params");
  23. std::string featfile = conf.gS("main", "feats");
  24. ifstream fin( params.c_str() );
  25. if(fin.bad())
  26. {
  27. cout << "there were errors while opening the file " << params << endl;
  28. }
  29. int dim, gaussians, feats;
  30. double tmp;
  31. fin >> dim;
  32. for(int i = 1; i < dim; i++)
  33. {
  34. fin >> tmp;
  35. }
  36. fin >> gaussians;
  37. for(int i = 1; i < dim; i++)
  38. {
  39. fin >> tmp;
  40. }
  41. fin >> feats;
  42. for(int i = 1; i < dim; i++)
  43. {
  44. fin >> tmp;
  45. }
  46. vector<double> prior1;
  47. VVector mean1(gaussians, dim);
  48. VVector sigma1(gaussians, dim);
  49. for(int i = 0; i < gaussians; i++)
  50. {
  51. fin >> tmp;
  52. prior1.push_back(tmp);
  53. for(int d = 1; d < dim; d++)
  54. {
  55. fin >> tmp;
  56. }
  57. for(int d = 0; d < dim; d++)
  58. {
  59. fin >> mean1[i][d];
  60. }
  61. for(int d1 = 0; d1 < dim; d1++)
  62. {
  63. for(int d2 = 0; d2 < dim; d2++)
  64. {
  65. if(d1 == d2)
  66. fin >> sigma1[i][d1];
  67. else
  68. fin >> tmp;
  69. }
  70. }
  71. }
  72. int featsize;
  73. fin >> featsize;
  74. fin.close();
  75. ifstream fin2( featfile.c_str() );
  76. if(fin2.bad())
  77. {
  78. cout << "there were errors while opening the file " << featfile << endl;
  79. }
  80. VVector x;
  81. VVector prototypes;
  82. std::vector<double> weights;
  83. std::vector<int> assignments;
  84. for(int i = 0; i < featsize; i++)
  85. {
  86. NICE::Vector f(dim);
  87. for(int d = 0; d < dim; d++)
  88. {
  89. fin2 >> f[d];
  90. }
  91. x.push_back(f);
  92. }
  93. GMM clusteralg( &conf, gaussians);
  94. #if 1
  95. //clusteralg->computeMixture ( x );
  96. clusteralg.setGMMtoCompareWith(mean1, sigma1, prior1);
  97. clusteralg.setCompareTo2ndGMM(true);
  98. clusteralg.cluster ( x, prototypes, weights, assignments );
  99. double dist = clusteralg.compareTo2ndGMM();
  100. cout << "dist1: " << endl << dist << endl;
  101. int width = 500;
  102. int height = 500;
  103. NICE::Image panel (width, height);
  104. NICE::Image overlay (width, height);
  105. panel.set(255);
  106. overlay.set(0);
  107. width--;
  108. height--;
  109. for ( size_t i = 0 ; i < prototypes.size() ; i++ )
  110. {
  111. // refactor-nice.pl: check this substitution
  112. fprintf (stderr, "prototype %d (%f, %f)\n", (int)i, prototypes[i][0], prototypes[i][1] );
  113. Cross cross ( Coord( (int)(prototypes[i][0]*(double)width), (int)(prototypes[i][1]*(double)height) ), 3 );
  114. overlay.draw ( cross, i+1 );
  115. }
  116. fprintf (stderr, "assignments: %ld\n", assignments.size() );
  117. for ( size_t i = 0 ; i < assignments.size() ; i++ )
  118. {
  119. Cross cross ( Coord( (int)(x[i][0]*(double)width), (int)(x[i][1]*(double)height)) , 3 );
  120. overlay.draw ( cross, assignments[i]+1 );
  121. }
  122. NICE::showImageOverlay ( panel, overlay, "Clustering results" );
  123. #else
  124. double scale = 2.0;
  125. vector<double> prior2 = prior1;
  126. VVector sigma2 = sigma1;
  127. VVector mean2 = mean1;
  128. for(int i = 0; i < sigma2.size(); i++)
  129. {
  130. for(int j = 0; j < sigma2[i].size(); j++)
  131. {
  132. sigma2[i][j] /= scale;
  133. }
  134. }
  135. for(int i = 0; i < mean2.size(); i++)
  136. {
  137. for(int j = 0; j < mean2[i].size(); j++)
  138. {
  139. mean2[i][j] /= scale;
  140. }
  141. }
  142. for(int i = 0; i < prior2.size(); i++)
  143. {
  144. prior2[i] /= scale;
  145. }
  146. double distkii = 0.0;
  147. double distkjj = 0.0;
  148. double distkij = 0.0;
  149. double dist = 0.0;
  150. for(int i = 0; i < gaussians; i++)
  151. {
  152. for(int j = 0; j < gaussians; j++)
  153. {
  154. double kij = clusteralg.kPPK(sigma1[i],sigma2[j], mean1[i], mean2[j], 0.5);
  155. double kii = clusteralg.kPPK(sigma1[i],sigma1[j], mean1[i], mean1[j], 0.5);
  156. double kjj = clusteralg.kPPK(sigma2[i],sigma2[j], mean2[i], mean2[j], 0.5);
  157. kij = prior1[i]*prior2[j]*kij;
  158. kii = prior1[i]*prior1[j]*kii;
  159. kjj = prior2[i]*prior2[j]*kjj;
  160. distkii += kii;
  161. distkjj += kjj;
  162. distkij += kij;
  163. }
  164. }
  165. double dist2 = distkij / (sqrt(distkii)*sqrt(distkjj));
  166. cout << "dist: " << dist << endl;
  167. cout << "dist2: " << dist2 << endl;
  168. for(int k = 0; k < 10; k++)
  169. {
  170. dist = 0.0;
  171. for(int i = 0; i < gaussians; i++)
  172. {
  173. int j = i+k;
  174. if(j >= gaussians)
  175. j -= gaussians;
  176. double kij = clusteralg.kPPK(sigma1[i],sigma2[j], mean1[i], mean2[j], 0.5);
  177. kij = prior1[i]*prior2[j]*kij;
  178. double kii = clusteralg.kPPK(sigma1[i],sigma1[i], mean1[i], mean1[i], 0.5);
  179. kii = prior1[i]*prior1[j]*kii;
  180. double kjj = clusteralg.kPPK(sigma2[j],sigma2[j], mean2[j], mean2[j], 0.5);
  181. kjj = prior2[i]*prior2[j]*kjj;
  182. double val = kij / (sqrt(kii)*sqrt(kjj));
  183. if(kii == 0.0 || kjj == 0.0)
  184. continue;
  185. dist+=val;
  186. //printf ("%4.6f ", val*10.0);
  187. //cout << endl;
  188. }
  189. cout << "dist" << 3+k << " " << dist << endl;
  190. }
  191. #endif
  192. return 0;
  193. }
  194. #endif