/** 
* @file testClassifier.cpp
* @brief main program for classifier evaluation
* @author Erik Rodner
* @date 2007-10-12
*/

#include <objrec/nice_nonvis.h>

#include <fstream>
#include <iostream>

#include <objrec/cbaselib/MultiDataset.h>
#include <objrec/iclassifier/icgeneric/CSGeneric.h>
#include <objrec/cbaselib/ClassificationResults.h>
#include <objrec/iclassifier/codebook/MutualInformation.h>

#include "objrec/classifier/classifierbase/FeaturePoolClassifier.h"
#include <objrec/classifier/fpclassifier/randomforest/FPCRandomForestTransfer.h>
#include <objrec/classifier/classifierinterfaces/VCFeaturePool.h>

#include <objrec/baselib/Config.h>
#include <objrec/baselib/Preprocess.h>
#include <objrec/baselib/StringTools.h>

#include "objrec/math/cluster/GMM.h"

#undef DEBUG

using namespace OBJREC;

using namespace NICE;
using namespace std;

/**
 * @brief Turn a real-valued feature vector into a 0/1 vector.
 *
 * Component k of @p xout is 1.0 when |x[k]| is strictly greater than
 * thresholds[k], otherwise 0.0. @p xout is resized to x.size().
 *
 * @param xout        [out] binary result
 * @param x           input feature vector
 * @param thresholds  per-dimension thresholds; must provide at least x.size()
 *                    entries (not checked here)
 */
void binarizeVector ( NICE::Vector & xout, const NICE::Vector & x, const NICE::Vector & thresholds )
{
    const size_t dim = x.size();
    xout.resize ( dim );
    for ( size_t k = 0 ; k < dim ; ++k )
    {
        const bool exceeds = fabs ( x[k] ) > thresholds[k];
        xout[k] = exceeds ? 1.0 : 0.0;
    }
}

/**
 * @brief Binarize every feature vector of a labeled set.
 *
 * Each vector of @p src is passed through binarizeVector() with @p thresholds.
 * The result is added to @p dst under the same class number via dst.add().
 * @p dst is not cleared beforehand.
 *
 * @param dst         [out] receives the binarized vectors
 * @param src         labeled set to binarize
 * @param thresholds  per-dimension thresholds forwarded to binarizeVector()
 */
void binarizeSet ( LabeledSetVector & dst, const LabeledSetVector & src, const NICE::Vector & thresholds )
{
    LOOP_ALL(src)
    {
        EACH(label, features);
        NICE::Vector binary;
        binarizeVector ( binary, features, thresholds );
        dst.add ( label, binary );
    }
}

int main (int argc, char **argv)
{  
    fprintf (stderr, "testClassifier: init\n");

    std::set_terminate(__gnu_cxx::__verbose_terminate_handler);

    Config conf ( argc, argv );

    string wekafile = conf.gS("main", "weka", "");
    string trainfn = conf.gS("main", "train", "train.vec");
	string testfn = conf.gS("main", "test", "test.vec");
	int format = conf.gI("main", "format", 0 );
	bool binarize = conf.gB("main", "binarize", false );
	int wekaclass = conf.gI("main", "wekaclass", 1 );
	string classifier_cache = conf.gS("main", "classifiercache", "");
	string classifier_cache_in = conf.gS("main", "classifierin", "");
	int numRuns = conf.gI("main", "runs", 1);
	
	// classno:text,classno:text,...
	string classes = conf.gS("main", "classes", "");
	int classesnb = conf.gI("main", "classes", 0);
	string classesconf = conf.gS("main", "classesconf", "");

    fprintf (stderr, "testClassifier: reading config\n");
    Preprocess::Init ( &conf );
    
    fprintf (stderr, "testClassifier: reading multi dataset\n");
    int testMaxClassNo;
    int trainMaxClassNo;
    

    ClassNames *classNames;
	if(classes.size() == 0 && classesnb != 0)
	{
		classNames = new ClassNames ();
		for ( int classno = 0 ; classno < classesnb ; classno++ )
		{
			classNames->addClass ( classno, StringTools::convertToString<int> ( classno ), StringTools::convertToString<int> (classno) );
		}
		trainMaxClassNo = classNames->getMaxClassno();
		testMaxClassNo = trainMaxClassNo;
	}
	else
    if ( classes.size() > 0 ) 
    {
		classNames = new ClassNames ();
		
		vector<string> classes_sub;
		StringTools::split ( string(classes), ',', classes_sub );

		for ( vector<string>::const_iterator i = classes_sub.begin();
				i != classes_sub.end(); i++ )
		{
			vector<string> desc;
			StringTools::split ( *i, ':', desc);
			if ( desc.size() != 2 ) 
				break;
			int classno = StringTools::convert<int> ( desc[0] );
			classNames->addClass ( classno, desc[1], desc[1] );
		}

		trainMaxClassNo = classNames->getMaxClassno();
		testMaxClassNo = trainMaxClassNo;

		classNames->store(cout);
    } 
	else if ( classesconf.size() > 0 ) {
		classNames = new ClassNames ();
		Config cConf ( classesconf );
		classNames->readFromConfig ( cConf, "*" );
		trainMaxClassNo = classNames->getMaxClassno();
		testMaxClassNo = trainMaxClassNo;
	}
	else 
	{
		MultiDataset md ( &conf );
		classNames = new ClassNames ( md.getClassNames("train"), "*" );
		testMaxClassNo = md.getClassNames("test").getMaxClassno();
		trainMaxClassNo = md.getClassNames("train").getMaxClassno();
    }
        
    LabeledSetVector train;
    if ( classifier_cache_in.size() <= 0 )
    {
		fprintf (stderr, "testClassifier: Reading training dataset from %s\n", trainfn.c_str() );
		train.read ( trainfn, format );
		train.printInformation();
    } else {
		fprintf (stderr, "testClassifier: skipping training set %s\n", trainfn.c_str() );
    }

    LabeledSetVector test;
    fprintf (stderr, "testClassifier: Reading test dataset from %s\n", testfn.c_str() );
    test.read ( testfn, format );

    GMM *gmm = NULL;
    int nbgmm = conf.gI("main", "gmm", 0);
    if(nbgmm > 0)
    {
	gmm = new GMM(&conf, nbgmm);
	VVector vset;
	Vector l;
	train.getFlatRepresentation(vset,l);
	gmm->computeMixture(vset);
	
	map<int, vector<NICE::Vector *> >::iterator iter;
    	for( iter = train.begin(); iter != train.end(); ++iter ) 
	{
		for(uint i = 0; i < iter->second.size(); ++i)
		{
			gmm->getProbs(*(iter->second[i]),*(iter->second[i]));
		}
	}
	
	for( iter = test.begin(); iter != test.end(); ++iter ) 
	{
		for(uint i = 0; i < iter->second.size(); ++i)
		{
			gmm->getProbs(*(iter->second[i]),*(iter->second[i]));
		}
	}
    }
    
	ClassificationResults cresults;
	

    for (int runs = 0 ; runs < numRuns ; runs++ ) {
		VecClassifier *vec_classifier = NULL;

		if ( conf.gS("main", "classifier") == "random_forest_transfer" ) 
		{
			FeaturePoolClassifier *fpc = new FPCRandomForestTransfer ( &conf, classNames );
			vec_classifier = new VCFeaturePool ( &conf, fpc );
		} else {
			vec_classifier = CSGeneric::selectVecClassifier ( &conf, "main" );
		}

		NICE::Vector thresholds;

		if ( classifier_cache_in.size() <= 0 )
		{
			if ( binarize ) {
				LabeledSetVector trainbin;
				NICE::Vector mis;
				MutualInformation mi;
				fprintf (stderr, "testClassifier: computing mutual information\n");
				mi.computeThresholdsOverall ( train, thresholds, mis );
				fprintf (stderr, "testClassifier: done!\n");
				binarizeSet ( trainbin, train, thresholds );
				vec_classifier->teach ( trainbin );
			} else {

				vec_classifier->teach ( train );

			}

			vec_classifier->finishTeaching();

		if ( classifier_cache.size() > 0 )
				vec_classifier->save ( classifier_cache );
		} else {
			vec_classifier->setMaxClassNo ( classNames->getMaxClassno() );
			vec_classifier->read ( classifier_cache_in );
		}

		ProgressBar pb ("Classification");

		pb.show();

		std::vector<int> count   ( testMaxClassNo+1, 0 );

		std::vector<int> correct ( testMaxClassNo+1, 0 );

		MatrixT<int> confusionMatrix ( testMaxClassNo+1, trainMaxClassNo+1, 0 );

		int n = test.count();
		LOOP_ALL(test)
		{
			EACH(classno,v);
			pb.update ( n );

			fprintf (stderr, "\tclassification\n" );
			ClassificationResult r;

			if ( binarize ) 
			{
				NICE::Vector vout;
				binarizeVector ( vout, v, thresholds );
				r = vec_classifier->classify ( vout );
 			} else {
				r = vec_classifier->classify ( v );
			}
			r.classno_groundtruth = classno;
			r.classname = classNames->text( r.classno );

#ifdef DEBUG
			if ( r.classno == classno )
				fprintf (stderr, "+ classification %d (\"%s\") <-> %d (\"%s\") score=%f\n", classno, 
					classNames->text(classno).c_str(), r.classno, r.classname.c_str(), r.scores[r.classno]);
			else
				fprintf (stderr, "- classification %d (\"%s\") <-> %d (\"%s\") score=%f\n", classno, 
					classNames->text(classno).c_str(), r.classno, r.classname.c_str(), r.scores[r.classno] );
#endif

			r.scores.store ( cerr );
			if ( r.classno >= 0 )
			{
				if ( classno == r.classno ) correct[classno]++;

				count[classno]++;

				if ( r.ok() ) {
					confusionMatrix(classno, r.classno)++;
				}
				cresults.push_back ( r );
			}
		}
		pb.hide();

		if ( wekafile.size() > 0 )
		{
			string wekafile_s = wekafile;
			if ( numRuns > 1 )
				wekafile_s = wekafile_s + "." + StringTools::convertToString<int>(runs) + ".txt";
			cresults.writeWEKA ( wekafile_s, wekaclass );
		}
			
		int count_total = 0;
		int correct_total = 0;
		int classes_tested = 0;
		double avg_recognition = 0.0;
		for ( size_t classno = 0; classno < correct.size(); classno++ )
		{
			if ( count[classno] == 0 ) {
				fprintf (stdout, "class %d not tested !!\n", (int)classno);
			} else {
				fprintf (stdout, "classification result class %d (\"%s\") : %5.2f %%\n",
					(int)classno, classNames->text(classno).c_str(), correct[classno]*100.0/count[classno] );
				avg_recognition += correct[classno]/(double)count[classno];
				classes_tested++;
			}

			count_total += count[classno];
			correct_total += correct[classno];
		}
		avg_recognition /= classes_tested;


		fprintf (stdout, "overall recognition rate : %-5.3f %%\n", correct_total*100.0/count_total );
		fprintf (stdout, "average recognition rate : %-5.3f %%\n", avg_recognition*100 );
		fprintf (stdout, "total:%d misclassified:%d\n", count_total, count_total - correct_total );

		int max_count = *(max_element( count.begin(), count.end() ));
		fprintf (stdout, "no of classes : %d\n", classNames->numClasses() );
		fprintf (stdout, "lower bound 1 : %f\n", 100.0/(classNames->numClasses()));
		fprintf (stdout, "lower bound 2 : %f\n", max_count * 100.0 / (double) count_total);

		cout << confusionMatrix << endl;
		
		delete vec_classifier;
    }

    delete classNames;

    return 0;
}