/** 
* @file VCOneVsAll.cpp
* @author Erik Rodner
* @date 10/25/2007

*/
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <iostream>
#include <limits>
#include <vector>

#include "core/image/ImageT.h"
#include "core/imagedisplay/ImageDisplay.h"

#include "vislearning/classifier/vclassifier/VCOneVsAll.h"


using namespace std;
using namespace NICE;
using namespace OBJREC;


/**
 * @brief Create an untrained one-vs-all classifier.
 *
 * @param _conf      configuration forwarded unchanged to the VecClassifier base
 * @param _prototype binary classifier that teach() clones once per class; the
 *                   pointer itself is stored as given (it is not cloned here)
 */
VCOneVsAll::VCOneVsAll ( const Config *_conf, const VecClassifier *_prototype )
    : VecClassifier ( _conf ),
      prototype ( _prototype )
{
    // nothing else to set up: the per-class classifiers are built by teach()
}

/**
 * @brief Release the per-class binary classifiers owned by this object.
 *
 * Every entry of `classifiers` is created with clone() (in teach() and in the
 * copy constructor), so this object owns those pointers and must delete them.
 *
 * NOTE(review): `prototype` is deliberately not deleted. The regular
 * constructor stores the caller's pointer (borrowed), while the copy
 * constructor stores a fresh clone (owned, currently leaked). Ownership has to
 * be unified before it can be released here.
 * NOTE(review): unless the header disables copy assignment, the implicit
 * operator= would share these pointers between two objects. Confirm it is
 * private/deleted, or implement it as a deep copy.
 */
VCOneVsAll::~VCOneVsAll()
{
    for ( vector< pair<int, VecClassifier *> >::const_iterator it = classifiers.begin();
          it != classifiers.end(); ++it )
    {
        delete it->second;
    }
    classifiers.clear();
}

/**
 * @brief Classify a feature vector with the trained one-vs-all ensemble.
 *
 * The score of class i is the positive-class score (r.scores[1]) returned by
 * the binary classifier trained for i. Labels in [0, maxClassNo] that have no
 * classifier (no training examples) get a score strictly below the smallest
 * score of any trained class, so they can only win if nothing was trained.
 *
 * @param x feature vector to classify
 * @return result whose class number is scores.maxElement(), together with the
 *         full per-class score vector
 */
ClassificationResult VCOneVsAll::classify ( const NICE::Vector & x ) const
{
    FullVector scores ( maxClassNo + 1 );
    vector<bool> exists ( maxClassNo + 1, false );
    scores.set ( 0.0 );

    double minval = numeric_limits<double>::max();

    for ( vector< pair<int, VecClassifier *> >::const_iterator i = classifiers.begin();
          i != classifiers.end(); ++i )
    {
        const int classno = i->first;
        exists[classno] = true;

        ClassificationResult r = i->second->classify ( x );
        scores[classno] += r.scores[1];
        minval = std::min ( minval, scores[classno] );
    }

    // epsilon() is the spacing of doubles near 1.0, not near minval: a bare
    // "minval - epsilon" can round back to minval once |minval| >= 2, letting
    // an untrained class tie with the weakest trained one. Scaling epsilon by
    // |minval| (at least 1) always yields a strictly smaller double.
    const double belowMin =
        minval - std::max ( 1.0, std::fabs ( minval ) ) * numeric_limits<double>::epsilon();

    for ( int i = 0; i <= maxClassNo; i++ )
    {
        if ( !exists[i] )
            scores[i] = belowMin;
    }

    return ClassificationResult ( scores.maxElement(), scores );
}

void VCOneVsAll::teach ( const LabeledSetVector & _teachSet )
{
    if ( _teachSet.count() <= 0 )
		fthrow(Exception, "Number of training examples is zero!\n");

    maxClassNo = _teachSet.getMaxClassno();
    classifiers.clear();

    for ( int i = 0 ; i <= maxClassNo ; i++ )
    {
		LabeledSetVector binarySubSet (true);
		LabeledSetVector::const_iterator exiv = _teachSet.find(i);
		if ( exiv == _teachSet.end() )
		{
			// a test example might be classified as this class
			// if we do not use probability scores
			cerr << "Class " << i << " does not have any training examples; skipping training." << endl;
			continue;
		}
				
		int poscount = _teachSet.count(i);
		int negcount = _teachSet.count() - poscount;
		int mincount = std::min(poscount, negcount);		
		
		
		int c = 0;
		for ( vector<Vector *>::const_iterator exi = exiv->second.begin();
					exi != exiv->second.end(); exi++, c++ )
		{
			binarySubSet.add_reference ( 1, *exi );
			if( c >= mincount)
				break;
		}

		c = 0;
		
		for ( LabeledSetVector::const_iterator exjv = _teachSet.begin(); 
			exjv != _teachSet.end(); exjv++ , c++)
		{
			if ( exjv == exiv ) continue;
			for ( vector<Vector *>::const_iterator exj = exjv->second.begin();
					exj != exjv->second.end(); exj++ )
			binarySubSet.add_reference ( 0, *exj );
			if( c >= mincount)
				break;
		}
		VecClassifier *classifier;
		classifier = prototype->clone();
		
		fprintf (stderr, "Training classifier: class %d <-> remainder\n", i );
		classifier->teach ( binarySubSet );
		classifier->finishTeaching();

		classifiers.push_back ( pair<int, VecClassifier*> (i, classifier) );
    } 
}

/**
 * @brief No-op: each per-class classifier already has its own
 *        finishTeaching() called inside teach().
 */
void VCOneVsAll::finishTeaching()
{
    // intentionally empty
}

/**
 * @brief Polymorphic copy built with the VCOneVsAll copy constructor.
 * @return newly heap-allocated copy of this object; the caller owns it
 */
VecClassifier *VCOneVsAll::clone(void) const
{
    return new VCOneVsAll ( *this );
}

/**
 * @brief Deep-copy constructor (used by clone()).
 *
 * Clones the prototype and every trained per-class classifier, and copies
 * maxClassNo so that classify() on the copy builds its score vector with the
 * same size as the source object.
 *
 * NOTE(review): the base is default-constructed, so other VecClassifier state
 * (e.g. the configuration passed to the regular constructor) is not copied.
 * Confirm nothing else there is needed by classify().
 */
VCOneVsAll::VCOneVsAll( const VCOneVsAll &vcova ): VecClassifier()
{
    // Base-class state is not copied by VecClassifier(). Unless that default
    // constructor happens to set maxClassNo to the trained value, classify()
    // on the copy sizes its vectors wrongly and can index past their end.
    maxClassNo = vcova.maxClassNo;

    prototype = vcova.prototype->clone();

    classifiers.reserve ( vcova.classifiers.size() );
    for ( vector< pair<int, VecClassifier *> >::const_iterator it = vcova.classifiers.begin();
          it != vcova.classifiers.end(); ++it )
    {
        classifiers.push_back ( pair<int, VecClassifier*> ( it->first, it->second->clone() ) );
    }
}