/** 
* @file testPLSA.cpp
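* @brief Simulation-based test program for the PLSA implementation: generates synthetic count data, fits the model, and reports estimation and fold-in errors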
* @author Erik Rodner
* @date 05/21/2008

*/
#include "core/vector/VectorT.h"
#include "core/vector/MatrixT.h"
#include "core/image/ImageT.h"
#include "core/imagedisplay/ImageDisplay.h"

#ifdef NICE_USELIB_ICE

#include <core/vector/Distance.h>

#include <image_nonvis.h>
#include <core/iceconversion/convertice.h>

#include <distancefunctions.h>
#include <assert.h>

#include <core/basics/Config.h>
#include <vislearning/baselib/cmdline.h>
#include <vislearning/baselib/Gnuplot.h>
#include <vislearning/baselib/ICETools.h>
#include <vislearning/baselib/Conversions.h>
#include <vislearning/math/pdf/PDFDirichlet.h>
#include <vislearning/math/pdf/PDFMultinomial.h>
#include <vislearning/math/distances/ChiSqDistance.h>
#include <vislearning/math/distances/KLDistance.h>
#include <vislearning/math/distances/HistIntersectDistance.h>
#include <vislearning/math/topics/PLSA.h>

using namespace OBJREC;
using namespace NICE;
using namespace std;

/**
 * Draw a random discrete probability distribution.
 *
 * Every component is drawn with drand48() and the vector is then divided
 * by the sum of all draws, so the components add up to one.
 *
 * @param dimension number of components of the distribution
 * @return vector with @p dimension entries
 */
NICE::Vector randomDiscreteDistribution ( int dimension )
{
    NICE::Vector distribution ( dimension );

    // first pass: raw draws plus their running total
    double total = 0.0;
    int idx = 0;
    while ( idx < dimension )
    {
        distribution[idx] = drand48();
        total += distribution[idx];
        ++idx;
    }

    // second pass: rescale by the total
    for ( idx = 0 ; idx < dimension ; ++idx )
        distribution[idx] = distribution[idx] / total;

    return distribution;
}

/**
 * Generate synthetic PLSA data with a fixed, block-structured p(w|z).
 *
 * The topic/word distribution is hard-coded for exactly 3 topics and
 * 12 words (each topic spreads its mass uniformly over four words), which
 * is why @p m and @p n are asserted. The per-document topic mixtures are
 * sampled from a Dirichlet distribution and each document receives
 * @p samples_count word samples.
 *
 * @param d             number of documents
 * @param m             number of topics (must be 3)
 * @param n             number of words in the vocabulary (must be 12)
 * @param counts        output: word counts, indexed (document, word)
 * @param pw_z          output: p(w|z), one row per topic
 * @param pd            output: p(d), set to the uniform distribution
 * @param pz_d          output: p(z|d), indexed (topic, document)
 * @param samples_count number of words sampled for every document
 */
void simulation_sivic ( int d, // number of documents
                        int m, // number of topics
                        int n,  // number of words in the vocabulary
                        NICE::Matrix & counts,
                        NICE::Matrix & pw_z,
                        NICE::Vector & pd,
                        NICE::Matrix & pz_d,
                        int samples_count
                      )
{
    assert ( m == 3 );
    assert ( n == 12 );

    std::istringstream pwz_string ( string("<\n<0.25,0.25,0.25,0.25,0,0,0,0,0,0,0,0>,\n") +
       string("<0,0,0,0,0.25,0.25,0.25,0.25,0,0,0,0>,\n") +
       string("<0,0,0,0,0,0,0,0,0.25,0.25,0.25,0.25>\n>") );

    pwz_string >> pw_z;

    // one word sampler per topic
    vector<PDFMultinomial> beta;
    for ( int k = 0 ; k < m ; k++ )
        beta.push_back ( PDFMultinomial(pw_z.getRow(k),1) );

    // sample one topic mixture per document
    PDFDirichlet dirichlet ( 0.2, m );
    VVector pzd;
    dirichlet.sample ( pzd, d );

    // Store the sampled mixtures as ground truth p(z|d).
    // The previous code copied pz_d onto itself (pz_d(k,i) = pz_d(i,k)),
    // which ignored the samples and indexed the matrix with swapped,
    // out-of-range coordinates.
    for ( int i = 0 ; i < d ; i++ )
        for ( int k = 0 ; k < m ; k++ )
            pz_d(k, i) = pzd[i][k];

    for ( int i = 0 ; i < d ; i++ )
        pd[i] = 1.0 / d;

    counts.set(0);

    fprintf (stderr, "Generation ...\n");
    for ( int i = 0 ; i < d ; i++ )
    {
        PDFMultinomial theta ( pzd[i], 1 );
        for ( int w = 0 ; w < samples_count; w++ )
        {
            int topic = theta.sample();
            assert ( topic < m );

            int word  = beta[topic].sample();

            counts(i, word) ++;
        }
    }

    pz_d.normalizeColumnsL1();
    pd.normalizeL1();
}

/**
 * Generate synthetic PLSA data for an arbitrary number of topics and words.
 *
 * p(d) is uniform, p(z|d) is sampled from a Dirichlet distribution and
 * p(w|z) is block-structured: topic zk puts uniform mass on the words
 * zk*n/m ... (zk+1)*n/m - 1. In total samples_count*d (document, topic,
 * word) triples are drawn and accumulated into @p counts.
 *
 * @param d             number of documents
 * @param m             number of topics
 * @param n             number of words in the vocabulary
 * @param counts        output: word counts, indexed (document, word)
 * @param pw_z          output: p(w|z), one row per topic
 * @param pd            output: p(d), set to the uniform distribution
 * @param pz_d          output: p(z|d), indexed (topic, document)
 * @param samples_count average number of samples per document
 */
void simulation ( int d, // number of documents
                  int m, // number of topics
                  int n,  // number of words in the vocabulary
                  NICE::Matrix & counts,
                  NICE::Matrix & pw_z,
                  NICE::Vector & pd,
                  NICE::Matrix & pz_d,
                  int samples_count
                )
{
    fprintf (stderr, "Generating model...\n");
    //pd = randomDiscreteDistribution(d);
    for ( int i = 0 ; i < d ; i++ )
        pd[i] = 1.0 / d;

    PDFMultinomial pds ( pd, 1);
    vector<PDFMultinomial> pz_ds;
    vector<PDFMultinomial> pw_zs;

    // sample one topic mixture per document
    PDFDirichlet dirichlet ( 0.2, m );
    VVector pzd;
    dirichlet.sample ( pzd, d );

    // Store the sampled mixtures as ground truth p(z|d).
    // The previous code copied pz_d onto itself (pz_d(k,i) = pz_d(i,k)),
    // which ignored the samples and indexed the matrix with swapped,
    // out-of-range coordinates.
    for ( int i = 0 ; i < d ; i++ )
        for ( int k = 0 ; k < m ; k++ )
            pz_d(k, i) = pzd[i][k];

    for ( int di = 0 ; di < d ; di++ )
        pz_ds.push_back ( PDFMultinomial (pzd[di],1) );

    // block-structured p(w|z): every topic owns a contiguous word range
    pw_z.set(0);
    for ( int zk = 0 ; zk < m ; zk++ )
    {
        for ( int i = zk*n/m ; i < (zk+1)*n/m ; i++ )
            pw_z(zk, i) = 1.0;
    }

    pw_z.normalizeRowsL1();

    for ( int zk = 0 ; zk < m ; zk++ )
        pw_zs.push_back ( PDFMultinomial ( pw_z.getRow(zk), 1 ) );

    fprintf (stderr, "Normalization...\n");

    fprintf (stderr, "Generating samples...\n");

    counts.set(0);

    for ( int j = 0 ; j < samples_count*d ; j++ )
    {
        // sample document using p(d)
        int document = pds.sample();

        // sample topic using p(z|d)
        int topic = pz_ds[document].sample();

        // sample word of the vocabulary using p(w|z)
        int word = pw_zs[topic].sample();

        counts(document, word)++;
    }
}

/**
 * Sample the word histogram of a single document.
 *
 * For every sample a topic is drawn from @p pz_singled and afterwards a
 * word is drawn from the row of @p pw_z belonging to that topic. The
 * matching entry of @p histogramm is incremented; the histogram is not
 * reset by this function.
 *
 * @param pz_singled    topic distribution of the document
 * @param pw_z          p(w|z), one row per topic
 * @param histogramm    in/out: word counts, incremented in place
 * @param samples_count number of words to draw
 */
void sample_document ( const NICE::Vector & pz_singled,
                       const NICE::Matrix & pw_z,
                       NICE::Vector & histogramm,
                       int samples_count )
{
    const int topic_count = pz_singled.size();

    // one word sampler per topic
    vector<PDFMultinomial> word_samplers;
    int t = 0;
    while ( t < topic_count )
    {
        word_samplers.push_back ( PDFMultinomial ( pw_z.getRow(t), 1 ) );
        ++t;
    }

    PDFMultinomial topic_sampler ( pz_singled, 1 );

    int remaining = samples_count;
    while ( remaining > 0 )
    {
        const int drawn_topic = topic_sampler.sample();
        const int drawn_word = word_samplers[drawn_topic].sample();

        histogramm[drawn_word]++;
        --remaining;
    }
}

/**
 * Compare two matrices whose rows may be permuted against each other.
 *
 * The Euclidean distance between every row of @p A and every row of @p B
 * is used as assignment cost and the rows are matched with
 * ice::Hungarian.
 *
 * @param A               first matrix
 * @param B               second matrix, must have as many rows as @p A
 * @param reference_pairs output: the row pairs found by the assignment,
 *                        one pair per row
 * @return the assignment cost divided by the number of rows, or -1.0 if
 *         the assignment could not be completed
 */
double comparePermMatrices ( const NICE::Matrix & A, const NICE::Matrix & B, MatrixT<int> & reference_pairs )
{
    assert ( A.rows() == B.rows() );

    const int row_count = A.rows();

    double min_cost = 0.0;
    ice::Matrix cost ( A.rows(), B.rows() );

    // stack object instead of new/delete; accessed through the base
    // interface exactly like before
    NICE::EuclidianDistance<double> euclidean;
    NICE::VectorDistance<double> & df = euclidean;

    for ( int i = 0 ; i < row_count ; i++ )
        for ( int j = 0 ; j < row_count ; j++ )
            cost[i][j] = df.calculate ( A.getRow(i), B.getRow(j) );

    ice::IMatrix reference_pairs_ice;
    ice::Hungarian ( cost, reference_pairs_ice, min_cost );

    const int pair_count = reference_pairs_ice.rows();

    if ( pair_count == row_count - 1 ) {
        // Exactly one row of A and one row of B were left without a
        // partner: find both and add them as the last pair.
        set<int> a;
        set<int> b;
        for ( int i = 0 ; i < row_count ; i++ )
        {
            a.insert ( i );
            b.insert ( i );
        }
        for ( int i = 0 ; i < pair_count ; i++ )
        {
            // erase by key: a no-op for a missing key, whereas
            // erase(find(..)) would be undefined behaviour
            a.erase ( reference_pairs_ice[i][0] );
            b.erase ( reference_pairs_ice[i][1] );
        }
        assert ( (a.size() == 1) && (b.size() == 1) );
        // the second index is the leftover row of B (the old code
        // dereferenced a.end(), which is undefined behaviour)
        reference_pairs_ice.Append ( ice::IVector ( *(a.begin()), *(b.begin()) ) );
    } else if ( pair_count != row_count ) {
        // incomplete assignment that cannot be repaired
        reference_pairs = NICE::makeIntegerMatrix<int> ( reference_pairs_ice );

        return -1.0;
    }

    // Complete assignment (either delivered directly or completed above).
    // The directly-complete case previously fell off the end of the
    // function without returning a value.
    reference_pairs = NICE::makeIntegerMatrix<int> ( reference_pairs_ice );

    return min_cost / ( A.rows() );
}

/**
 * Test driver: simulate data, estimate a PLSA model and report errors.
 *
 * For every run synthetic data is generated (simulation() or
 * simulation_sivic()), the model is estimated with PLSA::pLSA and the
 * estimated p(w|z) is compared with the ground truth using
 * comparePermMatrices(). Afterwards a fold-in test is performed for each
 * training document: a new histogram is sampled from the estimated model,
 * folded in, and the resulting p(z|d) is compared with the estimated one
 * using the KL distance. All results are written to stderr.
 *
 * @param argc number of command line arguments
 * @param argv command line arguments, see the option table below
 * @return always 0
 */
int main (int argc, char **argv)
{
    std::set_terminate(__gnu_cxx::__verbose_terminate_handler);
    srand48(time(NULL));

    int tempered = 0;
    int optimization_verbose = 0;
    int samples_count = 10000;
    int d = 1000;
    int m = 3;
    int n = 12;

    double delta_eps = 1e-3;
    int maxiterations = 500;
    int runs = 1;
    double holdoutportion = 0.2;

    int use_simulation_sivic = 0;

    // NOTE(review): the default string of "holdout" ("0.8") differs from
    // the initial value of holdoutportion (0.2) -- confirm which one is
    // intended.
    struct CmdLineOption options[] = {
	{"verbose", "print details about the optimization", NULL, NULL, &optimization_verbose},
	{"documents", "number of documents", "1000", "%d", &d},
	{"topics", "number of topics", "3", "%d", &m},
	{"words", "number of words in the vocabulary", "12", "%d", &n},
	{"samples", "number of samples (in total)", "10000", "%d", &samples_count},
	{"maxiterations", "maximum number of EM-iterations", "500", "%d", &maxiterations},
	{"runs", "number of runs performed", "1", "%d", &runs},
	{"deltaeps", "terminate optimization if delta of likelihood is below this threshold", NULL, "%lf", &delta_eps},
	{"sivic", "use the simulation routine and parameters of Josef Sivic's matlab implementation", NULL, NULL, &use_simulation_sivic},
	{"holdout", "fraction of data for validation", "0.8", "%lf", &holdoutportion},
	{"tempered", "use tempered version of EM algorithm", NULL, NULL, &tempered},
	{NULL, NULL, NULL, NULL, NULL} 
    };
    int ret;
    char *more_options[argc];
    ret = parse_arguments( argc, (const char**)argv, options, more_options);
    (void) ret; // the meaning of the return value is not visible here

    // the hold-out set is only used by the tempered version
    if ( ! tempered )
        holdoutportion = 0.0;

    PLSA pLSA ( maxiterations, delta_eps, 0.9, holdoutportion );

    double error_pw_z_avg = 0.0;
    double error_foldin_avg = 0.0;
    double error_pw_z_stddev = 0.0;
    int runs_successfull = 0;

    fprintf (stderr, "documents: %d\n", d );
    fprintf (stderr, "topics: %d\n", m );
    fprintf (stderr, "runs: %d\n", runs );
    for ( int run = 0 ; run < runs ; run++ )
    {
        // 1-based run counter, consistent with the "final error" line below
        fprintf (stderr, "[run %d/%d]\n", run+1, runs );

        NICE::Matrix counts ( d, n );
        NICE::Matrix pw_z ( m, n );
        NICE::Vector pd   ( d );
        NICE::Matrix pz_d ( m, d );

        if ( use_simulation_sivic )
            simulation_sivic ( d, m, n,
                               counts,
                               pw_z,
                               pd,
                               pz_d,
                               samples_count );
        else
            simulation ( d, m, n,
                         counts,
                         pw_z,
                         pd,
                         pz_d,
                         samples_count );


    /*
	// refactor-nice.pl: check this substitution
	// old: ImageRGB img_groundtruth ( n, m );
	NICE::ColorImage img_groundtruth ( n, m );
	ICETools::convertToRGB ( pw_z, img_groundtruth );
	Show(ON, img_groundtruth, "P(w|z) groundtruth");
    */

        /***************** model estimation *******************/
        double *counts_raw = ICETools::convertICE2M ( counts );
        double *pw_z_estimate_raw = new double [n*m];
        double *pz_d_estimate_raw = new double [m*d];
        NICE::Vector pd_estimate ( d );
        double likelihood_estimate =
            pLSA.pLSA ( counts_raw,
                        pw_z_estimate_raw,
                        pd_estimate.getDataPointer(),
                        pz_d_estimate_raw,
                        n, m, d,
                        true /* do not perform folding */,
                        tempered ? true : false /* use tempered version */,
                        optimization_verbose ? true : false );

        double *pw_z_raw = ICETools::convertICE2M ( pw_z );
        double *pz_d_raw = ICETools::convertICE2M ( pz_d );

        // number of documents that are not part of the hold-out set
        int dtrained = d - (int) ( d * holdoutportion );
        double likelihood_groundtruth = pLSA.computeLikelihood (
            counts_raw, pw_z_raw, pd.getDataPointer(), pz_d_raw, n, m , d, dtrained );
        delete [] pw_z_raw;
        delete [] pz_d_raw;

        NICE::Matrix pw_z_estimate ( m, n );
        ICETools::convertM2ICE ( pw_z_estimate, pw_z_estimate_raw );
        delete [] counts_raw;

        MatrixT<int> refPWZ;
        double error_pw_z = comparePermMatrices ( pw_z, pw_z_estimate, refPWZ );

        fprintf (stderr, "---------- final error %d/%d ---------\n", run+1, runs );
        fprintf (stderr, "error p(w|z) pwz: %f\n", error_pw_z );
        error_pw_z_avg += error_pw_z;
        error_pw_z_stddev += error_pw_z * error_pw_z;

        /***************** FOLDIN Tests *******************/
        KLDistance<double> dist;
        double error_foldin = 0.0;
        for ( int i = 0 ; i < dtrained ; i++ )
        {
            // estimated p(z|d) of training document i
            // NOTE(review): presumably pLSA stores the estimates with a
            // stride of dtrained documents per topic -- confirm.
            NICE::Vector pz_singled ( m );
            for ( int k = 0 ; k < m ; k++ )
                pz_singled[k] = pz_d_estimate_raw[k*dtrained+i];

            // sample_document() only increments the histogram, so start
            // from explicit zeros (the sized constructor is not known to
            // zero-initialise its elements)
            NICE::Vector histogramm ( n );
            for ( int w = 0 ; w < n ; w++ )
                histogramm[w] = 0.0;
            sample_document ( pz_singled, pw_z_estimate, histogramm, samples_count );

            NICE::Vector pd_foldin (1);
            NICE::Vector pz_d_foldin (m);
            pLSA.pLSA ( histogramm.getDataPointer(),
                        pw_z_estimate_raw,
                        pd_foldin.getDataPointer(),
                        pz_d_foldin.getDataPointer(),
                        n, m, 1,
                        false /* do not use tempered version */,
                        false );

            /*
            cerr << "p(1,w) = " << histogramm << endl;
            cerr << "p(d) =fold " << pd_foldin << endl;
            cerr << "p(z|d) =fold " << pz_d_foldin << endl;
            cerr << "p(z|d) =gt " << pz_singled << endl;
            */

            error_foldin += dist ( pz_d_foldin, pz_singled );
        }
        // avoid 0/0 if every document ended up in the hold-out set
        if ( dtrained > 0 )
            error_foldin /= dtrained;
        fprintf (stderr, "foldin error: %f\n", error_foldin );
        error_foldin_avg += error_foldin;

        runs_successfull++;

        double error_likelihood = likelihood_estimate - likelihood_groundtruth;
        fprintf (stderr, "likelihood error: %f\n", error_likelihood );
        if ( error_likelihood > 1e3 )
        {
            fprintf (stderr, "This seems to be a severe local minimum !\n");
        }

        delete [] pw_z_estimate_raw;
        delete [] pz_d_estimate_raw;
    }

    // nothing to average (e.g. runs <= 0): avoid dividing by zero
    if ( runs_successfull == 0 )
    {
        fprintf (stderr, "no successful runs, no statistics available\n");
        return 0;
    }

    error_pw_z_avg /= runs_successfull;
    error_foldin_avg /= runs_successfull;
    error_pw_z_stddev /= runs_successfull;

    // standard deviation from E[x^2] - E[x]^2, clamped against rounding
    error_pw_z_stddev -= error_pw_z_avg*error_pw_z_avg;
    error_pw_z_stddev = (error_pw_z_stddev < 0.0) ? 0.0 : sqrt(error_pw_z_stddev);
    fprintf (stderr, "pwz: %f %f\n", error_pw_z_avg, error_pw_z_stddev );
    fprintf (stderr, "foldin: %f\n", error_foldin_avg );

    return 0;
}

#else
/**
 * Fallback entry point used when the program is built without
 * NICE_USELIB_ICE: the test cannot run, so it throws a C string.
 */
int main (int argc, char **argv)
{
    // thrown as a plain C string, exactly like a thrown string literal
    const char *const reason = "no ice installed\n";
    throw reason;

    // never reached
    return 0;
}
#endif