/** 
* @file DTEstimateAPriori.cpp
* @brief estimate decision tree probabilities using an a priori density
* @author Erik Rodner
* @date 05/27/2008
*/
#include <iostream>

#include "vislearning/classifier/fpclassifier/randomforest/DTEstimateAPriori.h"

#include "vislearning/optimization/mapestimation/MAPMultinomialGaussianPrior.h"
#include "vislearning/optimization/mapestimation/MAPMultinomialDirichlet.h"

using namespace OBJREC;

using namespace std;
using namespace NICE;

#define DEBUG_DTESTIMATE

DTEstimateAPriori::DTEstimateAPriori( const Config *conf, const std::string & section )
{
    std::string mapEstimatorType_s = conf->gS(section, "map_multinomial_estimator", 
					 "gaussianprior" );

    if ( mapEstimatorType_s == "gaussianprior" )
		map_estimator = new MAPMultinomialGaussianPrior();
    else if ( mapEstimatorType_s == "dirichletprior" )
		map_estimator = new MAPMultinomialDirichlet();
    else {
		fprintf (stderr, "DTEstimateAPriori: estimator type %s unknown\n", mapEstimatorType_s.c_str() );
		exit(-1);
    }
}
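
/* Example configuration entry (a sketch; section name and file syntax depend
   on the caller's NICE::Config setup):

     [DTEstimateAPriori]
     map_multinomial_estimator = gaussianprior

   The only other accepted value is "dirichletprior".
*/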

DTEstimateAPriori::~DTEstimateAPriori()
{
    delete map_estimator;
}

void DTEstimateAPriori::reestimate ( DecisionTree & dt, 
			  Examples & examples,
			  double sigmaq,
			  int newClassNo,
			  set<int> muClasses,
			  set<int> substituteClasses,
			  int maxClassNo )
{
    mapEstimateClass (  dt, 
			examples,
			newClassNo,
			muClasses,
			substituteClasses,
			sigmaq,
			maxClassNo );
}

/** calculates node probabilities recursively
    using the following formula:
	p(n | i) = p(p | i) * ( c(i | n) / c(i | p) )
    where p denotes the parent node of n and c(i | n) is the
    (unnormalized) count of class i examples reaching node n
    @remark do not use normalized a-posteriori values!
*/
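
/* Worked example (illustrative numbers only): if the parent p has node
   probability p(p | i) = 0.5 and contains c(i | p) = 10 examples of class i,
   of which c(i | left) = 4 reach the left child, then
   p(left | i) = 0.5 * (4 / 10) = 0.2.
*/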
void DTEstimateAPriori::calculateNodeProbabilitiesRec ( 
	    map<DecisionNode *, FullVector> & p, 
	    DecisionNode *node )
{
    if ( node == NULL ) return;
	else if ( (node->left == NULL) && (node->right == NULL ) ) return;		
    else {
		assert ( node->left != NULL );
		assert ( node->right != NULL );

		// estimate probabilities for the children
		const FullVector & parent = p[node];

		// calculate left prob 
		const FullVector & posteriori = node->distribution;
		const FullVector & posteriori_left = node->left->distribution;
		const FullVector & posteriori_right = node->right->distribution;

		FullVector result_left (parent);
		FullVector result_right (parent);

		FullVector transition_left ( posteriori_left );
		transition_left.divide ( posteriori );

		assert ( transition_left.max() <= 1.0 );
		assert ( transition_left.min() >= 0.0 );

		FullVector transition_right ( posteriori_right );
		transition_right.divide ( posteriori );

		result_left.multiply ( transition_left );
		result_right.multiply ( transition_right );

		p.insert ( pair<DecisionNode *, FullVector> ( node->left, result_left ) );
		p.insert ( pair<DecisionNode *, FullVector> ( node->right, result_right ) );

		calculateNodeProbabilitiesRec ( p, node->left );
		calculateNodeProbabilitiesRec ( p, node->right );
    }
}

void DTEstimateAPriori::calculateNodeProbabilities ( map<DecisionNode *, FullVector> & p, DecisionTree & tree )
{
    DecisionNode *root = tree.getRoot();
    
    FullVector rootNP ( root->distribution.size() );
    // root node probability is 1 for each class 
    rootNP.set ( 1.0 );

    p.insert ( pair<DecisionNode *, FullVector> ( root, rootNP )  );
    calculateNodeProbabilitiesRec ( p, root );
}


void DTEstimateAPriori::calculateNodeProbVec 
	( map<DecisionNode *, FullVector> & nodeProbs, 
	  int classno,
	  NICE::Vector & p )
{
    double sum = 0.0;

    assert ( p.size() == 0 );

    for ( map<DecisionNode *, FullVector>::const_iterator k = nodeProbs.begin();
			k != nodeProbs.end(); k++ )
    {
		const FullVector & v = k->second;
		DecisionNode *node = k->first;

		if ( (node->left != NULL) || (node->right != NULL) )
			continue;

		double val = v[classno];
		NICE::Vector single (1);
		single[0] = val;
		// inefficient: appending single-element vectors reallocates p every time
		p.append ( single );

		sum += val;
    }

    // normalize so that the leaf probabilities of class classno sum to one
    for ( size_t i = 0 ; i < p.size() ; i++ )
		p[i] /= sum;
}
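
/* A more efficient variant would preallocate p instead of appending
   single-element vectors (a sketch, not used here):

     size_t numLeaves = 0;
     for ( map<DecisionNode *, FullVector>::const_iterator k = nodeProbs.begin();
           k != nodeProbs.end(); k++ )
         if ( (k->first->left == NULL) && (k->first->right == NULL) )
             numLeaves++;
     NICE::Vector p ( numLeaves ); // then fill p[idx] while iterating over the leaves
*/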

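/** computes the probability of an inner node as the sum of the probabilities
    of its children; p has to contain the values of all leaf nodes beforehand
    (the recursion stops at nodes that are already present in p)
*/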
double DTEstimateAPriori::calcInnerNodeProbs ( 
	DecisionNode *node, 
	map<DecisionNode *, double> & p )
{
    map<DecisionNode *, double>::const_iterator i = p.find( node );
    if ( i == p.end() )
    {
		double prob = 0.0;
		if ( node->left != NULL )
			prob += calcInnerNodeProbs ( node->left, p );

		if ( node->right != NULL )
			prob += calcInnerNodeProbs ( node->right, p );

		p.insert ( pair<DecisionNode *, double> ( node, prob ) );
		return prob;
    } else {
		return i->second;
    }
}

/** calculates a-posteriori probabilities using Bayes' rule:
    p(i | n) = p(n | i) p(i) / ( \sum_j p(n | j) p(j) )
*/
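
/* Worked example (illustrative numbers only): with two classes,
   p(n | 1) = 0.2, p(n | 2) = 0.1 and a uniform prior p(1) = p(2) = 0.5,
   the posterior is p(1 | n) = 0.1 / (0.1 + 0.05) = 2/3.
*/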
void DTEstimateAPriori::calcPosteriori ( 
	DecisionNode *node, 
	const FullVector & apriori,
	const map<DecisionNode *, FullVector> & nodeprob,
	map<DecisionNode *, FullVector> & posterioriResult )
{
    if ( node == NULL ) return;

    map<DecisionNode *, FullVector>::const_iterator i;
    i = nodeprob.find ( node );
    assert ( i != nodeprob.end() );
    const FullVector & np = i->second;

    assert ( np.sum() > 10e-7 );

    FullVector joint ( np );

    joint.multiply ( apriori );
    joint.normalize();
    
    posterioriResult.insert ( pair<DecisionNode *, FullVector> ( node, joint ) );

    calcPosteriori (node->left, apriori, nodeprob, posterioriResult);
    calcPosteriori (node->right, apriori, nodeprob, posterioriResult);
}

/** calculates a-posteriori probabilities by substituting support class
    values with new ones
*/
void DTEstimateAPriori::calcPosteriori ( DecisionTree & tree, 
		const FullVector & aprioriOld,
		const map<DecisionNode *, FullVector> & nodeprobOld,
		map<DecisionNode *, double> & nodeprobNew,
		const set<int> & substituteClasses,
		int newClassNo,
		map<DecisionNode *, FullVector> & posterioriResult )
{
    // calculating node probabilities of inner nodes
    calcInnerNodeProbs ( tree.getRoot(), nodeprobNew );

    // building new apriori probabilities
    FullVector apriori ( aprioriOld );
    for ( int i = 0 ; i < apriori.size() ; i++ )
		if ( substituteClasses.find( i ) != substituteClasses.end() )
			apriori[i] = 0.0;

    if ( substituteClasses.size() > 0 )
		apriori[newClassNo] =  1.0 - apriori.sum();
    else {
		// mean a priori
		double avg = apriori.sum() / apriori.size();
		apriori[newClassNo] = avg;
		apriori.normalize();
    }
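    // example for the branch above (no substitution; illustrative numbers only):
    // apriori = (0.6, 0.4, 0.0) with newClassNo = 2 yields avg = 1.0 / 3, hence
    // apriori = (0.6, 0.4, 1/3) and, after normalization, roughly (0.45, 0.30, 0.25)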
	
    if ( substituteClasses.size() > 0 )
    {
		fprintf (stderr, "WARNING: do you really want to do class substitution ?\n");
    }

    // building new node probabilities
    map<DecisionNode *, FullVector> nodeprob;
    for ( map<DecisionNode *, FullVector>::const_iterator j = nodeprobOld.begin();
							    j != nodeprobOld.end();
							    j++ )
    {
		const FullVector & d = j->second;
		DecisionNode *node = j->first;
		map<DecisionNode *, double>::const_iterator k = nodeprobNew.find ( node );
		assert ( k != nodeprobNew.end() );
		double newNP = k->second;

		assert ( d.sum() > 10e-7 );

		FullVector np ( d );
		for ( int i = 0 ; i < d.size() ; i++ )
			if ( substituteClasses.find( i ) != substituteClasses.end() )
				np[i] = 0.0;

		// special case: the leaf already has probability mass for the new class,
		// but the MAP estimate is numerically zero -- keep the old value
		if ( (np[ newClassNo ] > 10e-7) && (newNP < 10e-7) ) {
			fprintf (stderr, "DTEstimateAPriori: handling special case!\n");
		} else {
			np[ newClassNo ] = newNP;
		}
		
		if ( np.sum() < 10e-7 )
		{
			fprintf (stderr, "DTEstimateAPriori: handling special case (2), mostly for binary tasks!\n");
			assert ( substituteClasses.size() == 1 );
			int oldClassNo = *(substituteClasses.begin());
			np[ newClassNo ] = d[ oldClassNo ];
		}
		nodeprob.insert ( pair<DecisionNode *, FullVector> ( node, np ) );
    }
    
    calcPosteriori ( tree.getRoot(), apriori, nodeprob, posterioriResult );
}

void DTEstimateAPriori::mapEstimateClass ( DecisionTree & tree,
					   Examples & new_examples,
					   int newClassNo,
					   set<int> muClasses,
					   set<int> substituteClasses,
					   double sigmaq,
					   int maxClassNo )
{
    // ----------- (0) class a priori information ---------------------------------------------
    FullVector & root_distribution = tree.getRoot()->distribution;
    FullVector apriori ( root_distribution );
    apriori.normalize();

    // ----------- (1) collect leaf probabilities of oldClassNo -> mu -------------------------
    fprintf (stderr, "DTEstimateAPriori: calculating mu vector\n");
    map<DecisionNode *, FullVector> nodeProbs;
    calculateNodeProbabilities ( nodeProbs, tree );

    VVector priorDistributionSamples;
    for ( set<int>::const_iterator i = muClasses.begin();
				   i != muClasses.end();
				   i++ )
    {
		NICE::Vector p;
		calculateNodeProbVec ( nodeProbs, *i, p );
		priorDistributionSamples.push_back(p);
    }

    // ----------- (2) infer examples_new into tree -> leaf prob counts ----------------------- 
    FullVector distribution ( maxClassNo+1 );

    fprintf (stderr, "DTEstimateAPriori: Infering %d new examples into the tree\n", (int)new_examples.size() );
    assert ( new_examples.size() > 0 );

    tree.resetCounters ();
    for ( Examples::iterator j = new_examples.begin() ;
				       j != new_examples.end() ;
				       j++ )
	tree.traverse ( j->second, distribution );
    
    vector<double> scores_stl;
    for ( map<DecisionNode *, FullVector>::const_iterator k = nodeProbs.begin();
		    k != nodeProbs.end(); k++ )
    {
		DecisionNode *node = k->first;
		if ( (node->left != NULL) || (node->right != NULL) )
			continue;
		scores_stl.push_back ( node->counter );
    }
    NICE::Vector scores (scores_stl);

    VVector likelihoodDistributionSamples;
    likelihoodDistributionSamples.push_back ( scores );

    // ------------------------------- (3) map estimation ------------------------------------------
    fprintf (stderr, "DTEstimateAPriori: MAP estimation ...sigmaq = %e\n", sigmaq);
    NICE::Vector theta;
	// scores = ML solution = counts in each leaf node
	// theta = solution of the MAP estimation
    map_estimator->estimate ( theta, likelihoodDistributionSamples, priorDistributionSamples, sigmaq ); 

    assert ( theta.size() == scores.size() );

	// compute normalized scores
    NICE::Vector scores_n ( scores );
    double sum = 0.0;
    for ( int k = 0 ; k < (int)scores_n.size() ; k++ )
		sum += scores_n[k];
    
    assert ( fabs(sum) > 10e-8 );

    for ( int k = 0 ; k < (int)scores_n.size() ; k++ )
		scores_n[k] /= sum;
	
    // ---------- (4) calculate posteriori probs in each leaf according to leaf probs ---------------------
    map<DecisionNode *, double> npMAP;

    long index = 0;
    for ( map<DecisionNode *, FullVector>::const_iterator k = nodeProbs.begin();
		    k != nodeProbs.end(); k++ )
    {
		DecisionNode *node = k->first;
		if ( (node->left != NULL) || (node->right != NULL) )
			continue;
		npMAP[node] = theta[index];
		index++;
    }

    map<DecisionNode *, FullVector> posteriori;
    calcPosteriori ( tree, apriori, nodeProbs, npMAP, substituteClasses,
		newClassNo, posteriori );

    // (5) substitute class scores
    for ( map<DecisionNode *, FullVector>::iterator i = posteriori.begin();
						i != posteriori.end();
						i++ )
    {
		DecisionNode *node = i->first;

		if ( (node->left != NULL) || (node->right != NULL) )
		{
			//fprintf (stderr, "MAPMultinomialGaussianPrior: reestimating prob of a inner node !\n");
			continue;
		}
#ifdef DEBUG_DTESTIMATE
		FullVector old_distribution ( node->distribution );
		old_distribution.normalize();
		old_distribution.store (cerr);
#endif

		for ( int k = 0 ; k < node->distribution.size() ; k++ )
			if ( substituteClasses.find( k ) != substituteClasses.end() )
				node->distribution[k] = 0.0;


		// recalculate the class probabilities stored in this leaf
		double oldvalue = node->distribution.get(newClassNo);
		double supportsum = node->distribution.sum() - oldvalue;
		double pgamma = i->second[newClassNo];

		if ( (fabs(supportsum) > 10e-11) && (fabs(1.0-pgamma) < 10e-11 ) )
		{
			fprintf (stderr, "DTEstimateAPriori: corrupted probabilities\n");
			fprintf (stderr, "sum of all other class: %f\n", supportsum );
			fprintf (stderr, "prob of new class: %f\n", pgamma );
			exit(-1);
		}

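		// choose newvalue such that the new class obtains the posterior
		// probability pgamma within this leaf:
		//   newvalue / (newvalue + supportsum) = pgamma
		//   =>  newvalue = supportsum * pgamma / (1.0 - pgamma)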
		double newvalue = 0.0;
		if ( fabs(supportsum) < 10e-11 )
			newvalue = 0.0;
		else
			newvalue = supportsum * pgamma / (1.0 - pgamma);

		if ( (muClasses.size() == 1) && (substituteClasses.size() == 0) )
		{
			double muvalue = node->distribution.get( *(muClasses.begin()) );
#ifdef DEBUG_DTESTIMATE
			fprintf (stderr, "#REESTIMATE old=%f new=%f mu=%f pgamma=%f likelihood_prob=%f estimated_prob=%f\n", oldvalue, 
			newvalue, muvalue, pgamma, nodeProbs[node][newClassNo], npMAP[node] );
			fprintf (stderr, "#REESTIMATE mu_prob=%f\n", nodeProbs[node][ *(muClasses.begin()) ] );
#endif
		} else {   
#ifdef DEBUG_DTESTIMATE
			fprintf (stderr, "#REESTIMATE old=%f new=%f pgamma=%f supportsum=%f\n", oldvalue, newvalue, pgamma, supportsum );
#endif
		}

		//if ( newvalue > oldvalue ) 
		node->distribution[newClassNo] = newvalue;

#ifdef DEBUG_DTESTIMATE
		FullVector new_distribution ( node->distribution );
	//	new_distribution.normalize();
	//	new_distribution.store (cerr);

	/*
		for ( int i = 0 ; i < new_distribution.size() ; i++ )
		{
			if ( (muClasses.find(i) != muClasses.end()) || ( i == newClassNo ) )
			continue;

			if ( new_distribution[i] != old_distribution[i] )
			{
			fprintf (stderr, "class %d %f <-> %f\n", i, new_distribution[i], old_distribution[i] );
			new_distribution.store ( cerr );
			old_distribution.store ( cerr );
			node->distribution.store ( cerr );
			exit(-1);
			}
		}
	*/
#endif
    }


    int count, depth;
    tree.statistics ( depth, count );
    
    assert ( count == (int)posteriori.size() );

    tree.pruneTreeScore ( 10e-10 );
}
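
/* Example usage (a sketch; "conf", "tree", "newExamples", "newClassNo" and
   "maxClassNo" are assumed to be provided by the caller):

     DTEstimateAPriori estimator ( conf, "DTEstimateAPriori" );
     std::set<int> muClasses;
     muClasses.insert ( 0 );          // classes whose leaf distributions form the prior
     std::set<int> substituteClasses; // empty: add the new class, substitute nothing
     estimator.reestimate ( tree, newExamples,
                            0.5, // sigmaq (prior variance, an arbitrary example value)
                            newClassNo, muClasses, substituteClasses, maxClassNo );
*/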