/**
* @file FPCRandomForests.cpp
* @brief implementation of random set forests
* @author Erik Rodner
* @date 04/24/2008

*/

#ifdef NICE_USELIB_OPENMP
#include <omp.h>
#endif

#include <iostream>

#include "core/image/ImageT.h"
#include "core/imagedisplay/ImageDisplay.h"

#include "vislearning/classifier/fpclassifier/randomforest/FPCRandomForests.h"
#include "vislearning/classifier/fpclassifier/randomforest/DTBStandard.h"
#include "vislearning/classifier/fpclassifier/randomforest/DTBRandom.h"
#include "vislearning/classifier/fpclassifier/randomforest/DTBClusterRandom.h"
#include "vislearning/cbaselib/FeaturePool.h"

using namespace OBJREC;

using namespace std;

using namespace NICE;


FPCRandomForests::FPCRandomForests()
{
	// Default construction: no tree builder yet; a forest can only be
	// obtained via restore() in this state.
	builder = NULL;
	minimum_entropy = 0.0;
	enableOutOfBagEstimates = false;
	maxClassNo = -1;

	// Initialize the remaining members to the same defaults used by the
	// config-driven constructor. Previously these were left uninitialized,
	// which is undefined behavior as soon as they are read, e.g. by
	// setComplexity() (prints number_of_trees) or train().
	conf = NULL;
	number_of_trees = 20;
	features_per_tree = 1.0;
	samples_per_tree = 0.2;
	use_simple_balancing = false;
	weight_examples = false;
	memory_efficient = false;
}

/** Construct the classifier from a configuration section.
 * Reads the ensemble parameters and instantiates the requested
 * decision-tree builder ("standard", "random", "cluster_random",
 * or "none" to skip builder creation, e.g. when only restore() is used). */
FPCRandomForests::FPCRandomForests(const Config *_conf, std::string section) : conf(_conf)
{
	std::string builder_method = conf->gS(section, "builder", "random");
	minimum_entropy = conf->gD(section, "minimum_entropy", 0.0);
	enableOutOfBagEstimates = conf->gB(section, "enable_out_of_bag_estimates", false);
	maxClassNo = -1;

	confsection = section;

	// Read the ensemble parameters unconditionally so every member is
	// always initialized. Previously they were only set when a builder was
	// created, leaving them uninitialized for builder == "none" -- undefined
	// behavior once e.g. setComplexity() reads number_of_trees.
	number_of_trees = conf->gI(section, "number_of_trees", 20);
	features_per_tree = conf->gD(section, "features_per_tree", 1.0);
	samples_per_tree = conf->gD(section, "samples_per_tree", 0.2);
	use_simple_balancing = conf->gB(section, "use_simple_balancing", false);
	weight_examples = conf->gB(section, "weight_examples", false);
	memory_efficient = conf->gB(section, "memory_efficient", false);

	if (builder_method == "none") {
		// deliberately no builder: training is impossible, only restore()
		builder = NULL;
	} else {
		std::string builder_section = conf->gS(section, "builder_section", "DTBRandom");

		if (builder_method == "standard")
			builder = new DTBStandard(conf, builder_section);
		else if (builder_method == "random")
			builder = new DTBRandom(conf, builder_section);
		else if (builder_method == "cluster_random")
			builder = new DTBClusterRandom(conf, builder_section);
		else {
			fprintf(stderr, "DecisionTreeBuilder %s not yet implemented !\n", builder_method.c_str());
			exit(-1);
		}
	}
}

FPCRandomForests::~FPCRandomForests()
{
	// release every tree of the ensemble
	for (size_t t = 0; t < forest.size(); t++)
		delete forest[t];

	// the builder may be NULL (builder == "none" configuration)
	if (builder != NULL)
		delete builder;
}

void FPCRandomForests::calcOutOfBagEstimates ( 
    vector< vector<int> > & outofbagtrees, 
    Examples & examples )
{
	oobResults.clear();

	// Out-of-bag evaluation as suggested by Breiman: each training example
	// is classified only by the trees that did NOT see it during training,
	// i.e. the example serves as test data for exactly those trees.
	long exampleIndex = 0;
	for (Examples::iterator it = examples.begin();
	     it != examples.end();
	     it++, exampleIndex++)
	{
		const vector<int> & oobTrees = outofbagtrees[exampleIndex];

		// examples used by every tree have no out-of-bag trees; skip them
		if (oobTrees.size() <= 0)
			continue;

		ClassificationResult r = classify(it->second, oobTrees);

		// FIXME: assumption negative class dst is 0
		double score = r.scores.get(0 /*negativeClassDST*/);
		oobResults.push_back(pair<double, int>(score, it->first));
	}
}

/** Collect the leaf nodes of all trees in the forest.
 * The per-tree training example indices stored in each leaf are translated
 * back to indices into the original example set using the bootstrap
 * selection recorded during train() (exselection). */
void FPCRandomForests::getAllLeafNodes ( vector<DecisionNode *> & leafNodes)
{
	//leafNodes.reserve ( forest.size() );
	int treeIndex = 0;
	for ( vector<DecisionTree *>::const_iterator i = forest.begin();
		     i != forest.end();
		     i++, treeIndex++ )
	{
		DecisionTree & dt = *(*i);
		vector<DecisionNode *> leaves = dt.getAllLeafNodes();
		// use size_t loop counters (fixes signed/unsigned comparison
		// against trainExamplesIndices.size() in the inner loop)
		for (size_t j = 0; j < leaves.size(); j++)
		{
			// map tree-local example indices to global example indices
			for (size_t k = 0; k < leaves[j]->trainExamplesIndices.size(); k++)
			{
				leaves[j]->trainExamplesIndices[k] = exselection[treeIndex][leaves[j]->trainExamplesIndices[k]];
			}
			leafNodes.push_back(leaves[j]);
		}
	}
}

void FPCRandomForests::getLeafNodes ( Example & pce,
		    vector<DecisionNode *> & leafNodes,
		    int depth )
{
	// for the given example, collect the node reached in every tree
	// (descending at most 'depth' levels)
	leafNodes.reserve(forest.size());
	for (size_t t = 0; t < forest.size(); t++)
		leafNodes.push_back(forest[t]->getLeafNode(pce, depth));
}

ClassificationResult FPCRandomForests::classify ( Example & pce,
						  const vector<int> & outofbagtrees )
{
	// classify using only the subset of trees whose indices are listed
	// in outofbagtrees (used for out-of-bag estimation)
	FullVector overall_distribution;
	for (size_t t = 0; t < outofbagtrees.size(); t++)
	{
		int treeIndex = outofbagtrees[t];
		assert(treeIndex < (int)forest.size());

		// per-tree posterior, normalized before averaging
		FullVector distribution;
		forest[treeIndex]->traverse(pce, distribution);
		distribution.normalize();

		if (overall_distribution.empty())
			overall_distribution = distribution;
		else
			overall_distribution.add(distribution);
	}

	overall_distribution.normalize();

	// winning class = maximum of the averaged distribution
	int classno = overall_distribution.maxElement();

	return ClassificationResult(classno, overall_distribution);
}

ClassificationResult FPCRandomForests::classify(Example & pce)
{
	// average the normalized class distributions of all trees
	FullVector overall_distribution;

	for (size_t t = 0; t < forest.size(); t++)
	{
		FullVector distribution;
		forest[t]->traverse(pce, distribution);
		distribution.normalize();

		if (overall_distribution.empty())
			overall_distribution = distribution;
		else
			overall_distribution.add(distribution);
	}

	overall_distribution.normalize();

	// the class with the highest averaged score wins
	int classno = overall_distribution.maxElement();

	return ClassificationResult(classno, overall_distribution);
}

int FPCRandomForests::classify_optimize(Example & pce)
{
	// fast classification: accumulate the raw (un-normalized) per-tree
	// distributions and return only the winning class label
	FullVector overall_distribution;

	for (size_t t = 0; t < forest.size(); t++)
	{
		FullVector distribution;
		forest[t]->traverse(pce, distribution);

		if (overall_distribution.empty())
			overall_distribution = distribution;
		else
			overall_distribution.add(distribution);
	}

	return overall_distribution.maxElement();
}

/** Train the random forest: build number_of_trees decision trees, each on
 * a random per-class subset of the examples and a random subset of the
 * features, optionally followed by entropy pruning and an out-of-bag
 * error estimate. */
void FPCRandomForests::train(FeaturePool & fp, Examples & examples)
{
	cerr << "FPCRandomForests::train()" << endl;
	// a tree builder must have been configured (builder != "none")
	assert(builder != NULL);

	if (maxClassNo < 0)
		maxClassNo = examples.getMaxClassNo();

	// number of training examples per class label
	FullVector example_distribution(maxClassNo + 1);

	// class label -> indices (into 'examples') of its examples
	map<int, vector<int> > class_examples;

	long index = 0;

	for (Examples::const_iterator i = examples.begin(); i != examples.end(); i++, index++)
	{
		int classno = i->first;
		example_distribution[classno]++;
		class_examples[classno].push_back(index);
	}

	if (weight_examples)
	{
		// inverse class-frequency weighting: examples of rare classes
		// get a larger weight
		// (note: 'index' is still incremented here but no longer used)
		for (Examples::iterator i = examples.begin();
							i != examples.end();
							i++, index++)
			i->second.weight = examples.size() / example_distribution[i->first];
	}


	// size of the smallest non-empty class; used by the simple
	// balancing strategy below
	double minExamples = (double)examples.size();

	int minExamplesClassNo = 0;

	for (int i = 0 ; i < example_distribution.size() ; i++)
	{
		double val = example_distribution[i];

		if (minExamples > val && val != 0.0)
		{
			minExamples = val;
			minExamplesClassNo = i;
		}
	}

	fprintf(stderr, "FPCRandomForests: minimum number of examples: %f (classno: %d)\n", minExamples, minExamplesClassNo);

	// number of features offered to each individual tree
	int featuresCount = (int)(fp.size() * features_per_tree);
	fprintf(stderr, "FPCRandomForests: number of features %d\n", (int)fp.size());

	// outofbagtrees[e] lists the trees that did NOT train on example e
	// (input to calcOutOfBagEstimates)
	vector< vector<int> > outofbagtrees;
	outofbagtrees.resize(examples.size());
	
	// exselection[k] will record the global indices of the examples used
	// to train tree k; pre-sized here because the loop below runs in
	// parallel and must not resize the outer vector concurrently
	for (int k = 0 ; k < number_of_trees ; k++)
	{
		vector<int> tmp;
		exselection.push_back(tmp);
	}
	
	// NOTE(review): random_shuffle() below uses the shared global RNG;
	// inside this parallel region that is neither thread-safe nor
	// reproducible -- confirm whether this is acceptable here
#pragma omp parallel for
	for (int k = 0 ; k < number_of_trees ; k++)
	{
		fprintf(stderr, "[ -- building tree %d/%d -- ]\n", k + 1, number_of_trees);

		FeaturePool fp_subset;
		Examples examples_subset;

		// draw the training subset for tree k: a fraction samples_per_tree
		// of each class (or of the smallest class when balancing is on)
		for (map<int, vector<int> >::const_iterator j  = class_examples.begin();
							j != class_examples.end(); j++)
		{
			vector<int> examples_index ( j->second );
			int trainingExamples;

			if (use_simple_balancing)
				trainingExamples = (int)(minExamples * samples_per_tree);
			else
				trainingExamples = (int)(examples_index.size() * samples_per_tree);
		
			fprintf (stderr, "FPCRandomForests: selection of %d examples for each tree\n", trainingExamples );
		
			// degenerate class guard: with fewer than 3 selected examples,
			// fall back to using the whole class
			if ( (trainingExamples < 3) && ((int)examples_index.size() > trainingExamples) )
			{
				fprintf(stderr, "FPCRandomForests: number of examples < 3 !! minExamples=%f, trainingExamples=%d\n",
												minExamples, trainingExamples);
				trainingExamples = examples_index.size();
				fprintf(stderr, "FPCRandomForests: I will use all %d examples of this class !!\n", trainingExamples);
			}

			// TODO: optional include examples weights
			// shuffle so that taking the first trainingExamples entries
			// yields a random subset (skipped when everything is used)
			if(samples_per_tree < 1.0)
				random_shuffle(examples_index.begin(), examples_index.end());

			examples_subset.reserve(examples_subset.size() + trainingExamples);
						
			for (int e = 0 ; e < trainingExamples ; e++)
			{
				examples_subset.push_back(examples[examples_index[e]]);
				exselection[k].push_back(examples_index[e]);
			}
			

			// set out of bag trees
			// (the remaining examples of this class were not used for
			// tree k, so k is an out-of-bag tree for them)
			for (uint e = trainingExamples; e < examples_index.size() ; e++)
			{
				int index = examples_index[e];
				#pragma omp critical
				outofbagtrees[index].push_back(k);
			}
		}

		/******* select a random feature set *******/
		// shuffle a copy of the full pool and keep the first
		// featuresCountT features for this tree
		FeaturePool fpTree ( fp );
		

		int featuresCountT = featuresCount;

		if (featuresCountT >= (int)fpTree.size()) featuresCountT = fpTree.size();

		random_shuffle(fpTree.begin(), fpTree.end());
		if (featuresCountT < (int)fpTree.size())
		{
			fp_subset.insert(fp_subset.begin(), fpTree.begin(), fpTree.begin() + featuresCountT);
		}
		else
		{
			fp_subset = fpTree;
		}
		fp_subset.initRandomFeatureSelection();

		/******* training of an individual tree ****/
		DecisionTree *tree = new DecisionTree(conf, maxClassNo);

		builder->build(*tree, fp_subset, examples_subset, maxClassNo);

		/******* prune tree using a simple minimum entropy criterion *****/
		if (minimum_entropy != 0.0)
			tree->pruneTreeEntropy(minimum_entropy);

		/******* drop some precalculated data if memory should be saved **/
		// several examples may share one CachedExample, so each cache is
		// dropped at most once
		#pragma omp critical
		if (memory_efficient)
		{
			set<CachedExample *> alreadyDropped;

			for (Examples::iterator i = examples_subset.begin();
								i != examples_subset.end();
								i++)
			{
				CachedExample *ce = i->second.ce;

				if (alreadyDropped.find(ce) == alreadyDropped.end())
				{
					ce->dropPreCached();
					alreadyDropped.insert(ce);
				}
			}
		}

		/******* add individual tree to ensemble *****/
		#pragma omp critical
		forest.push_back(tree);
	}

	if (enableOutOfBagEstimates)
		calcOutOfBagEstimates(outofbagtrees, examples);
}


/** Re-create the forest from a stream written by store().
 * Each tree is introduced by the token "TREE <index>".
 * Note: the 'format' parameter is currently ignored. */
void FPCRandomForests::restore(istream & is, int format)
{
    std::string tag;
    int index;

    // NOTE(review): store() writes an "ENDTREE" token after each tree; if
    // DecisionTree::restore() does not consume it, this loop terminates
    // after the first tree -- verify against DecisionTree::restore().
    while ( (is >> tag) && (tag == "TREE") )
    {
		is >> index;
		DecisionTree *dt = new DecisionTree ( conf, maxClassNo );
		dt->restore ( is );
		// apply the same entropy-based pruning as after training
		if ( minimum_entropy != 0.0 )
			dt->pruneTreeEntropy ( minimum_entropy );

		forest.push_back(dt);
	}
}

void FPCRandomForests::store(ostream & os, int format) const
{
	// serialize the ensemble: every tree is framed by a "TREE <index>"
	// header and an "ENDTREE " trailer
	for (size_t t = 0; t < forest.size(); t++)
	{
		os << "TREE " << (int)t << endl;
		forest[t]->store(os, format);
		os << "ENDTREE ";
	}
}

void FPCRandomForests::clear()
{
    for ( vector<DecisionTree *>::iterator i = forest.begin();
					   i != forest.end();
					   i++ )
		delete (*i);

	forest.clear();
}

void FPCRandomForests::indexDescendants(map<DecisionNode *, pair<long, int> > & index) const
{
	// assign consecutive indices to the nodes of all trees; maxindex is
	// threaded through so indices stay unique across the whole forest
	long maxindex = 0;
	for (size_t t = 0; t < forest.size(); t++)
		forest[t]->indexDescendants(index, maxindex);
}

void FPCRandomForests::resetCounters()
{
    for ( vector<DecisionTree *>::const_iterator i = forest.begin();
					    i != forest.end();
					    i++ )
		(*i)->resetCounters ();
}

FeaturePoolClassifier *FPCRandomForests::clone() const
{
	// create a fresh classifier with the same configuration; the trained
	// forest itself is not copied, only the settings and the class range
	FPCRandomForests *copy = new FPCRandomForests(conf, confsection);
	copy->maxClassNo = maxClassNo;
	return copy;
}

void FPCRandomForests::setComplexity(int size)
{
	// the "complexity" of a random forest is its number of trees
	fprintf (stderr, "FPCRandomForests: set complexity to %d, overwriting current value %d\n", 
		size, number_of_trees );
	number_of_trees = size;
}