/** 
* @file testDirichlet.cpp
* @brief interactive tests for Dirichlet sampling/parameter estimation and the MAP estimation routines (MAPMultinomialGaussianPrior)
* @author Erik Rodner
* @date 05/21/2008

*/
#ifdef NOVISUAL
    #warning "testDirichlet needs ICE with visualization !!"
    // Stub entry point used when ICE visualization is unavailable: nothing to run.
    int main ( int /* argc */, char ** /* argv */ ) { return 0; }
#else

#include <algorithm>
#include <cmath>
#include <exception>
#include <limits>
#include <stdexcept>
#include <vector>

#include "core/vector/VectorT.h"
#include "core/vector/MatrixT.h"
#include "core/image/ImageT.h"
#include "core/imagedisplay/ImageDisplay.h"

#include <core/vector/Distance.h>
#include <core/image/CrossT.h>

#include <core/basics/Config.h>
#include <vislearning/baselib/cmdline.h>
#include <vislearning/baselib/Gnuplot.h>
#include <vislearning/baselib/ICETools.h>
#include <vislearning/baselib/Conversions.h>
#include <vislearning/math/pdf/PDFDirichlet.h>
#include <vislearning/optimization/mapestimation/MAPMultinomialGaussianPrior.h>

using namespace OBJREC;

// refactor-nice.pl: check this substitution
// old: using namespace ice;
using namespace NICE;
using namespace std;

/**
 * @brief Scatter-plot the first two coordinates of each sample into an image.
 *
 * A sample x is drawn at pixel ( x[0]*(width-1), x[1]*(height-1) ), so the
 * coordinates are expected to lie in [0,1] (as for Dirichlet samples).
 * Samples with fewer than two components, or whose coordinates map outside
 * the image (including NaN), are skipped instead of writing out of bounds.
 *
 * @param img     target image; pixels are overwritten in place
 * @param samples samples to plot
 * @param color   pixel value written for every sample
 */
void plotDistributions ( NICE::Image & img, const VVector & samples, int color = 1 )
{
    const int width = img.width();
    const int height = img.height();

    for ( VVector::const_iterator i = samples.begin(); i != samples.end(); ++i )
    {
        const NICE::Vector & x = *i;
        if ( x.size() < 2 )
            continue;

        const double fx = x[0] * ( width - 1 );
        const double fy = x[1] * ( height - 1 );
        // written as negated ranges so that NaN coordinates are rejected as well
        if ( !( fx >= 0.0 && fx <= width - 1 ) || !( fy >= 0.0 && fy <= height - 1 ) )
            continue;

        img.setPixel ( static_cast<int> ( fx ), static_cast<int> ( fy ), color );
    }
}

// refactor-nice.pl: check this substitution
// old: void estimateDirichlet_Newton ( Vector & alpha, 
void estimateDirichlet_Newton ( NICE::Vector & alpha, 
				const VVector & samples )
{

}

/**
 * @brief Sample from a random Dirichlet distribution and check simple estimates.
 *
 * Draws a random parameter vector alpha with entries in (0, 3] and samples
 * samplesCount vectors from Dir(alpha). It then runs estimateDirichlet_Newton
 * on those samples and compares the empirical sample mean with the true
 * Dirichlet mean alpha / sum(alpha). Results are printed to stderr.
 *
 * For dimension == 3 the samples and both means are drawn into an image and
 * displayed; otherwise the function waits for a key press.
 *
 * @param samplesCount number of samples to draw; must be positive
 * @param dimension    dimension of the Dirichlet distribution; must be positive
 *                     (defaults to 100, the value previously hard-coded)
 */
void simulation ( int samplesCount, int dimension = 100 )
{
    if ( samplesCount <= 0 || dimension <= 0 )
    {
        cerr << "simulation: samplesCount and dimension must be positive" << endl;
        return;
    }

    NICE::Vector alpha ( dimension );

    srand48 ( time ( NULL ) );

    // drand48() lies in [0,1); 1 - drand48() keeps every alpha_i strictly
    // positive, as required for a Dirichlet parameter
    double s = 0.0;
    for ( int i = 0 ; i < dimension ; i++ )
    {
        alpha[i] = ( 1.0 - drand48() ) * 3.0;
        s += alpha[i];
    }

    cerr << "alpha: " << alpha << endl;

    PDFDirichlet pdfdirichlet ( alpha );
    VVector samples;
    pdfdirichlet.sample ( samples, samplesCount );
    if ( samples.empty() )
    {
        cerr << "simulation: no samples were drawn" << endl;
        return;
    }

    // zero-initialise so the printout below is well-defined even if the
    // estimator leaves its output untouched
    NICE::Vector alphaEstimated ( alpha.size() );
    alphaEstimated.set ( 0 );
    estimateDirichlet_Newton ( alphaEstimated, samples );

    // empirical mean of the samples
    NICE::Vector muCGD ( dimension );
    muCGD.set ( 0 );
    for ( VVector::const_iterator i = samples.begin(); i != samples.end(); ++i )
        muCGD = muCGD + ( *i );
    muCGD = ( 1.0 / samples.size() ) * muCGD;

    EuclidianDistance<double> euclid;
    // analytic mean of Dir(alpha): alpha / sum(alpha)
    NICE::Vector meanDirichlet ( alpha * ( 1.0 / s ) );
    cerr << "alpha estimated: " << alphaEstimated << endl;
    cerr << "distance (alpha): " << euclid ( alpha, alphaEstimated ) << endl;
    cerr << "mu estimated (cgd): " << muCGD << endl;
    cerr << "mean (dirichlet): " << meanDirichlet << endl;
    cerr << "distance: " << euclid ( muCGD, meanDirichlet ) << endl;

    if ( dimension == 3 )
    {
        NICE::Image img ( 400, 400 );
        img.set ( 0 );
        plotDistributions ( img, samples, 1 );

        Cross cross1 ( Coord ( muCGD[0] * ( img.width() - 1 ), muCGD[1] * ( img.height() - 1 ) ), 10 );
        Cross cross2 ( Coord ( alpha[0] / s * ( img.width() - 1 ), alpha[1] / s * ( img.height() - 1 ) ), 10 );
        img.draw ( cross1, 2 );
        img.draw ( cross2, 3 );

        showImageOverlay ( img, img );
    } else {
        getchar();
    }
}

int main (int argc, char **argv)
{   
    std::set_terminate(__gnu_cxx::__verbose_terminate_handler);

    char configfile [300];

    struct CmdLineOption options[] = {
	{"config", "use config file", NULL, "%s", configfile},
	{NULL, NULL, NULL, NULL, NULL} 
    };
    int ret;
    char *more_options[argc];
    ret = parse_arguments( argc, (const char**)argv, options, more_options);

    if ( ret != 0 )
    {
	if ( ret != 1 ) fprintf (stderr, "Error parsing command line !\n");
	exit (-1);
    }

    Config conf ( configfile );

    const int samples_count_sim = 100;
    simulation ( samples_count_sim );
    
    const int samples_count = 100;

    const int dimension = 1000;
    // refactor-nice.pl: check this substitution
    // old: Vector mu (dimension);
    NICE::Vector mu (dimension);
    // refactor-nice.pl: check this substitution
    // old: Vector mu_noninformative (dimension);
    NICE::Vector mu_noninformative (dimension);
    // refactor-nice.pl: check this substitution
    // old: Vector counts (dimension);
    NICE::Vector counts (dimension);
    mu.set(0);
    counts.set(0);

    MAPMultinomialGaussianPrior map;

    // simulate gaussian distribution
    for ( int i = 0 ; i < samples_count ; i++ )
    {
	double r = randGaussDouble ( dimension / 6 );
	int bin = (int) (r+dimension/3);
	if ( (bin >= 0) && (bin < dimension) )
	    mu[bin]++;
	
	r = randGaussDouble ( dimension / 6 ) * randGaussDouble (dimension/6);
	bin = (int) (r+2*dimension/3);
	if ( (bin >= 0) && (bin < dimension) )
	    counts[bin]++;
    }

    MAPMultinomialGaussianPrior::normalizeProb ( mu );
    mu_noninformative = Vector(counts);
    MAPMultinomialGaussianPrior::normalizeProb ( mu_noninformative );

    const double eps = 10e-11;

    for ( int i = 1 ; i < 12 ; i++ )
    {
	double sigmaq = pow(10,-i);
	// refactor-nice.pl: check this substitution
	// old: Vector theta ( dimension );
	NICE::Vector theta ( dimension );
	map.estimate ( theta, counts, mu, sigmaq );

	vector<double> tp, mup, munonp;

	tp = Conversions::stl_vector ( theta );
	mup = Conversions::stl_vector ( mu );
	munonp = Conversions::stl_vector ( mu_noninformative );
	
	Gnuplot gp;

	gp.set_style ( "boxes" );
	gp.plot_x ( mup, "mu" );
	gp.plot_x ( munonp, "mu (noninformative)"  );
	gp.plot_x ( tp, "theta"  );

        // refactor-nice.pl: check this substitution
        // old: GetChar();
        getchar();
    }

    return 0;
}

#endif