// compareGMM.cpp
  1. #ifdef NOVISUAL
  2. #warning "testKMeans needs ICE with visualization !!"
  3. int main (int argc, char **argv) {};
  4. #else
  5. #ifndef M_PI
  6. #define M_PI 3.14159265358979323846f
  7. #endif
  8. #include <fstream>
  9. #include <vislearning/nice.h>
  10. #include <iostream>
  11. #include <core/image/CrossT.h>
  12. #include "vislearning/math/cluster/GMM.h"
  13. using namespace std;
  14. using namespace NICE;
  15. using namespace OBJREC;
  16. int main (int argc, char **argv)
  17. {
  18. Config conf ( argc, argv );
  19. std::string params = conf.gS("main", "params");
  20. std::string featfile = conf.gS("main", "feats");
  21. ifstream fin( params.c_str() );
  22. if(fin.bad())
  23. {
  24. cout << "there were errors while opening the file " << params << endl;
  25. }
  26. int dim, gaussians, feats;
  27. double tmp;
  28. fin >> dim;
  29. for(int i = 1; i < dim; i++)
  30. {
  31. fin >> tmp;
  32. }
  33. fin >> gaussians;
  34. for(int i = 1; i < dim; i++)
  35. {
  36. fin >> tmp;
  37. }
  38. fin >> feats;
  39. for(int i = 1; i < dim; i++)
  40. {
  41. fin >> tmp;
  42. }
  43. vector<double> prior1;
  44. VVector mean1(gaussians, dim);
  45. VVector sigma1(gaussians, dim);
  46. for(int i = 0; i < gaussians; i++)
  47. {
  48. fin >> tmp;
  49. prior1.push_back(tmp);
  50. for(int d = 1; d < dim; d++)
  51. {
  52. fin >> tmp;
  53. }
  54. for(int d = 0; d < dim; d++)
  55. {
  56. fin >> mean1[i][d];
  57. }
  58. for(int d1 = 0; d1 < dim; d1++)
  59. {
  60. for(int d2 = 0; d2 < dim; d2++)
  61. {
  62. if(d1 == d2)
  63. fin >> sigma1[i][d1];
  64. else
  65. fin >> tmp;
  66. }
  67. }
  68. }
  69. int featsize;
  70. fin >> featsize;
  71. fin.close();
  72. ifstream fin2( featfile.c_str() );
  73. if(fin2.bad())
  74. {
  75. cout << "there were errors while opening the file " << featfile << endl;
  76. }
  77. VVector x;
  78. VVector prototypes;
  79. std::vector<double> weights;
  80. std::vector<int> assignments;
  81. for(int i = 0; i < featsize; i++)
  82. {
  83. NICE::Vector f(dim);
  84. for(int d = 0; d < dim; d++)
  85. {
  86. fin2 >> f[d];
  87. }
  88. x.push_back(f);
  89. }
  90. GMM clusteralg( &conf, gaussians);
  91. #if 1
  92. //clusteralg->computeMixture ( x );
  93. clusteralg.setCompareGM(mean1, sigma1, prior1);
  94. clusteralg.comparing(true);
  95. clusteralg.cluster ( x, prototypes, weights, assignments );
  96. double dist = clusteralg.compare();
  97. cout << "dist1: " << endl << dist << endl;
  98. int width = 500;
  99. int height = 500;
  100. NICE::Image panel (width, height);
  101. NICE::Image overlay (width, height);
  102. panel.set(255);
  103. overlay.set(0);
  104. width--;
  105. height--;
  106. for ( size_t i = 0 ; i < prototypes.size() ; i++ )
  107. {
  108. // refactor-nice.pl: check this substitution
  109. fprintf (stderr, "prototype %d (%f, %f)\n", (int)i, prototypes[i][0], prototypes[i][1] );
  110. Cross cross ( Coord( (int)(prototypes[i][0]*(double)width), (int)(prototypes[i][1]*(double)height) ), 3 );
  111. overlay.draw ( cross, i+1 );
  112. }
  113. fprintf (stderr, "assignments: %ld\n", assignments.size() );
  114. for ( size_t i = 0 ; i < assignments.size() ; i++ )
  115. {
  116. Cross cross ( Coord( (int)(x[i][0]*(double)width), (int)(x[i][1]*(double)height)) , 3 );
  117. overlay.draw ( cross, assignments[i]+1 );
  118. }
  119. NICE::showImageOverlay ( panel, overlay, "Clustering results" );
  120. #else
  121. double scale = 2.0;
  122. vector<double> prior2 = prior1;
  123. VVector sigma2 = sigma1;
  124. VVector mean2 = mean1;
  125. for(int i = 0; i < sigma2.size(); i++)
  126. {
  127. for(int j = 0; j < sigma2[i].size(); j++)
  128. {
  129. sigma2[i][j] /= scale;
  130. }
  131. }
  132. for(int i = 0; i < mean2.size(); i++)
  133. {
  134. for(int j = 0; j < mean2[i].size(); j++)
  135. {
  136. mean2[i][j] /= scale;
  137. }
  138. }
  139. for(int i = 0; i < prior2.size(); i++)
  140. {
  141. prior2[i] /= scale;
  142. }
  143. double distkii = 0.0;
  144. double distkjj = 0.0;
  145. double distkij = 0.0;
  146. double dist = 0.0;
  147. for(int i = 0; i < gaussians; i++)
  148. {
  149. for(int j = 0; j < gaussians; j++)
  150. {
  151. double kij = clusteralg.kPPK(sigma1[i],sigma2[j], mean1[i], mean2[j], 0.5);
  152. double kii = clusteralg.kPPK(sigma1[i],sigma1[j], mean1[i], mean1[j], 0.5);
  153. double kjj = clusteralg.kPPK(sigma2[i],sigma2[j], mean2[i], mean2[j], 0.5);
  154. kij = prior1[i]*prior2[j]*kij;
  155. kii = prior1[i]*prior1[j]*kii;
  156. kjj = prior2[i]*prior2[j]*kjj;
  157. distkii += kii;
  158. distkjj += kjj;
  159. distkij += kij;
  160. }
  161. }
  162. double dist2 = distkij / (sqrt(distkii)*sqrt(distkjj));
  163. cout << "dist: " << dist << endl;
  164. cout << "dist2: " << dist2 << endl;
  165. for(int k = 0; k < 10; k++)
  166. {
  167. dist = 0.0;
  168. for(int i = 0; i < gaussians; i++)
  169. {
  170. int j = i+k;
  171. if(j >= gaussians)
  172. j -= gaussians;
  173. double kij = clusteralg.kPPK(sigma1[i],sigma2[j], mean1[i], mean2[j], 0.5);
  174. kij = prior1[i]*prior2[j]*kij;
  175. double kii = clusteralg.kPPK(sigma1[i],sigma1[i], mean1[i], mean1[i], 0.5);
  176. kii = prior1[i]*prior1[j]*kii;
  177. double kjj = clusteralg.kPPK(sigma2[j],sigma2[j], mean2[j], mean2[j], 0.5);
  178. kjj = prior2[i]*prior2[j]*kjj;
  179. double val = kij / (sqrt(kii)*sqrt(kjj));
  180. if(kii == 0.0 || kjj == 0.0)
  181. continue;
  182. dist+=val;
  183. //printf ("%4.6f ", val*10.0);
  184. //cout << endl;
  185. }
  186. cout << "dist" << 3+k << " " << dist << endl;
  187. }
  188. #endif
  189. return 0;
  190. }
  191. #endif