testPLSA.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445
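/**
* @file testPLSA.cpp
* @brief Test program for the PLSA implementation: generates synthetic document/word counts from a known topic model, runs pLSA on them and reports the error of the estimated p(w|z), the fold-in error and the likelihood difference to the ground truth.
* @author Erik Rodner
* @date 05/21/2008
*/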
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <exception>
#include <set>
#include <sstream>
#include <string>
#include <vector>

#include "core/vector/VectorT.h"
#include "core/vector/MatrixT.h"
#include "core/image/ImageT.h"
#include "core/imagedisplay/ImageDisplay.h"

#ifdef NICE_USELIB_ICE
#include <core/vector/Distance.h>
#include <image_nonvis.h>
#include <core/iceconversion/convertice.h>
#include <distancefunctions.h>
#include <assert.h>
#include <core/basics/Config.h>
#include <vislearning/baselib/cmdline.h>
#include <vislearning/baselib/Gnuplot.h>
#include <vislearning/baselib/ICETools.h>
#include <vislearning/baselib/Conversions.h>
#include <vislearning/math/pdf/PDFDirichlet.h>
#include <vislearning/math/pdf/PDFMultinomial.h>
#include <vislearning/math/distances/ChiSqDistance.h>
#include <vislearning/math/distances/KLDistance.h>
#include <vislearning/math/distances/HistIntersectDistance.h>
#include <vislearning/math/topics/PLSA.h>

using namespace OBJREC;
using namespace NICE;
using namespace std;
  31. NICE::Vector randomDiscreteDistribution ( int dimension )
  32. {
  33. NICE::Vector theta (dimension);
  34. double s = 0.0;
  35. for ( int i = 0 ; i < dimension ; i++ )
  36. {
  37. theta[i] = drand48();
  38. s += theta[i];
  39. }
  40. for ( int i = 0 ; i < dimension ; i++ )
  41. theta[i] /= s;
  42. return theta;
  43. }
  44. void simulation_sivic ( int d, // number of documents
  45. int m, // number of topics
  46. int n, // number of words in the vocabulary
  47. NICE::Matrix & counts,
  48. NICE::Matrix & pw_z,
  49. NICE::Vector & pd,
  50. NICE::Matrix & pz_d,
  51. int samples_count
  52. )
  53. {
  54. assert ( m == 3 );
  55. assert ( n == 12 );
  56. std::istringstream pwz_string ( string("<\n<0.25,0.25,0.25,0.25,0,0,0,0,0,0,0,0>,\n") +
  57. string("<0,0,0,0,0.25,0.25,0.25,0.25,0,0,0,0>,\n") +
  58. string("<0,0,0,0,0,0,0,0,0.25,0.25,0.25,0.25>\n>") );
  59. pwz_string >> pw_z;
  60. vector<PDFMultinomial> beta;
  61. for ( int k = 0 ; k < m ; k++ )
  62. beta.push_back ( PDFMultinomial(pw_z.getRow(k),1) );
  63. PDFDirichlet dirichlet ( 0.2, m );
  64. VVector pzd;
  65. dirichlet.sample ( pzd, d );
  66. for ( int i = 0 ; i < d ; i++ )
  67. for ( int k = 0 ; k < m ; k++ )
  68. pz_d(k, i) = pz_d(i, k);
  69. for ( int i = 0 ; i < d ; i++ )
  70. pd[i] = 1.0 / d;
  71. counts.set(0);
  72. fprintf (stderr, "Generation ...\n");
  73. for ( int i = 0 ; i < d ; i++ )
  74. {
  75. PDFMultinomial theta ( pzd[i], 1 );
  76. for ( int w = 0 ; w < samples_count; w++ )
  77. {
  78. int topic = theta.sample();
  79. assert ( topic < m );
  80. int word = beta[topic].sample();
  81. counts(i, word) ++;
  82. }
  83. }
  84. pz_d.normalizeColumnsL1();
  85. pd.normalizeL1();
  86. }
  87. void simulation ( int d, // number of documents
  88. int m, // number of topics
  89. int n, // number of words in the vocabulary
  90. NICE::Matrix & counts,
  91. NICE::Matrix & pw_z,
  92. NICE::Vector & pd,
  93. NICE::Matrix & pz_d,
  94. int samples_count
  95. )
  96. {
  97. fprintf (stderr, "Generating model...\n");
  98. //pd = randomDiscreteDistribution(d);
  99. for ( int i = 0 ; i < d ; i++ )
  100. pd[i] = 1.0 / d;
  101. PDFMultinomial pds ( pd, 1);
  102. vector<PDFMultinomial> pz_ds;
  103. vector<PDFMultinomial> pw_zs;
  104. PDFDirichlet dirichlet ( 0.2, m );
  105. VVector pzd;
  106. dirichlet.sample ( pzd, d );
  107. for ( int i = 0 ; i < d ; i++ )
  108. for ( int k = 0 ; k < m ; k++ )
  109. pz_d(k, i) =pz_d(i, k);
  110. for ( int di = 0 ; di < d ; di++ )
  111. pz_ds.push_back ( PDFMultinomial (pzd[di],1) );
  112. // funny distributed dirichlet parameter
  113. pw_z.set(0);
  114. for ( int zk = 0 ; zk < m ; zk++ )
  115. {
  116. for ( int i = zk*n/m ; i < (zk+1)*n/m ; i++ )
  117. pw_z(zk, i) = 1.0;
  118. }
  119. pw_z.normalizeRowsL1();
  120. for ( int zk = 0 ; zk < m ; zk++ )
  121. pw_zs.push_back ( PDFMultinomial ( pw_z.getRow(zk), 1 ) );
  122. fprintf (stderr, "Normalization...\n");
  123. fprintf (stderr, "Generating samples...\n");
  124. counts.set(0);
  125. for ( int j = 0 ; j < samples_count*d ; j++ )
  126. {
  127. // sample document using p(d)
  128. int document = pds.sample();
  129. // sample topic using p(z|d)
  130. int topic = pz_ds[document].sample();
  131. // sample word of the vocabulary using p(w|z)
  132. int word = pw_zs[topic].sample();
  133. // refactor-nice.pl: check this substitution
  134. // old: counts[document][word]++;
  135. counts(document, word)++;
  136. }
  137. }
  138. void sample_document ( const NICE::Vector & pz_singled,
  139. const NICE::Matrix & pw_z,
  140. NICE::Vector & histogramm,
  141. int samples_count )
  142. {
  143. int m = pz_singled.size();
  144. vector<PDFMultinomial> pw_zs;
  145. for ( int zk = 0 ; zk < m ; zk++ )
  146. pw_zs.push_back ( PDFMultinomial ( pw_z.getRow(zk), 1 ) );
  147. PDFMultinomial topicgen ( pz_singled, 1 );
  148. for ( int j = 0 ; j < samples_count ; j++ )
  149. {
  150. int topic = topicgen.sample();
  151. int word = pw_zs[topic].sample();
  152. histogramm[word]++;
  153. }
  154. }
  155. double comparePermMatrices ( const NICE::Matrix & A, const NICE::Matrix & B, MatrixT<int> & reference_pairs )
  156. {
  157. double min_cost;
  158. ice::Matrix cost ( A.rows(), B.rows() );
  159. NICE::VectorDistance<double> *df = new NICE::EuclidianDistance<double> ();
  160. assert ( A.rows() == B.rows() );
  161. for ( int i = 0 ; i < A.rows() ; i++ )
  162. for ( int j = 0 ; j < B.rows() ; j++ )
  163. cost[i][j] = df->calculate ( A.getRow(i), B.getRow(j) );
  164. delete df;
  165. ice::IMatrix reference_pairs_ice;
  166. ice::Hungarian ( cost, reference_pairs_ice, min_cost );
  167. if ( reference_pairs_ice.rows() == A.rows()-1 ) {
  168. set<int> a;
  169. set<int> b;
  170. for ( int i = 0 ; i < A.rows() ; i++ )
  171. {
  172. a.insert ( i );
  173. b.insert ( i );
  174. }
  175. for ( int i = 0 ; i < reference_pairs_ice.rows() ; i++ )
  176. {
  177. a.erase ( a.find(reference_pairs_ice[i][0] ) );
  178. b.erase ( b.find(reference_pairs_ice[i][1] ) );
  179. }
  180. assert ( (a.size() == 1) && (b.size() == 1) );
  181. reference_pairs_ice.Append ( ice::IVector ( *(a.begin()), *(a.end()) ) );
  182. reference_pairs = NICE::makeIntegerMatrix<int> ( reference_pairs_ice );
  183. return min_cost / ( A.rows() );
  184. } else if ( reference_pairs_ice.rows() != A.rows() ) {
  185. reference_pairs = NICE::makeIntegerMatrix<int> ( reference_pairs_ice );
  186. return -1.0;
  187. }
  188. }
  189. int main (int argc, char **argv)
  190. {
  191. std::set_terminate(__gnu_cxx::__verbose_terminate_handler);
  192. srand48(time(NULL));
  193. int tempered = 0;
  194. int optimization_verbose = 0;
  195. int samples_count = 10000;
  196. int d = 1000;
  197. int m = 3;
  198. int n = 12;
  199. double delta_eps = 1e-3;
  200. int maxiterations = 500;
  201. int runs = 1;
  202. double holdoutportion = 0.2;
  203. int use_simulation_sivic = 0;
  204. struct CmdLineOption options[] = {
  205. {"verbose", "print details about the optimization", NULL, NULL, &optimization_verbose},
  206. {"documents", "number of documents", "1000", "%d", &d},
  207. {"topics", "number of topics", "3", "%d", &m},
  208. {"words", "number of words in the vocabulary", "12", "%d", &n},
  209. {"samples", "number of samples (in total)", "10000", "%d", &samples_count},
  210. {"maxiterations", "maximum number of EM-iterations", "500", "%d", &maxiterations},
  211. {"runs", "number of runs performed", "1", "%d", &runs},
  212. {"deltaeps", "terminate optimization if delta of likelihood is below this threshold", NULL, "%lf", &delta_eps},
  213. {"sivic", "use the simulation routine and parameters of Josef Sivic's matlab implementation", NULL, NULL, &use_simulation_sivic},
  214. {"holdout", "fraction of data for validation", "0.8", "%lf", &holdoutportion},
  215. {"tempered", "use tempered version of EM algorithm", NULL, NULL, &tempered},
  216. {NULL, NULL, NULL, NULL, NULL}
  217. };
  218. int ret;
  219. char *more_options[argc];
  220. ret = parse_arguments( argc, (const char**)argv, options, more_options);
  221. if ( ! tempered )
  222. holdoutportion = 0.0;
  223. srand48(time(NULL));
  224. PLSA pLSA ( maxiterations, delta_eps, 0.9, holdoutportion );
  225. double error_pw_z_avg = 0.0;
  226. double error_foldin_avg = 0.0;
  227. double error_pw_z_stddev = 0.0;
  228. int runs_successfull = 0;
  229. fprintf (stderr, "documents: %d\n", d );
  230. fprintf (stderr, "topics: %d\n", m );
  231. fprintf (stderr, "runs: %d\n", runs );
  232. for ( int run = 0 ; run < runs ; run++ )
  233. {
  234. fprintf (stderr, "[run %d/%d]\n", run, runs );
  235. // refactor-nice.pl: check this substitution
  236. // old: Matrix counts ( d, n );
  237. NICE::Matrix counts ( d, n );
  238. // refactor-nice.pl: check this substitution
  239. // old: Matrix pw_z ( m, n );
  240. NICE::Matrix pw_z ( m, n );
  241. // refactor-nice.pl: check this substitution
  242. // old: Vector pd ( d );
  243. NICE::Vector pd ( d );
  244. // refactor-nice.pl: check this substitution
  245. // old: Matrix pz_d ( m, d );
  246. NICE::Matrix pz_d ( m, d );
  247. if ( use_simulation_sivic )
  248. simulation_sivic ( d, m, n,
  249. counts,
  250. pw_z,
  251. pd,
  252. pz_d,
  253. samples_count );
  254. else
  255. simulation ( d, m, n,
  256. counts,
  257. pw_z,
  258. pd,
  259. pz_d,
  260. samples_count );
  261. /*
  262. // refactor-nice.pl: check this substitution
  263. // old: ImageRGB img_groundtruth ( n, m );
  264. NICE::ColorImage img_groundtruth ( n, m );
  265. ICETools::convertToRGB ( pw_z, img_groundtruth );
  266. Show(ON, img_groundtruth, "P(w|z) groundtruth");
  267. */
  268. double *counts_raw = ICETools::convertICE2M ( counts );
  269. double *pw_z_estimate_raw = new double [n*m];
  270. double *pz_d_estimate_raw = new double [m*d];
  271. // refactor-nice.pl: check this substitution
  272. // old: Vector pd_estimate ( d );
  273. NICE::Vector pd_estimate ( d );
  274. double likelihood_estimate =
  275. pLSA.pLSA ( counts_raw,
  276. pw_z_estimate_raw,
  277. pd_estimate.getDataPointer(),
  278. pz_d_estimate_raw,
  279. n, m, d,
  280. true /* do not perform folding */,
  281. tempered ? true : false /* use tempered version */,
  282. optimization_verbose ? true : false );
  283. double *pw_z_raw = ICETools::convertICE2M ( pw_z );
  284. double *pz_d_raw = ICETools::convertICE2M ( pz_d );
  285. int dtrained = d - (int) ( d * holdoutportion );
  286. double likelihood_groundtruth = pLSA.computeLikelihood (
  287. counts_raw, pw_z_raw, pd.getDataPointer(), pz_d_raw, n, m , d, dtrained );
  288. delete [] pw_z_raw;
  289. delete [] pz_d_raw;
  290. // refactor-nice.pl: check this substitution
  291. // old: Matrix pw_z_estimate ( m, n );
  292. NICE::Matrix pw_z_estimate ( m, n );
  293. ICETools::convertM2ICE ( pw_z_estimate, pw_z_estimate_raw );
  294. delete [] counts_raw;
  295. MatrixT<int> refPWZ, refPDZ;
  296. double error_pw_z = comparePermMatrices ( pw_z, pw_z_estimate, refPWZ );
  297. fprintf (stderr, "---------- final error %d/%d ---------\n", run+1, runs );
  298. fprintf (stderr, "error p(w|z) pwz: %f\n", error_pw_z );
  299. error_pw_z_avg += error_pw_z;
  300. error_pw_z_stddev += error_pw_z * error_pw_z;
  301. /***************** FOLDIN Tests *******************/
  302. KLDistance<double> dist;
  303. double error_foldin = 0.0;
  304. for ( int i = 0 ; i < dtrained ; i++ )
  305. {
  306. // refactor-nice.pl: check this substitution
  307. // old: Vector pz_singled ( m );
  308. NICE::Vector pz_singled ( m );
  309. for ( int k = 0 ; k < m ; k++ )
  310. pz_singled[k] = pz_d_estimate_raw[k*dtrained+i];
  311. // refactor-nice.pl: check this substitution
  312. // old: Vector histogramm ( n );
  313. NICE::Vector histogramm ( n );
  314. sample_document ( pz_singled, pw_z_estimate, histogramm, samples_count );
  315. // refactor-nice.pl: check this substitution
  316. // old: Vector pd_foldin (1);
  317. NICE::Vector pd_foldin (1);
  318. // refactor-nice.pl: check this substitution
  319. // old: Vector pz_d_foldin (m);
  320. NICE::Vector pz_d_foldin (m);
  321. pLSA.pLSA ( histogramm.getDataPointer(),
  322. pw_z_estimate_raw,
  323. pd_foldin.getDataPointer(),
  324. pz_d_foldin.getDataPointer(),
  325. n, m, 1,
  326. false /* do not use tempered version */,
  327. false );
  328. /*
  329. cerr << "p(1,w) = " << histogramm << endl;
  330. cerr << "p(d) =fold " << pd_foldin << endl;
  331. cerr << "p(z|d) =fold " << pz_d_foldin << endl;
  332. cerr << "p(z|d) =gt " << pz_singled << endl;
  333. */
  334. error_foldin += dist ( pz_d_foldin, pz_singled );
  335. }
  336. error_foldin /= dtrained;
  337. fprintf (stderr, "foldin error: %f\n", error_foldin );
  338. error_foldin_avg += error_foldin;
  339. runs_successfull++;
  340. double error_likelihood = likelihood_estimate - likelihood_groundtruth;
  341. fprintf (stderr, "likelihood error: %f\n", error_likelihood );
  342. if ( error_likelihood > 1e3 )
  343. {
  344. fprintf (stderr, "This seems to be a severe local minimum !\n");
  345. }
  346. delete [] pw_z_estimate_raw;
  347. delete [] pz_d_estimate_raw;
  348. }
  349. error_pw_z_avg /= runs_successfull;
  350. error_foldin_avg /= runs_successfull;
  351. error_pw_z_stddev /= runs_successfull;
  352. error_pw_z_stddev -= error_pw_z_avg*error_pw_z_avg;
  353. error_pw_z_stddev = (error_pw_z_stddev < 0.0) ? 0.0 : sqrt(error_pw_z_stddev);
  354. fprintf (stderr, "pwz: %f %f\n", error_pw_z_avg, error_pw_z_stddev );
  355. fprintf (stderr, "foldin: %f\n", error_foldin_avg );
  356. return 0;
  357. }
  358. #else
  359. int main (int argc, char **argv)
  360. {
  361. throw("no ice installed\n");
  362. return 0;
  363. }
  364. #endif