/** 
* @file CascadeOptimization.cpp
* @brief optimization of a previously built cascade
* @author Erik Rodner
* @date 11/13/2008

*/
#include <iostream>
#include <algorithm>
#include <assert.h>

#include "CascadeOptimization.h"

using namespace OBJREC;

using namespace std;
// refactor-nice.pl: check this substitution
// old: using namespace ice;
using namespace NICE;



/** Default constructor; the body is empty, no work is done here. */
CascadeOptimization::CascadeOptimization()
{
}

/** Destructor; the body is empty, nothing is released explicitly here. */
CascadeOptimization::~CascadeOptimization()
{
}

/**
 * Recursive search for the candidate selection with the highest accumulated
 * true-positive rate whose accumulated false-positive rate drops below
 * minimumFPRate.
 *
 * For every round, matrix[round] holds candidate triplets; this function reads
 * .first as the true-positive rate and .second as the false-positive rate of a
 * candidate. Rates of the chosen candidates are multiplied along the path.
 *
 * @param matrix         candidate triplets, one vector per round
 * @param path           in/out: on success replaced by the indices of the
 *                       chosen candidates, one per round starting at @p round
 * @param besttprate     in/out: accumulated true-positive rate of the solution
 *                       currently stored in @p path; a candidate is only
 *                       explored if its accumulated rate exceeds this value
 * @param tprate         accumulated true-positive rate of the rounds before @p round
 * @param fprate         accumulated false-positive rate of the rounds before @p round
 * @param minimumFPRate  the accumulated false-positive rate has to be strictly
 *                       below this value for a selection to count as a solution
 * @param round          index of the round handled by this call
 * @return true if this call stored a new (better) solution in @p path
 */
bool CascadeOptimization::calcOptimalCascade ( const vector<vector< triplet<double, double, double> > > & matrix,
					   list<int> & path,
					   double & besttprate,
					   double tprate,
					   double fprate,
					   double minimumFPRate,
					   uint round )
{
    if ( round < matrix.size() )
    {
        int index = 0;
        bool solutionFoundLoop = false;
        const vector<triplet<double, double, double> > & statistics = matrix[round];
        for ( vector<triplet<double, double, double> >::const_iterator i = statistics.begin();
              i != statistics.end();
              ++i, ++index )
        {
            const double tp = i->first;
            const double fp = i->second;

            // The accumulated rate can only shrink in later rounds, so a
            // prefix that is not better than the stored solution is skipped.
            if ( tp*tprate > besttprate )
            {
                bool solutionFound = false;
                if ( fp*fprate < minimumFPRate )
                {
                    // The cascade is complete after this round: start a fresh
                    // path and record the rate it really achieves.
                    path.clear();
                    besttprate = tprate*tp;
                    solutionFound = true;
                }
                else if ( calcOptimalCascade ( matrix, path, besttprate,
                                               tprate*tp, fprate*fp, minimumFPRate, round+1 ) )
                {
                    // The deeper rounds already stored the rate of the complete
                    // cascade in besttprate. It must not be replaced by the
                    // larger prefix product tprate*tp, otherwise the pruning
                    // test above would reject cascades that are better than
                    // the one stored in path.
                    solutionFound = true;
                }

                if ( solutionFound )
                {
                    path.push_front ( index );
                    solutionFoundLoop = true;
                }
            }
        }
        return solutionFoundLoop;
    }
    else
    {
        // No rounds left. Recursive calls from above only get here after
        // fprate failed the minimumFPRate test, so they end in "false";
        // "true" is only possible when the function is entered with
        // round >= matrix.size(). besttprate is left untouched in that case.
        assert ( tprate > besttprate );
        if ( fprate < minimumFPRate )
        {
            path.clear();
            return true;
        }
        else
        {
            return false;
        }
    }
}

void CascadeOptimization::calcOptimalCascade ( const vector<vector< triplet<double, double, double> > > & matrix,
					 double minimumFPRate,
					 vector<double> & thresholds )
{
    list<int> path;
    double besttprate = 0.0;
    calcOptimalCascade ( matrix, path, besttprate, 1.0, 1.0, minimumFPRate, 0 );

    int index = 0;
    for ( list<int>::const_iterator i = path.begin(); i != path.end();
	  i++, index++ )
    {
	int entry = *i;
	const triplet<double, double, double> & vals = matrix[index][entry];
	
	double tprate = vals.first;
	double fprate = vals.second;
	double threshold = vals.third;
	fprintf (stderr, "cascade (%d): tp %f fp %f threshold %f\n",
	    index+1, tprate, fprate, threshold);
	thresholds.push_back(threshold);
    }
}

/**
 * Scans the sorted classification results for candidate thresholds and picks
 * one of them.
 *
 * @p results is sorted in place (ascending, by score and then class number).
 * While walking through it, every entry whose class differs from
 * @p negativeClassDST counts as a positive. After an entry has been counted,
 *   tprate = positives seen so far / P
 *   fprate = (entries seen so far - positives seen so far) / N
 * A candidate (tprate, fprate, threshold) is appended to @p statistics at each
 * entry that is not of the negative class, is directly followed by an entry of
 * the negative class, and has a score different from that following entry.
 * The last entry of @p results only serves as look-ahead; it is never counted
 * and never becomes a candidate.
 *
 * The first candidate with tprate >= @p requiredDetectionRate is the selected
 * one. If there is none, the fallback is the last candidate appended by this
 * call, or statistics[0] if this call appended nothing but @p statistics
 * already had entries.
 *
 * @param results                pairs of (score, class number); reordered by the sort
 * @param N                      divisor of the false-positive rate
 * @param P                      divisor of the true-positive rate
 *                               (NOTE(review): N and P are not checked for zero)
 * @param negativeClassDST       class number treated as the negative class
 * @param requiredDetectionRate  minimum tprate for a candidate to be accepted
 * @param bestthreshold          output: threshold of the selected candidate
 * @param besttprate             output: tprate of the selected candidate
 * @param bestfprate             output: fprate of the selected candidate
 * @param statistics             output: candidates are appended, existing entries are kept
 * @return true if a candidate reached @p requiredDetectionRate. If false is
 *         returned because @p statistics is empty, the three "best" outputs
 *         are left unchanged.
 */
bool CascadeOptimization::evaluateCascade ( vector<pair<double, int> > & results,
				      long N, long P,
				      int negativeClassDST,
				      double requiredDetectionRate,
				      double & bestthreshold,
				      double & besttprate,
				      double & bestfprate,
				      vector< triplet<double, double, double> > & statistics )
{
    typedef vector<pair<double, int> >::const_iterator ResultIterator;

    sort ( results.begin(), results.end() );

    long positives_count = 0;
    long count = 1;

    bool solutionFound = false;
    int bestEntry = 0;
    int secondBestEntry = 0;

    // Every visited entry needs a successor, so the scan stops in front of
    // the last element. For an empty vector there is nothing to visit;
    // forming begin()+1 in that case would be undefined behaviour.
    const ResultIterator lastResult = results.empty() ? results.end() : results.end() - 1;

    for ( ResultIterator j = results.begin(); j != lastResult; ++j, ++count )
    {
        const int classno = j->second;
        const double threshold = j->first;

        if ( classno != negativeClassDST )
            positives_count++;

        const double tprate = positives_count / (double)P;
        const double fprate = ( count - positives_count ) / (double) N;

        if ( (classno != negativeClassDST) && ((j+1)->second == negativeClassDST)
            && ((j+1)->first != threshold) )
        {
            statistics.push_back ( triplet<double, double, double> ( tprate, fprate, threshold ) );
            fprintf (stderr, "CascadeOptimization: tprate %f fprate %f threshold %f (required tprate %f)\n", tprate, fprate, j->first, requiredDetectionRate );
            if ( (!solutionFound) && (tprate >= requiredDetectionRate) )
            {
                bestEntry = statistics.size() - 1;
                fprintf (stderr, "CascadeOptimization: suitable entry found !\n");

                solutionFound = true;
            } else {
                secondBestEntry = statistics.size() - 1;
            }
        }
    }

    // Without a single candidate there is nothing to index below; reading
    // statistics[0] of an empty vector would be undefined behaviour.
    if ( statistics.empty() ) {
        fprintf (stderr, "CascadeOptimization: no candidate threshold found !\n");
        return false;
    }

    if ( ! solutionFound ) {
        fprintf (stderr, "CascadeOptimization: Using second best solution !!\n");

        besttprate = statistics[secondBestEntry].first;
        bestfprate = statistics[secondBestEntry].second;
        bestthreshold = statistics[secondBestEntry].third;

        fprintf (stderr, "CascadeOptimization: threshold %f detection rate %f fp rate %f\n",
            bestthreshold, besttprate, bestfprate );
    } else {
        besttprate = statistics[bestEntry].first;
        bestfprate = statistics[bestEntry].second;
        bestthreshold = statistics[bestEntry].third;
    }

    if ( besttprate == 0.0 ) {
        fprintf (stderr, "!!!! WORST CLASSIFIER I'VE EVER SEEN !!!!\n");
    }

    return solutionFound;
}