/**
 * @file DTBOblique.cpp
 * @brief random oblique decision tree
 * @author Sven Sickert
 * @date 10/15/2014
 */
#include <iostream>
#include <time.h>

#include "DTBOblique.h"
#include "vislearning/features/fpfeatures/ConvolutionFeature.h"

#include "core/vector/Algorithms.h"

using namespace OBJREC;

#define DEBUGTREE


using namespace std;
using namespace NICE;

DTBOblique::DTBOblique ( const Config *conf, string section )
{
    random_split_tests = conf->gI(section, "random_split_tests", 10 );     // number of evenly spaced thresholds tested per node
    max_depth = conf->gI(section, "max_depth", 10 );                       // maximum tree depth
    minimum_information_gain = conf->gD(section, "minimum_information_gain", 10e-7 ); // stop if the best gain falls below this
    minimum_entropy = conf->gD(section, "minimum_entropy", 10e-5 );        // stop if the node entropy falls below this
    use_shannon_entropy = conf->gB(section, "use_shannon_entropy", false );// normalize the gain by the split entropy
    min_examples = conf->gI(section, "min_examples", 50);                  // minimum number of examples to allow a split
    save_indices = conf->gB(section, "save_indices", false);               // store training example indices in the nodes
    lambdaInit = conf->gD(section, "lambdaInit", 0.5 );                    // initial regularization parameter lambda
}
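
/* Example configuration section (an illustrative sketch; the section name
 * and values are not prescribed by this file):
 *
 *   random_split_tests = 10
 *   max_depth = 10
 *   minimum_information_gain = 10e-7
 *   minimum_entropy = 10e-5
 *   use_shannon_entropy = false
 *   min_examples = 50
 *   save_indices = false
 *   lambdaInit = 0.5
 */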

DTBOblique::~DTBOblique()
{

}

/** compute class-weighted entropies of the left and right part of a threshold split */
bool DTBOblique::entropyLeftRight (
        const FeatureValuesUnsorted & values,
        double threshold,
        double* stat_left,
        double* stat_right,
        double & entropy_left,
        double & entropy_right,
        double & count_left,
        double & count_right,
        int maxClassNo )
{
    count_left = 0;
    count_right = 0;
    for ( FeatureValuesUnsorted::const_iterator i = values.begin(); i != values.end(); i++ )
    {
        int classno = i->second;
        double value = i->first;
        // accumulate class-weighted counts on each side of the threshold
        // (i->fourth holds the example weight)
        if ( value < threshold )
        {
            stat_left[classno] += i->fourth;
            count_left += i->fourth;
        }
        else
        {
            stat_right[classno] += i->fourth;
            count_right += i->fourth;
        }
    }

    if ( (count_left == 0) || (count_right == 0) )
        return false;

    // entropy via the identity
    // -sum_j (n_j/N)*log(n_j/N) = log(N) - (1/N)*sum_j n_j*log(n_j)
    entropy_left = 0.0;
    for ( int j = 0 ; j <= maxClassNo ; j++ )
        if ( stat_left[j] != 0 )
            entropy_left -= stat_left[j] * log(stat_left[j]);
    entropy_left /= count_left;
    entropy_left += log(count_left);

    // same computation for the right side
    entropy_right = 0.0;
    for ( int j = 0 ; j <= maxClassNo ; j++ )
        if ( stat_right[j] != 0 )
            entropy_right -= stat_right[j] * log(stat_right[j]);
    entropy_right /= count_right;
    entropy_right += log(count_right);

    return true;
}
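
/* Worked example for the entropy identity used above: two classes with
 * weighted counts (2, 2), N = 4:
 *   log(4) - (1/4)*(2*log(2) + 2*log(2)) = log(4) - log(2) = log(2),
 * which is the maximal entropy of a balanced two-class node.
 */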

/** refresh data matrix X and label vector y */
void DTBOblique::getDataAndLabel(
        const FeaturePool &fp,
        const Examples &examples,
        const std::vector<int> &examples_selection,
        NICE::Matrix & matX,
        NICE::Vector & vecY )
{
    // the feature pool is expected to hold a ConvolutionFeature as its
    // first entry; its parameter length defines the number of columns of X
    ConvolutionFeature *f = (ConvolutionFeature*)fp.begin()->second;
    int amountParams = f->getParameterLength();
    int amountExamples = examples_selection.size();

    NICE::Matrix X(amountExamples, amountParams, 0.0 );
    NICE::Vector y(amountExamples, 0.0);

    int matIndex = 0;
    for ( vector<int>::const_iterator si = examples_selection.begin();
          si != examples_selection.end();
          si++ )
    {
        const pair<int, Example> & p = examples[*si];
        int classno = p.first;
        const Example & ce = p.second;

        NICE::Vector pixelRepr = f->getFeatureVector( &ce );
        // scale the pixel representation so that its maximum entry is 1
        pixelRepr /= pixelRepr.Max();

        // TODO for multiclass scenarios we need ONEvsALL!

        // map class numbers {0,1} to regression targets {-1,+1}
        double label = 2*classno-1;

        label *= ce.weight;
        pixelRepr *= ce.weight;
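
        // scaling both the row of X and its target by the example weight
        // turns the least squares fit of the split direction into a
        // weighted one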

        y.set( matIndex, label );
        X.setRow(matIndex,pixelRepr);

        matIndex++;
    }

    matX = X;
    vecY = y;
}

/** recursive building method */
DecisionNode *DTBOblique::buildRecursive(
        const FeaturePool & fp,
        const Examples & examples,
        std::vector<int> & examples_selection,
        FullVector & distribution,
        double e,
        int maxClassNo,
        int depth,
        double lambdaCurrent )
{

#ifdef DEBUGTREE
    std::cerr << "Examples: " << (int)examples_selection.size()
              << " (depth " << (int)depth << ")" << std::endl;
#endif

    // initialize new node
    DecisionNode *node = new DecisionNode ();
    node->distribution = distribution;

    // stop criteria: max_depth, min_examples, min_entropy
    if ( depth > max_depth
         || (int)examples_selection.size() < min_examples
         || ( (e <= minimum_entropy) && (e != 0.0) ) )  // FIXME
    {
#ifdef DEBUGTREE
        std::cerr << "DTBOblique: Stopping criteria applied!" << std::endl;
#endif
        node->trainExamplesIndices = examples_selection;
        return node;
    }

    // refresh/set X and y
    NICE::Matrix X, G;
    NICE::Vector y, beta;
    getDataAndLabel( fp, examples, examples_selection, X, y );

    // regularized least squares solution (ridge regression)
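    // closed form: beta = (X^T X + lambda*I)^{-1} X^T y; the Cholesky
    // factor G of the regularized Gram matrix is computed, inverted and
    // applied to X^T y below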
    NICE::Matrix XTX = X.transpose()*X;
    XTX.addDiagonal ( NICE::Vector( XTX.rows(), lambdaCurrent) );
    choleskyDecomp(XTX, G);
    choleskyInvert(G, XTX);
    NICE::Matrix temp = XTX * X.transpose();
    beta.multiply(temp,y,false);

    // bookkeeping for the best split found so far
    double best_threshold = 0.0;
    double best_ig = -1.0;
    FeatureValuesUnsorted values;
    double *best_distribution_left = new double [maxClassNo+1];
    double *best_distribution_right = new double [maxClassNo+1];
    double *distribution_left = new double [maxClassNo+1];
    double *distribution_right = new double [maxClassNo+1];
    double best_entropy_left = 0.0;
    double best_entropy_right = 0.0;

    // set the parameters of the convolutional feature, which now realizes
    // the oblique test beta^T x
    ConvolutionFeature *f = (ConvolutionFeature*)fp.begin()->second;
    f->setParameterVector( beta );

    // compute the feature values of the oblique test for the selection
    values.clear();
    f->calcFeatureValues( examples, examples_selection, values);
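
    // each entry of values is a quadruple:
    // (feature value, class number, example index, example weight);
    // min_element/max_element below therefore pick the extreme feature
    // values, assuming the quadruplets compare by their first component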

    double minValue = (min_element ( values.begin(), values.end() ))->first;
    double maxValue = (max_element ( values.begin(), values.end() ))->first;

    if ( maxValue - minValue < 1e-7 )
        std::cerr << "DTBOblique: Difference between minimum and maximum of feature values is too small!" << std::endl;

    // determine the best threshold by grid search over the value range
    for ( int i = 0; i < random_split_tests; i++ )
    {
        double threshold = (i * (maxValue - minValue ) / (double)random_split_tests)
                            + minValue;
        // preparations
        double el, er;
        for ( int k = 0 ; k <= maxClassNo ; k++ )
        {
            distribution_left[k] = 0.0;
            distribution_right[k] = 0.0;
        }

        /** Test the current split */
        // Does another split make sense?
        double count_left;
        double count_right;
        if ( ! entropyLeftRight ( values, threshold,
                                  distribution_left, distribution_right,
                                  el, er, count_left, count_right, maxClassNo ) )
            continue;

        // information gain: ig = H(parent) - p_l*H(left) - (1-p_l)*H(right)
        double pl = (count_left) / (count_left + count_right);
        double ig = e - pl*el - (1-pl)*er;

        if ( use_shannon_entropy )
        {
            // normalize by the sum of parent and split entropy
            // (a symmetric-uncertainty-style criterion)
            double esplit = - ( pl*log(pl) + (1-pl)*log(1-pl) );
            ig = 2*ig / ( e + esplit );
        }

        if ( ig > best_ig )
        {
            best_ig = ig;
            best_threshold = threshold;

            for ( int k = 0 ; k <= maxClassNo ; k++ )
            {
                best_distribution_left[k] = distribution_left[k];
                best_distribution_right[k] = distribution_right[k];
            }
            best_entropy_left = el;
            best_entropy_right = er;
        }
    }

    // clean up temporary split statistics
    delete [] distribution_left;
    delete [] distribution_right;

    // stop criteria: minimum information gain
    if ( best_ig < minimum_information_gain )
    {
#ifdef DEBUGTREE
        std::cerr << "DTBOblique: Minimum information gain reached!" << std::endl;
#endif
        delete [] best_distribution_left;
        delete [] best_distribution_right;
        node->trainExamplesIndices = examples_selection;
        return node;
    }

    /** Save the best split to current node */
    node->f = f->clone();
    node->threshold = best_threshold;

    /** Split examples according to split function */
    vector<int> examples_left;
    vector<int> examples_right;

    examples_left.reserve ( values.size() / 2 );
    examples_right.reserve ( values.size() / 2 );
    for ( FeatureValuesUnsorted::const_iterator i = values.begin();
          i != values.end(); i++ )
    {
        double value = i->first;
        if ( value < best_threshold )
            examples_left.push_back ( i->third );
        else
            examples_right.push_back ( i->third );
    }

#ifdef DEBUGTREE
    node->f->store( std::cerr );
    std::cerr << std::endl;
    std::cerr << "mutual information / shannon entropy " << best_ig << " entropy "
              << e << " left entropy " <<  best_entropy_left << " right entropy "
              << best_entropy_right << std::endl;
#endif

    FullVector distribution_left_sparse ( distribution.size() );
    FullVector distribution_right_sparse ( distribution.size() );
    for ( int k = 0 ; k <= maxClassNo ; k++ )
    {
        double l = best_distribution_left[k];
        double r = best_distribution_right[k];
        if ( l != 0 )
            distribution_left_sparse[k] = l;
        if ( r != 0 )
            distribution_right_sparse[k] = r;
#ifdef DEBUGTREE
        if ( (l>0)||(r>0) )
        {
            std::cerr << "DTBOblique: split of class " << k << " ("
                      << l << " <-> " << r << ") " << std::endl;
        }
#endif
    }

    delete [] best_distribution_left;
    delete [] best_distribution_right;

    // update the regularization parameter lambda by the heuristic of
    // [Laptev/Buhmann, 2014]: lambda_child = lambda * (N_parent/N_child)^(2/d)
    double lambdaLeft = lambdaCurrent * pow (
        (double)examples_selection.size() / (double)examples_left.size(),
        2./f->getParameterLength() );
    double lambdaRight = lambdaCurrent * pow (
        (double)examples_selection.size() / (double)examples_right.size(),
        2./f->getParameterLength() );
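    // e.g. with d = 9 parameters and a child receiving half of the parent's
    // examples, lambda grows by a factor of 2^(2/9), roughly 1.17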

#ifdef DEBUGTREE
    std::cerr << "regularization parameter lambda: left " << lambdaLeft
              << " right " << lambdaRight << std::endl;
#endif

    /** Recursion */
    // left child
    node->left  = buildRecursive ( fp, examples, examples_left,
                                   distribution_left_sparse, best_entropy_left,
                                   maxClassNo, depth+1, lambdaLeft );
    // right child
    node->right = buildRecursive ( fp, examples, examples_right,
                                   distribution_right_sparse, best_entropy_right,
                                   maxClassNo, depth+1, lambdaRight );

    return node;
}

/** initial building method */
DecisionNode *DTBOblique::build ( const FeaturePool & fp,
                                  const Examples & examples,
                                  int maxClassNo )
{
    int index = 0;

    FullVector distribution ( maxClassNo+1 );
    vector<int> all;

    all.reserve ( examples.size() );
    for ( Examples::const_iterator j = examples.begin();
          j != examples.end(); j++ )
    {
        int classno = j->first;
        distribution[classno] += j->second.weight;

        all.push_back ( index );
        index++;
    }

    // entropy of the root distribution, again via
    // -sum_j p_j*log(p_j) = log(N) - (1/N)*sum_j n_j*log(n_j)
    double entropy = 0.0;
    double sum = 0.0;
    for ( int i = 0 ; i < distribution.size(); i++ )
    {
        double val = distribution[i];
        if ( val <= 0.0 ) continue;
        entropy -= val*log(val);
        sum += val;
    }
    entropy /= sum;
    entropy += log(sum);

    return buildRecursive ( fp, examples, all, distribution,
                            entropy, maxClassNo, 0, lambdaInit );
}