testPLSA.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456
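/**
* @file testPLSA.cpp
* @brief Simulation test for PLSA: draws synthetic document/word counts from a
*        known topic model, fits PLSA to them, compares the estimated p(w|z)
*        with the ground truth and runs fold-in tests on sampled documents.
* @author Erik Rodner
* @date 05/21/2008
*/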
  7. #if 0
  8. #include "core/vector/VectorT.h"
  9. #include "core/vector/MatrixT.h"
  10. #include "core/image/ImageT.h"
  11. #include "core/imagedisplay/ImageDisplay.h"
  12. #ifdef NICE_USELIB_ICE
  13. #include <core/vector/Distance.h>
  14. #include <image_nonvis.h>
  15. #include <core/iceconversion/convertice.h>
  16. #include <distancefunctions.h>
  17. #include <assert.h>
  18. #include <core/basics/Config.h>
  19. #include <vislearning/baselib/cmdline.h>
  20. #include <vislearning/baselib/Gnuplot.h>
  21. #include <vislearning/baselib/ICETools.h>
  22. #include <vislearning/baselib/Conversions.h>
  23. #include <vislearning/math/pdf/PDFDirichlet.h>
  24. #include <vislearning/math/pdf/PDFMultinomial.h>
  25. #include <vislearning/math/distances/ChiSqDistance.h>
  26. #include <vislearning/math/distances/KLDistance.h>
  27. #include <vislearning/math/distances/HistIntersectDistance.h>
  28. #include <vislearning/math/topics/PLSA.h>
  29. using namespace OBJREC;
  30. using namespace NICE;
  31. using namespace std;
  32. NICE::Vector randomDiscreteDistribution ( int dimension )
  33. {
  34. NICE::Vector theta (dimension);
  35. double s = 0.0;
  36. for ( int i = 0 ; i < dimension ; i++ )
  37. {
  38. theta[i] = drand48();
  39. s += theta[i];
  40. }
  41. for ( int i = 0 ; i < dimension ; i++ )
  42. theta[i] /= s;
  43. return theta;
  44. }
  45. void simulation_sivic ( int d, // number of documents
  46. int m, // number of topics
  47. int n, // number of words in the vocabulary
  48. NICE::Matrix & counts,
  49. NICE::Matrix & pw_z,
  50. NICE::Vector & pd,
  51. NICE::Matrix & pz_d,
  52. int samples_count
  53. )
  54. {
  55. assert ( m == 3 );
  56. assert ( n == 12 );
  57. std::istringstream pwz_string ( string("<\n<0.25,0.25,0.25,0.25,0,0,0,0,0,0,0,0>,\n") +
  58. string("<0,0,0,0,0.25,0.25,0.25,0.25,0,0,0,0>,\n") +
  59. string("<0,0,0,0,0,0,0,0,0.25,0.25,0.25,0.25>\n>") );
  60. pwz_string >> pw_z;
  61. vector<PDFMultinomial> beta;
  62. for ( int k = 0 ; k < m ; k++ )
  63. beta.push_back ( PDFMultinomial(pw_z.getRow(k),1) );
  64. PDFDirichlet dirichlet ( 0.2, m );
  65. VVector pzd;
  66. dirichlet.sample ( pzd, d );
  67. for ( int i = 0 ; i < d ; i++ )
  68. for ( int k = 0 ; k < m ; k++ )
  69. pz_d(k, i) = pz_d(i, k);
  70. for ( int i = 0 ; i < d ; i++ )
  71. pd[i] = 1.0 / d;
  72. counts.set(0);
  73. fprintf (stderr, "Generation ...\n");
  74. for ( int i = 0 ; i < d ; i++ )
  75. {
  76. PDFMultinomial theta ( pzd[i], 1 );
  77. for ( int w = 0 ; w < samples_count; w++ )
  78. {
  79. int topic = theta.sample();
  80. assert ( topic < m );
  81. int word = beta[topic].sample();
  82. counts(i, word) ++;
  83. }
  84. }
  85. pz_d.normalizeColumnsL1();
  86. pd.normalizeL1();
  87. }
  88. void simulation ( int d, // number of documents
  89. int m, // number of topics
  90. int n, // number of words in the vocabulary
  91. NICE::Matrix & counts,
  92. NICE::Matrix & pw_z,
  93. NICE::Vector & pd,
  94. NICE::Matrix & pz_d,
  95. int samples_count
  96. )
  97. {
  98. fprintf (stderr, "Generating model...\n");
  99. //pd = randomDiscreteDistribution(d);
  100. for ( int i = 0 ; i < d ; i++ )
  101. pd[i] = 1.0 / d;
  102. PDFMultinomial pds ( pd, 1);
  103. vector<PDFMultinomial> pz_ds;
  104. vector<PDFMultinomial> pw_zs;
  105. PDFDirichlet dirichlet ( 0.2, m );
  106. VVector pzd;
  107. dirichlet.sample ( pzd, d );
  108. for ( int i = 0 ; i < d ; i++ )
  109. for ( int k = 0 ; k < m ; k++ )
  110. pz_d(k, i) =pz_d(i, k);
  111. for ( int di = 0 ; di < d ; di++ )
  112. pz_ds.push_back ( PDFMultinomial (pzd[di],1) );
  113. // funny distributed dirichlet parameter
  114. pw_z.set(0);
  115. for ( int zk = 0 ; zk < m ; zk++ )
  116. {
  117. for ( int i = zk*n/m ; i < (zk+1)*n/m ; i++ )
  118. pw_z(zk, i) = 1.0;
  119. }
  120. pw_z.normalizeRowsL1();
  121. for ( int zk = 0 ; zk < m ; zk++ )
  122. pw_zs.push_back ( PDFMultinomial ( pw_z.getRow(zk), 1 ) );
  123. fprintf (stderr, "Normalization...\n");
  124. fprintf (stderr, "Generating samples...\n");
  125. counts.set(0);
  126. for ( int j = 0 ; j < samples_count*d ; j++ )
  127. {
  128. // sample document using p(d)
  129. int document = pds.sample();
  130. // sample topic using p(z|d)
  131. int topic = pz_ds[document].sample();
  132. // sample word of the vocabulary using p(w|z)
  133. int word = pw_zs[topic].sample();
  134. // refactor-nice.pl: check this substitution
  135. // old: counts[document][word]++;
  136. counts(document, word)++;
  137. }
  138. }
  139. void sample_document ( const NICE::Vector & pz_singled,
  140. const NICE::Matrix & pw_z,
  141. NICE::Vector & histogramm,
  142. int samples_count )
  143. {
  144. int m = pz_singled.size();
  145. vector<PDFMultinomial> pw_zs;
  146. for ( int zk = 0 ; zk < m ; zk++ )
  147. pw_zs.push_back ( PDFMultinomial ( pw_z.getRow(zk), 1 ) );
  148. PDFMultinomial topicgen ( pz_singled, 1 );
  149. for ( int j = 0 ; j < samples_count ; j++ )
  150. {
  151. int topic = topicgen.sample();
  152. int word = pw_zs[topic].sample();
  153. histogramm[word]++;
  154. }
  155. }
  156. double comparePermMatrices ( const NICE::Matrix & A, const NICE::Matrix & B, MatrixT<int> & reference_pairs )
  157. {
  158. double min_cost;
  159. ice::Matrix cost ( A.rows(), B.rows() );
  160. NICE::VectorDistance<double> *df = new NICE::EuclidianDistance<double> ();
  161. assert ( A.rows() == B.rows() );
  162. for ( int i = 0 ; i < A.rows() ; i++ )
  163. for ( int j = 0 ; j < B.rows() ; j++ )
  164. cost[i][j] = df->calculate ( A.getRow(i), B.getRow(j) );
  165. delete df;
  166. ice::IMatrix reference_pairs_ice;
  167. ice::Hungarian ( cost, reference_pairs_ice, min_cost );
  168. if ( reference_pairs_ice.rows() == A.rows()-1 ) {
  169. set<int> a;
  170. set<int> b;
  171. for ( int i = 0 ; i < A.rows() ; i++ )
  172. {
  173. a.insert ( i );
  174. b.insert ( i );
  175. }
  176. for ( int i = 0 ; i < reference_pairs_ice.rows() ; i++ )
  177. {
  178. a.erase ( a.find(reference_pairs_ice[i][0] ) );
  179. b.erase ( b.find(reference_pairs_ice[i][1] ) );
  180. }
  181. assert ( (a.size() == 1) && (b.size() == 1) );
  182. reference_pairs_ice.Append ( ice::IVector ( *(a.begin()), *(a.end()) ) );
  183. reference_pairs = NICE::makeIntegerMatrix<int> ( reference_pairs_ice );
  184. return min_cost / ( A.rows() );
  185. } else if ( reference_pairs_ice.rows() != A.rows() ) {
  186. reference_pairs = NICE::makeIntegerMatrix<int> ( reference_pairs_ice );
  187. return -1.0;
  188. }
  189. }
  190. int main (int argc, char **argv)
  191. {
  192. std::set_terminate(__gnu_cxx::__verbose_terminate_handler);
  193. srand48(time(NULL));
  194. int tempered = 0;
  195. int optimization_verbose = 0;
  196. int samples_count = 10000;
  197. int d = 1000;
  198. int m = 3;
  199. int n = 12;
  200. double delta_eps = 1e-3;
  201. int maxiterations = 500;
  202. int runs = 1;
  203. double holdoutportion = 0.2;
  204. int use_simulation_sivic = 0;
  205. struct CmdLineOption options[] = {
  206. {"verbose", "print details about the optimization", NULL, NULL, &optimization_verbose},
  207. {"documents", "number of documents", "1000", "%d", &d},
  208. {"topics", "number of topics", "3", "%d", &m},
  209. {"words", "number of words in the vocabulary", "12", "%d", &n},
  210. {"samples", "number of samples (in total)", "10000", "%d", &samples_count},
  211. {"maxiterations", "maximum number of EM-iterations", "500", "%d", &maxiterations},
  212. {"runs", "number of runs performed", "1", "%d", &runs},
  213. {"deltaeps", "terminate optimization if delta of likelihood is below this threshold", NULL, "%lf", &delta_eps},
  214. {"sivic", "use the simulation routine and parameters of Josef Sivic's matlab implementation", NULL, NULL, &use_simulation_sivic},
  215. {"holdout", "fraction of data for validation", "0.8", "%lf", &holdoutportion},
  216. {"tempered", "use tempered version of EM algorithm", NULL, NULL, &tempered},
  217. {NULL, NULL, NULL, NULL, NULL}
  218. };
  219. int ret;
  220. char *more_options[argc];
  221. ret = parse_arguments( argc, (const char**)argv, options, more_options);
  222. if ( ! tempered )
  223. holdoutportion = 0.0;
  224. srand48(time(NULL));
  225. PLSA pLSA ( maxiterations, delta_eps, 0.9, holdoutportion );
  226. double error_pw_z_avg = 0.0;
  227. double error_foldin_avg = 0.0;
  228. double error_pw_z_stddev = 0.0;
  229. int runs_successfull = 0;
  230. fprintf (stderr, "documents: %d\n", d );
  231. fprintf (stderr, "topics: %d\n", m );
  232. fprintf (stderr, "runs: %d\n", runs );
  233. for ( int run = 0 ; run < runs ; run++ )
  234. {
  235. fprintf (stderr, "[run %d/%d]\n", run, runs );
  236. // refactor-nice.pl: check this substitution
  237. // old: Matrix counts ( d, n );
  238. NICE::Matrix counts ( d, n );
  239. // refactor-nice.pl: check this substitution
  240. // old: Matrix pw_z ( m, n );
  241. NICE::Matrix pw_z ( m, n );
  242. // refactor-nice.pl: check this substitution
  243. // old: Vector pd ( d );
  244. NICE::Vector pd ( d );
  245. // refactor-nice.pl: check this substitution
  246. // old: Matrix pz_d ( m, d );
  247. NICE::Matrix pz_d ( m, d );
  248. if ( use_simulation_sivic )
  249. simulation_sivic ( d, m, n,
  250. counts,
  251. pw_z,
  252. pd,
  253. pz_d,
  254. samples_count );
  255. else
  256. simulation ( d, m, n,
  257. counts,
  258. pw_z,
  259. pd,
  260. pz_d,
  261. samples_count );
  262. /*
  263. // refactor-nice.pl: check this substitution
  264. // old: ImageRGB img_groundtruth ( n, m );
  265. NICE::ColorImage img_groundtruth ( n, m );
  266. ICETools::convertToRGB ( pw_z, img_groundtruth );
  267. Show(ON, img_groundtruth, "P(w|z) groundtruth");
  268. */
  269. double *counts_raw = ICETools::convertICE2M ( counts );
  270. double *pw_z_estimate_raw = new double [n*m];
  271. double *pz_d_estimate_raw = new double [m*d];
  272. // refactor-nice.pl: check this substitution
  273. // old: Vector pd_estimate ( d );
  274. NICE::Vector pd_estimate ( d );
  275. double likelihood_estimate =
  276. pLSA.pLSA ( counts_raw,
  277. pw_z_estimate_raw,
  278. pd_estimate.getDataPointer(),
  279. pz_d_estimate_raw,
  280. n, m, d,
  281. true /* do not perform folding */,
  282. tempered ? true : false /* use tempered version */,
  283. optimization_verbose ? true : false );
  284. double *pw_z_raw = ICETools::convertICE2M ( pw_z );
  285. double *pz_d_raw = ICETools::convertICE2M ( pz_d );
  286. int dtrained = d - (int) ( d * holdoutportion );
  287. double likelihood_groundtruth = pLSA.computeLikelihood (
  288. counts_raw, pw_z_raw, pd.getDataPointer(), pz_d_raw, n, m , d, dtrained );
  289. delete [] pw_z_raw;
  290. delete [] pz_d_raw;
  291. // refactor-nice.pl: check this substitution
  292. // old: Matrix pw_z_estimate ( m, n );
  293. NICE::Matrix pw_z_estimate ( m, n );
  294. ICETools::convertM2ICE ( pw_z_estimate, pw_z_estimate_raw );
  295. delete [] counts_raw;
  296. MatrixT<int> refPWZ, refPDZ;
  297. double error_pw_z = comparePermMatrices ( pw_z, pw_z_estimate, refPWZ );
  298. fprintf (stderr, "---------- final error %d/%d ---------\n", run+1, runs );
  299. fprintf (stderr, "error p(w|z) pwz: %f\n", error_pw_z );
  300. error_pw_z_avg += error_pw_z;
  301. error_pw_z_stddev += error_pw_z * error_pw_z;
  302. /***************** FOLDIN Tests *******************/
  303. KLDistance<double> dist;
  304. double error_foldin = 0.0;
  305. for ( int i = 0 ; i < dtrained ; i++ )
  306. {
  307. // refactor-nice.pl: check this substitution
  308. // old: Vector pz_singled ( m );
  309. NICE::Vector pz_singled ( m );
  310. for ( int k = 0 ; k < m ; k++ )
  311. pz_singled[k] = pz_d_estimate_raw[k*dtrained+i];
  312. // refactor-nice.pl: check this substitution
  313. // old: Vector histogramm ( n );
  314. NICE::Vector histogramm ( n );
  315. sample_document ( pz_singled, pw_z_estimate, histogramm, samples_count );
  316. // refactor-nice.pl: check this substitution
  317. // old: Vector pd_foldin (1);
  318. NICE::Vector pd_foldin (1);
  319. // refactor-nice.pl: check this substitution
  320. // old: Vector pz_d_foldin (m);
  321. NICE::Vector pz_d_foldin (m);
  322. pLSA.pLSA ( histogramm.getDataPointer(),
  323. pw_z_estimate_raw,
  324. pd_foldin.getDataPointer(),
  325. pz_d_foldin.getDataPointer(),
  326. n, m, 1,
  327. false /* do not use tempered version */,
  328. false );
  329. /*
  330. cerr << "p(1,w) = " << histogramm << endl;
  331. cerr << "p(d) =fold " << pd_foldin << endl;
  332. cerr << "p(z|d) =fold " << pz_d_foldin << endl;
  333. cerr << "p(z|d) =gt " << pz_singled << endl;
  334. */
  335. error_foldin += dist ( pz_d_foldin, pz_singled );
  336. }
  337. error_foldin /= dtrained;
  338. fprintf (stderr, "foldin error: %f\n", error_foldin );
  339. error_foldin_avg += error_foldin;
  340. runs_successfull++;
  341. double error_likelihood = likelihood_estimate - likelihood_groundtruth;
  342. fprintf (stderr, "likelihood error: %f\n", error_likelihood );
  343. if ( error_likelihood > 1e3 )
  344. {
  345. fprintf (stderr, "This seems to be a severe local minimum !\n");
  346. }
  347. delete [] pw_z_estimate_raw;
  348. delete [] pz_d_estimate_raw;
  349. }
  350. error_pw_z_avg /= runs_successfull;
  351. error_foldin_avg /= runs_successfull;
  352. error_pw_z_stddev /= runs_successfull;
  353. error_pw_z_stddev -= error_pw_z_avg*error_pw_z_avg;
  354. error_pw_z_stddev = (error_pw_z_stddev < 0.0) ? 0.0 : sqrt(error_pw_z_stddev);
  355. fprintf (stderr, "pwz: %f %f\n", error_pw_z_avg, error_pw_z_stddev );
  356. fprintf (stderr, "foldin: %f\n", error_foldin_avg );
  357. return 0;
  358. }
  359. #else
  360. int main (int argc, char **argv)
  361. {
  362. throw("no ice installed\n");
  363. return 0;
  364. }
  365. #endif
  366. #else
  367. int main (int argc, char **argv)
  368. {
  369. throw("Not converted to new structure of our library\n");
  370. return 0;
  371. }
  372. #endif