PLSA.h 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  /**
  * @file PLSA.h
  * @brief Declaration of the PLSA class, an implementation of the probabilistic latent semantic analysis (pLSA) model
  * @author Erik Rodner
  * @date 02/05/2009
  */
  7. #ifndef PLSAINCLUDE
  8. #define PLSAINCLUDE
  9. #ifdef NOVISUAL
  10. #include <vislearning/nice_nonvis.h>
  11. #else
  12. #include <vislearning/nice.h>
  13. #endif
  namespace OBJREC {

  /**
   * @brief Probabilistic latent semantic analysis (pLSA) model.
   *
   * This header only declares the interface; every member function is
   * defined out of line. Data, model distributions and workspaces are
   * exchanged as raw double buffers whose sizes are passed alongside them.
   * NOTE(review): the memory layout of these buffers (row- vs. column-major)
   * and the meaning of each dimension argument (n, m, d) are fixed by the
   * implementation file, not by this header -- confirm there before use.
   */
  class PLSA
  {
    public:
      /**
       * @brief Simple constructor taking the optimization settings.
       *
       * NOTE(review): the meanings below are inferred from the parameter
       * names; presumably each value is stored in the like-named member.
       *
       * @param maxiterations  iteration limit (default 5000)
       * @param delta_eps      convergence threshold (default 1e-3)
       * @param betadecrease   factor for decreasing the tempering parameter beta (default 0.9)
       * @param holdoutportion fraction of data held out, presumably for tempered estimation (default 0.2)
       */
      PLSA ( int maxiterations = 5000,
             double delta_eps = 1e-3,
             double betadecrease = 0.9,
             double holdoutportion = 0.2 );

      /** @brief Virtual destructor, so deleting a derived object through a PLSA pointer is well-defined. */
      virtual ~PLSA ();

      /**
       * @brief Fits the model to count data (by name; presumably via repeated pLSA_EMstep calls).
       *
       * @param counts               input counts, read only
       * @param pw_z                 writable buffer, presumably p(w|z)
       * @param pd                   writable buffer, presumably p(d)
       * @param pz_d                 writable buffer, presumably p(z|d)
       * @param n,m                  problem dimensions (NOTE(review): which is documents/words/topics -- confirm)
       * @param total_documents      total number of documents (per its name -- confirm in the implementation)
       * @param update_pw_z          whether pw_z is re-estimated (default true); when false presumably kept fixed
       * @param tempered             enable tempered estimation (default false)
       * @param optimization_verbose enable diagnostic output while optimizing (default false)
       * @return a scalar optimization result -- presumably the final (log-)likelihood; TODO confirm
       */
      double pLSA ( const double *counts,
                    double *pw_z,
                    double *pd,
                    double *pz_d,
                    int n,
                    int m,
                    int total_documents,
                    bool update_pw_z = true,
                    bool tempered = false,
                    bool optimization_verbose = false );

      /**
       * @brief Evaluates the likelihood of the counts under the given model buffers.
       *
       * A const member that takes every buffer read-only.
       *
       * @param counts        input counts
       * @param pw_z,pd,pz_d  model buffers, presumably p(w|z), p(d), p(z|d)
       * @param n,m,d         problem dimensions (NOTE(review): confirm their meaning in the implementation)
       * @param dtrained      presumably the number of training documents; defaults to -1
       *                      (what -1 selects is not visible here -- TODO confirm)
       * @return the likelihood value (log-scale or not: TODO confirm)
       */
      double computeLikelihood ( const double *counts,
                                 const double *pw_z,
                                 const double *pd,
                                 const double *pz_d,
                                 int n,
                                 int m,
                                 int d,
                                 int dtrained = -1 ) const;

      /**
       * @brief Computes the perplexity of the counts under the given model buffers.
       *
       * A const member; unlike computeLikelihood() it takes no p(d) buffer.
       *
       * @param counts     input counts
       * @param pw_z,pz_d  model buffers, presumably p(w|z) and p(z|d)
       * @param n,m,d      problem dimensions (NOTE(review): confirm their meaning in the implementation)
       * @return the perplexity value
       */
      double computePerplexity ( const double *counts,
                                 const double *pw_z,
                                 const double *pz_d,
                                 int n,
                                 int m,
                                 int d ) const;

      /**
       * @brief Algebraic fold-in (by name; presumably maps new documents onto a trained model).
       *
       * Non-const, and the model buffers are writable, so results may be
       * written through pw_z, pd and/or pz_d -- TODO confirm which.
       *
       * @param counts        input counts
       * @param pw_z,pd,pz_d  model buffers, presumably p(w|z), p(d), p(z|d)
       * @param n,m           problem dimensions
       * @return a scalar result (meaning not visible here -- TODO confirm)
       */
      double algebraicFoldIn ( const double *counts,
                               double *pw_z,
                               double *pd,
                               double *pz_d,
                               int n, int m );

    public: /* static buffer utilities */
      /** @brief Fills the size entries of A with random values (distribution defined in the implementation). */
      static void randomizeBuffer ( double *A, long size );

      /** @brief Writes a uniform distribution into the size entries of x (by name; presumably 1/size each). */
      static void uniformDistribution ( double *x, int size );

      /** @brief Normalizes each row of the r x c matrix A (by name; presumably to unit sum). */
      static void normalizeRows ( double *A, int r, int c );

      /** @brief Normalizes each column of the r x c matrix A (by name; presumably to unit sum). */
      static void normalizeCols ( double *A, int r, int c );

      /** @brief Returns a sparsity measure of the size entries of A (definition in the implementation). */
      static double computeSparsity ( const double *A, long size );

    protected:
      /** iteration limit (the like-named constructor argument defaults to 5000) */
      int maxiterations;
      /** convergence threshold (constructor default 1e-3) */
      double delta_eps;
      /** decrease factor for the tempering parameter beta (constructor default 0.9) */
      double betadecrease;
      /** held-out fraction, presumably used by tempered estimation (constructor default 0.2) */
      double holdoutportion;

      /**
       * @brief A single EM iteration (by name).
       *
       * Presumably reads the current estimates from pw_z, pd and pz_d and
       * writes updated estimates to the matching *_out buffers; p_dw is
       * presumably workspace for p(d,w). TODO confirm in the implementation.
       *
       * @param n,m,d        problem dimensions
       * @param beta         tempering parameter
       * @param update_pw_z  whether p(w|z) is re-estimated, as in pLSA()
       */
      void pLSA_EMstep ( const double *counts,
                         double *pw_z, double *pd, double *pz_d,
                         double *pw_z_out, double *pd_out, double *pz_d_out, double *p_dw,
                         int n, int m, int d,
                         double beta,
                         bool update_pw_z );
  };

  } // namespace
  74. #endif