/** 
* @file KernelClassifier.cpp
* @brief Classifier interface for kernel-based methods
* @author Erik Rodner
* @date 12/02/2009

*/
#include <iostream>

#include "KernelClassifier.h"

using namespace std;
using namespace NICE;
using namespace OBJREC;

#undef DEBUG

/**
 * Create a kernel classifier.
 * @param _conf configuration; it is dereferenced to initialize the member conf,
 *              so it must not be NULL
 * @param kernelFunction kernel used by classify() and teach(); if it is NULL,
 *              both of those methods throw
 * @param normalizationType preprocessing selector, e.g.
 *              KERNELCLASSIFIER_NORMALIZATION_EUCLIDEAN
 * NOTE(review): the kernel pointer is stored as-is (not cloned); as far as this
 * file shows, ownership stays with the caller.
 */
KernelClassifier::KernelClassifier( const Config *_conf, Kernel *kernelFunction, int normalizationType )
	: conf ( *_conf )
{
	this->normalizationType = normalizationType;
	this->kernelFunction = kernelFunction;
}

/**
 * Copy constructor.
 *
 * Copies the stored training vectors, their labels and the normalization
 * setting, and clones the kernel function if the source has one.
 *
 * NOTE(review): the VecClassifier base is default-constructed rather than
 * copied from src -- confirm that no base-class state needs to be copied.
 * NOTE(review): the cloned kernel is not deleted by the destructor in this
 * file -- possible leak; confirm the intended ownership.
 */
KernelClassifier::KernelClassifier ( const KernelClassifier & src ) : VecClassifier(), conf(src.conf)
{
	this->vecSet = src.vecSet;
	this->vecSetLabels = src.vecSetLabels;
	this->normalizationType = src.normalizationType;

	// clone instead of sharing the source's kernel pointer
	this->kernelFunction = ( src.kernelFunction == NULL ) ? NULL : src.kernelFunction->clone();
}

/**
 * Destructor. Performs no explicit cleanup; in particular, kernelFunction is
 * not deleted here.
 * NOTE(review): a kernel cloned by the copy constructor is therefore never
 * freed -- confirm the intended ownership before changing this.
 */
KernelClassifier::~KernelClassifier()
{
	// no explicit cleanup
}

ClassificationResult KernelClassifier::classify ( const NICE::Vector & x ) const
{
	if ( kernelFunction == NULL )
		fthrow( Exception, "KernelClassifier::classify: To use this function, you have to specify a kernel function using the constructor" );

	NICE::Vector kstar;

	// nothing happens, we should create a copy of x and then
	// use normalization
	//if ( this->normalizationType == KERNELCLASSIFIER_NORMALIZATION_EUCLIDEAN )
	//	x.normL2();

	kernelFunction->calcKernelVector ( vecSet, x, kstar );	

#ifdef DEBUG
	if ( kstar.normL2() == 0.0 )
		cerr << "The kernel vector k_* has zero norm, you might want to adjust kernel parameters with e.g. Kernel:log_rbf_gamma." << endl;
#endif


	double kernelSelf = kernelFunction->K(x,x);
	
	return classifyKernel (kstar, kernelSelf);
}

void KernelClassifier::teach ( const LabeledSetVector & teachSet )
{
	if ( kernelFunction == NULL )
		fthrow( Exception, "KernelClassifier::teach: To use this function, you have to specify a kernel function using the constructor" );

	teachSet.getFlatRepresentation ( vecSet, vecSetLabels );
	
	if ( this->normalizationType == KERNELCLASSIFIER_NORMALIZATION_EUCLIDEAN )
	{
		for ( VVector::iterator i = vecSet.begin(); i != vecSet.end(); i++ )
		{
			Vector & v = *i;
			v.normL2();
		}
	}

	KernelData *kernelData = new KernelData ( &conf );

	kernelFunction->calcKernelData ( vecSet, kernelData );

	teach ( kernelData, vecSetLabels );
	
	// maybe problems with other kernel classifiers
	delete kernelData;
}


void KernelClassifier::restore(std::istream& ifs, int type)
{
	vecSet.restore(ifs);
	
	//TODO: read and write kernel
}

void KernelClassifier::store(std::ostream& ofs, int type) const
{
	vecSet.store(ofs);
}