/**
* @file PLSA.cpp
* @brief implementation of the pLSA model
* @author Erik Rodner
* @date 02/05/2009
*/
#include <iostream>
#include <assert.h>
#include <time.h>
#include <math.h>
#include <string.h>
#include <limits>
#include <vector>

#include <core/vector/Algorithms.h>

#include "PLSA.h"

using namespace OBJREC;
using namespace std;
using namespace NICE;

PLSA::PLSA( int maxiterations,
            double delta_eps,
            double betadecrease,
            double holdoutportion )
{
#ifdef WIN32
    srand ( time ( NULL ) );
#else
    srand48 ( time ( NULL ) );
#endif

    this->maxiterations = maxiterations;
    this->delta_eps = delta_eps;
    this->betadecrease = betadecrease;
    this->holdoutportion = holdoutportion;
}

PLSA::~PLSA()
{
}

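/** fraction of entries of A that are numerically zero (|A[i]| < 1e-20) */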
double PLSA::computeSparsity ( const double *A, long int size )
{
    long count_zeros = 0;
    for ( long i = 0 ; i < size ; i++ )
        if ( fabs(A[i]) < 1e-20 )
            count_zeros++;

    return count_zeros / (double) size;
}

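/** negative log-likelihood of the count data under the current model:
    - sum_{i,j} n(d_i,w_j) * log ( P(d_i) * sum_k P(z_k|d_i) * P(w_j|z_k) )
    (only the first dtrained documents are used; dtrained = -1 means all d documents) */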
double PLSA::computeLikelihood ( const double *counts,
                                 const double *pw_z,
                                 const double *pd,
                                 const double *pz_d,
                                 int n, int m, int d,
                                 int dtrained ) const
{
    const double eps = 1e-20;
    double likelihood = 0.0;

    if ( dtrained == -1 )
        dtrained = d;

    for ( int i = 0 ; i < dtrained ; i++ ) // documents
        for ( int j = 0 ; j < n ; j++ ) // words
        {
            double pdw = 0.0;
            assert ( ! NICE::isNaN(counts[i*n+j]) );
            assert ( ! NICE::isNaN(pd[i]) );
            for ( int k = 0 ; k < m ; k++ )
            {
                assert ( ! NICE::isNaN(pz_d[k*d + i]) );
                assert ( ! NICE::isNaN(pw_z[k*n + j]) );
                pdw += pz_d[k*d+i] * pw_z[k*n+j];
            }

            likelihood += counts[i*n+j] * log(pdw*pd[i] + eps);
        }

    return - likelihood;
}

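/** perplexity of the count data:
    exp ( - sum_{i,j} n(d_i,w_j) * log P(w_j|d_i) / sum_{i,j} n(d_i,w_j) )
    with P(w_j|d_i) = sum_k P(z_k|d_i) * P(w_j|z_k); lower values indicate a better fit */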
double PLSA::computePerplexity ( const double *counts,
                                 const double *pw_z,
                                 const double *pz_d,
                                 int n, int m, int d ) const
{
    const double eps = 1e-7;
    double perplexity = 0.0;
    double normalization = 0.0;

    for ( int i = 0 ; i < d ; i++ ) // documents
        for ( int j = 0 ; j < n ; j++ ) // words
        {
            double pdw = 0.0;
            for ( int k = 0 ; k < m ; k++ )
            {
                assert ( ! NICE::isNaN(pz_d[k*d + i]) );
                assert ( ! NICE::isNaN(pw_z[k*n + j]) );
                pdw += pz_d[k*d+i] * pw_z[k*n+j];
            }

            perplexity += counts[i*n+j] * log(pdw + eps);
            normalization += counts[i*n+j];
        }

    return exp ( - perplexity / normalization );
}

void PLSA::uniformDistribution ( double *x, int size )
{
    for ( int i = 0 ; i < size ; i++ )
        x[i] = 1.0 / size;
}

void PLSA::normalizeRows ( double *A, int r, int c )
{
    long index = 0;
    for ( int i = 0 ; i < r ; i++ )
    {
        double sum = 0.0;
        long index_row = index;
        for ( int j = 0 ; j < c ; j++, index++ )
            sum += A[index];

        if ( sum > 1e-20 )
            for ( int j = 0 ; j < c ; j++, index_row++ )
                A[index_row] /= sum;
        else
            for ( int j = 0 ; j < c ; j++, index_row++ )
                A[index_row] = 1.0/c;
    }
}

void PLSA::normalizeCols ( double *A, int r, int c )
{
    for ( int j = 0 ; j < c ; j++ )
    {
        double sum = 0.0;
        for ( int i = 0 ; i < r ; i++ )
            sum += A[i*c+j];

        if ( sum > 1e-20 )
            for ( int i = 0 ; i < r ; i++ )
                A[i*c+j] /= sum;
        else
            for ( int i = 0 ; i < r ; i++ )
                A[i*c+j] = 1.0/r;
    }
}

void PLSA::randomizeBuffer ( double *A, long size )
{
    // use a long loop index: the buffer size may exceed the range of int
    for ( long index = 0 ; index < size ; index++ )
    {
#ifdef WIN32
        A[index] = (double( rand() ) / RAND_MAX );
#else
        A[index] = drand48();
#endif
    }
}

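/** one combined E/M-step of (tempered) EM.
    memory layout: counts[i*n+j] = n(d_i,w_j), pw_z[k*n+j] = P(w_j|z_k),
    pz_d[k*d+i] = P(z_k|d_i), pd[i] = P(d_i)
    E-step: P(z_k|d_i,w_j) proportional to ( P(z_k|d_i) * P(w_j|z_k) * P(d_i) )^beta
    M-step: P(w_j|z_k) proportional to sum_i n(d_i,w_j) * P(z_k|d_i,w_j)   (only if update_pw_z)
            P(z_k|d_i) proportional to sum_j n(d_i,w_j) * P(z_k|d_i,w_j)
            P(d_i)     proportional to sum_j n(d_i,w_j) */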
void PLSA::pLSA_EMstep ( const double *counts,
                         double *pw_z,
                         double *pd,
                         double *pz_d,
                         double *pw_z_out,
                         double *pd_out,
                         double *pz_d_out,
                         double *p_dw,
                         int n, int m, int d,
                         double beta,
                         bool update_pw_z )
{
    /************************ E-step ****************************/
    if ( update_pw_z )
        memset ( pw_z_out, 0x0, n*m*sizeof(double) );
    memset ( pz_d_out, 0x0, d*m*sizeof(double) );
    memset ( pd_out, 0x0, d*sizeof(double) );

    // tempering only has an effect for beta != 1.0 (pow(x, 1.0) == x)
    bool tempered = ( beta != 1.0 );

    long indexij = 0;
    for ( int i = 0 ; i < d ; i++ )
        for ( int j = 0 ; j < n ; j++, indexij++ )
        {
            double c = counts[indexij];
            if ( c < 1e-20 ) continue;

            double sumk = 0.0;
            for ( int k = 0 ; k < m ; k++ )
            {
                double p_dw_value = pz_d[k*d+i] * pw_z[k*n+j] * pd[i];
                if ( tempered ) {
                    p_dw_value = pow ( p_dw_value, beta );
                }
                // normalization of E-step
                sumk += p_dw_value;
                // M-step
                double val = c * p_dw_value;
                p_dw[k] = val;
                /*
                if ( i == 0 )
                    fprintf (stderr, "k=%d, j=%d, val=%e p_dw=%e pw_z=%e pz_d=%e beta=%e pd=%e recomp=%e\n", k, j, val, p_dw_value,
                        pw_z[k*n+j], pz_d[k*d+i], beta, pd[i], pz_d[k*d+i] * pw_z[k*n+j] * pd[i]);
                */
            }

            for ( int k = 0 ; k < m ; k++ )
            {
                if ( sumk > 1e-20 )
                    p_dw[k] /= sumk;
                else
                    p_dw[k] = 1.0/m;

                if ( update_pw_z )
                    pw_z_out[k*n+j] += p_dw[k];
                pz_d_out[k*d+i] += p_dw[k];
            }

            pd_out[i] += counts[indexij];
        }

    if ( update_pw_z )
    {
        normalizeRows ( pw_z_out, m, n );
        memcpy ( pw_z, pw_z_out, sizeof(double)*n*m );
    }

    normalizeCols ( pz_d_out, m, d );
    memcpy ( pz_d, pz_d_out, sizeof(double)*d*m );

    // compute P(d)
    double sum_pd = 0.0;
    for ( int i = 0 ; i < d ; i++ )
        sum_pd += pd_out[i];
    if ( sum_pd > 1e-20 )
        for ( int i = 0 ; i < d ; i++ )
            pd[i] = pd_out[i] / sum_pd;
    /******** end of M-step ********/
}

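/** (tempered) EM estimation of the pLSA model.
    - randomly initializes P(w|z) (only if update_pw_z) and P(z|d)
    - runs standard EM (beta = 1.0) until the relative change of the negative
      log-likelihood falls below delta_eps or maxiterations is reached
    - if tempered: the last holdoutportion of the documents is held out, and
      whenever the hold-out perplexity stops decreasing, beta is multiplied by
      betadecrease; optimization stops early once decreasing beta no longer helps
    returns the final negative log-likelihood on the training documents */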
double PLSA::pLSA ( const double *counts,
                    double *pw_z,
                    double *pd,
                    double *pz_d,
                    int n, int m, int total_documents,
                    bool update_pw_z,
                    bool tempered,
                    bool optimization_verbose )
{
    if ( optimization_verbose ) {
        fprintf (stderr, "pLSA: EM algorithm ...\n");
        fprintf (stderr, "pLSA: EM algorithm: sparsity %f\n", computeSparsity(counts, n*total_documents) );
    }

    int d;         // documents used for training
    int d_holdout; // documents used for validation
    const double *counts_holdout = NULL;
    double *pz_d_holdout = NULL;
    double *pd_holdout = NULL;

    if ( tempered ) {
        if ( optimization_verbose )
            fprintf (stderr, "pLSA: Tempered EM algorithm ...\n");

        d_holdout = (int)(holdoutportion * total_documents);
        d = total_documents - d_holdout;
        counts_holdout = counts + d*n;
        pz_d_holdout = new double[ d_holdout*m ];
        pd_holdout = new double[ d_holdout ];
    } else {
        d = total_documents;
        d_holdout = 0;
    }

    // EM algorithm
    if ( update_pw_z ) {
        randomizeBuffer ( pw_z, n*m );
        normalizeRows ( pw_z, m, n );
    }

    uniformDistribution ( pd, d );
    randomizeBuffer ( pz_d, d*m );
    normalizeCols ( pz_d, m, d );

    double *pz_d_out = new double [ d*m ];
    double *pw_z_out = NULL;
    if ( update_pw_z )
        pw_z_out = new double [ n*m ];

    int iteration = 0;
    vector<double> likelihoods;
    likelihoods.push_back ( computeLikelihood ( counts, pw_z, pd, pz_d, n, m, d ) );

    double delta_likelihood = 0.0;
    bool early_stop = false;

    double *p_dw = new double [m];
    double *pd_out = new double[d];

    double beta = 1.0;
    double oldperplexity = numeric_limits<double>::max();
    vector<double> delta_perplexities;

    do {
        pLSA_EMstep ( counts,
                      pw_z, pd, pz_d,
                      pw_z_out, pd_out, pz_d_out, p_dw,
                      n, m, d,
                      beta,
                      update_pw_z );

        double newlikelihood = computeLikelihood(counts, pw_z, pd, pz_d, n, m, d);
        delta_likelihood = fabs(likelihoods.back() - newlikelihood) / (1.0 + fabs(newlikelihood));

        if ( optimization_verbose ) {
            fprintf (stderr, "pLSA %6d %f %e\n", iteration, newlikelihood, delta_likelihood );
        }
        likelihoods.push_back ( newlikelihood );

        if ( counts_holdout != NULL )
        {
            pLSA ( counts_holdout, pw_z, pd_holdout, pz_d_holdout,
                   n, m, d_holdout, false, false );
            double perplexity = computePerplexity ( counts_holdout, pw_z,
                                                    pz_d_holdout, n, m, d_holdout );
            double delta_perplexity = (oldperplexity - perplexity) / (1.0 + perplexity);

            if ( delta_perplexities.size() > 0 ) {
                if ( optimization_verbose )
                    fprintf (stderr, "PLSA: early stopping: perplexity: %d %f %e (%e)\n", iteration, perplexity,
                             delta_perplexity, oldperplexity);

                double last_delta_perplexity = delta_perplexities.back ();

                // if the perplexity did not decrease in the last two iterations -> early stop
                if ( (delta_perplexity <= 0.0) && (last_delta_perplexity <= 0.0) )
                {
                    early_stop = true;
                    if ( optimization_verbose )
                        fprintf (stderr, "PLSA: stopped due to early stopping !\n");
                }
            }
            delta_perplexities.push_back ( delta_perplexity );
            oldperplexity = perplexity;
        }

        iteration++;
    } while ( (iteration < maxiterations) && (delta_likelihood > delta_eps) && (! early_stop) );

    if ( tempered )
    {
        early_stop = false;
        delta_perplexities.clear();
        beta *= betadecrease;

        do {
            pLSA_EMstep ( counts,
                          pw_z, pd, pz_d,
                          pw_z_out, pd_out, pz_d_out, p_dw,
                          n, m, d,
                          beta,
                          update_pw_z );

            double newlikelihood = computeLikelihood(counts, pw_z, pd, pz_d, n, m, d);
            delta_likelihood = fabs(likelihoods.back() - newlikelihood) / ( 1.0 + newlikelihood );

            if ( optimization_verbose )
                fprintf (stderr, "pLSA_tempered %6d %f %e\n", iteration, newlikelihood, delta_likelihood );
            likelihoods.push_back ( newlikelihood );

            pLSA ( counts_holdout, pw_z, pd_holdout, pz_d_holdout,
                   n, m, d_holdout, false, false );
            double perplexity = computePerplexity ( counts_holdout, pw_z,
                                                    pz_d_holdout, n, m, d_holdout );
            double delta_perplexity = (oldperplexity - perplexity) / (1.0 + perplexity);

            if ( delta_perplexities.size() > 0 ) {
                double last_delta_perplexity = delta_perplexities.back ();

                if ( optimization_verbose )
                    fprintf (stderr, "PLSA: early stopping: perplexity: %d %f %f\n", iteration, perplexity,
                             delta_perplexity);

                // if the perplexity did not decrease in the last two iterations -> early stop
                if ( (delta_perplexity <= 0.0) && (last_delta_perplexity <= 0.0) )
                {
                    if ( delta_perplexities.size() <= 1 ) {
                        // perplexity stalled right after the previous beta decrease:
                        // decreasing beta further will not help, so stop
                        if ( optimization_verbose )
                            fprintf (stderr, "PLSA: early stop !\n");
                        early_stop = true;
                    } else {
                        if ( optimization_verbose )
                            fprintf (stderr, "PLSA: decreasing beta !\n");
                        delta_perplexities.clear();
                        beta *= betadecrease;
                    }
                }
            }
            delta_perplexities.push_back ( delta_perplexity );
            oldperplexity = perplexity;

            iteration++;
        } while ( (iteration < maxiterations) && (delta_likelihood > delta_eps) && (! early_stop) );
    }

    if ( optimization_verbose )
        fprintf (stderr, "pLSA: total number of iterations %d\n", iteration );

    delete [] pz_d_out;
    delete [] pd_out;
    if ( update_pw_z )
        delete [] pw_z_out;
    delete [] p_dw;

    if ( counts_holdout != NULL )
    {
        delete [] pz_d_holdout;
        delete [] pd_holdout;
    }

    /*
    Gnuplot gp ("lines");
    gp.plot_x ( likelihoods, "pLSA optimization" );
    getchar();
    */

    return likelihoods.back();
}

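/** algebraic fold-in of a single new document: instead of running EM with fixed
    P(w|z), the linear system W * x = c is solved in the least-squares sense via QR,
    where W(j,k) = P(w_j|z_k) and c holds the word counts of the new document;
    the L1-normalized solution is stored as P(z|d_new) and P(d_new) is set to 1 */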
double PLSA::algebraicFoldIn ( const double *counts,
                               double *pw_z,
                               double *pd,
                               double *pz_d,
                               int n, int m )
{
    NICE::Matrix W ( n, m );
    NICE::Vector c ( n );

    for ( int i = 0 ; i < n ; i++ )
        c[i] = counts[i];

    for ( int i = 0 ; i < n ; i++ )
        for ( int k = 0 ; k < m ; k++ )
            W(i, k) = pw_z[k*n+i];

    NICE::Vector sol ( m );
    NICE::solveLinearEquationQR ( W, c, sol );

    (*pd) = 1.0;
    sol.normalizeL1();

    memcpy ( pz_d, sol.getDataPointer(), m*sizeof(double));

    return 0.0;
}
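
/* Minimal usage sketch (illustrative only; the sizes, parameter values and the
   way the count matrix is filled are assumptions, not part of this file):

   int n = 1000;  // vocabulary size
   int m = 20;    // number of topics / aspects
   int d = 500;   // number of documents

   vector<double> counts ( (size_t)n * d, 0.0 );  // counts[i*n+j] = n(d_i, w_j)
   // ... fill counts with bag-of-words histograms ...

   vector<double> pw_z ( (size_t)n * m );  // P(w|z), pw_z[k*n+j]
   vector<double> pd ( d );                // P(d)
   vector<double> pz_d ( (size_t)d * m );  // P(z|d), pz_d[k*d+i]

   PLSA plsa ( 300, 1e-5, 0.9, 0.2 );  // maxiterations, delta_eps, betadecrease, holdoutportion
   double nll = plsa.pLSA ( &counts[0], &pw_z[0], &pd[0], &pz_d[0],
                            n, m, d, true, true, true );
*/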