PLSA.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487
  1. /**
  2. * @file PLSA.cpp
  3. * @brief implementation of the pLSA model
  4. * @author Erik Rodner
  5. * @date 02/05/2009
  6. */
  7. #ifdef NOVISUAL
  8. #include <vislearning/nice_nonvis.h>
  9. #else
  10. #include <vislearning/nice.h>
  11. #endif
  12. #include <iostream>
  13. #include <assert.h>
  14. #include <core/vector/Algorithms.h>
  15. #include "PLSA.h"
  16. using namespace OBJREC;
  17. using namespace std;
  18. using namespace NICE;
  19. PLSA::PLSA( int maxiterations,
  20. double delta_eps,
  21. double betadecrease,
  22. double holdoutportion )
  23. {
  24. srand48(time(NULL));
  25. this->maxiterations = maxiterations;
  26. this->delta_eps = delta_eps;
  27. this->betadecrease = betadecrease;
  28. this->holdoutportion = holdoutportion;
  29. }
// Destructor: intentionally empty — this class releases all EM scratch
// buffers inside pLSA() itself and owns no persistent dynamic state here.
PLSA::~PLSA()
{
}
  33. double PLSA::computeSparsity ( const double *A, long int size )
  34. {
  35. long count_zeros = 0;
  36. for ( long i = 0 ; i < size ; i++ )
  37. if ( fabs(A[i]) < 1e-20 )
  38. count_zeros++;
  39. return count_zeros / (double) size;
  40. }
  41. double PLSA::computeLikelihood ( const double *counts,
  42. const double *pw_z,
  43. const double *pd,
  44. const double *pz_d,
  45. int n, int m, int d,
  46. int dtrained ) const
  47. {
  48. const double eps = 1e-20;
  49. double likelihood = 0.0;
  50. if ( dtrained == -1 )
  51. dtrained = d;
  52. for ( int i = 0 ; i < dtrained ; i++ ) // documents
  53. for ( int j = 0 ; j < n ; j++ ) // words
  54. {
  55. double pdw = 0.0;
  56. assert ( ! isnan(counts[i*n+j]) );
  57. assert ( ! isnan(pd[i]) );
  58. for ( int k = 0 ; k < m ; k++ )
  59. {
  60. assert ( ! isnan(pz_d[k*d + i]) );
  61. assert ( ! isnan(pw_z[k*n + j]) );
  62. pdw += pz_d[k*d+i] * pw_z[k*n+j];
  63. }
  64. likelihood += counts[i*n+j] * log(pdw*pd[i] + eps);
  65. }
  66. return - likelihood;
  67. }
  68. double PLSA::computePerplexity ( const double *counts,
  69. const double *pw_z,
  70. const double *pz_d,
  71. int n, int m, int d) const
  72. {
  73. const double eps = 1e-7;
  74. double perplexity = 0.0;
  75. double normalization = 0.0;
  76. for ( int i = 0 ; i < d ; i++ ) // documents
  77. for ( int j = 0 ; j < n ; j++ ) // words
  78. {
  79. double pdw = 0.0;
  80. for ( int k = 0 ; k < m ; k++ )
  81. {
  82. assert ( ! isnan(pz_d[k*d + i]) );
  83. assert ( ! isnan(pw_z[k*n + j]) );
  84. pdw += pz_d[k*d+i] * pw_z[k*n+j];
  85. }
  86. perplexity += counts[i*n+j] * log(pdw + eps);
  87. normalization += counts[i*n+j];
  88. }
  89. return exp ( - perplexity / normalization );
  90. }
  91. void PLSA::uniformDistribution ( double *x, int size )
  92. {
  93. for ( int i = 0 ; i < size; i++ )
  94. x[i] = 1.0 / size;
  95. }
  96. void PLSA::normalizeRows ( double *A, int r, int c )
  97. {
  98. long index = 0;
  99. for ( int i = 0 ; i < r ; i++ )
  100. {
  101. double sum = 0.0;
  102. long index_row = index;
  103. for ( int j = 0 ; j < c ; j++, index++ )
  104. sum += A[index];
  105. if ( sum > 1e-20 )
  106. for ( int j = 0 ; j < c ; j++, index_row++ )
  107. A[index_row] /= sum;
  108. else
  109. for ( int j = 0 ; j < c ; j++, index_row++ )
  110. A[index_row] = 1.0/c;
  111. }
  112. }
  113. void PLSA::normalizeCols ( double *A, int r, int c )
  114. {
  115. for ( int j = 0 ; j < c ; j++ )
  116. {
  117. double sum = 0.0;
  118. for ( int i = 0 ; i < r ; i++ )
  119. sum += A[i*c+j];
  120. if ( sum > 1e-20 )
  121. for ( int i = 0 ; i < r ; i++ )
  122. A[i*c+j] /= sum;
  123. else
  124. for ( int i = 0 ; i < r ; i++ )
  125. A[i*c+j] = 1.0/r;
  126. }
  127. }
  128. void PLSA::randomizeBuffer ( double *A, long size )
  129. {
  130. for ( int index = 0 ; index < size ; index++ )
  131. {
  132. A[index] = drand48();
  133. }
  134. }
  135. void PLSA::pLSA_EMstep ( const double *counts,
  136. double *pw_z,
  137. double *pd,
  138. double *pz_d,
  139. double *pw_z_out,
  140. double *pd_out,
  141. double *pz_d_out,
  142. double *p_dw,
  143. int n, int m, int d,
  144. double beta,
  145. bool update_pw_z )
  146. {
  147. /************************ E-step ****************************/
  148. if ( update_pw_z )
  149. memset ( pw_z_out, 0x0, n*m*sizeof(double) );
  150. memset ( pz_d_out, 0x0, d*m*sizeof(double) );
  151. memset ( pd_out, 0x0, d*sizeof(double) );
  152. bool tempered = ( beta = 1.0 ) ? true : false;
  153. long indexij = 0;
  154. for ( int i = 0 ; i < d ; i++ )
  155. for ( int j = 0 ; j < n ; j++, indexij++ )
  156. {
  157. double c = counts[indexij];
  158. if ( c < 1e-20 ) continue;
  159. double sumk = 0.0;
  160. for ( int k = 0 ; k < m ; k++ )
  161. {
  162. double p_dw_value = pz_d[k*d+i] * pw_z[k*n+j] * pd[i];
  163. if ( tempered ) {
  164. p_dw_value = pow ( p_dw_value, beta );
  165. }
  166. // normalization of E-step
  167. sumk += p_dw_value;
  168. // M-step
  169. double val = c * p_dw_value;
  170. p_dw[k] = val;
  171. /*
  172. if ( i == 0 )
  173. fprintf (stderr, "k=%d, j=%d, val=%e p_dw=%e pw_z=%e pz_d=%e beta=%e pd=%e recomp=%e\n", k, j, val, p_dw_value,
  174. pw_z[k*n+j], pz_d[k*d+i], beta, pd[i], pz_d[k*d+i] * pw_z[k*n+j] * pd[i]);
  175. */
  176. }
  177. for ( int k = 0 ; k < m ; k++ )
  178. {
  179. if ( sumk > 1e-20 )
  180. p_dw[k] /= sumk;
  181. else
  182. p_dw[k] = 1.0/m;
  183. if ( update_pw_z )
  184. pw_z_out[k*n+j] += p_dw[k];
  185. pz_d_out[k*d+i] += p_dw[k];
  186. }
  187. pd_out[i] += counts[indexij];
  188. }
  189. if ( update_pw_z )
  190. {
  191. normalizeRows ( pw_z_out, m, n );
  192. memcpy ( pw_z, pw_z_out, sizeof(double)*n*m );
  193. }
  194. normalizeCols ( pz_d_out, m, d );
  195. memcpy ( pz_d, pz_d_out, sizeof(double)*d*m );
  196. // compute P(d)
  197. double sum_pd = 0.0;
  198. for ( int i = 0 ; i < d ; i++ )
  199. sum_pd += pd_out[i];
  200. if ( sum_pd > 1e-20 )
  201. for ( int i = 0 ; i < d ; i++ )
  202. pd[i] = pd_out[i] / sum_pd;
  203. /******** end of M step */
  204. }
  205. double PLSA::pLSA ( const double *counts,
  206. double *pw_z,
  207. double *pd,
  208. double *pz_d,
  209. int n, int m, int total_documents,
  210. bool update_pw_z,
  211. bool tempered,
  212. bool optimization_verbose )
  213. {
  214. if ( optimization_verbose ) {
  215. fprintf (stderr, "pLSA: EM algorithm ...\n");
  216. fprintf (stderr, "pLSA: EM algorithm: sparsity %f\n", computeSparsity(counts, n*total_documents) );
  217. }
  218. int d; // documents used for training
  219. int d_holdout; // documents used for validation
  220. const double *counts_holdout = NULL;
  221. double *pz_d_holdout = NULL;
  222. double *pd_holdout = NULL;
  223. if ( tempered ) {
  224. if ( optimization_verbose )
  225. fprintf (stderr, "pLSA: Tempered EM algorithm ...\n");
  226. d_holdout = (int)(holdoutportion * total_documents);
  227. d = total_documents - d_holdout;
  228. counts_holdout = counts + d*n;
  229. pz_d_holdout = new double[ d_holdout*m ];
  230. pd_holdout = new double[ d_holdout ];
  231. } else {
  232. d = total_documents;
  233. d_holdout = 0;
  234. }
  235. // EM algorithm
  236. if ( update_pw_z ) {
  237. randomizeBuffer ( pw_z, n*m );
  238. normalizeRows ( pw_z, m, n );
  239. }
  240. uniformDistribution ( pd, d );
  241. randomizeBuffer ( pz_d, d*m );
  242. normalizeCols ( pz_d, m, d );
  243. double *pz_d_out = new double [ d*m ];
  244. double *pw_z_out = NULL;
  245. if ( update_pw_z )
  246. pw_z_out = new double [ n*m ];
  247. int iteration = 0;
  248. vector<double> likelihoods;
  249. likelihoods.push_back ( computeLikelihood ( counts, pw_z, pd, pz_d, n, m, d ) );
  250. double delta_likelihood = 0.0;
  251. bool early_stop = false;
  252. double *p_dw = new double [m];
  253. double *pd_out = new double[d];
  254. double beta = 1.0;
  255. double oldperplexity = numeric_limits<double>::max();
  256. vector<double> delta_perplexities;
  257. do {
  258. pLSA_EMstep ( counts,
  259. pw_z, pd, pz_d,
  260. pw_z_out, pd_out, pz_d_out, p_dw,
  261. n, m, d,
  262. beta,
  263. update_pw_z );
  264. double newlikelihood = computeLikelihood(counts, pw_z, pd, pz_d, n, m, d);
  265. delta_likelihood = fabs(likelihoods.back() - newlikelihood) / (1.0 + fabs(newlikelihood));
  266. if ( optimization_verbose ) {
  267. fprintf (stderr, "pLSA %6d %f %e\n", iteration, newlikelihood, delta_likelihood );
  268. }
  269. likelihoods.push_back ( newlikelihood );
  270. if ( counts_holdout != NULL )
  271. {
  272. pLSA ( counts_holdout, pw_z, pd_holdout, pz_d_holdout,
  273. n, m, d_holdout, false, false );
  274. double perplexity = computePerplexity ( counts_holdout, pw_z,
  275. pz_d_holdout, n, m, d_holdout );
  276. double delta_perplexity = (oldperplexity - perplexity) / (1.0 + perplexity);
  277. if ( delta_perplexities.size() > 0 ) {
  278. if ( optimization_verbose )
  279. fprintf (stderr, "PLSA: early stopping: perplexity: %d %f %e (%e)\n", iteration, perplexity,
  280. delta_perplexity, oldperplexity);
  281. double last_delta_perplexity = delta_perplexities.back ();
  282. // if perplexity does not decrease in the last two iterations -> early stop
  283. if ( (delta_perplexity <= 0.0) && (last_delta_perplexity <= 0.0) )
  284. {
  285. early_stop = true;
  286. if ( optimization_verbose )
  287. fprintf (stderr, "PLSA: stopped due to early stopping !\n");
  288. }
  289. }
  290. delta_perplexities.push_back ( delta_perplexity );
  291. oldperplexity = perplexity;
  292. }
  293. iteration++;
  294. } while ( (iteration < maxiterations) && (delta_likelihood > delta_eps) && (! early_stop) );
  295. if ( tempered )
  296. {
  297. early_stop = false;
  298. delta_perplexities.clear();
  299. beta *= betadecrease;
  300. do {
  301. pLSA_EMstep ( counts,
  302. pw_z, pd, pz_d,
  303. pw_z_out, pd_out, pz_d_out, p_dw,
  304. n, m, d,
  305. beta,
  306. update_pw_z );
  307. double newlikelihood = computeLikelihood(counts, pw_z, pd, pz_d, n, m, d);
  308. delta_likelihood = fabs(likelihoods.back() - newlikelihood) / ( 1.0 + newlikelihood );
  309. if ( optimization_verbose )
  310. fprintf (stderr, "pLSA_tempered %6d %f %e\n", iteration, newlikelihood, delta_likelihood );
  311. likelihoods.push_back ( newlikelihood );
  312. pLSA ( counts_holdout, pw_z, pd_holdout, pz_d_holdout,
  313. n, m, d_holdout, false, false );
  314. double perplexity = computePerplexity ( counts_holdout, pw_z,
  315. pz_d_holdout, n, m, d_holdout );
  316. double delta_perplexity = (oldperplexity - perplexity) / (1.0 + perplexity);
  317. if ( delta_perplexities.size() > 0 ) {
  318. double last_delta_perplexity = delta_perplexities.back ();
  319. if ( optimization_verbose )
  320. fprintf (stderr, "PLSA: early stopping: perplexity: %d %f %f\n", iteration, perplexity,
  321. delta_perplexity);
  322. // if perplexity does not decrease in the last two iterations -> early stop
  323. if ( (delta_perplexity <= 0.0) && (last_delta_perplexity <= 0.0) )
  324. {
  325. if ( delta_perplexities.size() <= 1 ) {
  326. if ( optimization_verbose )
  327. fprintf (stderr, "PLSA: early stop !\n");
  328. } else {
  329. if ( optimization_verbose )
  330. fprintf (stderr, "PLSA: decreasing beta !\n");
  331. delta_perplexities.clear();
  332. beta *= betadecrease;
  333. }
  334. }
  335. }
  336. delta_perplexities.push_back ( delta_perplexity );
  337. oldperplexity = perplexity;
  338. iteration++;
  339. } while ( (iteration < maxiterations) && (delta_likelihood > delta_eps) && (! early_stop) );
  340. }
  341. if ( optimization_verbose )
  342. fprintf (stderr, "pLSA: total number of iterations %d\n", iteration );
  343. delete [] pz_d_out;
  344. delete [] pd_out;
  345. if ( update_pw_z )
  346. delete [] pw_z_out;
  347. delete [] p_dw;
  348. if ( counts_holdout != NULL )
  349. {
  350. delete [] pz_d_holdout;
  351. delete [] pd_holdout;
  352. }
  353. /*
  354. Gnuplot gp ("lines");
  355. gp.plot_x ( likelihoods, "pLSA optimization" );
  356. // refactor-nice.pl: check this substitution
  357. // old: GetChar();
  358. getchar();
  359. */
  360. return likelihoods.back();
  361. }
  362. double PLSA::algebraicFoldIn ( const double *counts,
  363. double *pw_z,
  364. double *pd,
  365. double *pz_d,
  366. int n, int m )
  367. {
  368. // refactor-nice.pl: check this substitution
  369. // old: Matrix W ( n, m );
  370. NICE::Matrix W ( n, m );
  371. // refactor-nice.pl: check this substitution
  372. // old: Vector c ( n );
  373. NICE::Vector c ( n );
  374. for ( int i = 0 ; i < n ; i++ )
  375. c[i] = counts[i];
  376. for ( int i = 0 ; i < n ; i++ )
  377. for ( int k = 0 ; k < m ; k++ )
  378. // refactor-nice.pl: check this substitution
  379. // old: W[i][k] = pw_z[k*n+i];
  380. W(i, k) = pw_z[k*n+i];
  381. // refactor-nice.pl: check this substitution
  382. // old: Vector sol ( m );
  383. NICE::Vector sol ( m );
  384. NICE::solveLinearEquationQR ( W, c, sol );
  385. (*pd) = 1.0;
  386. sol.normalizeL1();
  387. memcpy ( pz_d, sol.getDataPointer(), m*sizeof(double));
  388. return 0.0;
  389. }