/** 
* @file FPCCascade.cpp
* @brief cascade of feature pool classifiers: the stages' scores are accumulated and an example can be rejected as background after any stage
* @author Erik Rodner
* @date 04/21/2008

*/
#include <iostream>

#include "FPCCascade.h"

using namespace OBJREC;

using namespace std;
// refactor-nice.pl: check this substitution
// old: using namespace ice;
using namespace NICE;

// Upper bound on the number of stages evaluated by FPCCascade::classify().
const int maxCascades ( 9999 );

/**
 * Creates a cascade without any stages.
 * @param _classifier      prototype classifier; train() and restore() create every
 *                         stage by cloning it. It is deleted by the destructor.
 * @param _backgroundClass class number used for rejection in classify()
 */
FPCCascade::FPCCascade ( FeaturePoolClassifier *_classifier, int _backgroundClass )
    : classifier ( _classifier )
    , backgroundClass ( _backgroundClass )
    , nextComplexity ( 0 )   // complexity handed to the next trained stage
{
}

/**
 * Releases every stage of the cascade and the prototype classifier.
 */
FPCCascade::~FPCCascade()
{
    // stages are created with clone() in train()/restore() and owned here
    Cascade::iterator stage = cascade.begin();
    while ( stage != cascade.end() )
    {
        delete stage->second;
        ++stage;
    }

    // the prototype handed to the constructor is owned as well
    delete classifier;
}

/**
 * Very steep logistic function centred at 0.5.
 * @param x input value
 * @return 1 / (1 + exp(100 * (0.5 - x))); exactly 0.5 at x == 0.5
 */
double shiftedSigmoid ( double x )
{
    const double steepness = 100.0;
    const double exponent = steepness * ( 0.5 - x );
    return 1.0 / ( 1.0 + exp ( exponent ) );
}

// calibrateProb: disabled helper (excluded by #if 0, never compiled).
// For each entry of scores it remaps the value piecewise-linearly so that the
// threshold maps to 0.5 (T for backgroundClass, 1-T for every other class),
// sharpens the result with shiftedSigmoid() and finally renormalizes scores.
#if 0
void calibrateProb ( SparseVector & scores, double T, int backgroundClass )
{
    for ( SparseVector::iterator i = scores.begin();
				i != scores.end();
				i++ )
    {
	if ( i->first == backgroundClass ) {
	    // background: [0,T) -> [0,0.5), [T,1] -> [0.5,1]
	    if ( i->second < T ) 
		i->second *= 0.5 / T;
	    else
		i->second = 0.5 * (i->second - T) / (1.0 - T) + 0.5;
	} else {
	    // other classes: [0,1-T) -> [0,0.5), [1-T,1] -> [0.5,1]
	    if ( i->second < 1.0-T ) 
		i->second *= 0.5 / (1.0-T);
	    else
		i->second = 0.5 * (i->second - 1.0 + T) / T + 0.5;

	}
	i->second = shiftedSigmoid ( i->second );
    }

    scores.normalize();
}
#endif

/**
 * Classifies an example by running the stages in order.
 *
 * The score vectors of all evaluated stages are summed. After each stage the
 * fraction of the accumulated score that belongs to backgroundClass is
 * compared with the stage threshold; if it is larger, the example is rejected
 * as background and no further stages are evaluated. At most maxCascades
 * stages are run.
 *
 * @param pe example to classify
 * @return backgroundClass if rejected, otherwise the best non-background class
 *         of the normalized accumulated scores (the scores are returned too)
 *
 * NOTE(review): if the accumulated scores sum to zero the background fraction
 * is not finite (0/0 yields NaN, which never exceeds the threshold).
 */
ClassificationResult FPCCascade::classify ( Example & pe )
{
    FullVector accumulated;
    bool isBackground = false;

    Cascade::const_iterator stage = cascade.begin();
    for ( int stageCount = 0;
          stage != cascade.end() && stageCount < maxCascades;
          ++stage, ++stageCount )
    {
        const double rejectThreshold = stage->first;
        ClassificationResult stageResult = stage->second->classify ( pe );

        // the first stage initializes the sum, later stages add to it
        if ( !accumulated.empty() )
            accumulated.add ( stageResult.scores );
        else
            accumulated = stageResult.scores;

        const double backgroundFraction =
            accumulated.get ( backgroundClass ) / accumulated.sum();

        if ( backgroundFraction > rejectThreshold )
        {
            isBackground = true;
            break;
        }
    }

    accumulated.normalize();

    const int classno = isBackground
        ? backgroundClass
        : accumulated.maxElementExclusive ( backgroundClass );

    // a non-rejected example must never end up as background
    assert ( isBackground || classno != backgroundClass );

    return ClassificationResult ( classno, accumulated );
}

/**
 * Trains one additional stage and appends it to the cascade.
 *
 * The stage is a clone of the prototype classifier with this object's
 * maxClassNo (taken from the examples) and, if nextComplexity is non-negative,
 * that complexity. The new stage starts with a threshold of 0.5, which can be
 * changed afterwards with setLastThreshold().
 *
 * @param fp       feature pool passed to the stage's train()
 * @param examples training examples passed to the stage's train()
 */
void FPCCascade::train ( FeaturePool & fp, Examples & examples )
{
    maxClassNo = examples.getMaxClassNo();

    const int stageNumber = (int)cascade.size() + 1;
    fprintf (stderr, "FPCCascade: adding new cascade %d\n", stageNumber );

    FeaturePoolClassifier *stage = classifier->clone();
    stage->maxClassNo = maxClassNo;

    // with a negative nextComplexity the clone keeps its own complexity
    if ( nextComplexity >= 0 )
        stage->setComplexity ( nextComplexity );

    stage->train ( fp, examples );

    const double initialThreshold = 0.5;
    cascade.push_back ( pair<double, FeaturePoolClassifier *> ( initialThreshold, stage ) );
}

/**
 * Restores cascade stages from a stream in the layout written by store():
 *
 *     WEAKCLASSIFIER
 *     <threshold>
 *     <stage data, read by the stage classifier's own restore()>
 *     ENDWEAKCLASSIFIER
 *
 * The prototype classifier first receives this object's maxClassNo. Every
 * stage is then a clone of it and is appended to the current cascade; stages
 * that are already present are kept.
 *
 * Reading stops at the first token that is not a stage header, when an
 * ENDWEAKCLASSIFIER marker is not followed by another stage, or when the
 * stream fails. A non-matching token that ends the loop is consumed.
 *
 * @param is     input stream
 * @param format currently unused; stages are restored with their default format
 */
void FPCCascade::restore (istream & is, int format)
{
    std::string tag;

    classifier->maxClassNo = maxClassNo;
    fprintf (stderr, "FPCCascade: max classno %d\n", maxClassNo );

    while ( is >> tag )
    {
        if ( tag == "ENDWEAKCLASSIFIER" )
        {
            // store() writes this marker after every stage. If the stage's
            // restore() leaves it in the stream it has to be skipped here,
            // otherwise only the first stage would be restored. Continue only
            // if the next token looks like another stage header, so that a
            // token of an enclosing format is normally left untouched.
            is >> std::ws;
            if ( is.peek() != 'W' )
                break;
            continue;
        }

        if ( tag != "WEAKCLASSIFIER" )
            break;

        double threshold;
        if ( !( is >> threshold ) )
        {
            fprintf ( stderr, "FPCCascade: unable to read threshold of stage %d\n",
                      (int)cascade.size() + 1 );
            break;
        }

        FeaturePoolClassifier *c = classifier->clone();
        c->restore ( is );

        cascade.push_back ( pair<double, FeaturePoolClassifier *> ( threshold, c ) );
    }
}

/**
 * Writes all stages to a stream. Per stage, the output is the line
 * "WEAKCLASSIFIER", the stage threshold, the stage classifier's own store()
 * output and the line "ENDWEAKCLASSIFIER".
 *
 * @param os     output stream
 * @param format currently unused; the stages are stored with their default format
 */
void FPCCascade::store (ostream & os, int format) const
{
    Cascade::const_iterator stage = cascade.begin();
    while ( stage != cascade.end() )
    {
        // named differently from the member 'classifier' to avoid shadowing it
        const FeaturePoolClassifier *stageClassifier = stage->second;

        os << "WEAKCLASSIFIER" << endl
           << stage->first << endl;

        stageClassifier->store (os);

        os << "ENDWEAKCLASSIFIER" << endl;

        ++stage;
    }
}

/**
 * Removes the most recently added stage from the cascade and deletes it.
 *
 * Calling this on an empty cascade is a programming error (the assertion
 * fires in debug builds). In release builds the call is a no-op instead of
 * the undefined behaviour of back()/pop_back() on an empty container.
 */
void FPCCascade::deleteLastClassifier ()
{
    assert ( cascade.size() > 0 );
    if ( cascade.size() == 0 )
        return;

    FeaturePoolClassifier *c = cascade.back().second;

    delete c;
    cascade.pop_back();
}

/**
 * Returns the most recently added stage.
 *
 * The cascade keeps ownership of the returned classifier; use
 * deleteLastClassifier() to remove and delete it.
 *
 * @return last stage, or NULL if the cascade is empty (the assertion fires
 *         in debug builds; previously this was undefined behaviour)
 */
FeaturePoolClassifier *FPCCascade::getLastClassifier ()
{
    assert ( cascade.size() > 0 );
    if ( cascade.size() == 0 )
        return NULL;

    return cascade.back().second;
}

void FPCCascade::clear ()
{
    cascade.clear();
}

/**
 * Creates an untrained copy of this cascade.
 *
 * The copy gets a clone of the prototype classifier and the same
 * backgroundClass, maxClassNo and nextComplexity. The trained stages are NOT
 * copied, so the result starts with an empty cascade. The caller owns the
 * returned object.
 */
FPCCascade *FPCCascade::clone () const
{
    FeaturePoolClassifier *prototype = classifier->clone();
    prototype->maxClassNo = maxClassNo;

    FPCCascade *copy = new FPCCascade ( prototype, backgroundClass );
    copy->maxClassNo     = maxClassNo;
    copy->nextComplexity = nextComplexity;

    return copy;
}

/**
 * Sets the rejection threshold of the most recently added stage.
 * In classify(), an example is rejected after this stage if the accumulated
 * background fraction exceeds the threshold.
 * @param threshold new threshold (train() initializes it to 0.5)
 */
void FPCCascade::setLastThreshold ( double threshold )
{
    // there must be at least one stage
    assert ( cascade.size() != 0 );

    cascade.back().first = threshold;
}
	
void FPCCascade::setComplexity ( int size )
{
    nextComplexity = size;
}