/** 
* @file ImageNetData.cpp
* @brief wrapper class for matlab IO with ImageNet data
* @author Erik Rodner
* @date 02/03/2012

*/

#ifdef NICE_USELIB_MATIO

#include <cmath>
#include <fstream>
#include <iostream>
#include <vector>

#include <core/basics/Exception.h>
#include <core/vector/VectorT.h>

#include "ImageNetData.h"

using namespace NICE;
using namespace std;


/**
 * @brief Create a data wrapper that reads its MAT files from the given directory.
 * @param imageNetRoot directory prefix that is later prepended to the batch file names
 */
ImageNetData::ImageNetData( const string & imageNetRoot )
  : imageNetRoot ( imageNetRoot )
{
}

/** @brief Destructor; performs no explicit clean-up. */
ImageNetData::~ImageNetData()
{
  // intentionally empty
}

/**
 * @brief Read one batch (sparse data matrix and label vector) from a MAT file.
 *
 * The file opened is "<imageNetRoot>/demo.<fileTag>.mat" (read-only). The two
 * variables fetched from it are "<variableTag>_instance_matrix" and
 * "<variableTag>_label_vector".
 *
 * @param data        receives the sparse matrix variable
 * @param y           receives the label vector variable
 * @param fileTag     middle part of the file name
 * @param variableTag prefix of the variable names inside the file
 */
void ImageNetData::getBatchData ( sparse_t & data, Vector & y, const string & fileTag, const string & variableTag )
{
  // names of the two variables stored in the batch file
  string dataVariable = variableTag + "_instance_matrix";
  string labelVariable = variableTag + "_label_vector";

  // location of the batch file below the root directory
  string matFile = imageNetRoot + "/demo." + fileTag + ".mat";

  MatFileIO reader ( matFile, MAT_ACC_RDONLY );
  reader.getSparseVariableViaName ( data, dataVariable );
  reader.getVectorViaName ( y, labelVariable );
}

/**
 * @brief Load a batch and keep it in memory as one sparse vector per example.
 *
 * The labels are stored in yPreload, the features in XPreload. The sparse
 * matrix delivered by getBatchData() is walked column by column: the column
 * index is used as feature dimension and the row index (ir) as example index.
 *
 * Any previously preloaded examples are discarded first, so that repeated
 * calls do not mix stale feature entries into the newly loaded data.
 *
 * @param fileTag     middle part of the file name (see getBatchData())
 * @param variableTag prefix of the variable names (see getBatchData())
 * @throws Exception if a row index of the matrix does not fit the number of labels
 */
void ImageNetData::preloadData ( const string & fileTag, const string & variableTag )
{
  sparse_t m_XPreload;
  getBatchData ( m_XPreload, yPreload, fileTag, variableTag );

  // resize() alone would keep (and insert() below would extend) old examples
  XPreload.clear();
  XPreload.resize ( yPreload.size() );

  cerr << "ImageNetData: converting data ... " << yPreload.size() << " examples" << endl;
  for ( int i = 0; i < m_XPreload.njc-1; i++ ) //walk over dimensions
  {
    // jc[i] .. jc[i+1]-1 are the positions of the non-zero entries of dimension i
    for ( int j = m_XPreload.jc[i]; j < m_XPreload.jc[i+1] && j < m_XPreload.ndata; j++ )
    {
      int exampleIndex = m_XPreload.ir[ j];
      if ( exampleIndex < 0 || exampleIndex >= (int)XPreload.size() )
        fthrow(Exception, "Label and data file seem to mismatch according the sizes: " << XPreload.size() << " vs. " << exampleIndex);
      XPreload[exampleIndex].insert ( pair<int, double> ( i, ((double *)m_XPreload.data)[j] ) );
    }
  }
  cerr << "ImageNetData: data conversion finished." << endl;
}

/**
 * @brief Normalize every preloaded example in place.
 *
 * Supported tags:
 *  - "L1": calls SparseVector::normalize() on every example
 *  - "L2": divides every example by its Euclidean norm, i.e. the square root
 *          of the sum of its squared entries; examples whose norm is zero are
 *          left untouched to avoid a division by zero
 *
 * For any other tag a message is written to stderr and the data stays unchanged.
 *
 * @param normTag name of the normalization ("L1" or "L2")
 */
void ImageNetData::normalizeData ( const string & normTag )
{
  if ( normTag.compare("L1") == 0 )
  {
    for ( std::vector< SparseVector >::iterator it = XPreload.begin(); it != XPreload.end(); ++it )
    {
      it->normalize();
    }
    return;
  }

  if ( normTag.compare("L2") == 0 )
  {
    NICE::SparseVector tmpVec;
    for ( std::vector< SparseVector >::iterator it = XPreload.begin(); it != XPreload.end(); ++it )
    {
      // element-wise square of the example, summed up -> squared L2 norm
      tmpVec = *it;
      tmpVec.multiply(*it);
      const double L2norm = std::sqrt ( tmpVec.sum() );

      // an all-zero (or empty) example cannot be normalized
      if ( L2norm > 0.0 )
        it->divide(L2norm);
    }
    return;
  }

  cerr << "ImageNetData::normalizeData: invalid normTag... data was not normalized" << endl;
}

/**
 * @brief Load a batch and store it as dense vectors in a LabeledSetVector.
 *
 * The labels are additionally kept in yPreload. The sparse matrix delivered by
 * getBatchData() is walked column by column: the column index is used as
 * feature dimension and the row index (ir) as example index. Every example
 * becomes a dense vector with one entry per matrix column (njc-1 entries),
 * initialized with zero and filled with the stored non-zero values.
 *
 * @param lsVector    is cleared and then receives all (label, vector) pairs
 * @param fileTag     middle part of the file name (see getBatchData())
 * @param variableTag prefix of the variable names (see getBatchData())
 * @throws Exception if a row index of the matrix does not fit the number of labels
 */
void ImageNetData::loadDataAsLabeledSetVector( OBJREC::LabeledSetVector & lsVector, const std::string & fileTag, const std::string & variableTag )
{
  sparse_t m_XPreload;

  //load raw data
  getBatchData ( m_XPreload, yPreload, fileTag, variableTag );

  // The dense vectors are indexed by dimension below, so they need one entry
  // per dimension (matrix column) - not one entry per example.
  const size_t numDimensions = ( m_XPreload.njc > 0 ) ? static_cast<size_t> ( m_XPreload.njc - 1 ) : 0;

  //tmp storage, every entry initialized with a zero vector
  const NICE::Vector vZero ( numDimensions, 0.0 );
  std::vector<NICE::Vector> dataTmp ( yPreload.size(), vZero );

  //set non-zero entries according to the stored values
  std::cerr << "ImageNetData: converting data ... " << yPreload.size() << " examples" << std::endl;
  for ( int i = 0; i < m_XPreload.njc-1; i++ ) //walk over dimensions
  {
    for ( int j = m_XPreload.jc[i]; j < m_XPreload.jc[i+1] && j < m_XPreload.ndata; j++ ) //and over every non-zero entry in this dimension
    {
      //what is the original index?
      int exampleIndex = m_XPreload.ir[ j];
      if ( exampleIndex < 0 || exampleIndex >= (int)yPreload.size() )
        fthrow(Exception, "Label and data file seem to mismatch according the sizes: " << yPreload.size() << " vs. " << exampleIndex);

      //insert at the original index and the corresponding dimension
      dataTmp[exampleIndex][i] =  ((double *)m_XPreload.data)[j];
    }
  }

  std::cerr << "ImageNetData: data conversion finished." << std::endl;

  lsVector.clear();
  for ( size_t i = 0; i < yPreload.size(); i++ )
  {
    lsVector.add( yPreload[i], dataTmp[i] );
  }
}

/**
 * @brief Access the feature vector of a preloaded example.
 * @param index position of the example in XPreload
 * @return reference to the stored sparse vector
 * @throws Exception if index is negative or not smaller than the number of stored examples
 */
const SparseVector & ImageNetData::getPreloadedExample ( int index ) const
{
  const int numStored = static_cast<int> ( XPreload.size() );
  const bool inRange = ( index >= 0 ) && ( index < numStored );

  if ( !inRange )
  {
    fthrow(Exception, "Invalid index!");
  }

  return XPreload[index];
}

/**
 * @brief Access the label of a preloaded example.
 * @param index position of the label in yPreload
 * @return the stored label value
 * @throws Exception if index is negative or not smaller than the number of stored labels
 */
double ImageNetData::getPreloadedLabel ( int index ) const
{
  const int numLabels = static_cast<int> ( yPreload.size() );
  const bool inRange = ( index >= 0 ) && ( index < numLabels );

  if ( !inRange )
  {
    fthrow(Exception, "Invalid index!");
  }

  return yPreload[index];
}

/**
 * @brief Number of preloaded examples.
 * @return size of the label vector yPreload
 */
int ImageNetData::getNumPreloadedExamples () const
{
  const int numExamples = static_cast<int> ( yPreload.size() );
  return numExamples;
}

void ImageNetData::loadExternalLabels ( const string & fn, int n )
{
  if ( n <= 0  && yPreload.size() == 0 ) {
    fthrow(Exception, "Please initialize with preloadData() first, or use the second optional argument to give the number of examples.");
  }
  if ( n >= 0 )
    yPreload.resize( n );

  ifstream ifs ( fn.c_str(), ios::in );
  if ( ! ifs.good() )
    fthrow(Exception, "Unable to read " << fn );

  int value;
  int i = 0;
  while ( (i < yPreload.size()) && (ifs >> value) ) 
    yPreload[i++] = value;

  ifs.close();

  if ( (XPreload.size() > 0) && (yPreload.size() != XPreload.size()) )
    fthrow(Exception, "Size of the label vector and the size of the data structure do not match.");
}

#endif