/**
* @file DTBRandom.cpp
* @brief random decision tree builder
* @author Erik Rodner
* @date 05/06/2008
*/
#include <iostream>
#include <cstdlib> // rand, srand
#include <ctime>   // time
#include <cmath>   // log, sqrt
#include <cassert> // assert (used when DETAILTREE is defined)

#include "vislearning/classifier/fpclassifier/randomforest/DTBRandom.h"

using namespace OBJREC;

// switch these to #define to enable verbose debugging output
#undef DEBUGTREE
#undef DETAILTREE


using namespace std;
using namespace NICE;

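/** reads all parameters from the given config section
 *
 *  Config keys (defaults): random_split_tests (10), random_features (500),
 *  max_depth (10), minimum_information_gain (10e-7), minimum_entropy (10e-5),
 *  use_shannon_entropy (false), min_examples (50), save_indices (false),
 *  start_random_generator (false)
 */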
DTBRandom::DTBRandom( const Config *conf, std::string section )
{
    random_split_tests = conf->gI(section, "random_split_tests", 10 );
    random_features = conf->gI(section, "random_features", 500 );
    max_depth = conf->gI(section, "max_depth", 10 );
    minimum_information_gain = conf->gD(section, "minimum_information_gain", 10e-7 );
    minimum_entropy = conf->gD(section, "minimum_entropy", 10e-5 );
    use_shannon_entropy = conf->gB(section, "use_shannon_entropy", false );
    min_examples = conf->gI(section, "min_examples", 50);
    save_indices = conf->gB(section, "save_indices", false);

    if ( conf->gB(section, "start_random_generator", false ) )
        srand(time(NULL));
}

DTBRandom::~DTBRandom()
{
}

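/** splits the weighted feature values at @p threshold and computes the class
 *  entropy of the left (value < threshold) and the right partition
 *
 *  Each element of @p values is a quadruplet (feature value, class number,
 *  example index, weight). Entropies are computed from the unnormalized
 *  weighted class counts n_c with total weight n using the identity
 *  H = -\sum_c (n_c/n) \log(n_c/n) = \log n - (1/n) \sum_c n_c \log n_c.
 *
 *  @note stat_left and stat_right must be zero-initialized by the caller
 *  @return false if one of the two partitions is empty, true otherwise
 */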
bool DTBRandom::entropyLeftRight ( const FeatureValuesUnsorted & values,
		     double threshold,
		     double* stat_left,
		     double* stat_right,
		     double & entropy_left,
		     double & entropy_right,
		     double & count_left,
		     double & count_right,
		     int maxClassNo ) 
{
	count_left = 0;
	count_right = 0;
	for ( FeatureValuesUnsorted::const_iterator i = values.begin(); i != values.end(); i++ )
	{
		int classno = i->second;
		double value = i->first;
		if ( value < threshold ) {
			stat_left[classno] += i->fourth; 
			count_left+=i->fourth; 
		}
		else
		{
			stat_right[classno] += i->fourth; 
			count_right+=i->fourth;
		}
	}

	if ( (count_left == 0) || (count_right == 0) )
	   return false;

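	// entropy from unnormalized counts: H = log(n) - (1/n) * sum_c n_c * log(n_c)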
	entropy_left = 0.0;
	for ( int j = 0 ; j <= maxClassNo ; j++ )
	   if ( stat_left[j] != 0 )
		   entropy_left -= stat_left[j] * log(stat_left[j]);
	entropy_left /= count_left;
	entropy_left += log(count_left);

	entropy_right = 0.0;
	for ( int j = 0 ; j <= maxClassNo ; j++ )
	   if ( stat_right[j] != 0 )
			entropy_right -= stat_right[j] * log(stat_right[j]);
	entropy_right /= count_right;
	entropy_right += log (count_right);

	return true;
}

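/** recursively grows the decision tree below one node
 *
 *  Draws random_features random features and tests random_split_tests
 *  uniformly sampled thresholds per feature; the split with the highest
 *  information gain is kept. The node becomes a leaf when max_depth is
 *  exceeded, fewer than min_examples examples are left, the entropy e falls
 *  below minimum_entropy, or no split reaches minimum_information_gain.
 */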
DecisionNode *DTBRandom::buildRecursive ( const FeaturePool & fp, 
		     const Examples & examples,
		     vector<int> & examples_selection,
		     FullVector & distribution,
		     double e,
		     int maxClassNo,
		     int depth )
{
#ifdef DEBUGTREE
    fprintf (stderr, "Examples: %d (depth %d)\n", (int)examples_selection.size(),
		(int)depth);
#endif

    DecisionNode *node = new DecisionNode ();
    node->distribution = distribution;

    if ( depth > max_depth ) {
#ifdef DEBUGTREE
		fprintf (stderr, "DTBRandom: maxmimum depth reached !\n");
#endif
		node->trainExamplesIndices = examples_selection;
		return node;
    }

    if ( (int)examples_selection.size() < min_examples ) {
#ifdef DEBUGTREE
		fprintf (stderr, "DTBRandom: minimum examples reached %d < %d !\n",
			(int)examples_selection.size(), min_examples );
#endif
		node->trainExamplesIndices = examples_selection;
		return node;
    }

    // FIXME: the (e != 0.0) clause lets already-pure nodes slip past this
    // check; they only become leaves via the information gain test below
    //if ( e <= minimum_entropy ) {
    if ( (e <= minimum_entropy) && (e != 0.0) ) {
#ifdef DEBUGTREE
		fprintf (stderr, "DTBRandom: minimum entropy reached !\n");
#endif
		node->trainExamplesIndices = examples_selection;
		return node;
    }

    Feature *best_feature = NULL;
    double best_threshold = 0.0;
    double best_ig = -1.0;
    FeatureValuesUnsorted values;
    double *best_distribution_left = new double [maxClassNo+1];
    double *best_distribution_right = new double [maxClassNo+1];
    double *distribution_left = new double [maxClassNo+1];
    double *distribution_right = new double [maxClassNo+1];
    double best_entropy_left = 0.0;
    double best_entropy_right = 0.0;

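    // Extra-Trees style split search (cf. Geurts et al.): evaluate
    // random_features randomly drawn features with random_split_tests
    // random thresholds each and keep the best candidate split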
    for ( int k = 0 ; k < random_features ; k++ )
    {
#ifdef DETAILTREE
		fprintf (stderr, "calculating random feature %d\n", k );
#endif
		Feature *f = fp.getRandomFeature ();

		values.clear();
		f->calcFeatureValues ( examples, examples_selection, values );

		double minValue = (min_element ( values.begin(), values.end() ))->first;
		double maxValue = (max_element ( values.begin(), values.end() ))->first;

#ifdef DETAILTREE
		fprintf (stderr, "max %f min %f\n", maxValue, minValue );
#endif
		if ( maxValue - minValue < 1e-7 ) continue;

		for ( int i = 0 ; i < random_split_tests ; i++ )
		{
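			// sample a candidate threshold uniformly from [minValue, maxValue]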
			double threshold;
			threshold = rand() * (maxValue - minValue ) / RAND_MAX + minValue;

#ifdef DETAILTREE
			fprintf (stderr, "calculating split f/s(f) %d/%d %f\n", k, i, threshold );
#endif
			double el, er;
			// clear the left/right class distributions
			for ( int c = 0 ; c <= maxClassNo ; c++ )
			{
				distribution_left[c] = 0.0;
				distribution_right[c] = 0.0;
			}

			double count_left;
			double count_right;
			if ( ! entropyLeftRight ( values, threshold,
					   distribution_left, distribution_right,
					   el, er, count_left, count_right, maxClassNo ) )
				continue;
			
			double pl = (count_left) / (count_left + count_right);
			double ig = e - pl*el - (1-pl)*er;

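			// optional normalization: 2*IG / (H_node + H_split), the
			// relative information gain score used by Geurts et al.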
			if ( use_shannon_entropy )
			{
				double esplit = - ( pl*log(pl) + (1-pl)*log(1-pl) );
				ig = 2*ig / ( e + esplit );
			}

#ifdef DETAILTREE
			fprintf (stderr, "ig %f el %f er %f e %f\n", ig, el, er, e );
			assert ( ig >= -1e-7 );
#endif
			
			if ( ig > best_ig )
			{
				best_ig = ig;
				best_threshold = threshold;
#ifdef DETAILTREE
				fprintf (stderr, "t %f\n", best_threshold );
#endif
				best_feature = f;		
				for ( int c = 0 ; c <= maxClassNo ; c++ )
				{
					best_distribution_left[c] = distribution_left[c];
					best_distribution_right[c] = distribution_right[c];
				}
				best_entropy_left = el;
				best_entropy_right = er;
			}
		}
    }

    delete [] distribution_left;
    delete [] distribution_right;

    if ( best_ig < minimum_information_gain )
    {
#ifdef DEBUGTREE
		fprintf (stderr, "DTBRandom: minimum information gain reached !\n");
#endif
		delete [] best_distribution_left;
		delete [] best_distribution_right;
		node->trainExamplesIndices = examples_selection;
		return node;
    }

    node->f = best_feature->clone();
    node->threshold = best_threshold;

    // recompute the left/right example split for the best feature and threshold
    vector<int> best_examples_left;
    vector<int> best_examples_right;
    values.clear();
    best_feature->calcFeatureValues ( examples, examples_selection, values );

    best_examples_left.reserve ( values.size() / 2 );
    best_examples_right.reserve ( values.size() / 2 );
    for ( FeatureValuesUnsorted::const_iterator i = values.begin();
				       i != values.end();
				       i++ )
    {
       double value = i->first;
       if ( value < best_threshold ) {
			best_examples_left.push_back ( i->third );
       } else {
			best_examples_right.push_back ( i->third );
       }
    }

#ifdef DEBUGTREE
    node->f->store(cerr);
    cerr << endl;
    fprintf (stderr, "mutual information / shannon entropy %f entropy %f, left entropy %f right entropy %f\n", best_ig, e, best_entropy_left, 
	best_entropy_right );
#endif
    
    FullVector best_distribution_left_sparse ( distribution.size() );
    FullVector best_distribution_right_sparse ( distribution.size() );
    for ( int k = 0 ; k <= maxClassNo ; k++ )
    {
		double l = best_distribution_left[k];
		double r = best_distribution_right[k];
		if ( l != 0 )
			best_distribution_left_sparse[k] = l;
		if ( r != 0 )
			best_distribution_right_sparse[k] = r;
#ifdef DEBUGTREE
		if ( (l>0)||(r>0) )
			fprintf (stderr, "DTBRandom: split of class %d (%f <-> %f)\n", k, l, r );
#endif
    }

    delete [] best_distribution_left;
    delete [] best_distribution_right;

    node->left = buildRecursive ( fp, examples, best_examples_left,
		         best_distribution_left_sparse, best_entropy_left, maxClassNo, depth+1 );

    node->right = buildRecursive ( fp, examples, best_examples_right,
		         best_distribution_right_sparse, best_entropy_right, maxClassNo, depth+1 );
			 
    return node;
}


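/** builds a random decision tree for the given feature pool and examples
 *
 *  Computes the class distribution and entropy of the root node and then
 *  delegates to buildRecursive().
 *
 *  Minimal usage sketch (constructing @p fp and @p examples is
 *  application-specific; the config file name is hypothetical):
 *  @code
 *  Config conf ( "classifier.conf" );
 *  DTBRandom builder ( &conf, "DTBRandom" );
 *  DecisionNode *root = builder.build ( fp, examples, maxClassNo );
 *  @endcode
 */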
DecisionNode *DTBRandom::build ( const FeaturePool & fp, 
		     const Examples & examples,
		     int maxClassNo )
{
    int index = 0;

    fprintf (stderr, "Feature Statistics (Geurts et al.): N=%d sqrt(N)=%lf K=%d\n",
	(int)fp.size(), sqrt((double)fp.size()), random_split_tests*random_features );

    FullVector distribution ( maxClassNo+1 );
    vector<int> all;

    all.reserve ( examples.size() );
    for ( Examples::const_iterator j = examples.begin();
				    j != examples.end();
				    j++ )
    {
       int classno = j->first;
       distribution[classno] += j->second.weight;
       
       all.push_back ( index );

       index++;
    }
    
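    // root entropy from the unnormalized weighted counts:
    // H = log(sum) - (1/sum) * sum_c w_c * log(w_c)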
    double entropy = 0.0;
    double sum = 0.0;
    for ( int i = 0 ; i < distribution.size(); i++ )
    {
		double val = distribution[i];
		if ( val <= 0.0 ) continue;
		entropy -= val*log(val);
		sum += val;
    }
    entropy /= sum;
    entropy += log(sum);
    

    return buildRecursive ( fp, examples, all, distribution, entropy, maxClassNo, 0 );
}