/** 
* @file PLSA.cpp
* @brief implementation of the pLSA model
* @author Erik Rodner
* @date 02/05/2009

*/
#include <iostream>
#include <assert.h>
#include <time.h>
#include <core/vector/Algorithms.h>

#include "PLSA.h"

using namespace OBJREC;

using namespace std;
using namespace NICE;



/**
 * @brief Store the EM optimizer settings and seed the pseudo-random generator.
 *
 * @param maxiterations  upper bound on the number of EM iterations in pLSA()
 * @param delta_eps      EM stops once the relative change of the objective
 *                       drops to or below this value
 * @param betadecrease   factor by which beta is multiplied in tempered EM
 * @param holdoutportion fraction of documents held out for validation when
 *                       tempered EM is requested
 */
PLSA::PLSA ( int maxiterations,
             double delta_eps,
             double betadecrease,
             double holdoutportion )
{
    this->maxiterations  = maxiterations;
    this->delta_eps      = delta_eps;
    this->betadecrease   = betadecrease;
    this->holdoutportion = holdoutportion;

    // randomizeBuffer() draws from rand() on Windows and from drand48()
    // elsewhere, so seed the generator of this platform from the clock
    const time_t seed = time ( NULL );
#ifdef WIN32
    srand ( seed );
#else
    srand48 ( seed );
#endif
}

PLSA::~PLSA ()
{
    // nothing to release: the constructor only stores scalars and seeds the RNG
}

/**
 * @brief Fraction of (numerically) zero entries in a buffer.
 *
 * An entry counts as zero when its magnitude is below 1e-20.
 *
 * @param A    buffer holding @p size doubles
 * @param size number of entries in @p A
 * @return fraction of zero entries in [0,1]; 0.0 for an empty buffer
 */
double PLSA::computeSparsity ( const double *A, long int size )
{
    // an empty buffer has nothing to count; avoid the 0/0 = NaN division
    if ( size <= 0 )
        return 0.0;

    long count_zeros = 0;
    for ( long i = 0 ; i < size ; i++ )
        if ( fabs ( A[i] ) < 1e-20 )
            count_zeros++;

    return count_zeros / (double) size;
}

/**
 * @brief Negative log-likelihood of the count matrix under the model.
 *
 * Computes
 *   - sum_{i,j} counts(i,j) * log( pd[i] * sum_k pz_d(k,i) * pw_z(k,j) + 1e-20 )
 * over the first @p dtrained documents.
 *
 * Layouts (as indexed below): counts is row-major with document i / word j
 * at i*n+j; pw_z holds topic k / word j at k*n+j; pz_d holds topic k /
 * document i at k*d+i; pd has one entry per document.
 *
 * @param n        vocabulary size
 * @param m        number of topics
 * @param d        number of documents (column stride of pz_d)
 * @param dtrained number of leading documents to include; -1 means all d
 * @return the negated log-likelihood
 */
double PLSA::computeLikelihood ( const double *counts,
                                 const double *pw_z,
                                 const double *pd,
                                 const double *pz_d,
                                 int n, int m, int d,
                                 int dtrained ) const
{
    // small offset keeps log() finite when the modelled probability is zero
    const double eps = 1e-20;
    const int ndocs = ( dtrained == -1 ) ? d : dtrained;

    double loglik = 0.0;
    for ( int doc = 0 ; doc < ndocs ; doc++ )
    {
        const double *doc_counts = counts + doc * n;
        for ( int word = 0 ; word < n ; word++ )
        {
            assert ( ! NICE::isNaN ( doc_counts[word] ) );
            assert ( ! NICE::isNaN ( pd[doc] ) );

            // mixture over topics: sum_k P(z_k|d) P(w|z_k)
            double p_w_given_d = 0.0;
            for ( int topic = 0 ; topic < m ; topic++ )
            {
                const double p_z = pz_d[topic * d + doc];
                const double p_w = pw_z[topic * n + word];
                assert ( ! NICE::isNaN ( p_z ) );
                assert ( ! NICE::isNaN ( p_w ) );
                p_w_given_d += p_z * p_w;
            }

            loglik += doc_counts[word] * log ( p_w_given_d * pd[doc] + eps );
        }
    }

    return - loglik;
}

/**
 * @brief Perplexity of a count matrix under the model (pd is not used).
 *
 * Returns
 *   exp( - sum_{i,j} counts(i,j) * log( P(w_j|d_i) + 1e-7 ) / sum_{i,j} counts(i,j) )
 * with P(w_j|d_i) = sum_k pz_d(k,i) * pw_z(k,j).
 *
 * Layouts (as indexed below): counts has document i / word j at i*n+j,
 * pw_z has topic k / word j at k*n+j, pz_d has topic k / document i at k*d+i.
 *
 * NOTE: when the counts sum to zero (e.g. d == 0) the ratio is 0/0 and the
 * result is NaN.
 */
double PLSA::computePerplexity ( const double *counts,
                                 const double *pw_z,
                                 const double *pz_d,
                                 int n, int m, int d ) const
{
    const double eps = 1e-7;
    double weighted_log_sum = 0.0;
    double total_count = 0.0;

    for ( int doc = 0 ; doc < d ; doc++ )
    {
        const double *doc_counts = counts + doc * n;
        for ( int word = 0 ; word < n ; word++ )
        {
            double p_w_given_d = 0.0;
            for ( int topic = 0 ; topic < m ; topic++ )
            {
                const double p_z = pz_d[topic * d + doc];
                const double p_w = pw_z[topic * n + word];
                assert ( ! NICE::isNaN ( p_z ) );
                assert ( ! NICE::isNaN ( p_w ) );
                p_w_given_d += p_z * p_w;
            }

            const double c = doc_counts[word];
            weighted_log_sum += c * log ( p_w_given_d + eps );
            total_count += c;
        }
    }

    return exp ( - weighted_log_sum / total_count );
}

/**
 * @brief Overwrite @p x with the uniform distribution over @p size entries.
 *
 * Every entry becomes 1/size; nothing is written when size <= 0.
 */
void PLSA::uniformDistribution ( double *x, int size )
{
    if ( size <= 0 )
        return;

    const double mass = 1.0 / size;
    for ( double *p = x ; p != x + size ; ++p )
        *p = mass;
}

/**
 * @brief Scale each row of the r x c row-major matrix @p A to sum to one.
 *
 * A row whose sum is not above 1e-20 (including a NaN sum) is replaced by
 * the uniform value 1/c.
 */
void PLSA::normalizeRows ( double *A, int r, int c )
{
    double *row = A;
    for ( int i = 0 ; i < r ; i++, row += c )
    {
        double total = 0.0;
        for ( int j = 0 ; j < c ; j++ )
            total += row[j];

        const bool degenerate = ! ( total > 1e-20 );
        for ( int j = 0 ; j < c ; j++ )
        {
            if ( degenerate )
                row[j] = 1.0 / c;
            else
                row[j] /= total;
        }
    }
}

/**
 * @brief Scale each column of the r x c row-major matrix @p A to sum to one.
 *
 * A column whose sum is not above 1e-20 (including a NaN sum) is replaced
 * by the uniform value 1/r.
 */
void PLSA::normalizeCols ( double *A, int r, int c )
{
    for ( int col = 0 ; col < c ; col++ )
    {
        // consecutive elements of one column are c entries apart
        double *column = A + col;

        double total = 0.0;
        for ( int row = 0 ; row < r ; row++ )
            total += column[row * c];

        if ( ! ( total > 1e-20 ) )
        {
            for ( int row = 0 ; row < r ; row++ )
                column[row * c] = 1.0 / r;
            continue;
        }

        for ( int row = 0 ; row < r ; row++ )
            column[row * c] /= total;
    }
}

/**
 * @brief Fill @p A with pseudo-random values.
 *
 * On Windows the values are rand()/RAND_MAX, i.e. in [0,1]. Elsewhere they
 * come from drand48(), i.e. in [0,1). The generator is seeded in the
 * PLSA constructor.
 *
 * @param A    buffer of @p size doubles to overwrite
 * @param size number of entries
 */
void PLSA::randomizeBuffer ( double *A, long size )
{
    // the counter must be as wide as size: an int index overflows
    // (undefined behaviour) for buffers with more than INT_MAX entries
    for ( long index = 0 ; index < size ; index++ )
    {
#ifdef WIN32
        A[index] = double ( rand() ) / RAND_MAX;
#else
        A[index] = drand48();
#endif
    }
}

/**
 * @brief One (optionally tempered) EM iteration of pLSA.
 *
 * Layouts (as indexed below):
 *  - counts : document i / word j at i*n+j
 *  - pw_z   : topic k / word j at k*n+j    (rows normalized over words)
 *  - pz_d   : topic k / document i at k*d+i (columns normalized over topics)
 *  - pd     : one entry per document
 *
 * E-step: for every count c = counts(i,j) >= 1e-20 the posterior over
 * topics is proportional to ( pz_d(k,i) * pw_z(k,j) * pd[i] )^beta. The
 * exponent is only applied when beta != 1.
 * M-step: pw_z (only if @p update_pw_z) and pz_d are re-estimated from the
 * expected counts c * posterior(k). pd becomes proportional to each
 * document's total count, and is left unchanged if all counts are zero.
 *
 * @param pw_z_out scratch buffer of n*m doubles; not touched when
 *                 update_pw_z is false (may then be NULL)
 * @param pd_out   scratch buffer of d doubles
 * @param pz_d_out scratch buffer of d*m doubles
 * @param p_dw     scratch buffer of m doubles
 * @param beta     tempering exponent; 1.0 yields standard EM
 * @param update_pw_z if false, pw_z is kept fixed (pLSA() uses this to fit
 *                 the holdout documents with the current topics)
 */
void PLSA::pLSA_EMstep ( const double *counts,
                         double *pw_z,
                         double *pd,
                         double *pz_d,

                         double *pw_z_out,
                         double *pd_out,
                         double *pz_d_out,
                         double *p_dw,

                         int n, int m, int d,
                         double beta,
                         bool update_pw_z )
{
    /************************ E-step ****************************/
    if ( update_pw_z )
        memset ( pw_z_out, 0x0, sizeof(double) * n * m );
    memset ( pz_d_out, 0x0, sizeof(double) * d * m );
    memset ( pd_out, 0x0, sizeof(double) * d );

    // tempering is active whenever beta differs from 1 (a comparison --
    // an assignment here would silently reset beta to 1 for every step)
    const bool tempered = ( beta != 1.0 );

    long indexij = 0;
    for ( int i = 0 ; i < d ; i++ )
        for ( int j = 0 ; j < n ; j++, indexij++ )
        {
            const double c = counts[indexij];
            // zero counts contribute nothing to the expected counts
            if ( c < 1e-20 ) continue;

            // unnormalized (tempered) posteriors, already scaled by the count
            double sumk = 0.0;
            for ( int k = 0 ; k < m ; k++ )
            {
                double p_dw_value = pz_d[k*d+i] * pw_z[k*n+j] * pd[i];
                if ( tempered )
                    p_dw_value = pow ( p_dw_value, beta );

                sumk += p_dw_value;
                p_dw[k] = c * p_dw_value;
            }

            // normalize to expected counts that sum to c and accumulate them
            for ( int k = 0 ; k < m ; k++ )
            {
                if ( sumk > 1e-20 )
                    p_dw[k] /= sumk;
                else
                    // degenerate posterior: spread the count uniformly so the
                    // expected counts still sum to c, as in the branch above
                    p_dw[k] = c / m;

                if ( update_pw_z )
                    pw_z_out[k*n+j] += p_dw[k];
                pz_d_out[k*d+i] += p_dw[k];
            }

            pd_out[i] += c;
        }

    /************************ M-step ****************************/
    // topic/word table: normalize every topic row over the vocabulary
    if ( update_pw_z )
    {
        normalizeRows ( pw_z_out, m, n );
        memcpy ( pw_z, pw_z_out, sizeof(double) * n * m );
    }

    // topic/document table: normalize every document column over the topics
    normalizeCols ( pz_d_out, m, d );
    memcpy ( pz_d, pz_d_out, sizeof(double) * d * m );

    // document weights proportional to the total count per document
    double sum_pd = 0.0;
    for ( int i = 0 ; i < d ; i++ )
        sum_pd += pd_out[i];

    if ( sum_pd > 1e-20 )
        for ( int i = 0 ; i < d ; i++ )
            pd[i] = pd_out[i] / sum_pd;
}


/**
 * @brief Fit a pLSA model to a document/word count matrix with (tempered) EM.
 *
 * Layouts (see pLSA_EMstep): counts holds total_documents rows of n
 * entries. pw_z holds m rows of n entries. pz_d holds topic k / document i
 * at k*d+i, and pd has one entry per training document. Here d is the
 * number of training documents.
 *
 * Without tempering all documents are used for training (d = total_documents).
 * With tempering, the last (int)(holdoutportion * total_documents) rows are
 * held out. After every step their topic weights are re-fitted with pw_z
 * fixed, and the resulting perplexity drives early stopping in phase 1. In
 * phase 2 it also drives the decrease of beta (beta *= betadecrease).
 * NOTE(review): if the holdout set is empty, computePerplexity() returns NaN
 * and the perplexity-based criteria never trigger.
 *
 * @param counts               word counts, one row of n entries per document
 * @param pw_z                 topic/word table; randomly initialized and
 *                             learned if update_pw_z, otherwise used as given
 * @param pd                   receives the document weights (d entries)
 * @param pz_d                 receives the topic/document table (d*m entries)
 * @param n                    vocabulary size
 * @param m                    number of topics
 * @param total_documents      number of rows in counts
 * @param update_pw_z          learn pw_z (true) or keep it fixed (false)
 * @param tempered             use tempered EM with a holdout set
 * @param optimization_verbose print progress to stderr
 * @return value of computeLikelihood() on the training documents after the
 *         last step
 */
double PLSA::pLSA ( const double *counts,
                    double *pw_z,
                    double *pd,
                    double *pz_d,
                    int n, int m, int total_documents,
                    bool update_pw_z,
                    bool tempered,
                    bool optimization_verbose )
{
    if ( optimization_verbose ) {
        fprintf (stderr, "pLSA: EM algorithm ...\n");
        // widen before multiplying so large corpora do not overflow int
        fprintf (stderr, "pLSA: EM algorithm: sparsity %f\n", computeSparsity(counts, (long)n * total_documents) );
    }

    int d; // documents used for training
    int d_holdout; // documents used for validation
    const double *counts_holdout = NULL;
    double *pz_d_holdout = NULL;
    double *pd_holdout = NULL;

    if ( tempered ) {
        if ( optimization_verbose )
            fprintf (stderr, "pLSA: Tempered EM algorithm ...\n");

        // the last d_holdout rows of counts form the validation set
        d_holdout = (int)(holdoutportion * total_documents);
        d = total_documents - d_holdout;
        counts_holdout = counts + d*n;
        pz_d_holdout = new double[ d_holdout*m ];
        pd_holdout = new double[ d_holdout ];
    } else {
        d = total_documents;
        d_holdout = 0;
    }

    // random start for the topic tables, uniform document weights
    if ( update_pw_z ) {
        randomizeBuffer ( pw_z, n*m );
        normalizeRows ( pw_z, m, n );
    }

    uniformDistribution ( pd, d );
    randomizeBuffer ( pz_d, d*m );
    normalizeCols ( pz_d, m, d );

    // scratch buffers for pLSA_EMstep
    double *pz_d_out = new double [ d*m ];
    double *pw_z_out = NULL;

    if ( update_pw_z )
        pw_z_out = new double [ n*m ];

    int iteration = 0;
    vector<double> likelihoods;
    likelihoods.push_back ( computeLikelihood ( counts, pw_z, pd, pz_d, n, m, d ) );
    double delta_likelihood = 0.0;
    bool early_stop = false;

    double *p_dw = new double [m];
    double *pd_out = new double[d];
    double beta = 1.0;
    double oldperplexity = numeric_limits<double>::max();
    vector<double> delta_perplexities;

    /******** phase 1: standard EM (beta = 1) ********/
    do {
        pLSA_EMstep ( counts,
            pw_z, pd, pz_d,
            pw_z_out, pd_out, pz_d_out, p_dw,
            n, m, d,
            beta,
            update_pw_z );

        double newlikelihood = computeLikelihood(counts, pw_z, pd, pz_d, n, m, d);

        // relative change of the objective
        delta_likelihood = fabs(likelihoods.back() - newlikelihood) / (1.0 + fabs(newlikelihood));

        if ( optimization_verbose ) {
            fprintf (stderr, "pLSA %6d %f %e\n", iteration, newlikelihood, delta_likelihood );
        }

        likelihoods.push_back ( newlikelihood );

        if ( counts_holdout != NULL )
        {
            // re-fit the holdout documents with the current, fixed pw_z
            pLSA ( counts_holdout, pw_z, pd_holdout, pz_d_holdout,
                   n, m, d_holdout, false, false );

            double perplexity = computePerplexity ( counts_holdout, pw_z,
                pz_d_holdout, n, m, d_holdout );

            // positive when the holdout perplexity improved
            double delta_perplexity = (oldperplexity - perplexity) / (1.0 + perplexity);

            if ( delta_perplexities.size() > 0 ) {
                if ( optimization_verbose )
                    fprintf (stderr, "PLSA: early stopping: perplexity: %d %f %e (%e)\n", iteration, perplexity,
                        delta_perplexity, oldperplexity);

                double last_delta_perplexity = delta_perplexities.back ();

                // if perplexity does not decrease in the last two iterations -> early stop
                if ( (delta_perplexity <= 0.0) && (last_delta_perplexity <= 0.0) )
                {
                    early_stop = true;
                    if ( optimization_verbose )
                        fprintf (stderr, "PLSA: stopped due to early stopping !\n");
                }
            }

            delta_perplexities.push_back ( delta_perplexity );
            oldperplexity = perplexity;
        }

        iteration++;
    } while ( (iteration < maxiterations) && (delta_likelihood > delta_eps) && (! early_stop) );

    /******** phase 2: tempered EM with decreasing beta ********/
    if ( tempered )
    {
        early_stop = false;
        delta_perplexities.clear();
        beta *= betadecrease;
        do {
            pLSA_EMstep ( counts,
                pw_z, pd, pz_d,
                pw_z_out, pd_out, pz_d_out, p_dw,
                n, m, d,
                beta,
                update_pw_z );

            double newlikelihood = computeLikelihood(counts, pw_z, pd, pz_d, n, m, d);

            // same relative change as in phase 1; |newlikelihood| keeps the
            // denominator positive
            delta_likelihood = fabs(likelihoods.back() - newlikelihood) / ( 1.0 + fabs(newlikelihood) );
            if ( optimization_verbose )
                fprintf (stderr, "pLSA_tempered %6d %f %e\n", iteration, newlikelihood, delta_likelihood );

            likelihoods.push_back ( newlikelihood );

            pLSA ( counts_holdout, pw_z, pd_holdout, pz_d_holdout,
                   n, m, d_holdout, false, false );

            double perplexity = computePerplexity ( counts_holdout, pw_z,
                pz_d_holdout, n, m, d_holdout );
            double delta_perplexity = (oldperplexity - perplexity) / (1.0 + perplexity);

            if ( delta_perplexities.size() > 0 ) {
                double last_delta_perplexity = delta_perplexities.back ();

                if ( optimization_verbose )
                    fprintf (stderr, "PLSA: early stopping: perplexity: %d %f %f\n", iteration, perplexity,
                        delta_perplexity);

                // no improvement in the last two iterations: stop if beta was
                // just lowered (only one delta since the last reset),
                // otherwise lower beta further
                if ( (delta_perplexity <= 0.0) && (last_delta_perplexity <= 0.0) )
                {
                    if ( delta_perplexities.size() <= 1 ) {
                        if ( optimization_verbose )
                            fprintf (stderr, "PLSA: early stop !\n");
                        // actually leave the loop; without this flag the
                        // perplexity criterion could never end phase 2
                        early_stop = true;
                    } else {
                        if ( optimization_verbose )
                            fprintf (stderr, "PLSA: decreasing beta !\n");

                        delta_perplexities.clear();
                        beta *= betadecrease;
                    }
                }
            }
            delta_perplexities.push_back ( delta_perplexity );
            oldperplexity = perplexity;

            iteration++;
        } while ( (iteration < maxiterations) && (delta_likelihood > delta_eps) && (! early_stop) );
    }

    if ( optimization_verbose )
        fprintf (stderr, "pLSA: total number of iterations %d\n", iteration );

    delete [] pz_d_out;
    delete [] pd_out;

    if ( update_pw_z )
        delete [] pw_z_out;
    delete [] p_dw;

    if ( counts_holdout != NULL )
    {
        delete [] pz_d_holdout;
        delete [] pd_holdout;
    }

    return likelihoods.back();
}


/**
 * @brief Fold in a single document by solving a linear system instead of EM.
 *
 * Builds W (n x m) with W(i,k) = pw_z[k*n+i], so column k of W holds the
 * word weights of topic k, and c (n) from @p counts. It solves for sol with
 * NICE::solveLinearEquationQR ( W, c, sol ), L1-normalizes sol and copies its
 * m entries into @p pz_d. @p pd is set to 1.
 * NOTE(review): entries of sol are not clipped to be non-negative before
 * normalizeL1(), so pz_d may contain negative weights -- confirm whether
 * callers rely on this.
 *
 * @param counts word counts of the document (n entries)
 * @param pw_z   topic/word table, topic k / word j at k*n+j
 * @param pd     receives 1.0
 * @param pz_d   receives the m normalized coefficients
 * @param n      vocabulary size
 * @param m      number of topics
 * @return always 0.0
 */
double PLSA::algebraicFoldIn ( const double *counts,
                               double *pw_z,
                               double *pd,
                               double *pz_d,
                               int n, int m )
{
    NICE::Matrix W ( n, m );
    NICE::Vector c ( n );
    for ( int word = 0 ; word < n ; word++ )
    {
        c[word] = counts[word];
        for ( int topic = 0 ; topic < m ; topic++ )
            W ( word, topic ) = pw_z[topic * n + word];
    }

    NICE::Vector sol ( m );
    NICE::solveLinearEquationQR ( W, c, sol );

    (*pd) = 1.0;
    sol.normalizeL1();

    memcpy ( pz_d, sol.getDataPointer(), m * sizeof(double) );
    return 0.0;
}