/** 
* @file FPCRandomForestTransfer.cpp
* @brief implementation of random set forests
* @author Erik Rodner
* @date 04/24/2008

*/
#include <iostream>
#include <list>

#include "vislearning/classifier/fpclassifier/randomforest/FPCRandomForestTransfer.h"
#include "vislearning/classifier/fpclassifier/randomforest/DTBStandard.h"
#include "vislearning/classifier/fpclassifier/randomforest/DTBRandom.h"

using namespace OBJREC;

using namespace std;
using namespace NICE;


/** Read all transfer-learning settings from the config and (optionally)
 *  create the tree builder used for MAP-tree extension.
 *  @param _conf configuration object (section @p section is consulted)
 *  @param classNames used to resolve class-name selections into class-number sets
 *  @param section config section to read from
 */
FPCRandomForestTransfer::FPCRandomForestTransfer( const Config *_conf, 
    const ClassNames *classNames, std::string section ) :
    FPCRandomForests ( _conf, section ), dte ( _conf, section )
{
    // FIX: builder_extend was left uninitialized whenever extend_map_tree is
    // false; give it a defined value so it can be safely checked/deleted later.
    builder_extend = NULL;

    reduce_training_set = _conf->gB(section, "reduce_training_set", false);
    entropy_rejection_threshold = _conf->gD(section, "entropy_rejection_threshold", 0.0 );
    extend_only_critical_leafs = _conf->gB(section, "extend_only_critical_leafs", true );

    if ( reduce_training_set ) {
		// absolute count takes precedence; ratio is only read when absolute is unset
		training_absolute = _conf->gI ( section, "training_absolute", -1 );
	if ( training_absolute < 0 )
	    training_ratio = _conf->gD ( section, "training_ratio" );
    }

    // resolve the three class selections (substitute / mu / new) to class numbers
    std::string substituteClasses_s = _conf->gS ( section, "substitute_classes" );
    classNames->getSelection ( substituteClasses_s, substituteClasses );
 
    std::string muClasses_s = _conf->gS ( section, "mu_classes" );
    classNames->getSelection ( muClasses_s, muClasses );
   
    std::string newClass_s = _conf->gS ( section, "new_classes" );
    classNames->getSelection ( newClass_s, newClass );

    sigmaq = _conf->gD ( section, "sigmaq" );
    cached_prior_structure = _conf->gS(section, "cached_prior_structure", "prior.tree" );
    read_cached_prior_structure = _conf->gB(section, "read_cached_prior_structure", false );

    // only a single new class is supported so far
    if ( newClass.size() != 1 )
    {
		fprintf (stderr, "Multi-New-Class stuff not yet implemented\n");
		exit(-1);
    }

    partial_ml_estimation = _conf->gB(section, "partial_ml_estimation", false );
    partial_ml_estimation_depth = _conf->gI(section, "partial_ml_estimation_depth", 4 );

    extend_map_tree = _conf->gB(section, "extend_map_tree", false );

    if ( extend_map_tree )
    {
		// builder used later by extendMapTree() to grow new subtrees below leaves
		std::string builder_e_method = _conf->gS(section, "builder_extend", "random" );
		std::string builder_e_section = _conf->gS(section, "builder_extend_section" );

		if ( builder_e_method == "standard" )
			builder_extend = new DTBStandard ( _conf, builder_e_section );
		else if (builder_e_method == "random" )
			builder_extend = new DTBRandom ( _conf, builder_e_section );
		else {
			fprintf (stderr, "DecisionTreeBuilder %s not yet implemented !\n",
			builder_e_method.c_str() );
			exit(-1);
		}
    }

    learn_ert_with_newclass = _conf->gB(section, "learn_ert_with_newclass", false);
}

/** Destructor: releases the extension tree builder allocated in the
 *  constructor (previously leaked). It is only allocated when
 *  extend_map_tree is enabled, so the flag guards the delete. */
FPCRandomForestTransfer::~FPCRandomForestTransfer()
{
    if ( extend_map_tree )
		delete builder_extend;
}

/** ML re-estimation of the new-class probability below @p node:
 *  traverse every new-class example to refresh the per-node counters,
 *  then write each descendant's counter into its distribution entry
 *  for @p newClassNo. */
void FPCRandomForestTransfer::mlEstimate ( DecisionNode *node,
					   Examples & examples_new,
					   int newClassNo )
{
    node->resetCounters();

    // route every new-class example through the subtree; traverse()
    // updates the counters along the path (the distribution is discarded)
    for ( Examples::iterator exIt = examples_new.begin();
	  exIt != examples_new.end();
	  ++exIt )
    {
		FullVector scratchDistribution (maxClassNo+1);

		assert ( exIt->first == newClassNo );
		node->traverse ( exIt->second, scratchDistribution );
    }

    // enumerate all descendants and copy counter -> distribution[newClassNo]
    map<DecisionNode *, pair<long, int> > descendants;
    long maxindex = 0;
    node->indexDescendants ( descendants, maxindex, 0 );

    for ( map<DecisionNode *, pair<long, int> >::iterator it = descendants.begin();
	  it != descendants.end();
	  ++it )
    {
		DecisionNode *descendant = it->first;
		descendant->distribution[newClassNo] = descendant->counter;
    }
}

/** Run mlEstimate() on every node of @p tree at depth @p mldepth,
 *  reweighting the new-class examples by the node's current mass for
 *  @p newClassNo before each estimation.
 *  @param tree tree whose nodes are re-estimated
 *  @param examples_new examples of the new class (copied per node)
 *  @param newClassNo class number of the new class
 *  @param mldepth depth at which nodes are selected for ML estimation
 */
void FPCRandomForestTransfer::partialMLEstimate ( DecisionTree & tree,
						  Examples & examples_new,
						  int newClassNo,
						  int mldepth )
{
    map<DecisionNode *, pair<long, int> > index;
    long maxindex = 0;

    tree.indexDescendants ( index, maxindex );

    for ( map<DecisionNode *, pair<long, int> >::iterator i = index.begin();
							  i != index.end();
							  i++ )
    {
	DecisionNode *node = i->first;
	pair<long, int> & data = i->second;
	int depth = data.second;
	if ( depth == mldepth ) {
	    // I do not care whether this is a leaf node or not

	    // skip nodes with (near) zero probability mass for the new class
	    // FIX: this check now runs BEFORE the example copy below, so we no
	    // longer pay for an O(n) copy of the whole example set on nodes
	    // that are skipped anyway (behavior is unchanged)
	    double weight = ( node->distribution.get ( newClassNo ) );

	    if ( fabs(weight) < 10e-10 )
	    {	
		continue;
	    }

	    // working copy: weights are modified per node
	    Examples examples_new_rw;

	    examples_new_rw.insert ( examples_new_rw.begin(),
				     examples_new.begin(),
				     examples_new.end() );

	    // distribute the node's new-class mass uniformly over the examples
	    for ( Examples::iterator j = examples_new_rw.begin(); 
				   j != examples_new_rw.end() ; 
				   j++ )
	    {
		j->second.weight = weight / examples_new_rw.size();
	    }

	    mlEstimate ( node, examples_new_rw, newClassNo );
	}
    } 

}

/** Extend suitable leaf nodes of a MAP-estimated tree with new subtrees
 *  built from transfer examples (of the mu/support classes) plus the
 *  new-class examples, then renormalize the class distributions in the
 *  grown subtrees so the total mass of the support classes is preserved.
 *  @param fp feature pool used by the subtree builder
 *  @param tree tree whose leaves may be extended (modified in place)
 *  @param examples_transfer examples of the old/support classes
 *  @param examples_new examples of the new class
 *  @param newClassNo class number of the new class
 *  @param muClasses set of support (mu) class numbers
 */
void FPCRandomForestTransfer::extendMapTree ( FeaturePool & fp,
					      DecisionTree & tree, 
					      Examples & examples_transfer, 
					      Examples & examples_new, 
					      int newClassNo,
					      const set<int> & muClasses )
{
    // leaf -> indices (into examples_transfer) of examples that reached it
    map<DecisionNode *, set<int> > examplesTransferLeafs;

    fprintf (stderr, "FPCRandomForestTransfer: classify all %ld transfer examples\n", 
	examples_transfer.size());
    // Phase 1: drop every suitable transfer example down the tree and record
    // which leaf it lands in, filtering out unsuitable examples/leaves.
    int index = 0;
    for ( Examples::iterator i = examples_transfer.begin() ; 
				       i != examples_transfer.end();
				       i++, index++ )
    {
	Example & pce = i->second;

	// only examples of the new class or of a support (mu) class are used
	int example_classno = i->first;
	if ( (example_classno != newClassNo) && 
	     (muClasses.find(example_classno) == muClasses.end() ) )
	    continue;
	else
	    fprintf (stderr, "suitable example of class %d found !\n", example_classno);
	
	DecisionNode *leaf = tree.getLeafNode ( pce );	

	// skip leaves with (almost) no probability mass for the new class
	double weight = ( leaf->distribution.get ( newClassNo ) );
	if ( fabs(weight) < 10e-2 )
	    continue;
   
	if ( extend_only_critical_leafs )
	{
	    // "critical" = a support class currently dominates this leaf
	    int maxClass = leaf->distribution.maxElement();
	    if ( muClasses.find(maxClass) == muClasses.end() )
		continue;
	}
	
	// entropy normalized to [0,1] by dividing by log(#classes)
	double avgentropy = leaf->distribution.entropy() / log(leaf->distribution.size());
	if ( examplesTransferLeafs.find(leaf) == examplesTransferLeafs.end() )
	{
	    /*fprintf (stderr, "FPCRandomForestTransfer: leaf owned by %d (normalized entropy %f)\n", maxClass, avgentropy );
	    leaf->distribution.store(cerr); */
	}

	// reject leaves whose distribution is too peaked (low entropy)
	if ( avgentropy < entropy_rejection_threshold ) 
	{
	    fprintf (stderr, "FPCRandomForestTransfer: leaf rejected due to entropy %f < %f!\n", avgentropy, entropy_rejection_threshold);
	    continue;
	}

	examplesTransferLeafs[leaf].insert ( index );
    }

    fprintf (stderr, "FPCRandomForestTransfer: %ld leaf nodes will be extended\n", 
	examplesTransferLeafs.size() );

    // Phase 2: for each selected leaf, build a new subtree from the collected
    // transfer examples plus all new-class examples, then fix up distributions.
    fprintf (stderr, "FPCRandomForestTransfer: Extending Leaf Nodes !\n"); 
    for ( map<DecisionNode *, set<int> >::iterator k = examplesTransferLeafs.begin();
						   k != examplesTransferLeafs.end();
						   k++ )
    {
	DecisionNode *node = k->first;
	FullVector examples_counts ( maxClassNo+1 );

	// gather the training examples for this leaf's subtree
	Examples examples_node;
	set<int> & examplesset = k->second;
	for ( set<int>::iterator i = examplesset.begin(); i != examplesset.end(); i++ )
	{
	    pair<int, Example> & example = examples_transfer[ *i ];
	    // skip classes with (near) zero mass at this leaf
	    if ( node->distribution [ example.first ] < 10e-11 ) 
		continue;
	    examples_node.push_back ( example );
	    examples_counts[ example.first ]++;
	}

	fprintf (stderr, "FPCRandomForestTransfer: Examples from support classes %ld\n", examples_node.size() );
	fprintf (stderr, "FPCRandomForestTransfer: Examples from new class %ld (classno %d)\n", examples_new.size(), 
	    newClassNo);

	// always add all new-class examples
	examples_node.insert ( examples_node.begin(), examples_new.begin(), examples_new.end() );
	examples_counts[newClassNo] = examples_new.size();

	fprintf (stderr, "FPCRandomForestTransfer: Extending leaf node with %ld examples\n", examples_node.size() );

	// reweight each example so each class contributes exactly its current
	// probability mass at this leaf (mass split evenly over its examples)
	for ( Examples::iterator j = examples_node.begin(); 
					   j != examples_node.end() ; 
					   j++ )
	{
	    int classno = j->first;
	    double weight = ( node->distribution.get ( classno ) );
	    fprintf (stderr, "examples_counts[%d] = %f; weight %f\n", classno, examples_counts[classno], weight );
	    j->second.weight = weight / examples_counts[classno];
	}
	// grow a fresh subtree and graft it in place of the leaf,
	// keeping the leaf's original distribution at the root
	DecisionNode *newnode = builder_extend->build ( fp, examples_node, maxClassNo );

	FullVector orig_distribution ( node->distribution );
	node->copy ( newnode );
	node->distribution = orig_distribution;
	orig_distribution.normalize();
	orig_distribution.store(cerr);

	// total mass of the support classes (new class + mu classes) at the root
	double support_node_sum = 0.0;
	for ( int classi = 0 ; classi < node->distribution.size() ; classi++ )
	    if ( (classi == newClassNo) || (muClasses.find(classi) != muClasses.end() ) )
		support_node_sum += node->distribution[classi];

	// set all probabilities for non support classes
	// BFS over the new subtree: copy the root's non-support probabilities
	// into every node and rescale the support-class mass so its total
	// matches the root's support_node_sum
	std::list<DecisionNode *> stack;
	stack.push_back ( node );
	while ( stack.size() > 0 )
	{
	    DecisionNode *cnode = stack.front();
	    stack.pop_front();

	    double cnode_sum = 0.0;
	    for ( int classi = 0 ; classi < cnode->distribution.size() ; classi++ )
		if ( (classi != newClassNo) && (muClasses.find(classi) == muClasses.end() ) )
		    cnode->distribution[classi] = node->distribution[classi];
		else 
		    cnode_sum += cnode->distribution[classi];
	    
	    if ( fabs(cnode_sum) > 10e-11 )
		for ( int classi = 0 ; classi < node->distribution.size() ; classi++ )
		    if ( (classi == newClassNo) || (muClasses.find(classi) != muClasses.end() ) )
			cnode->distribution[classi] *= support_node_sum / cnode_sum;

	    // debug output: normalized distribution of each leaf
	    if ( (cnode->left == NULL) && (cnode->right == NULL ) )
	    {
			FullVector stuff ( cnode->distribution );
			stuff.normalize();
			stuff.store(cerr);
	    }

	    if ( cnode->left != NULL )
			stack.push_back ( cnode->left );

	    if ( cnode->right != NULL )
			stack.push_back ( cnode->right );
	}

    }
    
    fprintf (stderr, "FPCRandomForestTransfer: MAP tree extension done !\n"); 
}


/** Train the transfer forest: learn (or load) the prior forest on the
 *  transfer classes, then MAP-re-estimate each tree with the new-class
 *  examples and optionally apply partial ML estimation and MAP-tree
 *  extension.
 *  @param fp feature pool used for training
 *  @param examples all training examples (new class + transfer classes)
 */
void FPCRandomForestTransfer::train ( FeaturePool & fp,
      Examples & examples )
{
    maxClassNo = examples.getMaxClassNo();

    fprintf (stderr, "############### FPCRandomForestTransfer::train ####################\n");
    assert ( newClass.size() == 1 );
    int newClassNo = *(newClass.begin());

    // split the training set into new-class examples and transfer examples
    Examples examples_new;
    Examples examples_transfer;
    
    for ( Examples::const_iterator i = examples.begin();
					     i != examples.end();
					     i++ )
    {
	int classno = i->first;
	if ( newClass.find(classno) != newClass.end() ) {
	    examples_new.push_back ( *i );
	} else {
	    examples_transfer.push_back ( *i );
	}
    }

    if ( examples_new.size() <= 0 ) 
    {
		if ( newClass.size() <= 0 ) {
			fprintf (stderr, "FPCRandomForestTransfer::train: no new classes given !\n");
		} else {
			fprintf (stderr, "FPCRandomForestTransfer::train: no examples found of class %d\n", newClassNo );
		}
		exit(-1);
    }

    if ( reduce_training_set ) 
    {
	// reduce training set
	random_shuffle ( examples_new.begin(), examples_new.end() );
	int oldsize = (int)examples_new.size();
	int newsize;
	
	if ( training_absolute < 0 ) 
	    newsize = (int)(training_ratio*examples_new.size());
	else
	    newsize = training_absolute;

	// FIX: clamp to the available number of examples; otherwise
	// examples_new.begin() + newsize is a past-the-end iterator
	// (undefined behavior) when training_absolute > oldsize
	if ( newsize > oldsize )
	    newsize = oldsize;

	Examples::iterator j = examples_new.begin() + newsize;
	examples_new.erase ( j, examples_new.end() );

	fprintf (stderr, "Size of training set randomly reduced from %d to %d\n", oldsize, 
	    (int)examples_new.size() );
    }


    // prior forest: either read from cache or train and cache it
    if ( read_cached_prior_structure )
    {
	FPCRandomForests::read ( cached_prior_structure );
    } else {
	if ( learn_ert_with_newclass ) 
	{
	    FPCRandomForests::train ( fp, examples ); 
	} else {
	    FPCRandomForests::train ( fp, examples_transfer ); 
	}
	FPCRandomForests::save ( cached_prior_structure );
    }

    fprintf (stderr, "MAP ESTIMATION sigmaq = %e\n", sigmaq);

    // MAP re-estimation (and optional refinements) for every tree
    for ( vector<DecisionTree *>::iterator i  = forest.begin();
					   i != forest.end();
					   i++ )
    {
	DecisionTree & tree = *(*i);
	dte.reestimate ( tree, 
			 examples_new,
			 sigmaq, 
			 newClassNo, 
			 muClasses, 
			 substituteClasses,
			 maxClassNo);

	if ( partial_ml_estimation )
	{
	    partialMLEstimate ( tree,
				examples_new,
				newClassNo,
				partial_ml_estimation_depth );
	}

	if ( extend_map_tree )
	{
	    fp.initRandomFeatureSelection ();
	    extendMapTree ( fp, 
			    tree, 
			    examples_transfer, 
			    examples_new,
			    newClassNo,
			    muClasses);
	}
    }

    save ( "map.tree" );

}

/** Cloning is not supported yet; prints an error and terminates.
 *  @return never returns normally (process exits) */
FeaturePoolClassifier *FPCRandomForestTransfer::clone () const
{
    fprintf (stderr, "FPCRandomForestTransfer::clone() not yet implemented !\n");
    exit(-1);

    // FIX: unreachable, but satisfies the non-void return type and
    // silences missing-return compiler warnings
    return NULL;
}