testPLSA.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442
  1. /**
  2. * @file testPLSA.cpp
  * @brief Tests PLSA on simulated corpora: recovery of p(w|z) and fold-in estimation of p(z|d)
  4. * @author Erik Rodner
  5. * @date 05/21/2008
  6. */
  7. #include <vislearning/nice.h>
  8. #ifdef NICE_USELIB_ICE
  9. #include <core/vector/Distance.h>
  10. #include <image_nonvis.h>
  11. #include <core/iceconversion/convertice.h>
  12. #include <distancefunctions.h>
  13. #include <assert.h>
  14. #include <core/basics/Config.h>
  15. #include <vislearning/baselib/cmdline.h>
  16. #include <vislearning/baselib/Gnuplot.h>
  17. #include <vislearning/baselib/ICETools.h>
  18. #include <vislearning/baselib/Conversions.h>
  19. #include <vislearning/math/pdf/PDFDirichlet.h>
  20. #include <vislearning/math/pdf/PDFMultinomial.h>
  21. #include <vislearning/math/distances/ChiSqDistance.h>
  22. #include <vislearning/math/distances/KLDistance.h>
  23. #include <vislearning/math/distances/HistIntersectDistance.h>
  24. #include <vislearning/math/topics/PLSA.h>
  25. using namespace OBJREC;
  26. using namespace NICE;
  27. using namespace std;
  28. NICE::Vector randomDiscreteDistribution ( int dimension )
  29. {
  30. NICE::Vector theta (dimension);
  31. double s = 0.0;
  32. for ( int i = 0 ; i < dimension ; i++ )
  33. {
  34. theta[i] = drand48();
  35. s += theta[i];
  36. }
  37. for ( int i = 0 ; i < dimension ; i++ )
  38. theta[i] /= s;
  39. return theta;
  40. }
  41. void simulation_sivic ( int d, // number of documents
  42. int m, // number of topics
  43. int n, // number of words in the vocabulary
  44. NICE::Matrix & counts,
  45. NICE::Matrix & pw_z,
  46. NICE::Vector & pd,
  47. NICE::Matrix & pz_d,
  48. int samples_count
  49. )
  50. {
  51. assert ( m == 3 );
  52. assert ( n == 12 );
  53. std::istringstream pwz_string ( string("<\n<0.25,0.25,0.25,0.25,0,0,0,0,0,0,0,0>,\n") +
  54. string("<0,0,0,0,0.25,0.25,0.25,0.25,0,0,0,0>,\n") +
  55. string("<0,0,0,0,0,0,0,0,0.25,0.25,0.25,0.25>\n>") );
  56. pwz_string >> pw_z;
  57. vector<PDFMultinomial> beta;
  58. for ( int k = 0 ; k < m ; k++ )
  59. beta.push_back ( PDFMultinomial(pw_z.getRow(k),1) );
  60. PDFDirichlet dirichlet ( 0.2, m );
  61. VVector pzd;
  62. dirichlet.sample ( pzd, d );
  63. for ( int i = 0 ; i < d ; i++ )
  64. for ( int k = 0 ; k < m ; k++ )
  65. pz_d(k, i) = pz_d(i, k);
  66. for ( int i = 0 ; i < d ; i++ )
  67. pd[i] = 1.0 / d;
  68. counts.set(0);
  69. fprintf (stderr, "Generation ...\n");
  70. for ( int i = 0 ; i < d ; i++ )
  71. {
  72. PDFMultinomial theta ( pzd[i], 1 );
  73. for ( int w = 0 ; w < samples_count; w++ )
  74. {
  75. int topic = theta.sample();
  76. assert ( topic < m );
  77. int word = beta[topic].sample();
  78. counts(i, word) ++;
  79. }
  80. }
  81. pz_d.normalizeColumnsL1();
  82. pd.normalizeL1();
  83. }
  84. void simulation ( int d, // number of documents
  85. int m, // number of topics
  86. int n, // number of words in the vocabulary
  87. NICE::Matrix & counts,
  88. NICE::Matrix & pw_z,
  89. NICE::Vector & pd,
  90. NICE::Matrix & pz_d,
  91. int samples_count
  92. )
  93. {
  94. fprintf (stderr, "Generating model...\n");
  95. //pd = randomDiscreteDistribution(d);
  96. for ( int i = 0 ; i < d ; i++ )
  97. pd[i] = 1.0 / d;
  98. PDFMultinomial pds ( pd, 1);
  99. vector<PDFMultinomial> pz_ds;
  100. vector<PDFMultinomial> pw_zs;
  101. PDFDirichlet dirichlet ( 0.2, m );
  102. VVector pzd;
  103. dirichlet.sample ( pzd, d );
  104. for ( int i = 0 ; i < d ; i++ )
  105. for ( int k = 0 ; k < m ; k++ )
  106. pz_d(k, i) =pz_d(i, k);
  107. for ( int di = 0 ; di < d ; di++ )
  108. pz_ds.push_back ( PDFMultinomial (pzd[di],1) );
  109. // funny distributed dirichlet parameter
  110. pw_z.set(0);
  111. for ( int zk = 0 ; zk < m ; zk++ )
  112. {
  113. for ( int i = zk*n/m ; i < (zk+1)*n/m ; i++ )
  114. pw_z(zk, i) = 1.0;
  115. }
  116. pw_z.normalizeRowsL1();
  117. for ( int zk = 0 ; zk < m ; zk++ )
  118. pw_zs.push_back ( PDFMultinomial ( pw_z.getRow(zk), 1 ) );
  119. fprintf (stderr, "Normalization...\n");
  120. fprintf (stderr, "Generating samples...\n");
  121. counts.set(0);
  122. for ( int j = 0 ; j < samples_count*d ; j++ )
  123. {
  124. // sample document using p(d)
  125. int document = pds.sample();
  126. // sample topic using p(z|d)
  127. int topic = pz_ds[document].sample();
  128. // sample word of the vocabulary using p(w|z)
  129. int word = pw_zs[topic].sample();
  130. // refactor-nice.pl: check this substitution
  131. // old: counts[document][word]++;
  132. counts(document, word)++;
  133. }
  134. }
  135. void sample_document ( const NICE::Vector & pz_singled,
  136. const NICE::Matrix & pw_z,
  137. NICE::Vector & histogramm,
  138. int samples_count )
  139. {
  140. int m = pz_singled.size();
  141. vector<PDFMultinomial> pw_zs;
  142. for ( int zk = 0 ; zk < m ; zk++ )
  143. pw_zs.push_back ( PDFMultinomial ( pw_z.getRow(zk), 1 ) );
  144. PDFMultinomial topicgen ( pz_singled, 1 );
  145. for ( int j = 0 ; j < samples_count ; j++ )
  146. {
  147. int topic = topicgen.sample();
  148. int word = pw_zs[topic].sample();
  149. histogramm[word]++;
  150. }
  151. }
  152. double comparePermMatrices ( const NICE::Matrix & A, const NICE::Matrix & B, MatrixT<int> & reference_pairs )
  153. {
  154. double min_cost;
  155. ice::Matrix cost ( A.rows(), B.rows() );
  156. NICE::VectorDistance<double> *df = new NICE::EuclidianDistance<double> ();
  157. assert ( A.rows() == B.rows() );
  158. for ( int i = 0 ; i < A.rows() ; i++ )
  159. for ( int j = 0 ; j < B.rows() ; j++ )
  160. cost[i][j] = df->calculate ( A.getRow(i), B.getRow(j) );
  161. delete df;
  162. ice::IMatrix reference_pairs_ice;
  163. ice::Hungarian ( cost, reference_pairs_ice, min_cost );
  164. if ( reference_pairs_ice.rows() == A.rows()-1 ) {
  165. set<int> a;
  166. set<int> b;
  167. for ( int i = 0 ; i < A.rows() ; i++ )
  168. {
  169. a.insert ( i );
  170. b.insert ( i );
  171. }
  172. for ( int i = 0 ; i < reference_pairs_ice.rows() ; i++ )
  173. {
  174. a.erase ( a.find(reference_pairs_ice[i][0] ) );
  175. b.erase ( b.find(reference_pairs_ice[i][1] ) );
  176. }
  177. assert ( (a.size() == 1) && (b.size() == 1) );
  178. reference_pairs_ice.Append ( ice::IVector ( *(a.begin()), *(a.end()) ) );
  179. reference_pairs = NICE::makeIntegerMatrix<int> ( reference_pairs_ice );
  180. return min_cost / ( A.rows() );
  181. } else if ( reference_pairs_ice.rows() != A.rows() ) {
  182. reference_pairs = NICE::makeIntegerMatrix<int> ( reference_pairs_ice );
  183. return -1.0;
  184. }
  185. }
  186. int main (int argc, char **argv)
  187. {
  188. std::set_terminate(__gnu_cxx::__verbose_terminate_handler);
  189. srand48(time(NULL));
  190. int tempered = 0;
  191. int optimization_verbose = 0;
  192. int samples_count = 10000;
  193. int d = 1000;
  194. int m = 3;
  195. int n = 12;
  196. double delta_eps = 1e-3;
  197. int maxiterations = 500;
  198. int runs = 1;
  199. double holdoutportion = 0.2;
  200. int use_simulation_sivic = 0;
  201. struct CmdLineOption options[] = {
  202. {"verbose", "print details about the optimization", NULL, NULL, &optimization_verbose},
  203. {"documents", "number of documents", "1000", "%d", &d},
  204. {"topics", "number of topics", "3", "%d", &m},
  205. {"words", "number of words in the vocabulary", "12", "%d", &n},
  206. {"samples", "number of samples (in total)", "10000", "%d", &samples_count},
  207. {"maxiterations", "maximum number of EM-iterations", "500", "%d", &maxiterations},
  208. {"runs", "number of runs performed", "1", "%d", &runs},
  209. {"deltaeps", "terminate optimization if delta of likelihood is below this threshold", NULL, "%lf", &delta_eps},
  210. {"sivic", "use the simulation routine and parameters of Josef Sivic's matlab implementation", NULL, NULL, &use_simulation_sivic},
  211. {"holdout", "fraction of data for validation", "0.8", "%lf", &holdoutportion},
  212. {"tempered", "use tempered version of EM algorithm", NULL, NULL, &tempered},
  213. {NULL, NULL, NULL, NULL, NULL}
  214. };
  215. int ret;
  216. char *more_options[argc];
  217. ret = parse_arguments( argc, (const char**)argv, options, more_options);
  218. if ( ! tempered )
  219. holdoutportion = 0.0;
  220. srand48(time(NULL));
  221. PLSA pLSA ( maxiterations, delta_eps, 0.9, holdoutportion );
  222. double error_pw_z_avg = 0.0;
  223. double error_foldin_avg = 0.0;
  224. double error_pw_z_stddev = 0.0;
  225. int runs_successfull = 0;
  226. fprintf (stderr, "documents: %d\n", d );
  227. fprintf (stderr, "topics: %d\n", m );
  228. fprintf (stderr, "runs: %d\n", runs );
  229. for ( int run = 0 ; run < runs ; run++ )
  230. {
  231. fprintf (stderr, "[run %d/%d]\n", run, runs );
  232. // refactor-nice.pl: check this substitution
  233. // old: Matrix counts ( d, n );
  234. NICE::Matrix counts ( d, n );
  235. // refactor-nice.pl: check this substitution
  236. // old: Matrix pw_z ( m, n );
  237. NICE::Matrix pw_z ( m, n );
  238. // refactor-nice.pl: check this substitution
  239. // old: Vector pd ( d );
  240. NICE::Vector pd ( d );
  241. // refactor-nice.pl: check this substitution
  242. // old: Matrix pz_d ( m, d );
  243. NICE::Matrix pz_d ( m, d );
  244. if ( use_simulation_sivic )
  245. simulation_sivic ( d, m, n,
  246. counts,
  247. pw_z,
  248. pd,
  249. pz_d,
  250. samples_count );
  251. else
  252. simulation ( d, m, n,
  253. counts,
  254. pw_z,
  255. pd,
  256. pz_d,
  257. samples_count );
  258. /*
  259. // refactor-nice.pl: check this substitution
  260. // old: ImageRGB img_groundtruth ( n, m );
  261. NICE::ColorImage img_groundtruth ( n, m );
  262. ICETools::convertToRGB ( pw_z, img_groundtruth );
  263. Show(ON, img_groundtruth, "P(w|z) groundtruth");
  264. */
  265. double *counts_raw = ICETools::convertICE2M ( counts );
  266. double *pw_z_estimate_raw = new double [n*m];
  267. double *pz_d_estimate_raw = new double [m*d];
  268. // refactor-nice.pl: check this substitution
  269. // old: Vector pd_estimate ( d );
  270. NICE::Vector pd_estimate ( d );
  271. double likelihood_estimate =
  272. pLSA.pLSA ( counts_raw,
  273. pw_z_estimate_raw,
  274. pd_estimate.getDataPointer(),
  275. pz_d_estimate_raw,
  276. n, m, d,
  277. true /* do not perform folding */,
  278. tempered ? true : false /* use tempered version */,
  279. optimization_verbose ? true : false );
  280. double *pw_z_raw = ICETools::convertICE2M ( pw_z );
  281. double *pz_d_raw = ICETools::convertICE2M ( pz_d );
  282. int dtrained = d - (int) ( d * holdoutportion );
  283. double likelihood_groundtruth = pLSA.computeLikelihood (
  284. counts_raw, pw_z_raw, pd.getDataPointer(), pz_d_raw, n, m , d, dtrained );
  285. delete [] pw_z_raw;
  286. delete [] pz_d_raw;
  287. // refactor-nice.pl: check this substitution
  288. // old: Matrix pw_z_estimate ( m, n );
  289. NICE::Matrix pw_z_estimate ( m, n );
  290. ICETools::convertM2ICE ( pw_z_estimate, pw_z_estimate_raw );
  291. delete [] counts_raw;
  292. MatrixT<int> refPWZ, refPDZ;
  293. double error_pw_z = comparePermMatrices ( pw_z, pw_z_estimate, refPWZ );
  294. fprintf (stderr, "---------- final error %d/%d ---------\n", run+1, runs );
  295. fprintf (stderr, "error p(w|z) pwz: %f\n", error_pw_z );
  296. error_pw_z_avg += error_pw_z;
  297. error_pw_z_stddev += error_pw_z * error_pw_z;
  298. /***************** FOLDIN Tests *******************/
  299. KLDistance<double> dist;
  300. double error_foldin = 0.0;
  301. for ( int i = 0 ; i < dtrained ; i++ )
  302. {
  303. // refactor-nice.pl: check this substitution
  304. // old: Vector pz_singled ( m );
  305. NICE::Vector pz_singled ( m );
  306. for ( int k = 0 ; k < m ; k++ )
  307. pz_singled[k] = pz_d_estimate_raw[k*dtrained+i];
  308. // refactor-nice.pl: check this substitution
  309. // old: Vector histogramm ( n );
  310. NICE::Vector histogramm ( n );
  311. sample_document ( pz_singled, pw_z_estimate, histogramm, samples_count );
  312. // refactor-nice.pl: check this substitution
  313. // old: Vector pd_foldin (1);
  314. NICE::Vector pd_foldin (1);
  315. // refactor-nice.pl: check this substitution
  316. // old: Vector pz_d_foldin (m);
  317. NICE::Vector pz_d_foldin (m);
  318. pLSA.pLSA ( histogramm.getDataPointer(),
  319. pw_z_estimate_raw,
  320. pd_foldin.getDataPointer(),
  321. pz_d_foldin.getDataPointer(),
  322. n, m, 1,
  323. false /* do not use tempered version */,
  324. false );
  325. /*
  326. cerr << "p(1,w) = " << histogramm << endl;
  327. cerr << "p(d) =fold " << pd_foldin << endl;
  328. cerr << "p(z|d) =fold " << pz_d_foldin << endl;
  329. cerr << "p(z|d) =gt " << pz_singled << endl;
  330. */
  331. error_foldin += dist ( pz_d_foldin, pz_singled );
  332. }
  333. error_foldin /= dtrained;
  334. fprintf (stderr, "foldin error: %f\n", error_foldin );
  335. error_foldin_avg += error_foldin;
  336. runs_successfull++;
  337. double error_likelihood = likelihood_estimate - likelihood_groundtruth;
  338. fprintf (stderr, "likelihood error: %f\n", error_likelihood );
  339. if ( error_likelihood > 1e3 )
  340. {
  341. fprintf (stderr, "This seems to be a severe local minimum !\n");
  342. }
  343. delete [] pw_z_estimate_raw;
  344. delete [] pz_d_estimate_raw;
  345. }
  346. error_pw_z_avg /= runs_successfull;
  347. error_foldin_avg /= runs_successfull;
  348. error_pw_z_stddev /= runs_successfull;
  349. error_pw_z_stddev -= error_pw_z_avg*error_pw_z_avg;
  350. error_pw_z_stddev = (error_pw_z_stddev < 0.0) ? 0.0 : sqrt(error_pw_z_stddev);
  351. fprintf (stderr, "pwz: %f %f\n", error_pw_z_avg, error_pw_z_stddev );
  352. fprintf (stderr, "foldin: %f\n", error_foldin_avg );
  353. return 0;
  354. }
  355. #else
  356. int main (int argc, char **argv)
  357. {
  358. throw("no ice installed\n");
  359. return 0;
  360. }
  361. #endif