/** 
* @file analyseLocalization.cpp
* @brief print recall/precision curves etc
* @author Erik Rodner
* @date 09/01/2008

*/
#include "core/vector/VectorT.h"
#include "core/vector/MatrixT.h"
#include "core/image/ImageT.h"

#include <core/basics/Config.h>
#include <vislearning/baselib/Gnuplot.h>
#include <core/basics/StringTools.h>
#include <vislearning/cbaselib/LocalizationAnalysis.h>

using namespace OBJREC;

using namespace NICE;
using namespace std;

void readResults ( const string & resultsfn, vector<pair<double, int> > & results )
{
	ifstream ifs ( resultsfn.c_str(), ios::in );

	if ( !ifs.good() ) 
		fthrow(IOException, "Unable to open " << resultsfn << "." );

	while ( !ifs.eof() ) 
	{
		char buf [1024];
		ifs.getline( buf, 1024 );
		if ( !ifs.good() ) break;
		
		Vector entry;
		StringTools::splitVector ( string(buf), ',', entry );

		if ( entry.size() != 4 )
			fthrow(IOException, "Parse error in " << resultsfn << "." );

		results.push_back ( pair<double, int> ( entry[2], (int)entry[1] ) );
	}

	ifs.close();
}

/** 
    print recall/precision curves etc 
*/
int main (int argc, char **argv)
{   
    std::set_terminate(__gnu_cxx::__verbose_terminate_handler);

	Config conf ( argc, argv );

	string resultssetting = conf.gS("main", "results", "results.txt" );
	int graphtype = conf.gI("main", "graphtype", 0 );
	string outfn = conf.gS("main", "out", "");
	bool displayGraph = conf.gB("main", "display", false );
	bool plainMode = conf.gB("main", "plain", false );

	vector<string> submatches;
	vector<string> resultsfiles;
	if ( StringTools::regexMatch ( resultssetting, "^list:(.+)", submatches ) )
	{
		if ( !plainMode ) 
			cerr << "reading file list" << endl;
		string listfn = submatches[1];
		ifstream ifs ( listfn.c_str(), ios::in );
		if ( ifs.bad() )
			fthrow(IOException, "Unable to open " << listfn << "." );
		char buf[1024];
		while ( !ifs.eof() )
		{
			ifs.getline(buf, 1024);
			if ( strlen(buf) >= 1 ) 
				resultsfiles.push_back ( StringTools::chomp(string(buf)) );
		}
		ifs.close();
	} else {
		StringTools::split ( resultssetting, ',', resultsfiles );
	}
	
	for ( vector<string>::const_iterator i = resultsfiles.begin(); 
		i != resultsfiles.end(); i++ )
	{
		vector<pair<double, int> > results;
		string resultsfn = *i;

		if ( !plainMode )
			fprintf (stderr, "file: %s\n", resultsfn.c_str() );
		readResults ( resultsfn, results );
		if ( !plainMode ) 
			fprintf (stderr, "results.size() = %d\n", (int)results.size() );
	  
		int count_positives = 0;
		for ( vector<pair<double, int> >::const_iterator i = results.begin();
			i != results.end(); i++ )
			if ( i->second == 1 ) count_positives++;
		
		vector<double> thresholds;
		vector<double> x;
		vector<double> y;

		LocalizationAnalysis la;
		double areaMeasure = 0.0;
		if ( (graphtype == 0) || (graphtype == 2) )
		{
			la.calcRecallPrecisionCurve ( results, count_positives, thresholds, x, y );
			fprintf (stderr, "average precision (11-point): %f\n", la.calcAveragePrecision( x, y ) );
			areaMeasure = la.calcAveragePrecisionPrecise( x, y );
			if ( !plainMode ) 
				fprintf (stderr, "average precision (precise): %f\n", areaMeasure );
		} else {
			la.calcROCCurve ( results, count_positives, results.size() - count_positives, thresholds, x, y );
			areaMeasure = la.calcAreaUnderROC( x, y );
			if ( !plainMode )
				fprintf (stderr, "area under the ROC curve: %f\n", areaMeasure );
		}
		cerr << resultsfn << " " << areaMeasure << endl;


		if ( displayGraph ) {
			Gnuplot gp;
			gp.set_xrange ( 0, 1 );
			gp.set_yrange ( 0, 1 );
			gp.set_style ( "lines" );
			if ( graphtype == 0 )
				gp.plot_xy ( x, y, "Recall/Precision");
			else if ( graphtype == 2 ) {
				for ( uint i = 0 ; i < x.size() ; i++ )
				{
					double recall = x[i];
					x[i] = 1.0 - y[i];
					y[i] = recall;
				}
				gp.plot_xy ( x, y, "1-Precision/Recall");
			} else
				gp.plot_xy ( x, y, "TP/FP (ROC)");

			getchar();
		}

		if ( outfn.size() > 0 ) {
			ofstream ofs ( outfn.c_str(), ios::out );
			if ( !ofs.good() )
				fthrow(IOException, "Unable to write graph results to " << outfn << "." );

			for ( unsigned int i = 0 ; i < x.size(); i++ )
				ofs << x[i] << " " << y[i] << " " << thresholds[i] << endl;
			
			ofs.close();
		}
	}
    
    return 0;
}