/** 
* @file LabeledSet.cpp
* @brief Labeled set of vectors
* @author Erik Rodner
* @date 07.09.2007

*/
#ifndef LABELEDSETTCCINCLUDE
#define LABELEDSETTCCINCLUDE

#include "core/image/ImageT.h"
#include "core/vector/VectorT.h"
#include "core/vector/MatrixT.h"

#include <iostream>

#include "vislearning/cbaselib/LabeledSet.h"
#include "core/basics/StringTools.h"

using namespace OBJREC;

using namespace std;
using namespace NICE;


/**
 * Create an empty set.
 * @param _selection true if this set only references elements that are owned
 *        elsewhere; such a set never deletes its elements in clear() and
 *        rejects add().
 */
LabeledSet::LabeledSet ( bool _selection )
    : selection ( _selection )
{
}
	
LabeledSet::~LabeledSet ()
{
    // clear() is deliberately not called here. Selections share ImageInfo
    // pointers with the set that owns them, so releasing the elements from the
    // destructor of an owning set could delete objects that a selection still
    // references.
    // NOTE(review): consequently an owning set leaks its elements unless the
    // user calls clear() explicitly (known memory leak).
}

int LabeledSet::count ( int classno ) const
{
    const_iterator i = find(classno);
    return ( i == end() ) ? 0 : i->second.size();
}

int LabeledSet::count () const
{
    int mycount = 0;
    for ( const_iterator i = begin() ; i != end() ; i++ )
    {
		mycount += i->second.size();
    }
    return mycount;
}

/**
 * Remove all elements from the set.
 *
 * If this set owns its elements (it is not a selection), every ImageInfo
 * recorded in the insertion order is deleted first. A selection only
 * references elements owned by another set, so nothing is deleted then.
 * The insertion order is reset in both cases. Otherwise it would keep
 * pointers to deleted objects: a second clear() would delete them again, and
 * getPermutation() would hand out dangling pointers.
 */
void LabeledSet::clear ()
{
    if ( !selection )
    {
        for ( Permutation::const_iterator i  = insertOrder.begin();
                                          i != insertOrder.end();
                                          i++ )
        {
            const ImageInfo *s = i->second;
            delete s;
        }
    }

    insertOrder.clear();

    std::map< int, vector<ImageInfo *> >::clear();
}

/**
 * Store @p x under class @p classno and append it to the insertion order.
 * The set takes ownership of @p x; clear() deletes it.
 * Not available for selections: in that case an error is printed and the
 * program terminates.
 */
void LabeledSet::add ( int classno, ImageInfo *x )
{
    if ( selection )
    {
        fprintf (stderr, "Operation not available for selections !\n");
        exit(-1);
    }

    // operator[] creates the (empty) entry of a new class on first use
    operator[] ( classno ).push_back ( x );
    insertOrder.push_back ( ElementPointer ( classno, x ) );
}
	
/** Copy the insertion order into @p permutation (the caller may reorder it freely). */
void LabeledSet::getPermutation ( Permutation & permutation ) const
{
    const Permutation snapshot ( insertOrder );
    permutation = snapshot;
}

/**
 * Store @p pointer under class @p classno and append it to the insertion order.
 * Unlike add(), this is allowed for selections.
 * NOTE(review): ownership follows the selection flag of this set. clear()
 * deletes the pointer unless this set is a selection.
 */
void LabeledSet::add_reference ( int classno, ImageInfo *pointer )
{
    operator[] ( classno ).push_back ( pointer );
    insertOrder.push_back ( ElementPointer ( classno, pointer ) );
}

void LabeledSet::getClasses ( std::vector<int> & classes ) const
{
    for ( const_iterator i = begin(); i != end(); i++ )
		classes.push_back ( i->first );
}

void LabeledSet::printInformation () const
{
    for ( const_iterator i = begin(); i != end(); i++ )
    {
	cerr << "class " << i->first << ": " << i->second.size() << endl;
    }
}


/************************************
    LabeledSetVector
*************************************/


/**
 * Create an empty set of vectors.
 * @param _selection true if this set only references vectors owned elsewhere;
 *        such a set never deletes its vectors in clear() and rejects add().
 */
LabeledSetVector::LabeledSetVector ( bool _selection )
    : selection ( _selection )
{
}
	
LabeledSetVector::~LabeledSetVector ()
{
    // FIXME: clear() is intentionally not called. Selections share the
    // NICE::Vector pointers of the owning set, so releasing them here could
    // invalidate vectors that are still referenced elsewhere. An owning set
    // therefore leaks its vectors unless clear() is called explicitly.
}

int LabeledSetVector::dimension () const
{
    if ( insertOrder.size() <= 0 ) return -1;
    return (*(begin()->second.begin()))->size();
    //insertOrder[0].second->size();
}

/**
 * Read vectors from @p is and add them to the set.
 * FILEFORMAT_RAW is handled by restoreRAW(); every other format is passed on
 * to restoreASCII().
 */
void LabeledSetVector::restore (istream & is, int format)
{
    if ( format != FILEFORMAT_RAW )
    {
        restoreASCII ( is, format );
        return;
    }

    restoreRAW ( is );
}

/**
 * Read labeled vectors from a line-oriented text stream and add them to the
 * set. Each vector is copied by add().
 *
 * Every line starts with an integer class number. It is followed by the
 * entries, separated by single blanks. The entry syntax depends on @p format:
 *  - FILEFORMAT_INDEX: "index:value" tokens; the index is ignored and the
 *    values are stored in order of appearance.
 *  - FILEFORMAT_INDEX_SPARSE_ONE: sparse "index:value" tokens with 1-based
 *    indices. Missing entries are zero. At the end, all vectors of the set
 *    are padded with zeros up to the largest index found in the stream.
 *  - anything else: plain values.
 *
 * Reading stops at the end of the stream, at the first class number that
 * cannot be parsed, or at the first line without any entries.
 */
void LabeledSetVector::restoreASCII (istream & is, int format)
{
    const int bufsize = 1024*1024;
    // owned by a std::vector so the buffer is released even if add() or a
    // vector allocation throws
    std::vector<char> buf ( bufsize );
    std::string buf_s;

    vector<string> elements;
    vector<string> pair;

    // maximal dimension (largest 1-based index) of all sparse feature vectors
    int dataset_dimension = 0;

    while (! is.eof())
    {
        elements.clear();
        int classno;

        if ( ! (is >> classno) )
            break;

        // rest of the line; the newline stays in the stream and is skipped by
        // the next operator>> (get() always null-terminates the buffer)
        is.get ( &buf[0], bufsize );
        buf_s = &buf[0];

        if ( buf_s.empty() )
            break;

        StringTools::split ( buf_s, ' ', elements );

        // elements[0] is the empty token in front of the blank after classno
        if ( elements.size() <= 1 )
            break;

        int dimension = 0;
        if ( format == FILEFORMAT_INDEX_SPARSE_ONE )
        {
            // The vector has to hold the largest index of this line. A line
            // without any valid "index:value" token (e.g. an all-zero vector
            // written by storeElement) yields an empty vector. It used to
            // request a vector of size -INT_MAX.
            for ( vector<string>::const_iterator i  = elements.begin()+1;
                                                 i != elements.end();
                                                 i++ )
            {
                pair.clear();
                StringTools::split ( *i, ':', pair );
                if ( pair.size() != 2 ) continue;

                int index = atoi ( pair[0].c_str() );
                if ( index > dimension )
                    dimension = index;
            }

            if ( dimension > dataset_dimension )
                dataset_dimension = dimension;
        } else {
            // skip first element because of white space
            dimension = elements.size()-1;
        }

        NICE::Vector vec ( dimension, 0.0 );
        size_t l = 0;

        // skip first element because of white space
        for ( vector<string>::const_iterator i  = elements.begin()+1;
                                             i != elements.end();
                                             i++, l++ )
        {
            if ( format == FILEFORMAT_INDEX )
            {
                pair.clear();
                StringTools::split ( *i, ':', pair );
                if ( pair.size() == 2 )
                    vec[l] = atof ( pair[1].c_str() );
            } else if ( format == FILEFORMAT_INDEX_SPARSE_ONE ) {
                pair.clear();
                StringTools::split ( *i, ':', pair );
                if ( pair.size() == 2 ) {
                    int index = atoi ( pair[0].c_str() ) - 1;
                    // indices are 1-based; "0:..." or negative indices would
                    // otherwise write in front of the vector
                    if ( index >= 0 && index < dimension )
                        vec[index] = atof ( pair[1].c_str() );
                }
            } else {
                vec[l] = atof ( i->c_str() );
            }
        }
        add ( classno, vec );
    }

    if ( format == FILEFORMAT_INDEX_SPARSE_ONE )
    {
        // Pad all feature vectors of the set with zeros up to
        // dataset_dimension. Vectors that are already large enough are left
        // alone, so existing data is never truncated.
        const size_t target = dataset_dimension;
        for ( LabeledSetVector::iterator iLOOP_ALL = begin() ; iLOOP_ALL != end() ; iLOOP_ALL++ )
            for ( vector<NICE::Vector *>::iterator jLOOP_ALL = iLOOP_ALL->second.begin();
                                                   jLOOP_ALL != iLOOP_ALL->second.end();
                                                   jLOOP_ALL++ )
            {
                NICE::Vector *x = (*jLOOP_ALL);

                const size_t old_dimension = x->size();
                if ( old_dimension >= target )
                    continue;

                x->resize ( target );

                // set all elements to zero, which are new after the resize operation
                for ( size_t k = old_dimension; k < x->size(); k++ )
                    (*x)[k] = 0.0;
            }
    }
}

/** Write all vectors to @p os in insertion order, using storeElement() with @p format. */
void LabeledSetVector::store (ostream & os, int format) const
{
    for ( size_t n = 0; n < insertOrder.size(); n++ )
    {
        const ElementPointer & element = insertOrder[n];
        storeElement ( os, element.first, *(element.second), format );
    }
}

/**
 * Write one labeled vector to @p os.
 *
 * FILEFORMAT_RAW writes a binary record: the class number and the dimension
 * as native ints, followed by the raw doubles. Every other format writes one
 * text line "classno e1 e2 ...". Each entry is
 *  - "k:value" with 1-based k for FILEFORMAT_INDEX,
 *  - "value" for FILEFORMAT_NOINDEX,
 *  - "k:value" for non-zero values only for FILEFORMAT_INDEX_SPARSE_ONE.
 * A blank follows every entry position, including the last one and skipped
 * zeros of the sparse format.
 * NOTE(review): restoreASCII() relies on these blanks to read back all-zero
 * sparse vectors, which would otherwise appear as lines without entries.
 */
void LabeledSetVector::storeElement ( ostream & os, int classno, const NICE::Vector & x, int format )
{
    if ( format == FILEFORMAT_RAW )
    {
        int dimension = x.size();
        os.write ( reinterpret_cast<const char *>( &classno ), sizeof(int) );
        os.write ( reinterpret_cast<const char *>( &dimension ), sizeof(int) );
        os.write ( reinterpret_cast<const char *>( x.getDataPointer() ), sizeof(double)*dimension );
        return;
    }

    os << classno << " ";
    const size_t n = x.size();
    for ( size_t k = 0 ; k < n ; k++ )
    {
        if ( format == FILEFORMAT_NOINDEX )
            os << x[k];
        else if ( format == FILEFORMAT_INDEX
                  || ( format == FILEFORMAT_INDEX_SPARSE_ONE && x[k] != 0.0 ) )
            os << k+1 << ":" << x[k];

        os << " ";
    }
    os << endl;
}

/**
 * Read binary records as written by storeElement() with FILEFORMAT_RAW and
 * add them to the set. Each vector is copied by add().
 *
 * Each record is: int classno, int dimension, then dimension doubles.
 * Reading stops silently at the end of the stream or at an incomplete record.
 * NaN entries are replaced by 0.0, with a warning on stderr.
 *
 * @throws IOException if a record has a negative dimension or the vector
 *         cannot be allocated (both indicate corrupt input).
 */
void LabeledSetVector::restoreRAW (istream & is)
{
    while (! is.eof())
    {
        int classno;
        int dimension;

        is.read ( (char *)&classno, sizeof(int) );
        if ( is.gcount() != sizeof(int) )
            return;

        is.read ( (char *)&dimension, sizeof(int) );
        if ( is.gcount() != sizeof(int) )
            return;

        // A negative size can only come from a corrupt or foreign file.
        // Converted to an unsigned size, it would request a huge allocation.
        if ( dimension < 0 )
        {
            fthrow(IOException, "Invalid vector size " << dimension << "." << endl
                    << "(debug: class " << classno << " ; "
                    << "elements read = " << count() << " )" << endl );
        }

        NICE::Vector vec;

        try {
            vec.resize(dimension);
        } catch ( const std::bad_alloc & ) {
            fthrow(IOException, "Unable to allocate a vector with size " << dimension << "." << endl
                    << "(debug: class " << classno << " ; " << "sizeof(int) = " << 8*sizeof(int) << " Bit ; " << endl
                    << "elements read = " << count() << " )" << endl );
        }
        double *data = vec.getDataPointer();

        // Byte count in std::streamsize. The former comparison computed
        // (int)sizeof(double)*dimension in int, which overflows for large
        // dimensions.
        const std::streamsize bytes = static_cast<std::streamsize>( sizeof(double) ) * dimension;
        is.read ( (char *)data, bytes );
        if ( is.gcount() != bytes )
            return;

        for ( int k = 0 ; k < dimension ; k++ )
            if ( isnan(data[k]) ) {
                cerr << "WARNING: nan's found !!" << endl;
                data[k] = 0.0;
            }

        add( classno, vec );
    }
}

/**
 * Return a uniformly chosen element (via rand()) from the insertion order.
 * If the set is empty, an error is printed and the program terminates.
 */
LabeledSetVector::ElementPointer LabeledSetVector::pickRandomSample () const
{
    if ( insertOrder.empty() )
    {
        fprintf (stderr, "LabeledSet::pickRandomSample: failure !\n");
        exit(-1);
    }

    // named differently from the member `selection` to avoid shadowing it
    const int pick = rand() % insertOrder.size();
    return insertOrder[pick];
}

int LabeledSetVector::count ( int classno ) const
{
    const_iterator i = find(classno);
    return ( i == end() ) ? 0 : i->second.size();
}

int LabeledSetVector::count () const
{
    int mycount = 0;
    for ( const_iterator i = begin() ; i != end() ; i++ )
		mycount += i->second.size();
    return mycount;
}

/**
 * Pick a random vector (via rand()) of class @p classno.
 * @param classno class to sample from
 * @param i receives (classno, pointer to the chosen vector) on success
 * @return @p classno on success, or -1 if the class is unknown or has no
 *         vectors (in that case @p i is left unchanged)
 */
int LabeledSetVector::pickRandomSample ( int classno, ElementPointer & i ) const
{
    const_iterator j = find(classno);
    if ( j == end() ) return -1;

    const vector<Vector *> & l = j->second;
    // rand() % 0 would be a division by zero; treat an empty class like a missing one
    if ( l.empty() ) return -1;

    int num = rand() % l.size();

    i.first = classno;
    i.second = l[num];

    return classno;
}

/**
 * Remove all vectors from the set.
 *
 * If this set owns its vectors (it is not a selection), every vector recorded
 * in the insertion order is deleted first. A selection only references vectors
 * owned by another set, so nothing is deleted then. The insertion order is
 * reset in both cases. Otherwise a cleared selection would still report its
 * old elements through dimension(), store() and pickRandomSample(), which all
 * read insertOrder.
 */
void LabeledSetVector::clear ()
{
    if ( ! selection ) {
        for ( Permutation::const_iterator i  = insertOrder.begin();
                                          i != insertOrder.end();
                                          i++ )
        {
            const NICE::Vector *s = i->second;
            delete s;
        }
    }

    insertOrder.clear();

    std::map< int, vector<Vector *> >::clear();
}

/**
 * Store a copy of @p x under class @p classno and append it to the insertion
 * order. The copy is owned by the set and deleted by clear().
 * Not available for selections: in that case an error is printed and the
 * program terminates.
 */
void LabeledSetVector::add ( int classno, const NICE::Vector & x )
{
    if ( selection )
    {
        fprintf (stderr, "Add operation not available for selections !\n");
        exit(-1);
    }

    // operator[] creates the (empty) entry of a new class on first use
    vector<Vector *> & members = operator[] ( classno );
    NICE::Vector *stored = new Vector ( x );

    members.push_back ( stored );
    insertOrder.push_back ( ElementPointer ( classno, stored ) );
}
    
/** Copy the insertion order into @p permutation (the caller may reorder it freely). */
void LabeledSetVector::getPermutation ( Permutation & permutation ) const
{
    const Permutation snapshot ( insertOrder );
    permutation = snapshot;
}

/**
 * Store @p pointer (without copying) under class @p classno and append it to
 * the insertion order. Unlike add(), this is allowed for selections.
 * NOTE(review): ownership follows the selection flag of this set. clear()
 * deletes the vector unless this set is a selection.
 */
void LabeledSetVector::add_reference ( int classno, NICE::Vector *pointer )
{
    operator[] ( classno ).push_back ( pointer );
    insertOrder.push_back ( ElementPointer ( classno, pointer ) );
}
	
void LabeledSetVector::getClasses ( std::vector<int> & classes ) const
{
    for ( const_iterator i = begin(); i != end(); i++ )
	classes.push_back ( i->first );
}
	
void LabeledSetVector::printInformation () const
{
    for ( const_iterator i = begin(); i != end(); i++ )
    {
		cerr << "class " << i->first << ": " << i->second.size() << endl;
    }
}
	
int LabeledSetVector::getMaxClassno() const
{
    int maxclassno = 0;

    for ( const_iterator i = begin(); i != end(); i++ )
		if ( i->first > maxclassno ) 
			maxclassno = i->first;

    return maxclassno;
}

/**
 * Flatten the set into a list of vectors plus a parallel label vector.
 *
 * Classes are visited in ascending class order, and the vectors of each class
 * in storage order. Copies of the vectors are appended to @p vecSet.
 * @p vecSetLabels is resized to count() and its entry k receives the class
 * number of the k-th appended vector.
 * NOTE(review): if @p vecSet is not empty on entry, the label indices do not
 * line up with the positions in @p vecSet.
 */
void LabeledSetVector::getFlatRepresentation ( VVector & vecSet, NICE::Vector & vecSetLabels ) const
{
    vecSetLabels.resize ( count() );

    int position = 0;
    for ( const_iterator cls = begin(); cls != end(); ++cls )
    {
        const int label = cls->first;
        const vector<NICE::Vector *> & members = cls->second;
        for ( size_t m = 0; m < members.size(); m++, position++ )
        {
            vecSet.push_back ( *(members[m]) );
            vecSetLabels[position] = label;
        }
    }
}

#endif