/**
* @file LabeledSet.cpp
* @brief Labeled set of vectors
* @author Erik Rodner
* @date 07.09.2007

*/
#ifndef LABELEDSETTCCINCLUDE
#define LABELEDSETTCCINCLUDE

#include "core/image/ImageT.h"
#include "core/vector/VectorT.h"
#include "core/vector/MatrixT.h"

#include <iostream>
#include <string>

#include "vislearning/cbaselib/LabeledSet.h"
#include "core/basics/StringTools.h"

using namespace OBJREC;

using namespace std;
using namespace NICE;


/**
 * Creates an empty labeled set.
 * @param _selection true if the set only references elements owned elsewhere
 */
LabeledSet::LabeledSet ( bool _selection )
    : selection ( _selection )
{
  // all other members are default-constructed
}

/**
 * Destructor.
 *
 * Deliberately does not call clear(): the original author notes that this
 * is a problem when selections are used.
 * NOTE(review): as a consequence a non-selection set never deletes the
 * ImageInfo objects it stores (clear() would delete them), so they leak
 * unless clear() is called explicitly before destruction.
 */
LabeledSet::~LabeledSet ()
{
  //This is a big problem when using selections
  //clear();
  //fprintf (stderr, "LabeledSet: destructor (FIXME: memory leak)\n");
}

int LabeledSet::count ( int classno ) const
{
  const_iterator i = find ( classno );
  return ( i == end() ) ? 0 : i->second.size();
}

int LabeledSet::count () const
{
  int mycount = 0;
  for ( const_iterator i = begin() ; i != end() ; i++ )
  {
    mycount += i->second.size();
  }
  return mycount;
}

/**
 * Removes all elements from the set.
 *
 * If this set is not a selection, it owns the stored ImageInfo objects and
 * deletes them. For a selection nothing is deleted.
 *
 * The insertion-order list is emptied in both cases so it stays in sync with
 * the (now empty) class map. Otherwise it would keep dangling pointers, which
 * a later clear() would delete a second time and getPermutation() would
 * still hand out.
 */
void LabeledSet::clear ()
{
  if ( !selection )
  {
    for ( Permutation::const_iterator i  = insertOrder.begin();
          i != insertOrder.end();
          ++i )
    {
      const ImageInfo *s = i->second;
      delete s;
    }
  }

  insertOrder.clear();
  std::map< int, vector<ImageInfo *> >::clear();
}

void LabeledSet::add ( int classno, ImageInfo *x )
{
  if ( selection ) {
    fprintf ( stderr, "Operation not available for selections !\n" );
    exit ( -1 );
  }

  iterator i = find ( classno );
  if ( i == end() ) {
    operator[] ( classno ) = vector<ImageInfo *>();
    i = find ( classno );
  }
  i->second.push_back ( x );
  insertOrder.push_back ( ElementPointer ( classno, x ) );
}

/**
 * Copies the insertion order (class number, element pointer) into
 * @p permutation, replacing its previous contents.
 */
void LabeledSet::getPermutation ( Permutation & permutation ) const
{
  permutation = insertOrder;
}

void LabeledSet::add_reference ( int classno, ImageInfo *pointer )
{
  iterator i = find ( classno );
  if ( i == end() ) {
    operator[] ( classno ) = vector<ImageInfo *>();
    i = find ( classno );
  }
  i->second.push_back ( pointer );
  insertOrder.push_back ( ElementPointer ( classno, pointer ) );
}

void LabeledSet::getClasses ( std::vector<int> & classes ) const
{
  for ( const_iterator i = begin(); i != end(); i++ )
    classes.push_back ( i->first );
}

void LabeledSet::printInformation () const
{
  for ( const_iterator i = begin(); i != end(); i++ )
  {
    cerr << "class " << i->first << ": " << i->second.size() << endl;
  }
}


/************************************
    LabeledSetVector
*************************************/


/**
 * Creates an empty set of labeled vectors.
 * @param _selection true if the set only references vectors owned elsewhere
 */
LabeledSetVector::LabeledSetVector ( bool _selection )
    : selection ( _selection )
{
  // all other members are default-constructed
}

/**
 * Destructor.
 *
 * Deliberately does not call clear(): the original FIXME notes that doing
 * so is a problem with selections.
 * NOTE(review): for a non-selection set this means the NICE::Vector copies
 * allocated by add() are never deleted unless clear() is called explicitly
 * before destruction.
 */
LabeledSetVector::~LabeledSetVector ()
{
  // FIXME: THIS is a big problem with selections !!!
  //clear();
}

int LabeledSetVector::dimension () const
{
  if ( insertOrder.size() <= 0 ) return -1;
  return ( * ( begin()->second.begin() ) )->size();
  //insertOrder[0].second->size();
}

/**
 * Reads labeled vectors from @p is and adds them to the set.
 * FILEFORMAT_RAW is handled by restoreRAW(); every other format is a text
 * format handled by restoreASCII().
 */
void LabeledSetVector::restore ( istream & is, int format )
{
  if ( format != FILEFORMAT_RAW )
  {
    restoreASCII ( is, format );
    return;
  }
  restoreRAW ( is );
}

/**
 * Reads labeled feature vectors in a text format from @p is and adds them to
 * the set (existing contents are kept).
 *
 * Each line has the form "<classno> <e_1> <e_2> ...". Depending on @p format
 * an element is
 *  - FILEFORMAT_INDEX:            "<index>:<value>"; values are taken in
 *                                 order, the index itself is ignored;
 *  - FILEFORMAT_INDEX_SPARSE_ONE: "<index>:<value>" with 1-based indices;
 *                                 missing entries are zero;
 *  - any other format:            a plain value.
 * Reading stops at the first line whose class number cannot be parsed or
 * that has no elements.
 *
 * For FILEFORMAT_INDEX_SPARSE_ONE all vectors of the set are afterwards
 * padded with zeros to a common dimension. That dimension is the largest
 * index seen, but never smaller than a vector already in the set.
 *
 * @param is     input stream
 * @param format one of the FILEFORMAT_* text formats
 */
void LabeledSetVector::restoreASCII ( istream & is, int format )
{
  // rest of the current line after the class number; std::getline copes with
  // lines of any length and needs no manually managed buffer
  std::string buf_s;

  vector<string> elements;
  vector<string> pairTokens;

  // maximal dimension of all feature vectors read (sparse format only)
  int dataset_dimension = 0;

  while ( ! is.eof() )
  {
    elements.clear();
    int classno;

    if ( ! ( is >> classno ) ) {
      break;
    }

    std::getline ( is, buf_s );

    if ( buf_s.empty() )
      break;

    StringTools::split ( buf_s, ' ', elements );

    if ( elements.size() <= 1 )
      break;

    int dimension = 0;
    if ( format == FILEFORMAT_INDEX_SPARSE_ONE )
    {
      // in this format we have to determine the maximum (1-based) index;
      // a line without any valid entry yields dimension 0 instead of a
      // negative size
      for ( vector<string>::const_iterator i  = elements.begin() + 1;
            i != elements.end();
            ++i )
      {
        pairTokens.clear();
        StringTools::split ( *i, ':', pairTokens );
        if ( pairTokens.size() != 2 ) continue;

        int index = atoi ( pairTokens[0].c_str() );

        if ( index > dimension )
          dimension = index;
      }

      if ( dimension > dataset_dimension )
        dataset_dimension = dimension;
    } else {
      // skip first element because of white space
      dimension = elements.size() - 1;
    }

    NICE::Vector vec ( dimension, 0.0 );
    size_t l = 0;

    // skip first element because of white space
    for ( vector<string>::const_iterator i  = elements.begin() + 1;
          i != elements.end();
          ++i, ++l )
    {
      if ( format == FILEFORMAT_INDEX )
      {
        pairTokens.clear();
        StringTools::split ( *i, ':', pairTokens );
        if ( pairTokens.size() == 2 ) {
          vec[l] = atof ( pairTokens[1].c_str() );
        }
      } else if ( format == FILEFORMAT_INDEX_SPARSE_ONE )
      {
        pairTokens.clear();
        StringTools::split ( *i, ':', pairTokens );
        if ( pairTokens.size() == 2 ) {
          // convert the 1-based file index; indices outside the vector
          // (e.g. "0:..." or unparsable ones) would write out of bounds
          int index = atoi ( pairTokens[0].c_str() ) - 1;
          if ( index >= 0 && index < dimension )
            vec[index] = atof ( pairTokens[1].c_str() );
        }
      } else {
        vec[l] = atof ( i->c_str() );
      }
    }
    add ( classno, vec );
  }

  if ( format == FILEFORMAT_INDEX_SPARSE_ONE ) {
    // Bring all feature vectors of the set to a common dimension. The target
    // is never smaller than a vector that is already stored, so data present
    // before this call is not truncated.
    size_t target = dataset_dimension;
    for ( iterator c = begin(); c != end(); ++c )
      for ( vector<NICE::Vector *>::const_iterator v = c->second.begin();
            v != c->second.end();
            ++v )
      {
        if ( ( *v )->size() > target )
          target = ( *v )->size();
      }

    for ( iterator c = begin(); c != end(); ++c )
      for ( vector<NICE::Vector *>::iterator v = c->second.begin();
            v != c->second.end();
            ++v )
      {
        NICE::Vector *x = *v;
        const size_t old_dimension = x->size();
        if ( old_dimension == target )
          continue;

        x->resize ( target );

        // set all elements to zero, which are new after the resize operation
        for ( size_t k = old_dimension; k < x->size(); ++k )
          ( *x ) [k] = 0.0;
      }
  }
}

/**
 * Writes all elements to @p os in insertion order, each via storeElement().
 */
void LabeledSetVector::store ( ostream & os, int format ) const
{
  const size_t numElements = insertOrder.size();
  for ( size_t idx = 0; idx < numElements; ++idx )
  {
    const ElementPointer & element = insertOrder[idx];
    storeElement ( os, element.first, * ( element.second ), format );
  }
}

/**
 * Writes a single labeled vector to @p os.
 *
 * Text formats produce one line "<classno> <e_1> <e_2> ... <e_m>" where an
 * element is
 *  - FILEFORMAT_INDEX:            "<k>:<value>" for every component (1-based k);
 *  - FILEFORMAT_NOINDEX:          "<value>" for every component;
 *  - FILEFORMAT_INDEX_SPARSE_ONE: "<k>:<value>" for every non-zero component
 *                                 and always for the last component.
 * Elements are separated by exactly one blank and the line has no trailing
 * blank. The previous check "k != x.size()" was always true inside the loop.
 * That left a trailing blank, and in the sparse format one blank for every
 * skipped zero. The last sparse component is always written so that an
 * all-zero vector does not produce an element-less line (which restoreASCII
 * treats as end of data) and the reader can recover the full dimension.
 * Lines end with '\n' instead of endl, so the stream is not flushed once per
 * element.
 *
 * FILEFORMAT_RAW writes the class number and the dimension (both as int)
 * followed by the raw double data.
 */
void LabeledSetVector::storeElement ( ostream & os, int classno, const NICE::Vector & x, int format )
{
  if ( format == FILEFORMAT_RAW ) {
    const double *data = x.getDataPointer();
    int dimension = x.size();

    os.write ( reinterpret_cast<const char *> ( &classno ), sizeof ( int ) );
    os.write ( reinterpret_cast<const char *> ( &dimension ), sizeof ( int ) );
    os.write ( reinterpret_cast<const char *> ( data ), sizeof ( double ) * dimension );
    return;
  }

  os << classno << " ";

  const size_t n = x.size();
  bool wroteElement = false;
  for ( size_t k = 0 ; k < n ; k++ )
  {
    const bool indexed = ( format == FILEFORMAT_INDEX ) ||
                         ( format == FILEFORMAT_INDEX_SPARSE_ONE &&
                           ( x[k] != 0.0 || k + 1 == n ) );
    const bool plain = ( format == FILEFORMAT_NOINDEX );

    // skipped sparse zero entry (or unknown format): nothing to write
    if ( ! indexed && ! plain )
      continue;

    if ( wroteElement )
      os << " ";

    if ( indexed )
      os << k + 1 << ":" << x[k];
    else
      os << x[k];

    wroteElement = true;
  }
  os << '\n';
}

/**
 * Reads labeled vectors in the binary FILEFORMAT_RAW layout and adds them
 * to the set.
 *
 * Each record is: class number (int), dimension (int), dimension doubles.
 * Reading stops silently at the first incomplete record. NaN components are
 * replaced by 0 with a warning on stderr.
 *
 * @throws IOException if a record announces a negative dimension or the
 *         vector cannot be allocated
 */
void LabeledSetVector::restoreRAW ( istream & is )
{
  while ( ! is.eof() )
  {
    int classno;
    int dimension;

    is.read ( reinterpret_cast<char *> ( &classno ), sizeof ( int ) );
    if ( is.gcount() != sizeof ( int ) )
      return;

    is.read ( reinterpret_cast<char *> ( &dimension ), sizeof ( int ) );
    if ( is.gcount() != sizeof ( int ) )
      return;

    // a negative dimension can only come from a corrupt or foreign file; it
    // would otherwise be converted into a huge unsigned size below
    if ( dimension < 0 )
    {
      fthrow ( IOException, "Invalid vector dimension " << dimension << " in raw file." << endl
               << "(debug: class " << classno << " ; elements read = " << count() << " )" << endl );
    }

    NICE::Vector vec;

    try {
      vec.resize ( dimension );
    } catch ( const std::bad_alloc & ) {
      fthrow ( IOException, "Unable to allocate a vector with size " << dimension << "." << endl
               << "(debug: class " << classno << " ; " << "sizeof(int) = " << 8*sizeof ( int ) << " Bit ; " << endl
               << "elements read = " << count() << " )" << endl );
    }
    double *data = vec.getDataPointer();

    // computed as streamsize: "(int) sizeof(double) * dimension" overflows
    // int for large dimensions
    const std::streamsize numBytes =
      static_cast<std::streamsize> ( sizeof ( double ) ) * dimension;
    is.read ( reinterpret_cast<char *> ( data ), numBytes );
    if ( is.gcount() != numBytes )
      return;

    for ( int k = 0 ; k < dimension ; k++ )
      if ( isnan ( data[k] ) ) {
        cerr << "WARNING: nan's found !!" << endl;
        data[k] = 0.0;
      }

    add ( classno, vec );
  }
}

/**
 * Returns an element of the set chosen with rand(). Terminates the program
 * with exit(-1) if the set is empty.
 */
LabeledSetVector::ElementPointer LabeledSetVector::pickRandomSample () const
{
  if ( insertOrder.empty() ) {
    fprintf ( stderr, "LabeledSet::pickRandomSample: failure !\n" );
    exit ( -1 );
  }

  // named "pick" so it does not shadow the member "selection"
  const int pick = rand() % insertOrder.size();
  return insertOrder[pick];
}

int LabeledSetVector::count ( int classno ) const
{
  const_iterator i = find ( classno );
  return ( i == end() ) ? 0 : i->second.size();
}

int LabeledSetVector::count () const
{
  int mycount = 0;
  for ( const_iterator i = begin() ; i != end() ; i++ )
    mycount += i->second.size();
  return mycount;
}

/**
 * Picks a vector of class @p classno using rand().
 *
 * @param classno class to draw from
 * @param i       receives (classno, picked vector) on success; untouched otherwise
 * @return classno on success, -1 if the class is unknown or has no vectors
 */
int LabeledSetVector::pickRandomSample ( int classno, ElementPointer & i ) const
{
  const_iterator j = find ( classno );
  if ( j == end() ) return -1;

  const vector<Vector *> & l = j->second;
  // an empty bucket (possible through the std::map interface) would turn
  // "rand() % l.size()" into a division by zero
  if ( l.empty() ) return -1;

  int num = rand() % l.size();

  i.first = classno;
  i.second = l[num];

  return classno;
}

/**
 * Removes all vectors from the set.
 *
 * A non-selection set owns its vectors (add() allocates them) and deletes
 * them here. A selection deletes nothing.
 *
 * insertOrder is emptied in both cases. Previously a selection kept its
 * insertion order while the class map was emptied. dimension(), store() and
 * pickRandomSample() would then see stale entries that no longer match the
 * map.
 */
void LabeledSetVector::clear ()
{
  if ( ! selection ) {
    for ( Permutation::const_iterator i  = insertOrder.begin();
          i != insertOrder.end();
          ++i )
    {
      const NICE::Vector *s = i->second;
      delete s;
    }
  }

  insertOrder.clear();
  std::map< int, vector<Vector *> >::clear();
}

/**
 * Stores a heap copy of @p x under class @p classno. The copy is deleted by
 * clear(). Not allowed for selections (the program terminates with exit(-1)).
 */
void LabeledSetVector::add ( int classno, const NICE::Vector & x )
{
  if ( selection ) {
    fprintf ( stderr, "Add operation not available for selections !\n" );
    exit ( -1 );
  }

  // operator[] creates an empty bucket for a previously unseen class
  vector<Vector *> & bucket = operator[] ( classno );
  NICE::Vector *stored = new Vector ( x );

  bucket.push_back ( stored );
  insertOrder.push_back ( ElementPointer ( classno, stored ) );
}

/**
 * Copies the insertion order (class number, vector pointer) into
 * @p permutation, replacing its previous contents.
 */
void LabeledSetVector::getPermutation ( Permutation & permutation ) const
{
  permutation = insertOrder;
}

void LabeledSetVector::add_reference ( int classno, NICE::Vector *pointer )
{
  iterator i = find ( classno );
  if ( i == end() ) {
    operator[] ( classno ) = vector<Vector *>();
    i = find ( classno );
  }
  i->second.push_back ( pointer );
  insertOrder.push_back ( ElementPointer ( classno, pointer ) );
}

void LabeledSetVector::getClasses ( std::vector<int> & classes ) const
{
  for ( const_iterator i = begin(); i != end(); i++ )
    classes.push_back ( i->first );
}

void LabeledSetVector::printInformation () const
{
  for ( const_iterator i = begin(); i != end(); i++ )
  {
    cerr << "class " << i->first << ": " << i->second.size() << endl;
  }
}

int LabeledSetVector::getMaxClassno() const
{
  int maxclassno = 0;

  for ( const_iterator i = begin(); i != end(); i++ )
    if ( i->first > maxclassno )
      maxclassno = i->first;

  return maxclassno;
}

/**
 * Flattens the set into a list of vectors plus a parallel label vector,
 * class by class in ascending class order.
 *
 * @param vecSet       receives copies of all vectors (appended)
 * @param vecSetLabels resized to count(); entry k holds the class of the
 *                     k-th appended vector
 * NOTE(review): labels only line up with vecSet if vecSet was empty on entry.
 */
void LabeledSetVector::getFlatRepresentation ( VVector & vecSet, NICE::Vector & vecSetLabels ) const
{
  vecSetLabels.resize ( count() );

  int k = 0;
  for ( const_iterator cls = begin(); cls != end(); ++cls )
  {
    const vector<NICE::Vector *> & members = cls->second;
    for ( size_t m = 0; m < members.size(); ++m, ++k )
    {
      vecSet.push_back ( *members[m] );
      vecSetLabels[k] = cls->first;
    }
  }
}

#endif