/** 
* @file DecisionTree.cpp
* @brief decision tree implementation
* @author Erik Rodner
* @date 04/24/2008

*/
#include <iostream>

#include "vislearning/classifier/fpclassifier/randomforest/DecisionTree.h"

#include "vislearning/features/fpfeatures/createFeatures.h"

using namespace OBJREC;

using namespace std;
// NOTE: namespace "ice" was replaced by "NICE" during an automated refactoring.
using namespace NICE;


/** Construct an empty tree.
 *  @param _conf configuration object (stored, not copied — must outlive the tree)
 *  @param _maxClassNo largest class index this tree will ever score
 */
DecisionTree::DecisionTree( const Config *_conf, int _maxClassNo ) : conf(_conf)
{
    maxClassNo = _maxClassNo;
    // no nodes yet; root stays empty until training or restore()
    root = NULL;
}

/** Destructor: recursively frees every node of the tree. */
DecisionTree::~DecisionTree()
{
    // deleteNodes() handles a NULL root gracefully
    deleteNodes ( root );
}

/** Compute depth and node count of the tree.
 *  @param depth (out) maximum depth; 0 for an empty tree
 *  @param count (out) total number of nodes; 0 for an empty tree
 */
void DecisionTree::statistics ( int & depth, int & count ) const
{
    if ( root != NULL )
    {
	root->statistics ( depth, count );
	return;
    }
    // empty tree
    depth = 0;
    count = 0;
}

/** Classify a single example by walking from the root to a leaf.
 *  @param ce example to classify
 *  @param distribution (out) class score distribution of the reached leaf
 *  @pre the tree must be non-empty (asserted)
 */
void DecisionTree::traverse (
		    const Example & ce,
		    FullVector & distribution )
{
    // an empty tree cannot classify anything
    assert ( root != NULL );
    root->traverse ( ce, distribution );
}

/** Recursively delete a subtree (post-order). Safe to call with NULL.
 *  @param tree root of the subtree to free
 */
void DecisionTree::deleteNodes ( DecisionNode *tree )
{
    if ( tree == NULL )
	return;
    // free children first, then the node itself
    deleteNodes ( tree->left );
    deleteNodes ( tree->right );
    delete tree;
}

/** Deserialize a tree previously written by store().
 *
 *  Works in two phases:
 *   1. read all "NODE <id> <left-id> <right-id>" records, creating the
 *      nodes and remembering child ids per node;
 *   2. connect the nodes via the id->node index and set the root
 *      (by convention the root has id 1, id 0 denotes "no child").
 *
 *  @param is     input stream positioned at the first NODE record
 *  @param format format flag forwarded to Feature::restore
 *  @note terminates the process (exit) on an inconsistent or unknown
 *        feature description — no exception is thrown.
 */
void DecisionTree::restore (istream & is, int format)
{
    // indexing: node id -> node pointer, node id -> (left id, right id)
    map<long, DecisionNode *> index;
    map<long, pair<long, long> > descendants;

    // id 0 is the sentinel for "no child" and maps to NULL
    index.insert ( pair<long, DecisionNode *> ( 0, (DecisionNode*)NULL ) );

    std::string tag;

    // ---- phase 1: read node records until the NODE tag stops matching ----
    while ( (! is.eof()) && ( (is >> tag) && (tag == "NODE") ) )
    {
		long ind;
		long ind_l;
		long ind_r;
		if (! (is >> ind)) break;
		if (! (is >> ind_l)) break;
		if (! (is >> ind_r)) break;
	
		descendants.insert ( pair<long, pair<long, long> > ( ind, pair<long, long> ( ind_l, ind_r ) ) );
		DecisionNode *node = new DecisionNode();
		index.insert ( pair<long, DecisionNode *> ( ind, node ) );
	
		std::string feature_tag;
	
		// inner nodes carry a feature description + threshold; leaves carry "LEAF"
		is >> feature_tag;
		if ( feature_tag != "LEAF" )
		{
			node->f = createFeatureFromTag ( conf, feature_tag );
			if ( node->f == NULL ) 
			{
				fprintf (stderr, "Unknown feature description: %s\n",
					feature_tag.c_str() );
				exit(-1);
			}
			node->f->restore ( is, format );
			is >> node->threshold;
		}
	
		FullVector distribution ( maxClassNo+1 );
		int classno;
		double score;
	
		// class distribution is stored as (classno score)* terminated by -1
		//distribution.restore ( is );
		is >> classno;
		while ( classno >= 0 )
		{
			is >> score;
			if ( classno > maxClassNo )
			{
			fprintf (stderr, "classno: %d; maxClassNo: %d\n", classno, maxClassNo);
			exit(-1);
			}
			distribution[classno] = score;
			is >> classno;
		}
		//distribution.store(cerr);
		node->distribution = distribution;
    }

    // ---- phase 2: connect the tree via the recorded child ids ----
    for ( map<long, DecisionNode *>::const_iterator i = index.begin();
	     i != index.end(); i++ )
    {
	DecisionNode *node = i->second;

	// skip the sentinel NULL entry (id 0)
	if ( node == NULL ) continue;

	long ind_l = descendants[i->first].first;
	long ind_r = descendants[i->first].second;

	map<long, DecisionNode *>::const_iterator il = index.find ( ind_l );
	map<long, DecisionNode *>::const_iterator ir = index.find ( ind_r );

	if ( ( il == index.end() ) || ( ir == index.end() ) )
	{
	    fprintf (stderr, "File inconsistent: unable to build tree\n");
	    exit(-1);
	}

	DecisionNode *left = il->second;
	DecisionNode *right = ir->second;

	node->left = left;
	node->right = right;
    }
	
    // the root node is stored with id 1 (see store())
    map<long, DecisionNode *>::const_iterator iroot = index.find ( 1 );

    if ( iroot == index.end() ) 
    {
	fprintf (stderr, "File inconsistent: unable to build tree (root node not found)\n");
	exit(-1);
    }

    root = iroot->second;
}

/** Serialize the tree to a stream (inverse of restore()).
 *
 *  Each node is written as "NODE <id> <left-id> <right-id>" followed by
 *  either its feature description + threshold, or "LEAF", and then its
 *  class distribution as (classno score)* terminated by -1.
 *  The root always receives id 1; id 0 means "no child".
 *
 *  @param os     output stream
 *  @param format format flag forwarded to Feature::store
 */
void DecisionTree::store (ostream & os, int format) const
{
    if ( root == NULL ) return;

    // indexing: node pointer -> (id, depth)
    map<DecisionNode *, pair<long, int> > index;

    index.insert ( pair<DecisionNode *, pair<long, int> > ( (DecisionNode*)NULL, pair<long, int> ( 0, 0 ) ) );
    index.insert ( pair<DecisionNode *, pair<long, int> > ( root, pair<long, int> ( 1, 0 ) ) );
    long maxindex = 1;
    root->indexDescendants ( index, maxindex, 0 );

    for ( map<DecisionNode *, pair<long, int> >::const_iterator i = index.begin();
		i != index.end();
		i++ )
    {
	DecisionNode *node = i->first;

	// skip the sentinel NULL entry
	if ( node == NULL ) continue;

	long ind = i->second.first;
	long ind_l = index[ node->left ].first;
	long ind_r = index[ node->right ].first;

	os << "NODE " << ind << " " << ind_l << " " << ind_r << endl;

	Feature *f = node->f;

	if ( f != NULL ) {
	    f->store ( os, format );
	    os << endl;
	    os << node->threshold;
	    os << endl;
	} else {
	    os << "LEAF";
	    os << endl;
	}

	const FullVector & distribution = node->distribution;

	// fix: loop variable renamed from 'i', which shadowed the map iterator above
	for ( int c = 0 ; c < distribution.size() ; c++ )
	    os << c << " " << distribution[c] << " ";
	os << -1 << endl;
    }
}

void DecisionTree::clear ()
{
    deleteNodes ( root );
}

/** Reset the per-node example counters of the whole tree (no-op if empty). */
void DecisionTree::resetCounters ()
{
    if ( root == NULL )
		return;
    root->resetCounters ();
}
	
/** Assign consecutive ids (and depths) to all nodes of the tree.
 *  @param index (in/out) node pointer -> (id, depth) map to fill
 *  @param maxindex (in/out) highest id assigned so far
 */
void DecisionTree::indexDescendants ( map<DecisionNode *, pair<long, int> > & index, long & maxindex ) const
{
    if ( root == NULL )
		return;
    root->indexDescendants ( index, maxindex, 0 );
}

/** Descend from the root and return the leaf (or depth-limited) node
 *  reached for the given example.
 *  @param pce example to route through the tree
 *  @param maxdepth maximum depth to descend
 *  @return the reached node
 *  @pre the tree must be non-empty (asserted — previously a NULL root
 *       caused an unchecked NULL dereference; now consistent with traverse())
 */
DecisionNode *DecisionTree::getLeafNode ( Example & pce, int maxdepth )
{
    assert ( root != NULL );
    return root->getLeafNode ( pce, maxdepth );
}

/** Collect all leaves of the subtree rooted at node.
 *  @param node subtree root (NULL is now tolerated)
 *  @param leaves (out) vector the leaf pointers are appended to
 *  Fix: guard against NULL — pruning (pruneTreeScore/pruneTreeEntropy)
 *  can leave a node with exactly one NULL child, and the previous code
 *  dereferenced that NULL child when recursing.
 */
void DecisionTree::getLeaves(DecisionNode *node, vector<DecisionNode*> &leaves)
{
	if(node == NULL)
		return;
	if(node->left == NULL && node->right == NULL)
	{
		leaves.push_back(node);
		return;
	}
	getLeaves(node->right, leaves);
	getLeaves(node->left, leaves);
}

/** Return all leaves of the tree.
 *  @return vector of leaf node pointers; empty for an empty tree
 *  Fix: an empty tree (root == NULL) previously led to a NULL
 *  dereference inside getLeaves(); now an empty vector is returned.
 */
vector<DecisionNode *> DecisionTree::getAllLeafNodes()
{
	vector<DecisionNode*> leaves;
	if(root != NULL)
		getLeaves(root, leaves);
	return leaves;
}

/** Recursively prune subtrees whose distribution entropy falls below
 *  minEntropy.
 *  @param node subtree root (may be NULL)
 *  @param minEntropy pruning threshold
 *  @return the (possibly pruned) subtree root, or NULL if removed
 */
DecisionNode *DecisionTree::pruneTreeEntropy ( DecisionNode *node, double minEntropy )
{
    if ( node == NULL ) return NULL;
    // entropy below the threshold: remove the entire subtree
    if ( node->distribution.entropy() < minEntropy )
    {
		deleteNodes ( node );
		return NULL;
    }
    node->left = pruneTreeEntropy ( node->left, minEntropy );
    node->right = pruneTreeEntropy ( node->right, minEntropy );
    return node;
}

/** Recursively prune subtrees whose maximum distribution score falls
 *  below minScore.
 *  @param node subtree root (may be NULL)
 *  @param minScore pruning threshold
 *  @return the (possibly pruned) subtree root, or NULL if removed
 */
DecisionNode *DecisionTree::pruneTreeScore ( DecisionNode *node, double minScore )
{
    if ( node == NULL ) return NULL;
    // best class score below the threshold: remove the entire subtree
    if ( node->distribution.max() < minScore )
    {
		deleteNodes ( node );
		return NULL;
    }
    node->left = pruneTreeScore ( node->left, minScore );
    node->right = pruneTreeScore ( node->right, minScore );
    return node;
}

/** Prune the whole tree by maximum score and log statistics before and
 *  after pruning to stderr.
 *  @param minScore pruning threshold
 */
void DecisionTree::pruneTreeScore ( double minScore )
{
    int d;
    int n;
    statistics ( d, n );
    fprintf (stderr, "DecisionTree::pruneTreeScore: depth %d count %d\n", d, n );
    root = pruneTreeScore ( root, minScore );
    statistics ( d, n );
    fprintf (stderr, "DecisionTree::pruneTreeScore: depth %d count %d (modified)\n", d, n );
}

/** Prune the whole tree by distribution entropy and log statistics
 *  before and after pruning to stderr.
 *  @param minEntropy pruning threshold
 *  Fix: the log messages wrongly said "entropyTreeScore"; they now name
 *  the actual method, matching pruneTreeScore's logging style.
 */
void DecisionTree::pruneTreeEntropy ( double minEntropy )
{
    int depth, count;
    statistics ( depth, count );
    fprintf (stderr, "DecisionTree::pruneTreeEntropy: depth %d count %d\n", depth, count );
    root = pruneTreeEntropy ( root, minEntropy );
    statistics ( depth, count );
    fprintf (stderr, "DecisionTree::pruneTreeEntropy: depth %d count %d (modified)\n", depth, count );
}

/** Normalize the class distribution of every node in the subtree.
 *  @param node subtree root; NULL is a no-op
 */
void DecisionTree::normalize (DecisionNode *node)
{
    if ( node == NULL )
	return;
    node->distribution.normalize();
    normalize ( node->left );
    normalize ( node->right );
}

/** Normalize the class distributions of the whole tree. */
void DecisionTree::normalize ()
{
    // delegates to the recursive overload; handles an empty tree itself
    normalize ( root );
}

/** Replace the root node pointer.
 *  @param newroot new tree root (ownership transfers to this tree)
 *  @note the previous root is NOT freed here — callers must clear() first
 *        if they want the old tree released.
 */
void DecisionTree::setRoot ( DecisionNode *newroot )
{
    root = newroot;
}