/**
 * @file DTBOblique.cpp
 * @brief Random oblique decision tree builder
 * @author Sven Sickert
 * @date 10/15/2014
 */
#include <iostream>
#include <cmath>

#include "DTBOblique.h"
#include "vislearning/features/fpfeatures/ConvolutionFeature.h"

#include "core/vector/Algorithms.h"

using namespace OBJREC;

#define DEBUGTREE


using namespace std;
using namespace NICE;

DTBOblique::DTBOblique ( const Config *conf, string section )
{
    saveIndices = conf->gB( section, "save_indices", false );
    useShannonEntropy = conf->gB( section, "use_shannon_entropy", false );
    useOneVsOne = conf->gB( section, "use_one_vs_one", false );

    splitSteps = conf->gI( section, "split_steps", 20 );
    maxDepth = conf->gI( section, "max_depth", 10 );
    minExamples = conf->gI( section, "min_examples", 50 );
    regularizationType = conf->gI( section, "regularization_type", 1 );

    minimumEntropy = conf->gD( section, "minimum_entropy", 10e-5 );
    minimumInformationGain = conf->gD( section, "minimum_information_gain", 10e-7 );
    lambdaInit = conf->gD( section, "lambda_init", 0.5 );

}
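
// Example of a matching configuration section (illustrative values only; the
// key names follow the gB/gI/gD calls above, while the exact file syntax
// depends on the NICE::Config format in use):
//
//   [DTBOblique]
//   save_indices = false
//   use_shannon_entropy = false
//   use_one_vs_one = false
//   split_steps = 20
//   max_depth = 10
//   min_examples = 50
//   regularization_type = 1
//   minimum_entropy = 10e-5
//   minimum_information_gain = 10e-7
//   lambda_init = 0.5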

DTBOblique::~DTBOblique()
{

}

bool DTBOblique::entropyLeftRight (
        const FeatureValuesUnsorted & values,
        double threshold,
        double* stat_left,
        double* stat_right,
        double & entropy_left,
        double & entropy_right,
        double & count_left,
        double & count_right,
        int maxClassNo )
{
    count_left = 0;
    count_right = 0;
    for ( FeatureValuesUnsorted::const_iterator i = values.begin();
          i != values.end();
          i++ )
    {
        int classno = i->second;
        double value = i->first;
        if ( value < threshold )
        {
            stat_left[classno] += i->fourth;
            count_left += i->fourth;
        }
        else
        {
            stat_right[classno] += i->fourth;
            count_right += i->fourth;
        }
    }

    if ( (count_left == 0) || (count_right == 0) )
        return false;
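
    // The entropies are computed via the identity
    //   H = - sum_j (n_j/N) * log(n_j/N) = log(N) - (1/N) * sum_j n_j * log(n_j),
    // so the class counts never have to be normalized inside the loop.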

    entropy_left = 0.0;
    for ( int j = 0 ; j <= maxClassNo ; j++ )
        if ( stat_left[j] != 0 )
            entropy_left -= stat_left[j] * log(stat_left[j]);
    entropy_left /= count_left;
    entropy_left += log(count_left);

    entropy_right = 0.0;
    for ( int j = 0 ; j <= maxClassNo ; j++ )
        if ( stat_right[j] != 0 )
            entropy_right -= stat_right[j] * log(stat_right[j]);
    entropy_right /= count_right;
    entropy_right += log(count_right);

    return true;
}

bool DTBOblique::adaptDataAndLabelForMultiClass (
        const int posClass,
        const int negClass,
        NICE::Matrix & X,
        NICE::Vector & y )
{
    bool posHasExamples = false;
    bool negHasExamples = false;
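
    // Example (hypothetical): for posClass = 2, negClass = 5 and labels
    // y = (2,5,3), one-vs-one yields y = (+1,-1,0) and zeroes out the feature
    // row of the ignored third example; one-vs-all yields y = (+1,-1,-1).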

    // One-vs-one: Transforming into {-1,0,+1} problem
    if ( useOneVsOne )
        for ( int i = 0; i < (int)y.size(); i++ )
        {
            if ( y[i] == posClass )
            {
                y[i] = 1.0;
                posHasExamples = true;
            }
            else if ( y[i] == negClass )
            {
                y[i] = -1.0;
                negHasExamples = true;
            }
            else
            {
                y[i] = 0.0;
                X.setRow( i, NICE::Vector( X.cols(), 0.0 ) );
            }
        }
    // One-vs-all: Transforming into {-1,+1} problem
    else
        for ( int i = 0; i < (int)y.size(); i++ )
        {
            if ( y[i] == posClass )
            {
                y[i] = 1.0;
                posHasExamples = true;
            }
            else
            {
                y[i] = -1.0;
                negHasExamples = true;
            }
        }

    return ( posHasExamples && negHasExamples );
}

/** refresh data matrix X, label vector y, and weight vector w */
void DTBOblique::getDataAndLabel(
        const FeaturePool &fp,
        const Examples &examples,
        const std::vector<int> &examples_selection,
        NICE::Matrix & X,
        NICE::Vector & y,
        NICE::Vector & w )
{
    ConvolutionFeature *f = (ConvolutionFeature*)fp.begin()->second;
    int amountParams = f->getParameterLength();
    int amountExamples = examples_selection.size();

    X = NICE::Matrix(amountExamples, amountParams, 0.0 );
    y = NICE::Vector(amountExamples, 0.0);
    w = NICE::Vector(amountExamples, 1.0);

    int matIndex = 0;
    for ( vector<int>::const_iterator si = examples_selection.begin();
          si != examples_selection.end();
          si++ )
    {
        const pair<int, Example> & p = examples[*si];
        const Example & ex = p.second;

        NICE::Vector pixelRepr = f->getFeatureVector( &ex );

        double label = p.first;
        pixelRepr *= ex.weight;

        w.set    ( matIndex, ex.weight );
        y.set    ( matIndex, label );
        X.setRow ( matIndex, pixelRepr );

        matIndex++;
    }

}

void DTBOblique::regularizeDataMatrix(
        const NICE::Matrix &X,
        NICE::Matrix &XTXreg,
        const int regOption,
        const double lambda )
{
    XTXreg = X.transpose()*X;
    NICE::Matrix R;
    const int dim = X.cols();
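
    // For regularization options 0-2 the result is the regularized Gram matrix
    //   XTXreg = X^T * X + lambda * R^T * R,
    // where R is the identity (plain ridge regression) or a first/second
    // order finite difference operator that penalizes non-smooth parameter
    // vectors.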

    switch (regOption)
    {
        // identity matrix
        case 0:
            R.resize(dim,dim);
            R.setIdentity();
            R *= lambda;
            XTXreg += R;
            break;

        // differences operator, k=1
        case 1:
            R.resize(dim-1,dim);
            R.set( 0.0 );
            for ( int r = 0; r < dim-1; r++ )
            {
                R(r,r)   =  1.0;
                R(r,r+1) = -1.0;
            }
            R = R.transpose()*R;
            R *= lambda;
            XTXreg += R;
            break;

        // difference operator, k=2
        case 2:
            R.resize(dim-2,dim);
            R.set( 0.0 );
            for ( int r = 0; r < dim-2; r++ )
            {
                R(r,r)   =  1.0;
                R(r,r+1) = -2.0;
                R(r,r+2) =  1.0;
            }
            R = R.transpose()*R;
            R *= lambda;
            XTXreg += R;
            break;

        // as in [Chen et al., 2012]
        case 3:
        {
            NICE::Vector q ( dim, (1.0-lambda) );
            q[0] = 1.0;
            NICE::Matrix Q;
            Q.tensorProduct(q,q);
            R.multiply(XTXreg,Q);
            for ( int r = 0; r < dim; r++ )
                R(r,r) = q[r] * XTXreg(r,r);
            XTXreg = R;
            break;
        }

        // no regularization
        default:
            std::cerr << "DTBOblique::regularizeDataMatrix: No regularization applied!"
                      << std::endl;
            break;
    }
}

void DTBOblique::findBestSplitThreshold (
        FeatureValuesUnsorted &values,
        SplitInfo &bestSplitInfo,
        const NICE::Vector &beta,
        const double &e,
        const int &maxClassNo )
{
    double *distribution_left = new double [maxClassNo+1];
    double *distribution_right = new double [maxClassNo+1];
    double minValue = (min_element ( values.begin(), values.end() ))->first;
    double maxValue = (max_element ( values.begin(), values.end() ))->first;

    if ( maxValue - minValue < 1e-7 )
        std::cerr << "DTBOblique: Difference between minimum and maximum of feature values is too small!"
                  << " [" << minValue << "," << maxValue << "]" << std::endl;

    // complete search for the best threshold over a uniform grid
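    // note: i = 0 yields threshold == minValue, so no value falls strictly
    // below it; entropyLeftRight then reports an empty left child and the
    // candidate is skipped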
    for ( int i = 0; i < splitSteps; i++ )
    {
        double threshold = (i * (maxValue - minValue ) / (double)splitSteps)
                            + minValue;
        // preparations
        double el, er;
        for ( int k = 0 ; k <= maxClassNo ; k++ )
        {
            distribution_left[k] = 0.0;
            distribution_right[k] = 0.0;
        }

        // test the current split: does it separate the examples at all?
        double count_left;
        double count_right;
        if ( ! entropyLeftRight ( values, threshold,
                                  distribution_left, distribution_right,
                                  el, er, count_left, count_right, maxClassNo ) )
            continue;

        // information gain and entropy
        double pl = (count_left) / (count_left + count_right);
        double ig = e - pl*el - (1-pl)*er;
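
        // optional normalization by the split entropy (a gain-ratio-like
        // criterion in the spirit of C4.5)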

        if ( useShannonEntropy )
        {
            double esplit = - ( pl*log(pl) + (1-pl)*log(1-pl) );
            ig = 2*ig / ( e + esplit );
        }

        if ( ig > bestSplitInfo.informationGain )
        {
            bestSplitInfo.informationGain = ig;
            bestSplitInfo.threshold = threshold;
            bestSplitInfo.params = beta;

            for ( int k = 0 ; k <= maxClassNo ; k++ )
            {
                bestSplitInfo.distLeft[k] = distribution_left[k];
                bestSplitInfo.distRight[k] = distribution_right[k];
            }
            bestSplitInfo.entropyLeft = el;
            bestSplitInfo.entropyRight = er;
        }
    }

    // clean up
    delete [] distribution_left;
    delete [] distribution_right;
}

/** recursive building method */
DecisionNode *DTBOblique::buildRecursive(
        const FeaturePool & fp,
        const Examples & examples,
        std::vector<int> & examples_selection,
        FullVector & distribution,
        double e,
        int maxClassNo,
        int depth,
        double lambdaCurrent )
{

#ifdef DEBUGTREE
    std::cerr << "DTBOblique: Examples: " << (int)examples_selection.size()
              << ", Depth: " << (int)depth << ", Entropy: " << e << std::endl;
#endif

    // initialize new node
    DecisionNode *node = new DecisionNode ();
    node->distribution = distribution;

    // stop criteria: minimumEntropy, minExamples, maxDepth
    if (    ( e <= minimumEntropy )
         || ( (int)examples_selection.size() < minExamples )
         || ( depth > maxDepth ) )
    {
#ifdef DEBUGTREE
        std::cerr << "DTBOblique: Stopping criteria applied!" << std::endl;
#endif
        node->trainExamplesIndices = examples_selection;
        return node;
    }

    // bookkeeping for the best split found so far
    FeatureValuesUnsorted values;
    SplitInfo bestSplitInfo;
    bestSplitInfo.threshold = 0.0;
    bestSplitInfo.informationGain = -1.0;
    bestSplitInfo.distLeft = new double [maxClassNo+1];
    bestSplitInfo.distRight = new double [maxClassNo+1];
    bestSplitInfo.entropyLeft = 0.0;
    bestSplitInfo.entropyRight = 0.0;

    ConvolutionFeature *f = (ConvolutionFeature*)fp.begin()->second;
    bestSplitInfo.params = f->getParameterVector();

    // Create data matrix X, label vector y, and weight vector w
    NICE::Matrix X;
    NICE::Vector y, beta, weights;
    getDataAndLabel( fp, examples, examples_selection, X, y, weights );

    // Decompose the multi-class problem into binary ones (one-vs-one or one-vs-all)
    for ( int posClass = 0; posClass <= maxClassNo; posClass++ )
    {
        bool gotInnerIteration = false;
        for ( int negClass = 0; negClass <= maxClassNo; negClass++ )
        {
            if ( posClass == negClass ) continue;

            NICE::Vector yCur = y;
            NICE::Matrix XCur = X;

            bool hasExamples = adaptDataAndLabelForMultiClass(
                posClass, negClass, XCur, yCur );

            // are there examples for both the positive and the negative class?
            if ( !hasExamples ) continue;

            // weight the labels with the example weights
            yCur *= weights;

            // one-vs-all setting: run the inner loop only once
            if ( !useOneVsOne && gotInnerIteration ) continue;
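
            // The split parameters are the solution of the regularized least
            // squares problem
            //   min_beta ||X*beta - y||^2 + lambda * ||R*beta||^2
            // with the closed form
            //   beta = (X^T X + lambda * R^T R)^{-1} X^T y,
            // computed below via a Cholesky decomposition of the regularized
            // Gram matrix.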

            // Preparing system of linear equations
            NICE::Matrix XTXr, G, temp;
            regularizeDataMatrix( XCur, XTXr, regularizationType, lambdaCurrent );
            choleskyDecomp(XTXr, G);
            choleskyInvert(G, XTXr);
            temp = XTXr * XCur.transpose();

            // Solve system of linear equations in a least squares manner
            beta.multiply(temp,yCur,false);

            // Updating parameter vector in convolutional feature
            f->setParameterVector( beta );

            // Feature Values
            values.clear();
            f->calcFeatureValues( examples, examples_selection, values);

            // complete search for threshold
            findBestSplitThreshold ( values, bestSplitInfo, beta, e, maxClassNo );

            gotInnerIteration = true;
        }
    }

    // suppress numerical noise for entropy values near zero (e.g. 8.88178e-16)
    if (bestSplitInfo.entropyLeft < 1.0e-10 ) bestSplitInfo.entropyLeft = 0.0;
    if (bestSplitInfo.entropyRight < 1.0e-10 ) bestSplitInfo.entropyRight = 0.0;

    // stop criteria: minimum information gain
    if ( bestSplitInfo.informationGain < minimumInformationGain )
    {
#ifdef DEBUGTREE
        std::cerr << "DTBOblique: Minimum information gain reached!" << std::endl;
#endif
        delete [] bestSplitInfo.distLeft;
        delete [] bestSplitInfo.distRight;
        node->trainExamplesIndices = examples_selection;
        return node;
    }

    /** Save the best split to current node */
    f->setParameterVector( bestSplitInfo.params );
    values.clear();
    f->calcFeatureValues( examples, examples_selection, values);
    node->f = f->clone();
    node->threshold = bestSplitInfo.threshold;

    /** Split examples according to best split function */
    vector<int> examples_left;
    vector<int> examples_right;

    examples_left.reserve ( values.size() / 2 );
    examples_right.reserve ( values.size() / 2 );
    for ( FeatureValuesUnsorted::const_iterator i = values.begin();
          i != values.end(); i++ )
    {
        if ( i->first < bestSplitInfo.threshold )
            examples_left.push_back ( i->third );
        else
            examples_right.push_back ( i->third );
    }

#ifdef DEBUGTREE
    std::cerr << "DTBOblique: Information Gain: " << bestSplitInfo.informationGain
              << ", Left Entropy: " <<  bestSplitInfo.entropyLeft << ", Right Entropy: "
              << bestSplitInfo.entropyRight << std::endl;
#endif

    FullVector distribution_left_sparse ( distribution.size() );
    FullVector distribution_right_sparse ( distribution.size() );
    for ( int k = 0 ; k <= maxClassNo ; k++ )
    {
        double l = bestSplitInfo.distLeft[k];
        double r = bestSplitInfo.distRight[k];
        if ( l != 0 )
            distribution_left_sparse[k] = l;
        if ( r != 0 )
            distribution_right_sparse[k] = r;
    }

    delete [] bestSplitInfo.distLeft;
    delete [] bestSplitInfo.distRight;

    // update lambda by heuristic [Laptev/Buhmann, 2014]
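    // with d = number of split parameters:
    //   lambda_child = lambda_parent * (N_parent / N_child)^(2/d)
    // Children holding fewer examples are thus regularized more strongly.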
    double lambdaLeft = lambdaCurrent *
            pow( (double)examples_selection.size() / (double)examples_left.size(),
                 2.0 / f->getParameterLength() );
    double lambdaRight = lambdaCurrent *
            pow( (double)examples_selection.size() / (double)examples_right.size(),
                 2.0 / f->getParameterLength() );

    /** Recursion */
    // left child
    node->left  = buildRecursive ( fp, examples, examples_left,
                                   distribution_left_sparse, bestSplitInfo.entropyLeft,
                                   maxClassNo, depth+1, lambdaLeft );
    // right child
    node->right = buildRecursive ( fp, examples, examples_right,
                                   distribution_right_sparse, bestSplitInfo.entropyRight,
                                   maxClassNo, depth+1, lambdaRight );

    return node;
}

/** initial building method */
DecisionNode *DTBOblique::build ( const FeaturePool & fp,
                                  const Examples & examples,
                                  int maxClassNo )
{
    int index = 0;

    FullVector distribution ( maxClassNo+1 );
    vector<int> all;

    all.reserve ( examples.size() );
    for ( Examples::const_iterator j = examples.begin();
          j != examples.end(); j++ )
    {
        int classno = j->first;
        distribution[classno] += j->second.weight;

        all.push_back ( index );
        index++;
    }
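
    // weighted class entropy of the full training set, using the same
    // identity as in entropyLeftRight:
    //   H = log(sum) - (1/sum) * sum_j n_j * log(n_j)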

    double entropy = 0.0;
    double sum = 0.0;
    for ( int i = 0 ; i < distribution.size(); i++ )
    {
        double val = distribution[i];
        if ( val <= 0.0 ) continue;
        entropy -= val*log(val);
        sum += val;
    }
    entropy /= sum;
    entropy += log(sum);

    return buildRecursive ( fp, examples, all, distribution,
                            entropy, maxClassNo, 0, lambdaInit );
}