// PLSA.cpp
/**
 * @file PLSA.cpp
 * @brief implementation of the pLSA model
 * @author Erik Rodner
 * @date 02/05/2009
 */
#include "PLSA.h"

#include <assert.h>

#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <iostream>
#include <limits>
#include <vector>

#include <core/vector/Algorithms.h>

using namespace OBJREC;
using namespace std;
using namespace NICE;
  14. PLSA::PLSA( int maxiterations,
  15. double delta_eps,
  16. double betadecrease,
  17. double holdoutportion )
  18. {
  19. srand48(time(NULL));
  20. this->maxiterations = maxiterations;
  21. this->delta_eps = delta_eps;
  22. this->betadecrease = betadecrease;
  23. this->holdoutportion = holdoutportion;
  24. }
  25. PLSA::~PLSA()
  26. {
  27. }
  28. double PLSA::computeSparsity ( const double *A, long int size )
  29. {
  30. long count_zeros = 0;
  31. for ( long i = 0 ; i < size ; i++ )
  32. if ( fabs(A[i]) < 1e-20 )
  33. count_zeros++;
  34. return count_zeros / (double) size;
  35. }
  36. double PLSA::computeLikelihood ( const double *counts,
  37. const double *pw_z,
  38. const double *pd,
  39. const double *pz_d,
  40. int n, int m, int d,
  41. int dtrained ) const
  42. {
  43. const double eps = 1e-20;
  44. double likelihood = 0.0;
  45. if ( dtrained == -1 )
  46. dtrained = d;
  47. for ( int i = 0 ; i < dtrained ; i++ ) // documents
  48. for ( int j = 0 ; j < n ; j++ ) // words
  49. {
  50. double pdw = 0.0;
  51. assert ( ! isnan(counts[i*n+j]) );
  52. assert ( ! isnan(pd[i]) );
  53. for ( int k = 0 ; k < m ; k++ )
  54. {
  55. assert ( ! isnan(pz_d[k*d + i]) );
  56. assert ( ! isnan(pw_z[k*n + j]) );
  57. pdw += pz_d[k*d+i] * pw_z[k*n+j];
  58. }
  59. likelihood += counts[i*n+j] * log(pdw*pd[i] + eps);
  60. }
  61. return - likelihood;
  62. }
  63. double PLSA::computePerplexity ( const double *counts,
  64. const double *pw_z,
  65. const double *pz_d,
  66. int n, int m, int d) const
  67. {
  68. const double eps = 1e-7;
  69. double perplexity = 0.0;
  70. double normalization = 0.0;
  71. for ( int i = 0 ; i < d ; i++ ) // documents
  72. for ( int j = 0 ; j < n ; j++ ) // words
  73. {
  74. double pdw = 0.0;
  75. for ( int k = 0 ; k < m ; k++ )
  76. {
  77. assert ( ! isnan(pz_d[k*d + i]) );
  78. assert ( ! isnan(pw_z[k*n + j]) );
  79. pdw += pz_d[k*d+i] * pw_z[k*n+j];
  80. }
  81. perplexity += counts[i*n+j] * log(pdw + eps);
  82. normalization += counts[i*n+j];
  83. }
  84. return exp ( - perplexity / normalization );
  85. }
  86. void PLSA::uniformDistribution ( double *x, int size )
  87. {
  88. for ( int i = 0 ; i < size; i++ )
  89. x[i] = 1.0 / size;
  90. }
  91. void PLSA::normalizeRows ( double *A, int r, int c )
  92. {
  93. long index = 0;
  94. for ( int i = 0 ; i < r ; i++ )
  95. {
  96. double sum = 0.0;
  97. long index_row = index;
  98. for ( int j = 0 ; j < c ; j++, index++ )
  99. sum += A[index];
  100. if ( sum > 1e-20 )
  101. for ( int j = 0 ; j < c ; j++, index_row++ )
  102. A[index_row] /= sum;
  103. else
  104. for ( int j = 0 ; j < c ; j++, index_row++ )
  105. A[index_row] = 1.0/c;
  106. }
  107. }
  108. void PLSA::normalizeCols ( double *A, int r, int c )
  109. {
  110. for ( int j = 0 ; j < c ; j++ )
  111. {
  112. double sum = 0.0;
  113. for ( int i = 0 ; i < r ; i++ )
  114. sum += A[i*c+j];
  115. if ( sum > 1e-20 )
  116. for ( int i = 0 ; i < r ; i++ )
  117. A[i*c+j] /= sum;
  118. else
  119. for ( int i = 0 ; i < r ; i++ )
  120. A[i*c+j] = 1.0/r;
  121. }
  122. }
  123. void PLSA::randomizeBuffer ( double *A, long size )
  124. {
  125. for ( int index = 0 ; index < size ; index++ )
  126. {
  127. A[index] = drand48();
  128. }
  129. }
  130. void PLSA::pLSA_EMstep ( const double *counts,
  131. double *pw_z,
  132. double *pd,
  133. double *pz_d,
  134. double *pw_z_out,
  135. double *pd_out,
  136. double *pz_d_out,
  137. double *p_dw,
  138. int n, int m, int d,
  139. double beta,
  140. bool update_pw_z )
  141. {
  142. /************************ E-step ****************************/
  143. if ( update_pw_z )
  144. memset ( pw_z_out, 0x0, n*m*sizeof(double) );
  145. memset ( pz_d_out, 0x0, d*m*sizeof(double) );
  146. memset ( pd_out, 0x0, d*sizeof(double) );
  147. bool tempered = ( beta = 1.0 ) ? true : false;
  148. long indexij = 0;
  149. for ( int i = 0 ; i < d ; i++ )
  150. for ( int j = 0 ; j < n ; j++, indexij++ )
  151. {
  152. double c = counts[indexij];
  153. if ( c < 1e-20 ) continue;
  154. double sumk = 0.0;
  155. for ( int k = 0 ; k < m ; k++ )
  156. {
  157. double p_dw_value = pz_d[k*d+i] * pw_z[k*n+j] * pd[i];
  158. if ( tempered ) {
  159. p_dw_value = pow ( p_dw_value, beta );
  160. }
  161. // normalization of E-step
  162. sumk += p_dw_value;
  163. // M-step
  164. double val = c * p_dw_value;
  165. p_dw[k] = val;
  166. /*
  167. if ( i == 0 )
  168. fprintf (stderr, "k=%d, j=%d, val=%e p_dw=%e pw_z=%e pz_d=%e beta=%e pd=%e recomp=%e\n", k, j, val, p_dw_value,
  169. pw_z[k*n+j], pz_d[k*d+i], beta, pd[i], pz_d[k*d+i] * pw_z[k*n+j] * pd[i]);
  170. */
  171. }
  172. for ( int k = 0 ; k < m ; k++ )
  173. {
  174. if ( sumk > 1e-20 )
  175. p_dw[k] /= sumk;
  176. else
  177. p_dw[k] = 1.0/m;
  178. if ( update_pw_z )
  179. pw_z_out[k*n+j] += p_dw[k];
  180. pz_d_out[k*d+i] += p_dw[k];
  181. }
  182. pd_out[i] += counts[indexij];
  183. }
  184. if ( update_pw_z )
  185. {
  186. normalizeRows ( pw_z_out, m, n );
  187. memcpy ( pw_z, pw_z_out, sizeof(double)*n*m );
  188. }
  189. normalizeCols ( pz_d_out, m, d );
  190. memcpy ( pz_d, pz_d_out, sizeof(double)*d*m );
  191. // compute P(d)
  192. double sum_pd = 0.0;
  193. for ( int i = 0 ; i < d ; i++ )
  194. sum_pd += pd_out[i];
  195. if ( sum_pd > 1e-20 )
  196. for ( int i = 0 ; i < d ; i++ )
  197. pd[i] = pd_out[i] / sum_pd;
  198. /******** end of M step */
  199. }
  200. double PLSA::pLSA ( const double *counts,
  201. double *pw_z,
  202. double *pd,
  203. double *pz_d,
  204. int n, int m, int total_documents,
  205. bool update_pw_z,
  206. bool tempered,
  207. bool optimization_verbose )
  208. {
  209. if ( optimization_verbose ) {
  210. fprintf (stderr, "pLSA: EM algorithm ...\n");
  211. fprintf (stderr, "pLSA: EM algorithm: sparsity %f\n", computeSparsity(counts, n*total_documents) );
  212. }
  213. int d; // documents used for training
  214. int d_holdout; // documents used for validation
  215. const double *counts_holdout = NULL;
  216. double *pz_d_holdout = NULL;
  217. double *pd_holdout = NULL;
  218. if ( tempered ) {
  219. if ( optimization_verbose )
  220. fprintf (stderr, "pLSA: Tempered EM algorithm ...\n");
  221. d_holdout = (int)(holdoutportion * total_documents);
  222. d = total_documents - d_holdout;
  223. counts_holdout = counts + d*n;
  224. pz_d_holdout = new double[ d_holdout*m ];
  225. pd_holdout = new double[ d_holdout ];
  226. } else {
  227. d = total_documents;
  228. d_holdout = 0;
  229. }
  230. // EM algorithm
  231. if ( update_pw_z ) {
  232. randomizeBuffer ( pw_z, n*m );
  233. normalizeRows ( pw_z, m, n );
  234. }
  235. uniformDistribution ( pd, d );
  236. randomizeBuffer ( pz_d, d*m );
  237. normalizeCols ( pz_d, m, d );
  238. double *pz_d_out = new double [ d*m ];
  239. double *pw_z_out = NULL;
  240. if ( update_pw_z )
  241. pw_z_out = new double [ n*m ];
  242. int iteration = 0;
  243. vector<double> likelihoods;
  244. likelihoods.push_back ( computeLikelihood ( counts, pw_z, pd, pz_d, n, m, d ) );
  245. double delta_likelihood = 0.0;
  246. bool early_stop = false;
  247. double *p_dw = new double [m];
  248. double *pd_out = new double[d];
  249. double beta = 1.0;
  250. double oldperplexity = numeric_limits<double>::max();
  251. vector<double> delta_perplexities;
  252. do {
  253. pLSA_EMstep ( counts,
  254. pw_z, pd, pz_d,
  255. pw_z_out, pd_out, pz_d_out, p_dw,
  256. n, m, d,
  257. beta,
  258. update_pw_z );
  259. double newlikelihood = computeLikelihood(counts, pw_z, pd, pz_d, n, m, d);
  260. delta_likelihood = fabs(likelihoods.back() - newlikelihood) / (1.0 + fabs(newlikelihood));
  261. if ( optimization_verbose ) {
  262. fprintf (stderr, "pLSA %6d %f %e\n", iteration, newlikelihood, delta_likelihood );
  263. }
  264. likelihoods.push_back ( newlikelihood );
  265. if ( counts_holdout != NULL )
  266. {
  267. pLSA ( counts_holdout, pw_z, pd_holdout, pz_d_holdout,
  268. n, m, d_holdout, false, false );
  269. double perplexity = computePerplexity ( counts_holdout, pw_z,
  270. pz_d_holdout, n, m, d_holdout );
  271. double delta_perplexity = (oldperplexity - perplexity) / (1.0 + perplexity);
  272. if ( delta_perplexities.size() > 0 ) {
  273. if ( optimization_verbose )
  274. fprintf (stderr, "PLSA: early stopping: perplexity: %d %f %e (%e)\n", iteration, perplexity,
  275. delta_perplexity, oldperplexity);
  276. double last_delta_perplexity = delta_perplexities.back ();
  277. // if perplexity does not decrease in the last two iterations -> early stop
  278. if ( (delta_perplexity <= 0.0) && (last_delta_perplexity <= 0.0) )
  279. {
  280. early_stop = true;
  281. if ( optimization_verbose )
  282. fprintf (stderr, "PLSA: stopped due to early stopping !\n");
  283. }
  284. }
  285. delta_perplexities.push_back ( delta_perplexity );
  286. oldperplexity = perplexity;
  287. }
  288. iteration++;
  289. } while ( (iteration < maxiterations) && (delta_likelihood > delta_eps) && (! early_stop) );
  290. if ( tempered )
  291. {
  292. early_stop = false;
  293. delta_perplexities.clear();
  294. beta *= betadecrease;
  295. do {
  296. pLSA_EMstep ( counts,
  297. pw_z, pd, pz_d,
  298. pw_z_out, pd_out, pz_d_out, p_dw,
  299. n, m, d,
  300. beta,
  301. update_pw_z );
  302. double newlikelihood = computeLikelihood(counts, pw_z, pd, pz_d, n, m, d);
  303. delta_likelihood = fabs(likelihoods.back() - newlikelihood) / ( 1.0 + newlikelihood );
  304. if ( optimization_verbose )
  305. fprintf (stderr, "pLSA_tempered %6d %f %e\n", iteration, newlikelihood, delta_likelihood );
  306. likelihoods.push_back ( newlikelihood );
  307. pLSA ( counts_holdout, pw_z, pd_holdout, pz_d_holdout,
  308. n, m, d_holdout, false, false );
  309. double perplexity = computePerplexity ( counts_holdout, pw_z,
  310. pz_d_holdout, n, m, d_holdout );
  311. double delta_perplexity = (oldperplexity - perplexity) / (1.0 + perplexity);
  312. if ( delta_perplexities.size() > 0 ) {
  313. double last_delta_perplexity = delta_perplexities.back ();
  314. if ( optimization_verbose )
  315. fprintf (stderr, "PLSA: early stopping: perplexity: %d %f %f\n", iteration, perplexity,
  316. delta_perplexity);
  317. // if perplexity does not decrease in the last two iterations -> early stop
  318. if ( (delta_perplexity <= 0.0) && (last_delta_perplexity <= 0.0) )
  319. {
  320. if ( delta_perplexities.size() <= 1 ) {
  321. if ( optimization_verbose )
  322. fprintf (stderr, "PLSA: early stop !\n");
  323. } else {
  324. if ( optimization_verbose )
  325. fprintf (stderr, "PLSA: decreasing beta !\n");
  326. delta_perplexities.clear();
  327. beta *= betadecrease;
  328. }
  329. }
  330. }
  331. delta_perplexities.push_back ( delta_perplexity );
  332. oldperplexity = perplexity;
  333. iteration++;
  334. } while ( (iteration < maxiterations) && (delta_likelihood > delta_eps) && (! early_stop) );
  335. }
  336. if ( optimization_verbose )
  337. fprintf (stderr, "pLSA: total number of iterations %d\n", iteration );
  338. delete [] pz_d_out;
  339. delete [] pd_out;
  340. if ( update_pw_z )
  341. delete [] pw_z_out;
  342. delete [] p_dw;
  343. if ( counts_holdout != NULL )
  344. {
  345. delete [] pz_d_holdout;
  346. delete [] pd_holdout;
  347. }
  348. /*
  349. Gnuplot gp ("lines");
  350. gp.plot_x ( likelihoods, "pLSA optimization" );
  351. // refactor-nice.pl: check this substitution
  352. // old: GetChar();
  353. getchar();
  354. */
  355. return likelihoods.back();
  356. }
  357. double PLSA::algebraicFoldIn ( const double *counts,
  358. double *pw_z,
  359. double *pd,
  360. double *pz_d,
  361. int n, int m )
  362. {
  363. // refactor-nice.pl: check this substitution
  364. // old: Matrix W ( n, m );
  365. NICE::Matrix W ( n, m );
  366. // refactor-nice.pl: check this substitution
  367. // old: Vector c ( n );
  368. NICE::Vector c ( n );
  369. for ( int i = 0 ; i < n ; i++ )
  370. c[i] = counts[i];
  371. for ( int i = 0 ; i < n ; i++ )
  372. for ( int k = 0 ; k < m ; k++ )
  373. // refactor-nice.pl: check this substitution
  374. // old: W[i][k] = pw_z[k*n+i];
  375. W(i, k) = pw_z[k*n+i];
  376. // refactor-nice.pl: check this substitution
  377. // old: Vector sol ( m );
  378. NICE::Vector sol ( m );
  379. NICE::solveLinearEquationQR ( W, c, sol );
  380. (*pd) = 1.0;
  381. sol.normalizeL1();
  382. memcpy ( pz_d, sol.getDataPointer(), m*sizeof(double));
  383. return 0.0;
  384. }