/**
* @file FPCRandomForests.h
* @brief implementation of random set forests
* @author Erik Rodner
* @date 04/24/2008

*/
#ifndef FPCRANDOMFORESTSINCLUDE
#define FPCRANDOMFORESTSINCLUDE

#include <iosfwd>
#include <map>
#include <string>
#include <utility>
#include <vector>

#include "core/vector/VectorT.h"
#include "core/vector/MatrixT.h"

#include "vislearning/classifier/classifierbase/FeaturePoolClassifier.h"
#include "vislearning/cbaselib/FeaturePool.h"
#include "DecisionTree.h"

#include "DecisionTreeBuilder.h"


namespace OBJREC
{

/**
 * @brief Implementation of random set forests: an ensemble of DecisionTree
 *        objects trained on a FeaturePool.
 *
 * Trees are built with a DecisionTreeBuilder during train(); out-of-bag
 * statistics can optionally be collected (see enableOutOfBagEstimates).
 *
 * NOTE(review): this class stores raw DecisionTree* and DecisionTreeBuilder*
 * pointers and declares a destructor, but neither a copy constructor nor a
 * copy assignment operator. If the destructor releases these pointers
 * (verify in the implementation file), copying an instance results in a
 * double free. Use clone() instead of copying, or make the class
 * non-copyable.
 */
class FPCRandomForests : public FeaturePoolClassifier
{
  protected:
    /** all decision trees of the forest (stored as raw pointers) */
    std::vector<DecisionTree *> forest;

    /** number of trees generated during training */
    int number_of_trees;

    /** fraction of features used for each tree */
    double features_per_tree;

    /** fraction of training examples used for each tree */
    double samples_per_tree;

    /** use an equal number of training examples of each class
        to build a single tree */
    bool use_simple_balancing;

    /** weight examples according to a priori class probabilities
        as estimated from the class distribution of the training data */
    bool weight_examples;

    /** if > 0, the trees are pruned using pruneTreeEntropy */
    double minimum_entropy;

    /** clear all examples after building a tree; this deletes
        all cached images contained in CachedExample etc. */
    bool memory_efficient;

    /** stored config used to initialize a tree */
    const NICE::Config *conf;

    /** config section containing the relevant config values */
    std::string confsection;

    /** pointer to the tree builder method */
    DecisionTreeBuilder *builder;

    /** if true, out-of-bag statistics are computed during training */
    bool enableOutOfBagEstimates;

    /** out-of-bag results; the meaning of each (double, int) pair is
        defined by calcOutOfBagEstimates (TODO confirm in the implementation) */
    std::vector<std::pair<double, int> > oobResults;

    /** classify using only a subset of all trees
        @param pce example to classify
        @param outofbagtrees trees to use (presumably indices into forest) */
    ClassificationResult classify ( Example & pce,
                                    const std::vector<int> & outofbagtrees );

    /** calculate out-of-bag statistics
        @param outofbagtrees per-example selection of trees for which the
               example was out of bag (TODO confirm layout in the implementation)
        @param examples training examples */
    void calcOutOfBagEstimates ( std::vector< std::vector<int> > & outofbagtrees,
                                 Examples & examples );

    /** example selection saved per tree (presumably example indices) */
    std::vector<std::vector<int> > exselection;

  public:

    /** initialize the classifier from the given config section */
    FPCRandomForests ( const NICE::Config *conf,
                       std::string section );

    /** default constructor, does nothing */
    FPCRandomForests ();

    /** simple destructor */
    virtual ~FPCRandomForests();

    /** main classification function */
    ClassificationResult classify ( Example & pce );

    /** alternative classification entry point returning an int
        (presumably the predicted class number; see the implementation) */
    int classify_optimize ( Example & pce );

    /** get all leaf nodes reached by a given example
        (or inner nodes, if depth is set to their level) */
    void getLeafNodes ( Example & pce,
                        std::vector<DecisionNode *> & leafNodes,
                        int depth = 100000 );

    /** get all leaf nodes of the forest
        (or inner nodes, if depth is set to their level) */
    void getAllLeafNodes ( std::vector<DecisionNode *> & leafNodes );

    /** perform training using a given feature pool and some training data */
    virtual void train ( FeaturePool & fp,
                         Examples & examples );

    /** enumerate all nodes within the trees */
    void indexDescendants ( std::map<DecisionNode *, std::pair<long, int> > & index ) const;

    /** reset all counters in all nodes contained in the forest */
    void resetCounters ();

    /** read-only access to all trees */
    const std::vector<DecisionTree *> & getForest () const
    {
      return forest;
    }

    /** direct write access to all trees */
    std::vector<DecisionTree *> & getForestNonConst ()
    {
      return forest;
    }

    /** clone this object */
    FeaturePoolClassifier *clone () const;

    /** get the out-of-bag estimates (mutable reference to the stored results) */
    std::vector<std::pair<double, int> > & getOutOfBagResults ()
    {
      return oobResults;
    }

    /** set the number of trees */
    void setComplexity ( int size );

    /** IO functions */
    void restore ( std::istream & is, int format = 0 );
    void store ( std::ostream & os, int format = 0 ) const;
    void clear ();


};


} // namespace

#endif