/** 
* @file RegressionTree.cpp
* @brief regression tree implementation
* @author Sven Sickert
* @date 06/19/2013

*/
#include <assert.h>

#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <sstream>
#include <string>

#include "vislearning/regression/randomforest/RegressionTree.h"

using namespace OBJREC;

using namespace std;
using namespace NICE;

/**
 * @brief Creates a regression tree that has no nodes yet.
 * @param _conf configuration pointer, stored in the conf member
 */
RegressionTree::RegressionTree ( const Config *_conf )
  : conf ( _conf )
{
  // an empty tree is represented by a NULL root
  root = NULL;
}

/**
 * @brief Destroys the tree by deleting every node reachable from the root.
 */
RegressionTree::~RegressionTree ()
{
  // deleteNodes() ignores a NULL pointer, so an empty tree needs no check
  deleteNodes ( root );
}

/**
 * @brief Reports depth and node count of the tree.
 * @param depth output; 0 for an empty tree, otherwise whatever the root node reports
 * @param count output; 0 for an empty tree, otherwise whatever the root node reports
 */
void RegressionTree::statistics ( int & depth, int & count ) const
{
  if ( root != NULL )
  {
    // non-empty tree: the root node computes both values
    root->statistics ( depth, count );
    return;
  }

  // empty tree
  depth = 0;
  count = 0;
}

/**
 * @brief Forwards a traversal request to the root node.
 *
 * The tree must not be empty; this is only checked via assert().
 *
 * @param x       input vector handed to the root node
 * @param predVal output argument filled in by the root node's traverse()
 *                (presumably the predicted value for x)
 */
void RegressionTree::traverse ( const Vector & x, double & predVal )
{
  // traversing an empty tree is a programming error
  assert( root != NULL );

  root->traverse ( x, predVal );
}

/**
 * @brief Recursively deletes a subtree.
 * @param tree root of the subtree to delete; NULL is accepted and ignored
 */
void RegressionTree::deleteNodes ( RegressionNode *tree )
{
  if ( tree == NULL )
    return;

  // children first (left, then right), the node itself last
  deleteNodes ( tree->left );
  deleteNodes ( tree->right );
  delete tree;
}

/**
 * @brief Deletes all nodes and leaves the tree empty.
 *
 * The root pointer is reset afterwards so that the destructor, a second
 * clear() or any other member does not touch the freed nodes.
 */
void RegressionTree::clear ()
{
  deleteNodes ( root );
  root = NULL;
}

/**
 * @brief Asks the root node to reset its counters; does nothing for an empty tree.
 */
void RegressionTree::resetCounters ()
{
  if ( root == NULL )
    return;

  root->resetCounters ();
}

/**
 * @brief Lets the root node add its descendants to an index map.
 *
 * An empty tree leaves both arguments untouched.
 *
 * @param index    map from node pointer to a (long, int) pair, filled by the nodes
 * @param maxindex running index value passed through to the nodes
 */
void RegressionTree::indexDescendants (
          map<RegressionNode *, pair<long, int> > & index,
          long & maxindex ) const
{
    if ( root == NULL )
      return;

    // third argument is 0 for the root (presumably its depth)
    root->indexDescendants ( index, maxindex, 0 );
}

/**
 * @brief Asks the root node for the leaf node belonging to an input vector.
 * @param x        input vector handed to the root node
 * @param maxdepth depth limit handed to the root node
 * @return the node returned by the root's getLeafNode(), or NULL if the
 *         tree is empty
 */
RegressionNode *RegressionTree::getLeafNode ( 
          Vector & x,
          int maxdepth )
{
    // an empty tree has no leaf; avoid dereferencing a NULL root
    if ( root == NULL )
      return NULL;

    return root->getLeafNode ( x, maxdepth );
}

/**
 * @brief Collects all leaves of a subtree.
 *
 * A node counts as a leaf when both child pointers are NULL. The right
 * subtree is visited before the left one.
 *
 * @param node   root of the subtree; NULL is accepted and contributes nothing
 * @param leaves vector the leaves are appended to
 */
void RegressionTree::getLeaves(
          RegressionNode *node,
          vector<RegressionNode*> &leaves)
{
  // guards against an empty tree and against nodes with only one child
  if(node == NULL)
    return;

  if(node->left == NULL && node->right == NULL)
  {
    leaves.push_back(node);
    return;
  }
  getLeaves(node->right, leaves);
  getLeaves(node->left, leaves);
}

/**
 * @brief Returns all leaf nodes of the tree.
 * @return vector of leaf node pointers; empty if the tree has no root
 */
vector<RegressionNode *> RegressionTree::getAllLeafNodes()
{
  vector<RegressionNode*> leaves;

  // an empty tree has no leaves; do not hand a NULL root to getLeaves()
  if ( root != NULL )
    getLeaves(root, leaves);

  return leaves;
}

/**
 * @brief Replaces the root pointer of the tree.
 * @param newroot node that becomes the new root
 */
void RegressionTree::setRoot ( RegressionNode *newroot )
{
  // NOTE(review): the previous root is neither deleted nor returned here
  root = newroot;
}

/**
 * @brief Recursively prunes a subtree by least-squares error reduction.
 *
 * Children are pruned first. If a node still has both children afterwards
 * and the error reduction of its split,
 * lsError(node) - lsError(left) - lsError(right), is smaller than
 * minErrorReduction, both child subtrees are deleted and the node becomes
 * a leaf.
 *
 * @param node              root of the subtree to prune (may be NULL)
 * @param minErrorReduction minimum error reduction a split must achieve to be kept
 * @param lsError           output: the lsError stored in node; left untouched
 *                          when node is NULL
 * @return node itself, or NULL if node was NULL
 */
RegressionNode *RegressionTree::pruneTreeLeastSquares (
          RegressionNode *node,
          double minErrorReduction,
          double & lsError )
{
  if ( node == NULL )  return NULL;
  
  lsError = node->lsError;

  // the recursive calls do not write their output for NULL children,
  // so give both values a defined start
  double leftError = 0.0;
  double rightError = 0.0;
  node->left = pruneTreeLeastSquares ( node->left, minErrorReduction, leftError );
  node->right = pruneTreeLeastSquares ( node->right, minErrorReduction, rightError );

  if (node->left != NULL && node->right != NULL)
  {
    if (lsError-leftError-rightError < minErrorReduction)
    {
      deleteNodes( node->left );
      deleteNodes( node->right );

      // the subtrees are gone: reset the pointers so that later traversals
      // and the destructor do not touch freed memory
      node->left = NULL;
      node->right = NULL;
    }
  }
  
  return node;
}

/**
 * @brief Prunes the whole tree and logs its size before and after to stderr.
 * @param minErrorReduction threshold passed on to the recursive pruning
 */
void RegressionTree::pruneTreeLeastSquares ( double minErrorReduction )
{
  int treeDepth, nodeCount;

  // size before pruning
  statistics ( treeDepth, nodeCount );
  fprintf (stderr, "RegressionTree::pruneTreeLeastSquares: depth %d count %d\n", treeDepth, nodeCount );

  // the error reported for the root is not needed here
  double rootError;
  root = pruneTreeLeastSquares ( root, minErrorReduction, rootError );

  // size after pruning
  statistics ( treeDepth, nodeCount );
  fprintf (stderr, "RegressionTree::pruneTreeLeastSquares: depth %d count %d (modified)\n", treeDepth, nodeCount );
}

/**
 * @brief Writes the tree to a stream as text; an empty tree writes nothing.
 *
 * Every node is written as
 *   NODE <index> <index of left child> <index of right child>
 * followed by either the line "LEAF" or the feature and the threshold on
 * one line each, and finally a line with the node's lsError and a
 * constant -1. A missing child is written as index 0, the root has index 1.
 *
 * @param os     stream to write to
 * @param format not used by this implementation
 */
void RegressionTree::store (ostream & os, int format) const
{
  if ( root == NULL ) return;

  typedef pair<long, int> IndexEntry;
  typedef map<RegressionNode *, IndexEntry> NodeIndex;

  // assign an index to every node: NULL -> 0, root -> 1, the remaining
  // entries are added by the nodes themselves
  NodeIndex index;
  index.insert ( pair<RegressionNode *, IndexEntry> ( NULL, IndexEntry ( 0, 0 ) ) );
  index.insert ( pair<RegressionNode *, IndexEntry> ( root, IndexEntry ( 1, 0 ) ) );

  long maxindex = 1;
  root->indexDescendants ( index, maxindex, 0 );

  // nodes are written in the iteration order of the map
  for ( NodeIndex::iterator it = index.begin(); it != index.end(); it++ )
  {
    RegressionNode *current = it->first;

    // skip the placeholder entry for "no child"
    if ( current == NULL ) continue;

    const long ownId = it->second.first;
    const long leftId = index[ current->left ].first;
    const long rightId = index[ current->right ].first;

    os << "NODE " << ownId << " " << leftId << " " << rightId << endl;

    if ( current->isLeaf() )
    {
      os << "LEAF" << endl;
    } else {
      os << current->f << endl;
      os << current->threshold << endl;
    }

    os << current->lsError << " " << -1 << endl;
  }
}

/**
 * @brief Reads a tree from a stream in the text format written by store().
 *
 * Expected layout per node:
 *   NODE <index> <index of left child> <index of right child>
 *   either "LEAF", or the feature followed by the threshold
 *   <lsError> <placeholder>
 * Index 0 denotes a missing child, index 1 is the root. Reading stops at
 * the first token that is not "NODE" or when the stream ends.
 *
 * On an inconsistent file (unknown child index, no root) a message is
 * printed to stderr and the program exits with -1.
 *
 * NOTE(review): a root that is already set is overwritten without being
 * deleted; callers presumably restore into an empty tree - confirm.
 *
 * @param is     stream to read from
 * @param format not used by this implementation
 */
void RegressionTree::restore (istream & is, int format)
{
  // indexing
  map<long, RegressionNode *> index;
  map<long, pair<long, long> > descendants;

  index.insert ( pair<long, RegressionNode *> ( 0, NULL ) );

  std::string tag;

  while ( (! is.eof()) && ( (is >> tag) && (tag == "NODE") ) )
  {
    long ind;
    long ind_l;
    long ind_r;
    if (! (is >> ind)) break;
    if (! (is >> ind_l)) break;
    if (! (is >> ind_r)) break;
  
    descendants.insert ( pair<long, pair<long, long> > ( ind, pair<long, long> ( ind_l, ind_r ) ) );
    RegressionNode *node = new RegressionNode();
    index.insert ( pair<long, RegressionNode *> ( ind, node ) );
  
    // store() writes either the word LEAF or the feature itself at this
    // position, so the token just read already is the feature of an
    // inner node and must not be skipped
    std::string feature_tag;
  
    is >> feature_tag;
    if ( feature_tag != "LEAF" )
    {
      std::istringstream featureStream ( feature_tag );
      featureStream >> node->f;
      is >> node->threshold;
    }
  
    is >> node->lsError;

    // store() appends a placeholder value after lsError on the same line;
    // drop the rest of the line so the next token is the next NODE tag
    std::string restOfLine;
    std::getline ( is, restOfLine );
  }

  // connecting the tree
  for ( map<long, RegressionNode *>::const_iterator it = index.begin();
       it != index.end(); it++ )
  {
    RegressionNode *node = it->second;

    // entry 0 stands for "no child"
    if ( node == NULL ) continue;

    long ind_l = descendants[it->first].first;
    long ind_r = descendants[it->first].second;

    map<long, RegressionNode *>::const_iterator il = index.find ( ind_l );
    map<long, RegressionNode *>::const_iterator ir = index.find ( ind_r );

    if ( ( il == index.end() ) || ( ir == index.end() ) )
    {
      fprintf (stderr, "File inconsistent: unable to build tree\n");
      exit(-1);
    }

    RegressionNode *left = il->second;
    RegressionNode *right = ir->second;

    node->left = left;
    node->right = right;
  }
  
  // the root is the node stored with index 1
  map<long, RegressionNode *>::const_iterator iroot = index.find ( 1 );

  if ( iroot == index.end() ) 
  {
    fprintf (stderr, "File inconsistent: unable to build tree (root node not found)\n");
    exit(-1);
  }

  root = iroot->second;
}