/** 
* @file ClassificationResults.cpp
* @brief std::vector of ClassificationResult objects, with helpers to export and evaluate them
* @author Erik Rodner
* @date 02/13/2008

*/
#include "core/image/ImageT.h"
#include "core/vector/VectorT.h"
#include "core/vector/MatrixT.h"

#include <iostream>
#include <fstream>
#include <iomanip>

#include "vislearning/cbaselib/ClassificationResults.h"
#include "vislearning/cbaselib/LocalizationAnalysis.h"

using namespace OBJREC;

using namespace std;
using namespace NICE;



/**
 * Default constructor. Performs no initialization beyond what the
 * base/member default constructors already do.
 */
ClassificationResults::ClassificationResults ()
{
  // intentionally empty
}

/**
 * Destructor. Releases nothing explicitly; all cleanup is left to the
 * base/member destructors.
 */
ClassificationResults::~ClassificationResults ()
{
  // intentionally empty
}

void ClassificationResults::writeWEKA ( const std::string & filename, int classno ) const
{
    ofstream ofs ( filename.c_str(), ios::out );
    int instno = 0;
    for ( const_iterator i = begin(); i != end() ; i++, instno++ )
    {
		const ClassificationResult & r = *i;
		double confidence = r.scores.get(classno);

		ofs << instno << ", " << r.classno_groundtruth << ", " <<setiosflags(ios::fixed)<< setprecision(20)<<confidence << ", " << r.classno << endl;
    }
    ofs.close();
}

double ClassificationResults::getBinaryClassPerformance ( int type ) const
{
	LocalizationAnalysis la;
	vector< pair<double, int> > resultsFlat;
	uint countPositives = 0;
	uint countNegatives = 0;
	for ( const_iterator i = begin(); i != end(); i++ )
	{
		const ClassificationResult & r = *i;
		double confidence = r.scores.get(1);
		uint classno_groundtruth = r.classno_groundtruth;

		resultsFlat.push_back ( pair<double, int> ( confidence, classno_groundtruth ) );

		if ( classno_groundtruth == 1 )
			countPositives++;
		if ( classno_groundtruth == 0 )
			countNegatives++;
	}

	if ( countPositives <= 0 )
		fthrow(Exception, "No positive ground truth examples");
	if ( countNegatives <= 0 )
		fthrow(Exception, "No negative ground truth examples");

	vector<double> thresholds, x, y;

	if ( type == PERF_AUC )  
		la.calcROCCurve ( resultsFlat, countPositives, countNegatives, thresholds, x, y );
	else
		la.calcRecallPrecisionCurve ( resultsFlat, countPositives, thresholds, x, y );

	if ( type == PERF_AUC )
		return la.calcAreaUnderROC ( x, y );
	else if ( type == PERF_AVG_PRECISION_11_POINT )
		return la.calcAveragePrecision ( x, y );
	else
		return la.calcAveragePrecisionPrecise ( x, y );
}

double ClassificationResults::getAverageRecognitionRate() const
{
  const_iterator i = begin();
  NICE::Matrix confusion ( i->scores.size(),i->scores.size(),0.0 );

  for ( ; i != end(); i++ )
  {
    const ClassificationResult & r = *i;
    uint classno_estimated = r.classno;
    uint classno_groundtruth = r.classno_groundtruth;
    confusion( classno_estimated, classno_groundtruth ) += 1;
  }
  confusion.normalizeColumnsL1();
  return confusion.trace()/confusion.rows();  
}

double ClassificationResults::getOverallRecognitionRate() const
{
  const_iterator i = begin();
  NICE::Matrix confusion ( i->scores.size(),i->scores.size(),0.0 );

  for ( ; i != end(); i++ )
  {
    const ClassificationResult & r = *i;
    uint classno_estimated = r.classno;
    uint classno_groundtruth = r.classno_groundtruth;
    confusion( classno_estimated, classno_groundtruth ) += 1;
  }
  return confusion.trace()/size();  
}