/**
* @file VCPreRandomForest.cpp
* @brief Combination of a classifier with a pre-clustering using a random forest
* @author Erik Rodner
* @date 06/17/2010

*/

#include "VCPreRandomForest.h"

#include <iostream>

#include <vislearning/cbaselib/VectorFeature.h>

#include "core/image/ImageT.h"
//#include "core/imagedisplay/ImageDisplay.h"

using namespace OBJREC;
using namespace std;
using namespace NICE;


/**
 * @brief Build the combined classifier (random-forest pre-clustering plus
 *        one classifier per forest leaf).
 *
 * @param conf configuration object; it is also passed to the feature pool
 *        member and to the random forest
 * @param section config section holding the key "cluster_section", which
 *        names the section used to configure the random forest
 *        (default: "RandomForest")
 * @param _leafClassifierPrototype classifier that teach() clones for every
 *        leaf; it is not deleted by this class
 */
VCPreRandomForest::VCPreRandomForest( const Config *conf, const std::string & section, VecClassifier *_leafClassifierPrototype )
    : leafClassifierPrototype(_leafClassifierPrototype), fp(conf)
{
  string forestSection ( conf->gS ( section, "cluster_section", "RandomForest" ) );

  // NOTE(review): the configured DTBRandom::min_examples value is read but
  // then unconditionally replaced by the hard-coded 500 below. mEx is not
  // used in this file -- confirm which of the two values is intended.
  mEx = conf->gI ( "DTBRandom", "min_examples", numeric_limits<int>::max() );
  mEx = 500;

  randomforest = new FPCRandomForests ( conf, forestSection );
}

/**
 * @brief Release the random forest and every classifier trained for a leaf.
 *
 * NOTE(review): leafClassifierPrototype is not deleted here, so its lifetime
 * is presumably managed by whoever passed it to the constructor -- confirm.
 */
VCPreRandomForest::~VCPreRandomForest()
{
  // delete on a NULL pointer is a no-op, so no explicit check is required
  delete randomforest;

  // the per-leaf classifiers are clones created in teach() and owned here
  map<DecisionNode *, VecClassifier *>::const_iterator entry = leafClassifiers.begin();
  while ( entry != leafClassifiers.end() )
  {
    delete entry->second;
    ++entry;
  }
}

/**
 * @brief Classify a feature vector.
 *
 * The vector is routed through the random forest, which yields one leaf node
 * per tree. For every reached leaf:
 *  - if teach() trained a classifier for that leaf, its scores are
 *    normalized and added;
 *  - otherwise (e.g. the leaf was pure during training), the leaf's class
 *    distribution, normalized by its sum, is added.
 * The accumulated scores are then divided by the number of reached leaves.
 *
 * @param x feature vector to classify
 * @return result whose classno is the index of the largest averaged score
 */
ClassificationResult VCPreRandomForest::classify ( const NICE::Vector & x ) const
{
  // the forest works on Example objects; the example gets its own copy of x,
  // which example.clean() is expected to free (same pattern as in teach())
  NICE::Vector *v = new NICE::Vector(x);
  Example example(v);

  vector<DecisionNode *> leafNodes;

  // traverse the forest and obtain all involved leaf nodes; the example is
  // not needed afterwards, so release it right away -- also when the
  // traversal throws, to avoid leaking the vector copy
  try {
    randomforest->getLeafNodes(example, leafNodes);
  } catch ( ... ) {
    example.clean();
    throw;
  }
  example.clean();

  ClassificationResult r ( ClassificationResult::REJECTION_NONE, maxClassNo );
  r.scores.set(0.0);

  for ( vector<DecisionNode *>::const_iterator i = leafNodes.begin();
        i != leafNodes.end(); i++ )
  {
    DecisionNode *node = *i;
    map<DecisionNode *, VecClassifier *>::const_iterator leafClassifierIt =
      leafClassifiers.find ( node );

    if ( leafClassifierIt == leafClassifiers.end() ) {
      // this leaf has no associated classifier
      // -> we will use the random forest "score" :)
      double sum = node->distribution.sum();
      // an empty distribution carries no information, and dividing by a
      // zero sum would turn every accumulated score into NaN
      if ( sum > 0.0 )
      {
        for (uint k = 0; k < (uint)std::min(node->distribution.size(), r.scores.size());k++)
        {
          r.scores[k] += node->distribution[k] / sum;
        }
      }
      continue;
    }

    VecClassifier *leafClassifier = leafClassifierIt->second;
    ClassificationResult rSingle = leafClassifier->classify ( x );

    rSingle.scores.normalize();
    for (uint k = 0; k < (uint)std::min(rSingle.scores.size(), r.scores.size());k++)
    {
      r.scores[k] += rSingle.scores[k];
    }
  }

  // average over the reached leaves; with no leaves (e.g. an untrained
  // forest) scaling by 1/0 would make every score NaN
  if ( ! leafNodes.empty() )
    r.scores.multiply ( 1.0 / leafNodes.size() );
  r.classno = r.scores.maxElement();

  // NOTE(review): fallback when the averaged scores do not sum to one (e.g.
  // mismatching score sizes, skipped empty leaves). It only bumps scores[0]
  // and runs after classno was chosen, so classno is not adjusted -- confirm
  // the intended semantics before relying on it.
  if ( fabs(r.scores.sum() - 1.0) > 1e-2 )
  {
    //fthrow(Exception, "Ups !\n");
    r.scores[0] = 1.0;
  }

  return r;
}

void VCPreRandomForest::teach ( const LabeledSetVector & teachSet )
{
  Examples examples;
  maxClassNo = teachSet.getMaxClassno();

  LOOP_ALL(teachSet)
  {
    EACH(classno, x);
    NICE::Vector *v = new Vector(x);
    examples.push_back( pair<int, Example> (classno, Example(v)));
  }

  uint dimension = teachSet.dimension();
  fp.clear();
  Feature *f = new VectorFeature(dimension);
  f->explode(fp);

  // train the forest
  randomforest->setMaxClassNo( teachSet.getMaxClassno()  );
  randomforest->train ( fp, examples );
  // free some useless memory, we do not need this
  // data structure any more
  examples.clean();

  vector<DecisionNode *> leafNodes;
  randomforest->getAllLeafNodes ( leafNodes );

  int lsize = leafNodes.size();
  cout << "leafnodes: " << lsize << endl;
  int leafNo = 0;
#pragma omp parallel for
  for ( int l = 0; l < lsize; l++)
  {
    cerr << "Training classifier for leaf " << leafNo << endl;
    leafNo++;

    DecisionNode *node = leafNodes[l];

    if ( node->distribution.entropy() <= 0.0) continue;
    if ( ! node->isLeaf() ) continue;

    vector<int> examplesSet = node->trainExamplesIndices;
    assert(examplesSet.size() > 0);

    sort (examplesSet.begin(), examplesSet.end());

    LabeledSetVector trainSubSet;

    vector<double> counter(maxClassNo, 0.0);
    uint exampleIndex = 0;
    uint c = 0;

    LOOP_ALL(teachSet)
    {
      EACH(classno, x);
      if ( examplesSet[c] == exampleIndex )
      {
        c++;
        trainSubSet.add ( classno, x );
      }
      exampleIndex++;
    }

    VecClassifier *lc = leafClassifierPrototype->clone();

    lc->teach ( trainSubSet );

    leafClassifiers.insert ( pair<DecisionNode *, VecClassifier *> ( node, lc ) );
  }
}

/**
 * @brief Reset to an untrained state: delete all leaf classifiers and clear
 *        the random forest.
 *
 * The leaf classifiers are keyed by node pointers of the current forest.
 * Previously they were only cleared and stayed in the map, keyed by nodes of
 * a forest that randomforest->clear() presumably frees (confirm in
 * FPCRandomForests). A later teach() could then receive new leaves at reused
 * addresses, and map::insert would keep the stale, untrained classifier
 * instead of the freshly trained one. Deleting them and emptying the map
 * avoids this; the destructor then has nothing left to double-free.
 */
void VCPreRandomForest::clear()
{
  for ( map<DecisionNode *, VecClassifier *>::iterator iter = leafClassifiers.begin();
        iter != leafClassifiers.end(); ++iter )
  {
    delete iter->second;
  }
  leafClassifiers.clear();

  randomforest->clear();
}