/** 
 * @file FPCFullSearch.cpp
 * @brief optimal decision stump search, as performed within boosting
 * @author Erik Rodner
 * @date 04/21/2008
 */
#include <iostream>
#include <limits>
#include <cmath>
#include <cassert>
#include <cstdlib>
#include <ctime>

#include "FPCFullSearch.h"
#include "vislearning/features/fpfeatures/createFeatures.h"

using namespace OBJREC;

using namespace std;
using namespace NICE;

#undef DEBUG_BOOSTING


FPCFullSearch::FPCFullSearch( const Config *_conf ) : conf(_conf)
{
    f = NULL;
    srand(time(NULL));
    // if enabled, the stump outputs the least-squares fit alpha*h(x) + beta
    // instead of the plain {-1,+1} indicator (see train() below)
    use_regression = _conf->gB("FPCFullSearch", "use_regression", false);
}

FPCFullSearch::~FPCFullSearch()
{
    if ( f != NULL )
	delete f;
}

/** evaluate the learned decision stump on a single example */
ClassificationResult FPCFullSearch::classify ( Example & pe )
{
    assert ( f != NULL );
    double value = f->val ( &pe );
    // indicator is +1 iff the feature value lies on the "class 1" side
    // of the threshold; the side is selected by the parity
    int indicator = 0;
    if ( parity*value > parity*threshold )
	indicator =  1;
    else
	indicator = -1;
	
    FullVector scores(2);
    scores[0] = 0.0;

    if ( use_regression )
	scores[1] = alpha*indicator + beta;
    else
	scores[1] = indicator;
	
    int classno = (scores[1] >= 0);
    return ClassificationResult ( classno, scores );
}

void FPCFullSearch::train ( FeaturePool & fp, Examples & examples )
{
    // delete a previously trained stump to avoid a memory leak
    if ( f != NULL )
	delete f;
    f = NULL;
    fprintf (stderr, "FPCFullSearch: Feature optimization ...\n");

    // ------------ search for an optimal hypothesis
    // a hypothesis is defined by a feature index and a threshold (i,t):
    // h(x) = [ v_i(x) > t ]
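    // For every feature, all example values are collected in a sorted
    // container and the threshold is swept over the midpoints between
    // consecutive distinct values, while the class-weight sums are
    // updated incrementally. This evaluates all candidate splits of one
    // feature in O(n log n) instead of recomputing each error from scratch.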
    const Feature *optimal_feature = NULL;
    double optimal_alpha = 1.0;
    double optimal_beta  = 0.0;
    double minimum_error = numeric_limits<double>::max();
    double optimal_threshold = 0.0;
    int optimal_parity = 1;


    for ( FeaturePool::const_iterator feature_i = fp.begin();
	    feature_i != fp.end() ; feature_i++ )
    {
	const Feature *fcurrent = feature_i->second;
#ifdef DEBUG_BOOSTING
	fcurrent->store(cerr); cerr << endl; // DEBUG
#endif
	FeatureValues values;

	int example_index = 0;
	double overallWeight1 = 0.0;
	double overallWeight0 = 0.0;
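	// overallWeight0/1 will hold the total example weight of class 0/1;
	// w0/w1 in the threshold sweep further below track the weight mass
	// lying below the current threshold candidate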

	for ( vector< pair<int, Example> >::const_iterator example_i = examples.begin();
		    example_i != examples.end() ; example_i++, example_index++ )
	{
	    int classno = example_i->first;
	    assert ( (classno == 0) || (classno == 1) );
	    const Example & ce = example_i->second;
	    double value;

	    value = fcurrent->val ( &ce );
    
	    values.insert ( 
		quadruplet<double,int,int,double> ( 
		    value, classno, example_index, ce.weight 
		) 
	    );

	    overallWeight0 += ( 1 - classno ) * ce.weight;
	    overallWeight1 += classno * ce.weight;
	}

	assert ( overallWeight0 + overallWeight1 > 10e-8 );

	int previous_class = 0;
	double previous_value = 0.0;
#ifdef CHECK_AT_EQUAL_SPLIT
	// only needed when also testing equal splits at the borders
	// of the value range
	double first_value = values.begin()->first;
	FeatureValues::const_iterator lastElement = values.end();
	lastElement--;
	double last_value  = lastElement->first;
#endif

	double minimum_feature_error = minimum_error;
	double w0 = 0.0;
	double w1 = 0.0;
	for ( FeatureValues::const_iterator i = values.begin();
			i != values.end(); i++ )
	{
	    double w = i->fourth;
	    int current_class = i->second;
	    double current_value = i->first;

	    #ifdef DEBUG_BOOSTING
	    fprintf (stderr, "w %f w0 %f w1 %f cc %f\n", w, w0, w1, current_value );
	    #endif

	    assert (! isnan(current_value));

	    if ( (i != values.begin()) // do not check at the begin
		 #ifdef CHECK_AT_CLASS_BOUNDARIES_ONLY
		 && (current_class != previous_class)   // check only at class splits
		 #endif
		 && ( (current_value != previous_value) // check only at possible value splits
		 #ifdef CHECK_AT_EQUAL_SPLIT
		      || (  (current_value != first_value) 
		         && (current_value != last_value) // or if it is an equal split
							  // but can split the values somehow
			 ) 
		 #endif
		    ) 
	       )
	    {
		// possible split found -----

		#ifdef DEBUG_BOOSTING
		fprintf (stderr, "Current Best Setting: minimum_feature_error %f; optimal_threshold %f\n", minimum_feature_error, optimal_threshold );
		fprintf (stderr, "Current Split: 0(%f | %f) 1(%f | %f) w %f %d\n",
			w0, overallWeight0 - w0, w1, overallWeight1 - w1, w, current_class );

		#endif
		double current_threshold = (current_value + previous_value) / 2.0;
		
		// case: parity 1     -> all above threshold belongs to class 1
		// -> error: w1 + overallWeight0 - w0
		double error1 = 1.0;

		// some variables which are important for regression
		double a = 0.0;
		double c = 0.0;
		double d = 0.0;
		double W = 0.0;
		double regerror = 0.0;
		double det = 0.0;
		if ( use_regression )
		{
		    /* We solve the following linear equation system
		       (MATLAB syntax):

		       [ [ W, a ]; [ a, W ] ] * [ alpha; beta ] = [ c; d ]

		       with the definitions below (x_i in {-1,+1} is the
		       stump response, i.e. the indicator)
		    */

		    // W = \sum_i w_i = \sum_i w_i x_i^2
		    W = overallWeight0 + overallWeight1;
		    // a = \sum_i w_i x_i
		    a = (overallWeight0 - w0) + (overallWeight1 - w1) - (w0 + w1);
		    // c = \sum_i w_i x_i y_i
		    c = w0 - w1 + (overallWeight1 - w1) - (overallWeight0 - w0);
		    // d = \sum_i w_i y_i
		    d = overallWeight1 - overallWeight0;
		    // the determinant of the coefficient matrix
		    det = W*W - a*a;

		    /* The following is somewhat tricky.
		       We do not want to recalculate the regression
		       error by looping over all examples. Instead we
		       use a closed-form expression for the weighted
		       regression error
		       \sum_i w_i ( \alpha x_i + \beta - y_i )^2 ,
		       obtained by solving the linear equation system
		       above and substituting the optimal alpha and
		       beta back into the error.
		    */
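		    /* Derivation sketch: minimizing
		       E(alpha,beta) = \sum_i w_i ( alpha x_i + beta - y_i )^2
		       gives the normal equations
			  alpha*W + beta*a = c
			  alpha*a + beta*W = d      (using x_i^2 = 1),
		       solved by Cramer's rule:
			  alpha = (W*c - a*d)/det,  beta = (W*d - a*c)/det .
		       Since the residual is orthogonal to the fit and
		       y_i^2 = 1, the minimum error is
			  E* = W - (alpha*c + beta*d)
			     = W - (W*c*c - 2*a*c*d + W*d*d)/det .
		    */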
		    regerror = - (W*c*c - 2*a*c*d + W*d*d) / det + W;

		    error1 = regerror;
		} else {
		    error1 = w1 + overallWeight0 - w0;
		}

		// prefer splits which really separate the two classes!
		if ( current_class == previous_class ) error1 += 1e-12;

		#ifdef DEBUG_BOOSTING
		fprintf (stderr, "a=%f; c=%f, d=%f, W=%f, regerror=%f, alpha=%f, beta=%f\n", a, c, d, W, regerror,
		    c/a - (d-c)/(W-a), (d-c) / (W-a));
		fprintf (stderr, "> %f (>=%f) belongs to class 1: err=%f\n", current_threshold, current_value, error1);
		#endif
		if ( error1 < minimum_feature_error ) 
		{
		    optimal_threshold = current_threshold;
		    optimal_parity = 1;  
		    
		    if ( use_regression )
		    {
			// solution of the linear equation system
			optimal_beta = (W*d - a*c)/det;
			optimal_alpha = (W*c - a*d)/det;
		    } else {
			optimal_beta = 0.0;
			optimal_alpha = 1.0;
		    }
		    minimum_feature_error = error1;
		    #ifdef DEBUG_BOOSTING
		    fprintf (stderr, "optimal feature: 0(%f | %f) 1(%f | %f) %f %d\n",
			w0, overallWeight0 - w0, w1, overallWeight1 - w1, w, current_class );
		    #endif
		}
	
		if ( ! use_regression )
		{
		    // case: parity -1     -> all below the threshold belongs to class 1
		    // -> error: w0 + overallWeight1 - w1
		    // (in regression mode this case is redundant: flipping the
		    // parity only flips the sign of the fitted alpha, and the
		    // regression error stays identical)
		    double error0 = w0 + overallWeight1 - w1;
	    
		    if ( current_class == previous_class ) error0 += 1e-12;
		    #ifdef DEBUG_BOOSTING
		    fprintf (stderr, "< %f (<=%f) belongs to class 1: err=%f\n", current_threshold, previous_value, error0);
		    #endif
		    if ( error0 < minimum_feature_error ) 
		    {
			optimal_threshold = current_threshold;
			optimal_parity = -1;   
			
			optimal_beta = 0.0;
			optimal_alpha = 1.0;

			minimum_feature_error = error0;
			#ifdef DEBUG_BOOSTING
			fprintf (stderr, "optimal feature: 0(%f | %f) 1(%f | %f) w %f %d\n",
			    w0, overallWeight0 - w0, w1, overallWeight1 - w1, w, current_class );
			#endif
		    }
		}
	    }	
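	    // move the weight of the current example to the "below the
	    // threshold" side before looking at the next candidate split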
	    w1 += current_class * w;
	    w0 += (1-current_class) * w;

	    previous_class = current_class;
	    previous_value = current_value;
	}
	
	// update optimal feature
	if ( minimum_feature_error < minimum_error )
	{
	    optimal_feature = fcurrent;
	    minimum_error = minimum_feature_error;
	}
    }
    assert ( optimal_feature != NULL );

    fprintf (stderr, "FPCFullSearch: Feature optimization ...done\n");
    optimal_feature->store(cerr);
    cerr << endl;

    f = optimal_feature->clone();
    threshold = optimal_threshold;
    parity    = optimal_parity;
    alpha     = optimal_alpha;
    beta      = optimal_beta;

    last_error = minimum_error;
}
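
/* Usage sketch (hypothetical driver code; featurePool, trainingExamples
   and someExample are assumed to exist in the caller):

     FPCFullSearch stump ( conf );
     stump.train ( featurePool, trainingExamples ); // exhaustive stump search
     ClassificationResult r = stump.classify ( someExample );

   After train(), last_error holds the weighted training error of the
   selected stump, which a boosting wrapper can use to compute the voting
   weight of this weak classifier. */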

FPCFullSearch *FPCFullSearch::clone () const
{
    FPCFullSearch *myclone = new FPCFullSearch (conf);
    if ( f == NULL )
	myclone->f = NULL;
    else
	myclone->f = f->clone();

    myclone->threshold = threshold;
    myclone->parity = parity;
    myclone->beta = beta;
    myclone->alpha = alpha;
    return myclone;
}
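
// Stream layout assumed by restore(): the feature writes its tag followed
// by its own data (via Feature::store), then one line containing
//   <threshold> <parity> <alpha> <beta>
// restore() finally consumes one trailing token (presumably a closing tag).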

void FPCFullSearch::restore (istream & is, int format)
{
    std::string feature_tag;
    is >> feature_tag;
    fprintf (stderr, "feature tag: %s\n", feature_tag.c_str() );
    f = createFeatureFromTag ( conf, feature_tag );
    if ( f == NULL ) {
	fprintf (stderr, "Unknown feature description: %s\n",
	    feature_tag.c_str() );
	exit(-1);
    }

    f->restore ( is, format );
    is >> threshold;
    is >> parity;
    is >> alpha;
    is >> beta;

    cerr << "T " << threshold << " P " << parity << endl;

    std::string tmp;
    is >> tmp;
}

void FPCFullSearch::store (ostream & os, int format) const
{
    f->store ( os, format );
    os << endl;
    os << threshold << " " << parity << " ";
    os << alpha << " " << beta << endl;
}

void FPCFullSearch::clear() 
{
    if ( f != NULL ) {
	delete f;
	// reset the pointer, otherwise the destructor would delete f twice
	f = NULL;
    }
}
	
const Feature *FPCFullSearch::getStump ( double & _threshold, double & _parity ) const
{
    _threshold = threshold;
    _parity = parity;
    return f;
}