/**
 * @file DTBObliqueLS.cpp
 * @brief Builder for random oblique decision trees using least squares split optimization
 * @author Sven Sickert
 * @date 10/15/2014
 */
#include <iostream>
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <ctime>

#include "DTBObliqueLS.h"
#include "SCInformationGain.h"
#include "SCGiniIndex.h"

#include "vislearning/features/fpfeatures/ConvolutionFeature.h"

#include "core/vector/Algorithms.h"

using namespace OBJREC;

//#define DEBUGTREE

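/** simple constructor: read all relevant settings from the given config section */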
DTBObliqueLS::DTBObliqueLS ( const NICE::Config *conf, std::string section )
{
    saveIndices = conf->gB( section, "save_indices", false);
    useDynamicRegularization = conf->gB( section, "use_dynamic_regularization", true );
    multiClassMode = conf->gI( section, "multi_class_mode", 0 );

    splitSteps = conf->gI( section, "split_steps", 20 );
    maxDepth = conf->gI( section, "max_depth", 10 );
    regularizationType = conf->gI( section, "regularization_type", 1 );

    lambdaInit = conf->gD( section, "lambda_init", 0.5 );

    std::string splitCrit = conf->gS( section, "split_criterion", "information_gain" );
    if (splitCrit == "information_gain")
        splitCriterion = new SCInformationGain( conf );
    else if (splitCrit == "gini_index")
        splitCriterion = new SCGiniIndex( conf );
    else
    {
        std::cerr << "DTBObliqueLS::DTBObliqueLS: No valid splitting criterion defined!" << std::endl;
        splitCriterion = NULL;
    }

    if ( conf->gB(section, "start_random_generator", true ) )
        srand(time(NULL));
}

DTBObliqueLS::~DTBObliqueLS()
{
    if (splitCriterion != NULL)
        delete splitCriterion;
}

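/**
 * Turn the multi-class problem (X,y) into a binary regression target,
 * depending on multiClassMode: one-vs-one (0), one-vs-all (1), or a random
 * many-vs-many split of the class set (otherwise). Returns false if one of
 * the two resulting sets is empty, so the caller can retry with another pair.
 */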
bool DTBObliqueLS::adaptDataAndLabelForMultiClass (
        const int posClass,
        const int negClass,
        NICE::Matrix & X,
        NICE::Vector & y )
{
    int posCount = 0;
    int negCount = 0;

    // One-vs-one: Transforming into {-1,0,+1} problem
    if ( multiClassMode == 0 )
        for ( int i = 0; i < y.size(); i++ )
        {
            if ( y[i] == posClass )
            {
                y[i] = 1.0;
                posCount++;
            }
            else if ( y[i] == negClass )
            {
                y[i] = -1.0;
                negCount++;
            }
            else
            {
                y[i] = 0.0;
                X.setRow( i, NICE::Vector( X.cols(), 0.0 ) );
            }
        }
    // One-vs-all: Transforming into {-1,+1} problem
    else if ( multiClassMode == 1 )
        for ( int i = 0; i < y.size(); i++ )
        {
            if ( y[i] == posClass )
            {
                y[i] = 1.0;
                posCount++;
            }
            else
            {
                y[i] = -1.0;
                negCount++;
            }
        }
    // Many-vs-many: Transforming into {-1,+1} problem
    else
    {
        // get existing classes
        std::vector<double> unClass = y.std_vector();
        std::sort( unClass.begin(), unClass.end() );
        unClass.erase( std::unique( unClass.begin(), unClass.end() ), unClass.end() );

        // randomly split set of classes into two buckets
        std::random_shuffle ( unClass.begin(), unClass.end() );
        int firstHalf = std::ceil(unClass.size()/2.0);
        for ( int i = 0; i < y.size(); i++ )
        {
           bool wasFound = false;
           int c = 0;
           //assign new labels
           while ( (!wasFound) && (c<firstHalf) )
           {
               if ( y[i] == unClass[c] )
               {
                   wasFound = true;
               }
               c++;
           }
           if (wasFound)
           {
               y[i] = 1.0;
               posCount++;
           }
           else
           {
               y[i] = -1.0;
               negCount++;
           }
        }
    }

    return ( (posCount>0) && (negCount>0));
}

/** fill data matrix X, label vector y, and weight vector w from the selected examples */
void DTBObliqueLS::getDataAndLabel(
        const FeaturePool &fp,
        const Examples &examples,
        const std::vector<int> &examples_selection,
        NICE::Matrix & X,
        NICE::Vector & y,
        NICE::Vector & w )
{
    ConvolutionFeature *f = (ConvolutionFeature*)fp.begin()->second;
    int amountParams = f->getParameterLength();
    int amountExamples = examples_selection.size();

    X = NICE::Matrix(amountExamples, amountParams, 0.0 );
    y = NICE::Vector(amountExamples, 0.0);
    w = NICE::Vector(amountExamples, 1.0);

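    // one row per selected example: the convolutional feature's pixel
    // representation, scaled by the example weight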
    int matIndex = 0;
    for ( std::vector<int>::const_iterator si = examples_selection.begin();
          si != examples_selection.end();
          si++ )
    {
        const std::pair<int, Example> & p = examples[*si];
        const Example & ex = p.second;

        NICE::Vector pixelRepr (amountParams, 1.0);
        f->getFeatureVector( &ex, pixelRepr );

        double label = p.first;
        pixelRepr *= ex.weight;

        w.set    ( matIndex, ex.weight );
        y.set    ( matIndex, label );
        X.setRow ( matIndex, pixelRepr );

        matIndex++;
    }

}

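/**
 * Compute the regularized Gram matrix XTXreg from X^T X, with the
 * regularizer chosen via regOption: scaled identity (0), squared first (1)
 * or second (2) order difference operator, or the entry-wise rescaling of
 * [Chen et al., 2012] (3).
 */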
void DTBObliqueLS::regularizeDataMatrix(
        const NICE::Matrix &X,
        NICE::Matrix &XTXreg,
        const int regOption,
        const double lambda )
{
    XTXreg = X.transpose()*X;
    NICE::Matrix R;
    const int dim = X.cols();

    switch (regOption)
    {
        // identity matrix
        case 0:
            R.resize(dim,dim);
            R.setIdentity();
            R *= lambda;
            XTXreg += R;
            break;

        // differences operator, k=1
        case 1:
            R.resize(dim-1,dim);
            R.set( 0.0 );
            for ( int r = 0; r < dim-1; r++ )
            {
                R(r,r)   =  1.0;
                R(r,r+1) = -1.0;
            }
            R = R.transpose()*R;
            R *= lambda;
            XTXreg += R;
            break;

        // difference operator, k=2
        case 2:
            R.resize(dim-2,dim);
            R.set( 0.0 );
            for ( int r = 0; r < dim-2; r++ )
            {
                R(r,r)   =  1.0;
                R(r,r+1) = -2.0;
                R(r,r+2) =  1.0;
            }
            R = R.transpose()*R;
            R *= lambda;
            XTXreg += R;
            break;

        // as in [Chen et al., 2012]
        case 3:
        {
            NICE::Vector q ( dim, (1.0-lambda) );
            q[0] = 1.0;
            NICE::Matrix Q;
            Q.tensorProduct(q,q);
            R.resize(dim,dim);
            for ( int r = 0; r < dim; r++ )
            {
                for ( int c = 0; c < dim; c++ )
                    R(r,c) = XTXreg(r,c) * Q(r,c);

                R(r,r) = q[r] * XTXreg(r,r);
            }

            XTXreg = R;
            break;
        }

        // no regularization
        default:
            std::cerr << "DTBObliqueLS::regularizeDataMatrix: No regularization applied!"
                      << std::endl;
            break;
    }
}

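/**
 * Exhaustive line search for the best split threshold on the projected
 * feature values: try splitSteps equidistant thresholds between the minimum
 * and maximum value and keep the one yielding the highest purity.
 */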
void DTBObliqueLS::findBestSplitThreshold (
        FeatureValuesUnsorted &values,
        SplitInfo &bestSplitInfo,
        const NICE::Vector &params,
        const int &maxClassNo )
{
    double *distribution_left = new double [maxClassNo+1];
    double *distribution_right = new double [maxClassNo+1];
    double minValue = (std::min_element ( values.begin(), values.end() ))->first;
    double maxValue = (std::max_element ( values.begin(), values.end() ))->first;

    if ( maxValue - minValue < 1e-7 )
        std::cerr << "DTBObliqueLS: Difference between min and max of feature values is too small!"
                  << " [" << minValue << "," << maxValue << "]" << std::endl;

    // get best thresholds using complete search
    for ( int i = 0; i < splitSteps; i++ )
    {
        double threshold = (i * (maxValue - minValue ) / (double)splitSteps)
                            + minValue;
        // preparations
        for ( int k = 0 ; k <= maxClassNo ; k++ )
        {
            distribution_left[k] = 0.0;
            distribution_right[k] = 0.0;
        }

        /** Test the current split */
        SplittingCriterion *curSplit = splitCriterion->clone();

        if ( ! curSplit->evaluateSplit ( values, threshold,
                 distribution_left, distribution_right, maxClassNo ) )
            continue;

        // get value for impurity
        double purity = curSplit->computePurity();
        double entropy = curSplit->getEntropy();

        if ( purity > bestSplitInfo.purity )
        {
            bestSplitInfo.purity = purity;
            bestSplitInfo.entropy = entropy;
            bestSplitInfo.threshold = threshold;
            bestSplitInfo.params = params;

            for ( int k = 0 ; k <= maxClassNo ; k++ )
            {
                bestSplitInfo.distLeft[k] = distribution_left[k];
                bestSplitInfo.distRight[k] = distribution_right[k];
            }
        }

        delete curSplit;
    }

    // clean up
    delete [] distribution_left;
    delete [] distribution_right;
}

/** recursive building method */
DecisionNode *DTBObliqueLS::buildRecursive(
        const FeaturePool & fp,
        const Examples & examples,
        std::vector<int> & examples_selection,
        FullVector & distribution,
        double entropy,
        int maxClassNo,
        int depth,
        double lambdaCurrent )
{

    std::cerr << "DTBObliqueLS: Examples: " << (int)examples_selection.size()
              << ", Depth: " << (int)depth << ", Entropy: " << entropy << std::endl;

    // initialize new node
    DecisionNode *node = new DecisionNode ();
    node->distribution = distribution;

    // stopping criteria
    if (    ( entropy <= splitCriterion->getMinimumEntropy() )
         || ( (int)examples_selection.size() < splitCriterion->getMinimumExamples() )
         || ( depth > maxDepth ) )
    {
#ifdef DEBUGTREE
        std::cerr << "DTBObliqueLS: Stopping criteria applied!" << std::endl;
#endif
        node->trainExamplesIndices = examples_selection;
        return node;
    }

    // variables
    FeatureValuesUnsorted values;
    SplitInfo bestSplitInfo;
    bestSplitInfo.threshold = 0.0;
    bestSplitInfo.purity = -1.0;
    bestSplitInfo.entropy = 0.0;
    bestSplitInfo.distLeft = new double [maxClassNo+1];
    bestSplitInfo.distRight = new double [maxClassNo+1];

    ConvolutionFeature *f = (ConvolutionFeature*)fp.begin()->second;
    bestSplitInfo.params = f->getParameterVector();

    // Creating data matrix X and label vector y
    NICE::Matrix X;
    NICE::Vector y, params, weights;
    getDataAndLabel( fp, examples, examples_selection, X, y, weights );

    // Transforming into multi-class problem
    bool hasExamples = false;
    NICE::Vector yCur;
    NICE::Matrix XCur;

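    // retry with freshly drawn class pairs until both the positive and the
    // negative set contain at least one example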
    while ( !hasExamples )
    {
        int posClass, negClass;

        posClass = rand() % (maxClassNo+1);
        // offset by at least one so that negClass always differs from posClass
        negClass = (posClass + 1 + (rand() % maxClassNo)) % (maxClassNo+1);

        yCur = y;
        XCur = X;

        hasExamples = adaptDataAndLabelForMultiClass(
            posClass, negClass, XCur, yCur );
    }

    yCur *= weights;

    // Preparing system of linear equations
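    // The oblique split parameters are the regularized least squares solution
    //   w = (X^T X + lambda*R)^{-1} * X^T * y,
    // computed here via Cholesky decomposition and inversion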
    NICE::Matrix XTXr, G, temp;
    regularizeDataMatrix( XCur, XTXr, regularizationType, lambdaCurrent );
    choleskyDecomp(XTXr, G);
    choleskyInvert(G, XTXr);
    temp = XTXr * XCur.transpose();

    // Solve system of linear equations in a least squares manner
    params.multiply(temp,yCur,false);

    // Updating parameter vector in convolutional feature
    f->setParameterVector( params );

    // Feature Values
    values.clear();
    f->calcFeatureValues( examples, examples_selection, values);

    // complete search for threshold
    findBestSplitThreshold ( values, bestSplitInfo, params, maxClassNo );

    // stop criteria: minimum purity reached?
    if ( bestSplitInfo.purity < splitCriterion->getMinimumPurity() )
    {
#ifdef DEBUGTREE
        std::cerr << "DTBObliqueLS: Minimum purity reached!" << std::endl;
#endif
        delete [] bestSplitInfo.distLeft;
        delete [] bestSplitInfo.distRight;
        node->trainExamplesIndices = examples_selection;
        return node;
    }

    /** Save the best split to current node */
    f->setParameterVector( bestSplitInfo.params );
    values.clear();
    f->calcFeatureValues( examples, examples_selection, values);
    node->f = f->clone();
    node->threshold = bestSplitInfo.threshold;

    /** Split examples according to best split function */
    std::vector<int> examples_left;
    std::vector<int> examples_right;

    examples_left.reserve ( values.size() / 2 );
    examples_right.reserve ( values.size() / 2 );
    for ( FeatureValuesUnsorted::const_iterator i = values.begin();
          i != values.end(); i++ )
    {
        if ( i->first < bestSplitInfo.threshold )
            examples_left.push_back ( i->third );
        else
            examples_right.push_back ( i->third );
    }

#ifdef DEBUGTREE
//    node->f->store( std::cerr );
//    std::cerr << std::endl;
#endif

    FullVector distribution_left_sparse ( distribution.size() );
    FullVector distribution_right_sparse ( distribution.size() );
    for ( int k = 0 ; k <= maxClassNo ; k++ )
    {
        double l = bestSplitInfo.distLeft[k];
        double r = bestSplitInfo.distRight[k];
        if ( l != 0 )
            distribution_left_sparse[k] = l;
        if ( r != 0 )
            distribution_right_sparse[k] = r;
#ifdef DEBUGTREE
        std::cerr << "DTBObliqueLS: Split of Class " << k << " ("
                  << l << " <-> " << r << ") " << std::endl;
#endif
    }

    delete [] bestSplitInfo.distLeft;
    delete [] bestSplitInfo.distRight;

    // update lambda by heuristic [Laptev/Buhmann, 2014]
    double lambdaLeft, lambdaRight;

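    // fewer examples in a child node call for stronger regularization:
    // lambda_child = lambda_parent * (N_parent / N_child)^(2/d), d = #parameters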
    if (useDynamicRegularization)
    {
        lambdaLeft = lambdaCurrent *
            pow(((double)examples_selection.size()/(double)examples_left.size()),(2./f->getParameterLength()));
        lambdaRight = lambdaCurrent *
            pow(((double)examples_selection.size()/(double)examples_right.size()),(2./f->getParameterLength()));
    }
    else
    {
        lambdaLeft = lambdaCurrent;
        lambdaRight = lambdaCurrent;
    }


    /** Recursion */
    // left child
    node->left  = buildRecursive ( fp, examples, examples_left,
                                   distribution_left_sparse, bestSplitInfo.entropy,
                                   maxClassNo, depth+1, lambdaLeft );
    // right child
    node->right = buildRecursive ( fp, examples, examples_right,
                                   distribution_right_sparse, bestSplitInfo.entropy,
                                   maxClassNo, depth+1, lambdaRight );

    return node;
}

/** initial building method */
DecisionNode *DTBObliqueLS::build ( const FeaturePool & fp,
                                        const Examples & examples,
                                        int maxClassNo )
{
    int index = 0;

    FullVector distribution ( maxClassNo+1 );
    std::vector<int> all;

    all.reserve ( examples.size() );
    for ( Examples::const_iterator j = examples.begin();
          j != examples.end(); j++ )
    {
        int classno = j->first;
        distribution[classno] += j->second.weight;

        all.push_back ( index );
        index++;
    }

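    // entropy of the weighted class distribution:
    // H = log(S) - (1/S) * sum_k n_k * log(n_k),  with S = sum_k n_k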
    double entropy = 0.0;
    double sum = 0.0;
    for ( int i = 0 ; i < distribution.size(); i++ )
    {
        double val = distribution[i];
        if ( val <= 0.0 ) continue;
        entropy -= val*log(val);
        sum += val;
    }
    entropy /= sum;
    entropy += log(sum);

    return buildRecursive ( fp, examples, all, distribution,
                            entropy, maxClassNo, 0, lambdaInit );
}