testDirichlet.cpp 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. /**
  2. * @file testDirichlet.cpp
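* @brief interactive tests of Dirichlet sampling (PDFDirichlet) and of MAP estimation of multinomial parameters with a Gaussian prior (MAPMultinomialGaussianPrior)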
  4. * @author Erik Rodner
  5. * @date 05/21/2008
  6. */
  7. #ifdef NOVISUAL
  8. #warning "testDirichlet needs ICE with visualization !!"
  9. int main (int argc, char **argv) {};
  10. #else
  11. #include <vislearning/nice.h>
  12. #include <core/vector/Distance.h>
  13. #include <core/image/CrossT.h>
  14. #include <core/basics/Config.h>
  15. #include <vislearning/baselib/cmdline.h>
  16. #include <vislearning/baselib/Gnuplot.h>
  17. #include <vislearning/baselib/ICETools.h>
  18. #include <vislearning/baselib/Conversions.h>
  19. #include <vislearning/math/pdf/PDFDirichlet.h>
  20. #include <vislearning/optimization/mapestimation/MAPMultinomialGaussianPrior.h>
  21. using namespace OBJREC;
  22. // refactor-nice.pl: check this substitution
  23. // old: using namespace ice;
  24. using namespace NICE;
  25. using namespace std;
  26. // refactor-nice.pl: check this substitution
  27. // old: void plotDistributions ( Image & img, const VVector & samples, int color = 1 )
  28. void plotDistributions ( NICE::Image & img, const VVector & samples, int color = 1 )
  29. {
  30. for ( VVector::const_iterator i = samples.begin();
  31. i != samples.end();
  32. i++ )
  33. {
  34. // refactor-nice.pl: check this substitution
  35. // old: const Vector & x = *i;
  36. const NICE::Vector & x = *i;
  37. // refactor-nice.pl: check this substitution
  38. // old: PutVal ( img, (int)(x[0]*(img->xsize-1)), (int)(x[1]*(img->ysize-1)), color );
  39. img.setPixel((int)(x[0]*(img.width()-1)),(int)(x[1]*(img.height()-1)),color);
  40. }
  41. }
  42. // refactor-nice.pl: check this substitution
  43. // old: void estimateDirichlet_Newton ( Vector & alpha,
  44. void estimateDirichlet_Newton ( NICE::Vector & alpha,
  45. const VVector & samples )
  46. {
  47. }
  48. void simulation ( int samplesCount )
  49. {
  50. const int dimension = 100;
  51. // refactor-nice.pl: check this substitution
  52. // old: Vector alpha (dimension);
  53. NICE::Vector alpha (dimension);
  54. srand48(time(NULL));
  55. double s = 0.0;
  56. for ( int i = 0 ; i < dimension ; i++ )
  57. {
  58. alpha[i] = drand48() * 3.0;
  59. s += alpha[i];
  60. }
  61. cerr << "alpha: " << alpha << endl;
  62. PDFDirichlet pdfdirichlet (alpha);
  63. VVector samples;
  64. pdfdirichlet.sample ( samples, samplesCount );
  65. // refactor-nice.pl: check this substitution
  66. // old: Vector alphaEstimated (alpha.size());
  67. NICE::Vector alphaEstimated (alpha.size());
  68. estimateDirichlet_Newton ( alphaEstimated, samples );
  69. // refactor-nice.pl: check this substitution
  70. // old: Image img ( 400, 400, 255 );
  71. NICE::Image img (400, 400);
  72. // refactor-nice.pl: check this substitution
  73. // old: ClearImg(img);
  74. img.set(0);
  75. if ( dimension == 3 )
  76. {
  77. plotDistributions ( img, samples, 1 );
  78. }
  79. // calculate mean
  80. // refactor-nice.pl: check this substitution
  81. // old: Vector muCGD (dimension);
  82. NICE::Vector muCGD (dimension);
  83. muCGD.set(0);
  84. for ( VVector::const_iterator i = samples.begin();
  85. i != samples.end();
  86. i++ )
  87. muCGD = muCGD + (*i);
  88. muCGD = (1.0/samples.size())*muCGD;
  89. EuclidianDistance<double> euclid;
  90. // refactor-nice.pl: check this substitution
  91. // old: Vector meanDirichlet ( alpha*(1.0/s) );
  92. NICE::Vector meanDirichlet ( alpha*(1.0/s) );
  93. cerr << "mu estimated (cgd): " << muCGD << endl;
  94. cerr << "mean (dirichlet): " << alpha*(1.0/s) << endl;
  95. cerr << "distance: " << euclid(muCGD, meanDirichlet) << endl;
  96. if ( dimension == 3 )
  97. {
  98. Cross cross1 ( Coord( muCGD[0]*(img.width()-1), muCGD[1]*(img.height()-1) ), 10 );
  99. Cross cross2 ( Coord( alpha[0]/s*(img.width()-1), alpha[1]/s*(img.height()-1) ), 10 );
  100. img.draw ( cross1, 2 );
  101. img.draw ( cross2, 3 );
  102. showImageOverlay (img, img );
  103. } else {
  104. getchar();
  105. }
  106. }
  107. int main (int argc, char **argv)
  108. {
  109. std::set_terminate(__gnu_cxx::__verbose_terminate_handler);
  110. char configfile [300];
  111. struct CmdLineOption options[] = {
  112. {"config", "use config file", NULL, "%s", configfile},
  113. {NULL, NULL, NULL, NULL, NULL}
  114. };
  115. int ret;
  116. char *more_options[argc];
  117. ret = parse_arguments( argc, (const char**)argv, options, more_options);
  118. if ( ret != 0 )
  119. {
  120. if ( ret != 1 ) fprintf (stderr, "Error parsing command line !\n");
  121. exit (-1);
  122. }
  123. Config conf ( configfile );
  124. const int samples_count_sim = 100;
  125. simulation ( samples_count_sim );
  126. const int samples_count = 100;
  127. const int dimension = 1000;
  128. // refactor-nice.pl: check this substitution
  129. // old: Vector mu (dimension);
  130. NICE::Vector mu (dimension);
  131. // refactor-nice.pl: check this substitution
  132. // old: Vector mu_noninformative (dimension);
  133. NICE::Vector mu_noninformative (dimension);
  134. // refactor-nice.pl: check this substitution
  135. // old: Vector counts (dimension);
  136. NICE::Vector counts (dimension);
  137. mu.set(0);
  138. counts.set(0);
  139. MAPMultinomialGaussianPrior map;
  140. // simulate gaussian distribution
  141. for ( int i = 0 ; i < samples_count ; i++ )
  142. {
  143. double r = randGaussDouble ( dimension / 6 );
  144. int bin = (int) (r+dimension/3);
  145. if ( (bin >= 0) && (bin < dimension) )
  146. mu[bin]++;
  147. r = randGaussDouble ( dimension / 6 ) * randGaussDouble (dimension/6);
  148. bin = (int) (r+2*dimension/3);
  149. if ( (bin >= 0) && (bin < dimension) )
  150. counts[bin]++;
  151. }
  152. MAPMultinomialGaussianPrior::normalizeProb ( mu );
  153. mu_noninformative = Vector(counts);
  154. MAPMultinomialGaussianPrior::normalizeProb ( mu_noninformative );
  155. const double eps = 10e-11;
  156. for ( int i = 1 ; i < 12 ; i++ )
  157. {
  158. double sigmaq = pow(10,-i);
  159. // refactor-nice.pl: check this substitution
  160. // old: Vector theta ( dimension );
  161. NICE::Vector theta ( dimension );
  162. map.estimate ( theta, counts, mu, sigmaq );
  163. vector<double> tp, mup, munonp;
  164. tp = Conversions::stl_vector ( theta );
  165. mup = Conversions::stl_vector ( mu );
  166. munonp = Conversions::stl_vector ( mu_noninformative );
  167. Gnuplot gp;
  168. gp.set_style ( "boxes" );
  169. gp.plot_x ( mup, "mu" );
  170. gp.plot_x ( munonp, "mu (noninformative)" );
  171. gp.plot_x ( tp, "theta" );
  172. // refactor-nice.pl: check this substitution
  173. // old: GetChar();
  174. getchar();
  175. }
  176. return 0;
  177. }
  178. #endif