PLSA.h 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. /**
  2. * @file PLSA.h
  3. * @brief implementation of the pLSA model
  4. * @author Erik Rodner
  5. * @date 02/05/2009
  6. */
  7. #ifndef PLSAINCLUDE
  8. #define PLSAINCLUDE
  9. #include "core/vector/VectorT.h"
  10. #include "core/vector/MatrixT.h"
  namespace OBJREC {

  /**
   * @brief Probabilistic Latent Semantic Analysis (pLSA) model.
   *
   * This header only declares the interface; every member function is
   * defined elsewhere. It offers model fitting (pLSA()), evaluation
   * (computeLikelihood(), computePerplexity()), a fold-in routine
   * (algebraicFoldIn()) and a few static helpers for raw double buffers.
   *
   * Argument conventions shared by the methods below.
   * NOTE(review): inferred from the identifiers, not from the
   * implementation -- confirm before relying on them:
   *  - counts : presumably the observed word/document count matrix.
   *  - pw_z   : presumably p(w|z), the word distribution of each topic.
   *  - pd     : presumably p(d), the document prior.
   *  - pz_d   : presumably p(z|d), the topic distribution of each document.
   *  - n, m   : matrix dimensions; which one counts words and which one
   *             counts topics is not visible here -- TODO confirm.
   *  - d      : presumably the number of documents (pLSA() names the
   *             corresponding argument total_documents).
   *  All of these are caller-provided raw arrays.
   */
  class PLSA
  {
    protected:
      /** Maximum number of iterations (the constructor argument of the
       *  same name defaults to 5000). */
      int maxiterations;
      /** Presumably the convergence threshold (constructor default 1e-3). */
      double delta_eps;
      /** Presumably the factor by which beta is lowered in tempered EM
       *  (constructor default 0.9). */
      double betadecrease;
      /** Presumably the fraction of the data held out to steer tempering
       *  (constructor default 0.2). */
      double holdoutportion;

      /**
       * @brief One EM iteration.
       *
       * NOTE(review): presumably reads the current estimates
       * pw_z / pd / pz_d and writes the updated ones to
       * pw_z_out / pd_out / pz_d_out, with p_dw as scratch space --
       * confirm against the definition.
       *
       * @param counts      observed counts (read-only)
       * @param n,m,d       problem dimensions, see the class comment
       * @param beta        presumably the tempering exponent (1.0 = plain EM)
       * @param update_pw_z presumably, when false, pw_z is not re-estimated
       */
      void pLSA_EMstep ( const double *counts,
                         double *pw_z, double *pd, double *pz_d,
                         double *pw_z_out, double *pd_out, double *pz_d_out, double *p_dw,
                         int n, int m, int d,
                         double beta,
                         bool update_pw_z );

    public:
      /**
       * @brief Constructor taking the optimization settings.
       *
       * @param maxiterations  maximum number of iterations
       * @param delta_eps      presumably the convergence threshold
       * @param betadecrease   presumably the beta decrease factor (tempered EM)
       * @param holdoutportion presumably the held-out fraction (tempered EM)
       *
       * NOTE(review): because every parameter has a default, this constructor
       * also acts as an implicit conversion from int (e.g. `PLSA p = 10;`).
       * Consider marking it `explicit` once no caller relies on that.
       */
      PLSA ( int maxiterations = 5000,
             double delta_eps = 1e-3,
             double betadecrease = 0.9,
             double holdoutportion = 0.2 );

      /** Virtual destructor, so derived objects can be deleted through a
       *  PLSA pointer. */
      virtual ~PLSA();

      /**
       * @brief Fits the model to the given counts.
       *
       * @param counts          observed counts (read-only)
       * @param pw_z            writable buffer, presumably p(w|z)
       * @param pd              writable buffer, presumably p(d)
       * @param pz_d            writable buffer, presumably p(z|d)
       * @param n,m             problem dimensions, see the class comment
       * @param total_documents number of documents
       * @param update_pw_z     presumably, when false, pw_z is kept fixed
       * @param tempered        presumably enables tempered EM
       * @param optimization_verbose presumably enables progress output
       * @return a double -- presumably the final (log-)likelihood; confirm
       *         in the implementation
       */
      double pLSA ( const double *counts,
                    double *pw_z,
                    double *pd,
                    double *pz_d,
                    int n,
                    int m,
                    int total_documents,
                    bool update_pw_z = true,
                    bool tempered = false,
                    bool optimization_verbose = false );

      /**
       * @brief Evaluates the likelihood of the counts under the given
       *        parameters. Does not modify the object or any buffer.
       *
       * @param counts,pw_z,pd,pz_d read-only buffers, see the class comment
       * @param n,m,d    problem dimensions
       * @param dtrained defaults to -1; presumably the number of documents
       *                 used for training, -1 meaning "all" -- TODO confirm
       * @return the likelihood value (presumably a log-likelihood)
       */
      double computeLikelihood ( const double *counts,
                                 const double *pw_z,
                                 const double *pd,
                                 const double *pz_d,
                                 int n,
                                 int m,
                                 int d,
                                 int dtrained = -1 ) const;

      /**
       * @brief Evaluates the perplexity of the counts under pw_z / pz_d.
       *        Does not modify the object or any buffer. Unlike
       *        computeLikelihood(), no document prior is taken.
       */
      double computePerplexity ( const double *counts,
                                 const double *pw_z,
                                 const double *pz_d,
                                 int n,
                                 int m,
                                 int d ) const;

      /**
       * @brief Fold-in of new data.
       *
       * NOTE(review): presumably estimates pz_d (and pd) for the documents in
       * counts while pw_z stays fixed -- confirm; all three parameter buffers
       * are non-const here, so any of them may be written.
       */
      double algebraicFoldIn ( const double *counts,
                               double *pw_z,
                               double *pd,
                               double *pz_d,
                               int n, int m );

      /** @name Static helper functions
       *  Utilities operating on caller-provided raw double buffers.
       */
      ///@{

      /** Presumably fills the @p size entries of @p A with random values. */
      static void randomizeBuffer ( double *A, long size );

      /** Presumably sets the @p size entries of @p x to a uniform
       *  distribution. */
      static void uniformDistribution ( double *x, int size );

      /** Presumably normalizes each row of the r x c matrix @p A in place
       *  (memory layout not visible here -- TODO confirm). */
      static void normalizeRows ( double *A, int r, int c );

      /** Presumably normalizes each column of the r x c matrix @p A in place
       *  (memory layout not visible here -- TODO confirm). */
      static void normalizeCols ( double *A, int r, int c );

      /** Presumably returns the fraction of zero entries among the @p size
       *  entries of @p A. */
      static double computeSparsity ( const double *A, long size );

      ///@}
  };

  } // namespace OBJREC
  71. #endif