/** 
* @file KernelUtils.cpp
* @brief Utilities to split examples into training/test selections and to extract kernel sub-matrices and kernel vectors.
* @author Erik Rodner
* @date 03/01/2010

*/
#include <iostream>
#include <set>

#include "KernelUtils.h"
#include "core/basics/StringTools.h"

using namespace std;
using namespace NICE;
using namespace OBJREC;


/**
 * @brief Split the example indices 0..labels.size()-1 into a training and a test selection.
 *
 * The strategy is read from the config entry main:selection_type:
 *  - "seq":               the first main:selection_examples examples are used for training,
 *                         all remaining ones for testing.
 *  - "seq_class":         for every class the first examples (in index order) are used for
 *                         training until the class limit is reached, the others for testing.
 *                         The limit is main:selection_examples_class, either a single positive
 *                         integer or a comma separated list that is indexed by the class number.
 *  - "random_class_doaa": requires exactly 224*5 labels; for each of the classes 0..4,
 *                         main:selection_instances (default 3) distinct instance ids out of
 *                         1..6 are drawn at random and used for training, the other instances
 *                         for testing. Examples of instance 0 are skipped.
 *  - "seq_class_doaa":    per class the first 100 examples are used for training, the next
 *                         60 for testing, all further examples are dropped.
 *  - "random_class":      per class, the limit given by main:selection_examples_class (see
 *                         "seq_class") is drawn at random with rand(); everything that was
 *                         not drawn is used for testing.
 *
 * Indices are appended to the output vectors; the vectors are not cleared beforehand.
 * A per-class summary of both selections is written to cerr at the end.
 *
 * @param conf config object providing the entries of section "main" described above
 * @param labels class label of every example (converted to an integer class number)
 * @param trainSelection output: indices selected for training (appended)
 * @param testSelection output: indices selected for testing (appended)
 *
 * @throws Exception if the selection type is unknown or its settings do not fit the labels
 */
void KernelUtils::selectExamples ( const Config *conf, const Vector & labels,
	vector<int> & trainSelection, vector<int> & testSelection )
{
	string selectionType = conf->gS("main", "selection_type");
	// number of selected examples per class, only used for the summary and sanity checks
	map<int, int> trainExamplesCount;
	map<int, int> testExamplesCount;

	if ( selectionType == "seq" ) {
		int trainExamples = conf->gI("main", "selection_examples" );
		if ( ((int)labels.size() < trainExamples) || (trainExamples <= 0) )
			fthrow(Exception, "Unable to select " << trainExamples << " from " << labels.size() << ".");

		// leading part -> training
		for ( uint i = 0 ; i < (uint)trainExamples; i++ )
		{
			int classno = (int)labels[i];
			trainSelection.push_back( i );
			trainExamplesCount[classno] ++;
		}
		
		// trailing part -> testing
		for ( uint i = (uint)trainExamples ; i < labels.size(); i++ )
		{
			int classno = (int)labels[i];
			testSelection.push_back( i );
			testExamplesCount[classno] ++;
		}

	} else if ( selectionType == "seq_class" ) 
	{
		int trainExamplesForEachClassSingle = conf->gI("main", "selection_examples_class", -1 );

		// no single positive value given: fall back to a list with one limit per class
		Vector trainExamplesForEachClass;
		if ( trainExamplesForEachClassSingle <= 0 ) {
			string trainExamplesForEachClass_s = conf->gS("main", "selection_examples_class" );
			StringTools::splitVector ( trainExamplesForEachClass_s, ',', trainExamplesForEachClass );
		}

		for ( uint i = 0 ; i < labels.size() ; i++ )
		{
			int classno = (int)labels[i];
			// the per-class list is indexed by the class number, so the class number has to be
			// a valid index (negative labels would otherwise read out of bounds)
			if ( (trainExamplesForEachClassSingle <= 0) &&
			     ((classno < 0) || (classno >= (int)trainExamplesForEachClass.size())) )
				fthrow(Exception, "-selection_examples_class <n0>,<n1>,..." << endl << "Missing data in selection_examples_class!" );

			int limit = trainExamplesForEachClassSingle;
			if ( limit <= 0 )
				limit = (int)trainExamplesForEachClass[classno];

			if ( trainExamplesCount[classno] < limit )
			{
				trainSelection.push_back ( i );
				trainExamplesCount[classno] ++;
			} else {
				testSelection.push_back ( i );
				testExamplesCount[classno] ++;
			}
		}
	} else if ( selectionType == "random_class_doaa" ) 
	{
		if ( labels.size() != 224*5 )
			fthrow(Exception, "This selection only works with the Jena-Range-02 database!\n");
		int k = conf->gI("main", "selection_instances", 3 );

		// instance ids are drawn from 1..6 without replacement, therefore more than
		// 6 instances can never be found and the drawing loop below would not terminate
		const int maxInstances = 6;
		if ( (k < 0) || (k > maxInstances) )
			fthrow(Exception, "selection_instances has to be between 0 and " << maxInstances << ", but is " << k << "." );

		map<int, set<int> > trainInstances;
		// loop through all classes
		for ( int i = 0 ; i < 5 ; i++ ) 
		{
			trainInstances.insert ( pair<int, set<int> > ( i, set<int> () ) );
			set<int> & chosen = trainInstances[i];
			for ( int j = 0 ; j < k ; j++ )
			{
				// redraw until an instance id is found that was not chosen yet
				int inst;
				do {
					inst = randInt ( maxInstances ) + 1;
				} while ( chosen.find(inst) != chosen.end() );
				chosen.insert ( inst );
			}
		}

		for ( uint i = 0 ; i < labels.size() ; i++ )
		{
			int classno = (int)labels[i];
			// NOTE(review): presumably the examples are stored in blocks of 224 per class
			// with 32 consecutive examples per instance - confirm against the database layout
			int instance = (i % 224) / 32;

			// instance 0 is used neither for training nor for testing
			if ( instance == 0 )
				continue;

			const set<int> & chosen = trainInstances[classno];
			if ( chosen.find(instance) != chosen.end() ) {
				trainSelection.push_back ( i );
				trainExamplesCount[classno] ++;
			} else {
				testSelection.push_back ( i );
				testExamplesCount[classno] ++;
			}
		}

		// sanity check: every chosen instance has to contribute exactly 32 examples
		for ( int i = 0 ; i < 5 ; i++ ) 
			if ( trainExamplesCount[i] != k*32 ) {
				fthrow(Exception, "Something is wrong here: training examples of class " << i << " = " << trainExamplesCount[i] << " != " << k*32 );
			}
	} else if ( selectionType == "seq_class_doaa" ) 
	{
		int trainExamplesForEachClassSingle = 100;
		int testExamplesForEachClassSingle = 60;

		for ( uint i = 0 ; i < labels.size() ; i++ )
		{
			int classno = (int)labels[i];

			if ( trainExamplesCount[classno] < trainExamplesForEachClassSingle )
			{
				trainSelection.push_back ( i );
				trainExamplesCount[classno] ++;
			} else if ( testExamplesCount[classno] < testExamplesForEachClassSingle )  {
				testSelection.push_back ( i );
				testExamplesCount[classno] ++;
			}
			// examples beyond both limits are dropped
		}

	} else if ( selectionType == "random_class" ) 
	{
		int trainExamplesForEachClassSingle = conf->gI("main", "selection_examples_class", -1 );

		// no single positive value given: fall back to a list with one limit per class
		Vector trainExamplesForEachClass;
		if ( trainExamplesForEachClassSingle <= 0 ) {
			string trainExamplesForEachClass_s = conf->gS("main", "selection_examples_class" );
			StringTools::splitVector ( trainExamplesForEachClass_s, ',', trainExamplesForEachClass );
		}

		// count the available examples of each class
		// NOTE(review): labels are cast to uint here, negative labels are not handled
		map<uint, uint> counts;
		for ( uint j = 0 ; j < labels.size(); j++ )
		{
			uint classno = (uint)labels[j];
			map<uint, uint>::iterator i = counts.find( classno );
			if ( i == counts.end() )
				counts.insert ( pair<uint, uint> ( classno, 1 ) );
			else
				i->second += 1;
		}

		// indices that were already drawn for training
		set<int> memory;
		for ( map<uint, uint>::const_iterator c = counts.begin(); c != counts.end(); c++ )
		{
			uint count = c->second;
			uint classno = c->first;
			if ( (trainExamplesForEachClassSingle <= 0) && (classno >= trainExamplesForEachClass.size()) )
				fthrow(Exception, "-selection_examples_class <n0>,<n1>,..." << endl << "Missing data in selection_examples_class!" );

			int limit = trainExamplesForEachClassSingle;
			if ( limit <= 0 )
				limit = (int)trainExamplesForEachClass[classno];
			if ( limit > (int)count )
			{
				cerr << "Class " << classno << " has not enough examples, we will use all of them (" << count << ") !" << endl;
				limit = count;
			}

			for ( int j = 0 ; j < limit ; j++ )
			{
				int index;
				// inefficient random selection: redraw until an unused example of this class
				// is hit; this terminates because limit never exceeds the class size
				do {
					index = rand() % labels.size();
				} while ( (memory.find(index) != memory.end()) || ((uint)labels[index] != classno) );

				memory.insert(index);
				trainSelection.push_back ( index );	
				// keep the statistics up to date, otherwise the summary below stays empty
				trainExamplesCount[(int)classno] ++;
			}

			cerr << classno << " -> " << limit << endl;
		}

		// put the remainder to the test sets
		for ( uint i = 0 ; i < labels.size(); i++ )
		{
			if ( memory.find((int)i) == memory.end() ) {
				testSelection.push_back ( i );
				testExamplesCount[(int)labels[i]] ++;
			}
		}

	} else {
		fthrow(Exception, "Selection type " << selectionType << " is unknown." );
	}


	// summary of the selection
	cerr << "Learning" << endl;
	for ( map<int, int>::const_iterator j = trainExamplesCount.begin();
		j != trainExamplesCount.end(); j++ )
		cerr << "class " << j->first << ": " << j->second << endl;
		
	cerr << "Testing" << endl;
	for ( map<int, int>::const_iterator j = testExamplesCount.begin();
		j != testExamplesCount.end(); j++ )
		cerr << "class " << j->first << ": " << j->second << endl;
}
		
void KernelUtils::getKernelMatrix ( const vector<int> & trainSelection, 
	const Matrix & kernelMatrix, const Vector & labels, 
	Matrix & kernelMatrixTrain, Vector & labelsTrain )
{
	kernelMatrixTrain.resize ( trainSelection.size(), trainSelection.size() );
	labelsTrain.resize ( trainSelection.size() );

	int ik = 0;
	for ( vector<int>::const_iterator i = trainSelection.begin(); 
		i != trainSelection.end(); i++,ik++ )
	{
		int index_i = *i;
		labelsTrain[ik] = labels[index_i];

		int jk = 0;
		for ( vector<int>::const_iterator j = trainSelection.begin(); 
			j != trainSelection.end(); j++,jk++ )
		{
			int index_j = *j;
			kernelMatrixTrain(ik,jk) = kernelMatrix(index_i,index_j);
		}
	}
}
			
void KernelUtils::getKernelVector ( const vector<int> & trainSelection, 
	const Matrix & kernelMatrix, uint index, Vector & kernelVector )
{
	kernelVector.resize ( trainSelection.size() );
	int ik = 0;
	for ( vector<int>::const_iterator i = trainSelection.begin();
		i != trainSelection.end(); i++,ik++ )
	{
		int index_i = *i;
		kernelVector(ik) = kernelMatrix(index_i,index);
	}
}