/** 
* @file VCOneVsOne.cpp
* @author Erik Rodner
* @date 10/25/2007

*/
#include <iostream>

#include "core/basics/StringTools.h"
#include "vislearning/classifier/vclassifier/VCOneVsOne.h"


using namespace OBJREC;
using namespace std;
using namespace NICE;


/**
 * Build a one-vs-one classifier around a binary prototype.
 *
 * @param conf        configuration; the flag "use_weighted_voting" of section
 *                    "VCOneVsOne" (default false) selects weighted voting
 * @param _prototype  binary classifier that is clone()d once per class pair
 */
VCOneVsOne::VCOneVsOne ( const Config *conf, VecClassifier *_prototype ) 
    : VecClassifier ( conf ), prototype ( _prototype )
{
    const bool weightedFlag = conf->gB ( "VCOneVsOne", "use_weighted_voting", false );
    use_weighted_voting = weightedFlag;
}

/**
 * Destroy the one-vs-one wrapper.
 *
 * NOTE(review): the pairwise classifiers obtained from prototype->clone() in
 * teach()/read() are not deleted here, and neither is the prototype. Whether
 * this object owns them and whether it is ever shallow-copied cannot be seen
 * from this file. TODO: confirm before adding deletes, to avoid double frees.
 */
VCOneVsOne::~VCOneVsOne()
{
    // intentionally empty (see the ownership note above)
}

/**
 * Classify a feature vector by pairwise voting.
 *
 * Every stored binary classifier for the pair (i, j) is evaluated on @p x.
 * A binary result of 0 means class i won and 1 means class j won.
 * - Plain voting: the winner gets one vote.
 * - Weighted voting: class i gets -scores[1] when it wins, and class j gets
 *   +scores[1] when it wins.
 *   NOTE(review): this only yields positive weights if scores[1] is a signed
 *   decision value. Confirm for the prototype in use.
 * The vote vector is normalized and its arg-max is returned as the class.
 *
 * @param x  feature vector to classify
 * @return   winning class number together with the normalized vote vector
 */
ClassificationResult VCOneVsOne::classify ( const NICE::Vector & x ) const
{
    FullVector votes ( maxClassNo+1 );
    votes.set(0);

    for ( size_t k = 0 ; k < classifiers.size() ; ++k )
    {
        const triplet<int, int, VecClassifier *> & entry = classifiers[k];
        ClassificationResult binaryResult = entry.third->classify(x);
        const bool firstClassWins = ( binaryResult.classno == 0 );

        if ( !use_weighted_voting )
        {
            if ( firstClassWins )
                votes[entry.first]++;
            else
                votes[entry.second]++;
        }
        else if ( firstClassWins )
        {
            votes[entry.first] -= binaryResult.scores[1];
        }
        else
        {
            votes[entry.second] += binaryResult.scores[1];
        }
    }

    votes.normalize();
    return ClassificationResult ( votes.maxElement(), votes );
}

void VCOneVsOne::teach ( const LabeledSetVector & _teachSet )
{
    maxClassNo = _teachSet.getMaxClassno();
    classifiers.clear();
    assert ( maxClassNo+1 == _teachSet.numClasses() );

    for ( int i = 0 ; i <= maxClassNo ; i++ )
    {
		for ( int j = i+1 ; j <= maxClassNo ; j++ )
		{
			LabeledSetVector binarySubSet (true);
			LabeledSetVector::const_iterator exiv = _teachSet.find(i);
			for ( vector<Vector *>::const_iterator exi = exiv->second.begin();
					exi != exiv->second.end(); exi++ )
			binarySubSet.add_reference ( 0, *exi );

			LabeledSetVector::const_iterator exjv = _teachSet.find(j);
			for ( vector<Vector *>::const_iterator exj = exjv->second.begin();
					exj != exjv->second.end(); exj++ )
			binarySubSet.add_reference ( 1, *exj );
		
			VecClassifier *classifier;

			classifier = prototype->clone();
				   
			fprintf (stderr, "Training classifier: class %d <-> class %d\n", i, j );
			classifier->teach ( binarySubSet );
			classifier->finishTeaching();

			classifiers.push_back ( triplet<int, int, VecClassifier*> (i,j,classifier) );
		}
    }
}

/**
 * No-op for the wrapper itself. Each pairwise classifier already has its own
 * finishTeaching() called inside teach().
 */
void VCOneVsOne::finishTeaching()
{
    // nothing left to finalize at this level
}


void VCOneVsOne::read (const string& s, int format)
{
    ifstream ifs ( s.c_str(), ios::in );
    ifs >> maxClassNo;
    ifs.close();

    for ( int i = 0 ; i <= maxClassNo ; i++ )
    {
		for ( int j = i+1 ; j <= maxClassNo ; j++ )
		{
			VecClassifier *classifier;

			classifier = prototype->clone();
				   
			string classifiercache = s + ".onevsone." + StringTools::convertToString<int> ( i ) + "." + StringTools::convertToString<int> ( j );
			fprintf (stderr, "Loading classifier: class %d <-> class %d\n", i, j );

			classifier->read ( classifiercache, format );
			
			classifiers.push_back ( triplet<int, int, VecClassifier*> (i,j,classifier) );
		}
    }
}

void VCOneVsOne::save (const string& s, int format) const
{
    ofstream ofs ( s.c_str(), ios::out );
    ofs << maxClassNo << endl;
    ofs.close();

    for ( vector< triplet<int, int, VecClassifier *> >::const_iterator i = 
	    classifiers.begin(); i != classifiers.end(); i++ )
    {
	int classi = i->first;
	int classj = i->second;
	VecClassifier *classifier = i->third;

	string classifiercache = s + ".onevsone." + StringTools::convertToString<int> ( classi ) + "." + StringTools::convertToString<int> ( classj );

	classifier->save ( classifiercache, format );
    }
}

/**
 * Stream serialization is not supported for this wrapper.
 *
 * Prints a hint to stderr and returns without writing anything. Use save(),
 * which writes the model to a set of files.
 *
 * @param os      unused
 * @param format  unused
 */
void VCOneVsOne::store ( std::ostream & os, int format ) const
{
    // The hint used to name read(), the loading counterpart; writing is save().
    fprintf (stderr, "VCOneVsOne: unable to write to stream! please use save()\n");
}

/**
 * Stream deserialization is not supported for this wrapper.
 *
 * Prints a hint to stderr and terminates the process with exit(-1). Use
 * read(), which loads the model from the files written by save().
 *
 * @param is      unused
 * @param format  unused
 */
void VCOneVsOne::restore ( std::istream & is, int format )
{
    // The hint used to name save(), the writing counterpart; loading is read().
    fprintf (stderr, "VCOneVsOne: unable to read from stream! please use read()\n");
    exit (-1);
}