#ifdef NICE_USELIB_OPENMP
#include <omp.h>
#endif

#include <time.h>

#include "vislearning/classifier/fpclassifier/logisticregression/SLR.h"
#include "vislearning/cbaselib/FeaturePool.h"

#include "core/image/ImageT.h"
//#include "core/imagedisplay/ImageDisplay.h"

#include <iostream>

// enable verbose debug output and numerical sanity checks
#define SLRDEBUG
// enable per-dimension min/max feature scaling
#define featnorm

using namespace OBJREC;
using namespace std;
using namespace NICE;

SLR::SLR ()
{
}

SLR::SLR ( const Config *_conf, string section ) : conf ( _conf )
{
  maxiter = conf->gI ( section, "maxiter", 20000 );
  resamp_decay = conf->gD ( section, "resamp_decay", 0.5 );
  convergence_tol = conf->gD ( section, "convergence_tol", 1e-7 );
  min_resamp =  conf->gD ( section, "min_resamp", 0.001 );
  lambda = conf->gD ( section, "lambda", 0.1 );
  samplesperclass = conf->gD ( section, "samplesperclass", 200.0 );
  weight.setDim ( 0 );
  fdim = 0;
}
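
/* Minimal usage sketch (illustration only; assumes a NICE::Config "conf" whose
 * section provides the parameters read above, and a FeaturePool "fp", an
 * Examples container "examples" and an Example "example" filled elsewhere):
 *
 *   SLR classifier ( conf, "SLR" );
 *   classifier.train ( fp, examples, classno );   // learn weights for class classno
 *   double p = classifier.classify ( example );   // p(classno | example) in [0,1]
 */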

SLR::~SLR()
{
  //clean up
}

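// classify: evaluate the learned binary logistic model on one example and return
// p(classno | x) = 1 / (1 + exp(-w^T x)). Dense input vectors are first rescaled
// with the min/max values stored during training (see #define featnorm above).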
double SLR::classify ( Example & pce )
{
  double result;
  SparseVector *svec;

  if ( weight.getDim() == 0 )
    return 0.0;

  bool newvec = false;

  if ( pce.svec != NULL )
  {
    svec = pce.svec;
  }
  else
  {
    Vector x;

    x = * ( pce.vec );

#ifdef featnorm
    // scale each dimension to [0,1] with the ranges seen during training;
    // guard against constant features to avoid a division by zero
    for ( int m = 0; m < ( int ) x.size(); m++ )
    {
      double range = maxval[m] - minval[m];
      x[m] = ( range > 0.0 ) ? ( x[m] - minval[m] ) / range : 0.0;
    }
#endif
    svec = new SparseVector ( x );

    svec->setDim ( x.size() );

    newvec = true;

  }

  if ( weight.size() == 0 )
    result = 0.0;
  else
  {
    result = 1.0 / ( 1.0 + exp ( -svec->innerProduct ( weight ) ) );
  }

  if ( newvec )
    delete svec;

  return result;
}

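// train: copy the feature pool, determine the feature dimension from the first
// example and run the iterative stepwise regression for the given class.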
void SLR::train ( FeaturePool & _fp, Examples & examples, int classno )
{

  cout << "start train" << endl;
  fp = FeaturePool ( _fp );

  // number of training examples
  int fanz = examples.size();
  assert ( fanz >= 2 );

  // determine the feature dimension
  Vector x;
  fp.calcFeatureVector ( examples[0].second, x );
  fdim = x.size();
  assert ( fdim > 0 );

  stepwise_regression ( examples, classno );
  cout << "end train" << endl;
}

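// restore/store: the serialized layout is the weight vector followed by one
// "maxval minval" pair per feature dimension; for format == -9999 only the
// weight vector is written.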
void SLR::restore ( istream & is, int format )
{
  weight.restore ( is, format );
  fdim = ( int ) weight.getDim();
  maxval.resize ( fdim );
  minval.resize ( fdim );
  for ( int i = 0; i < fdim; i++ )
  {
    is >> maxval[i];
    is >> minval[i];
  }
}

void SLR::store ( ostream & os, int format ) const
{
  if ( format != -9999 ) {
    weight.store ( os, format );
    for ( int i = 0; i < fdim; i++ )
    {
      os << maxval[i] << " " << minval[i] << endl;
    }
  } else {
    weight.store ( os, format );
  }
}

void SLR::clear ()
{
//TODO: implement
}

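// stepwise_regression: coordinate-wise optimization of an L1-penalized (sparse)
// binary logistic regression. Every cycle sweeps over all feature dimensions,
// takes a step along the gradient scaled by the precomputed curvature bound
// ac[basis] and soft-thresholds the result, so many weights end up exactly zero.
// Weights that stay zero are revisited only with probability p_resamp, which
// decays by resamp_decay towards min_resamp. Returns the number of cycles needed
// (or maxiter if the relative weight change never drops below convergence_tol).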
int SLR::stepwise_regression ( Examples &x, int classno )
{
  // initialize randomization
  srand ( time ( NULL ) );
  //srand (1);

  cout << "start regression" << endl;

  // number of training examples
  int fnum = x.size();

  // create the data matrix X and the one-hot label matrix Y
  GMSparseVectorMatrix X;
  GMSparseVectorMatrix Y;

  vector<int> count ( 2, 0 );

  maxval.resize ( fdim );
  minval.resize ( fdim );

  if ( x[0].second.svec == NULL ) // input: dense feature vectors
  {
    Vector feat;

    for ( int i = 0; i < fdim; i++ )
    {
      // -max() is the correct neutral element for a running maximum
      // (numeric_limits<double>::min() is the smallest positive double)
      maxval[i] = -numeric_limits<double>::max();
      minval[i] = numeric_limits<double>::max();
    }

    for ( int i = 0; i < fnum; i++ )
    {
      int pos = 0;
      if ( x[i].first != classno )
        pos = 1;

      ++count[pos];

      fp.calcFeatureVector ( x[i].second, feat );
#ifdef featnorm
      for ( int m = 0; m < ( int ) feat.size(); m++ )
      {
        minval[m] = std::min ( minval[m], feat[m] );
        maxval[m] = std::max ( maxval[m], feat[m] );
      }
#endif
      fdim = feat.size();
    }
  }
  else // input: sparse feature vectors
  {
    for ( int i = 0; i < fnum; i++ )
    {
      int pos = 0;
      if ( x[i].first != classno )
        pos = 1;

      ++count[pos];
    }
    fdim = x[0].second.svec->getDim();
  }

  if ( count[0] == 0 || count[1] == 0 )
  {
    cerr << "not enought samples " << count[0] << " : " << count[1] << endl;
    weight.setDim ( 0 );
    return -1;
  }

  double samples = std::min ( count[0], count[1] );
  samples = std::min ( samples, samplesperclass );
  //samples = std::max(samples, 200);
  //samples = samplesperclass;

  vector<double> rands ( 2, 0.0 );
  vector<double> allsizes ( 2, 0.0 );
  for ( int i = 0; i < 2; i++ )
  {
    rands[i] = 1.0;        //use all samples (default if samples>count[i])
    if ( samples > 0 && samples < 1 ) rands[i] = samples;    //use relative number of samples wrt. class size
    if ( samples > 1 && samples <= count[i] ) rands[i] = samples / ( double ) count[i]; //use (approximately) fixed absolute number of samples for each class
    allsizes[i] = count[i];
    count[i] = 0;
  }

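  // second pass: keep each example of class pos with probability rands[pos] and
  // build the data matrix X and the one-hot label matrix Y from the kept examples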
  for ( int i = 0; i < fnum; i++ )
  {
    int pos = 0;
    if ( x[i].first != classno )
      pos = 1;


    double r = ( double ) rand() / ( double ) RAND_MAX;

    if ( r > rands[pos] )
      continue;
    ++count[pos];

    if ( x[0].second.svec == NULL )
    {
      Vector feat;
      fp.calcFeatureVector ( x[i].second, feat );
#ifdef featnorm
      // scale each dimension to [0,1]; guard against constant features
      for ( int m = 0; m < ( int ) feat.size(); m++ )
      {
        double range = maxval[m] - minval[m];
        feat[m] = ( range > 0.0 ) ? ( feat[m] - minval[m] ) / range : 0.0;
      }
#endif

      X.addRow ( feat );
    }
    else
    {
      X.addRow ( x[i].second.svec );
    }
    SparseVector *v = new SparseVector ( 2 );
    ( *v ) [pos] = 1.0;
    Y.addRow ( v );
  }
  Y.setDel();

  if ( x[0].second.svec == NULL )
    X.setDel();

  for ( int i = 0; i < 2; i++ )
  {
    cerr << "Examples for class " << i << ": " << count[i] << " out of " << allsizes[i] << " with p = " << rands[i] << endl;
  }

#undef NORMALIZATION
#ifdef NORMALIZATION
  GMSparseVectorMatrix Xred;
  Xred.resize ( X.rows(), X.cols() );

  for ( int r = 0; r < ( int ) Xred.rows(); r++ )
  {
    for ( int c = 0; c < ( int ) Xred.cols(); c++ )
    {
      double tmp = X[r].get ( c );

      // avoid integer division when reweighting by the class fraction
      if ( Y[r].get ( 0 ) == 1 )
        tmp *= count[0] / ( double ) fnum;
      else
        tmp *= count[1] / ( double ) fnum;

      if ( fabs ( tmp ) > 10e-7 )
        Xred[r][c] = tmp;
    }
  }
#endif

  fnum = X.rows();

  GMSparseVectorMatrix xY;
#ifdef NORMALIZATION
  Xred.mult ( Y, xY, true );
#else
  X.mult ( Y, xY, true );
#endif

  weight.setDim ( fdim );

  // precompute Xw = X * w so the per-coordinate updates stay cheap
  GMSparseVectorMatrix Xw;
  X.mult ( weight, Xw, false, true );
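  // S[r] caches the softmax normalizer sum_c exp(Xw[r][c]) for example r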
  SparseVector S ( fnum );

  for ( int r = 0; r < fnum; r++ )
  {
    S[r] = 0.0;
    for ( int c = 0; c < 2; c++ )
    {
      S[r] += exp ( Xw[r].get ( c ) );
    }
  }

  // precompute the curvature bounds ac[f] = (maxClassNo-1)/(2*maxClassNo) * sum_i (x_i(f))^2
  // (maxClassNo = 2 here, hence the factor 0.25) and the thresholds lambda / (2 * ac[f])

  Vector ac ( fdim );
  Vector lm_2_ac ( fdim );

  for ( int f = 0; f < fdim; f++ )
  {
    ac[f] = 0.0;
    for ( int a = 0; a < fnum; a++ )
    {
      ac[f] += X[a].get ( f ) * X[a].get ( f );
    }
    ac[f] *= 0.25;
    lm_2_ac[f] = ( lambda / 2.0 ) / ac[f];
  }

  // initialize the iterative optimization
  double incr = numeric_limits<double>::max();
  long non_zero = 0;
  long wasted_basis = 0;
  long needed_basis = 0;

  // probability of resampling each weight
  vector<double> p_resamp;
  p_resamp.resize ( fdim );

  // loop over cycles
  long cycle = 0;

  for ( cycle = 0; cycle < maxiter; cycle++ )
  {
#ifdef SLRDEBUG
    cerr << "iteration: " << cycle << " of " << maxiter << endl;
#endif
    // zero out the diffs for assessing change
    double sum2_w_diff = 0.0;
    double sum2_w_old = 0.0;
    wasted_basis = 0;
    if ( cycle == 1 )
      needed_basis = 0;

    // update each weight
//#pragma omp parallel for
    for ( int basis = 0; basis < fdim; basis++ ) // loop over all feature dimensions
    {
      int changed = 0;
      // get the starting weight
      double w_old = weight.get ( basis );

      // set the p_resamp if it's the first cycle
      if ( cycle == 0 )
      {
        p_resamp[basis] = 1.0;
      }

      // decide whether this weight gets updated in this cycle
      double rval = ( double ) rand() / ( double ) RAND_MAX;

      if ( ( w_old != 0 ) || ( rval < p_resamp[basis] ) )
      {
        // accumulate X(:,basis)^T p, the expected value of this feature under the current model
        double XdotP = 0.0;
        for ( int i = 0; i < fnum; i++ )
        {
#ifdef NORMALIZATION
          double e = Xred[i].get ( basis ) * exp ( Xw[i].get ( 0 ) ) / S[i];
#else
          double e = X[i].get ( basis ) * exp ( Xw[i].get ( 0 ) ) / S[i];
#endif
          if ( NICE::isFinite ( e ) )
            XdotP += e;
#ifdef SLRDEBUG
          else
            throw "numerical problems";
#endif
        }

        // get the gradient
        double grad = xY[basis].get ( 0 ) - XdotP;
        // set the new weight
        double w_new = w_old + grad / ac[basis];

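        // soft-thresholding: if the unpenalized update lies within
        // [-lambda/(2*ac), +lambda/(2*ac)] the weight is set to exactly zero,
        // otherwise it is shrunk by lambda/(2*ac); this is what creates sparsity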
        // test that we're within bounds
        if ( w_new > lm_2_ac[basis] )
        {
          // move towards the bound, but keep the weight
          w_new -= lm_2_ac[basis];
          changed = 1;

          // unmark from being zero if necessary
          if ( w_old == 0.0 )
          {
            non_zero += 1;

            // reset the p_resample
            p_resamp[basis] = 1.0;

            // we needed the basis
            needed_basis += 1;
          }
        }
        else if ( w_new < -lm_2_ac[basis] )
        {
          // move towards the bound, but keep the weight
          w_new += lm_2_ac[basis];
          changed = 1;

          // unmark from being zero if necessary
          if ( w_old == 0.0 )
          {
            non_zero += 1;

            // reset the p_resample
            p_resamp[basis] = 1.0;

            // we needed the basis
            needed_basis += 1;
          }
        }
        else
        {
          // zero it out (soft-thresholded away)
          w_new = 0.0;

          // decrease the p_resamp
          p_resamp[basis] -= ( p_resamp[basis] - min_resamp ) * resamp_decay;

          // set the number of non-zero
          if ( w_old == 0.0 )
          {
            // we didn't change
            changed = 0;

            // and wasted a basis
            wasted_basis += 1;
          }
          else
          {
            // we changed
            changed = 1;

            // must update num non_zero
            non_zero -= 1;
          }
        }
        // process changes if necessary
        if ( changed == 1 )
        {
          // update the expected values
          double w_diff = w_new - w_old;
          for ( int i = 0; i < fnum; i++ )
          {
            double E_old = exp ( Xw[i].get ( 0 ) );
            double val = X[i].get ( basis ) * w_diff;
            if ( Xw[i].get ( 0 ) == 0.0 )
            {
              if ( fabs ( val ) > 10e-7 )
                Xw[i][0] = val;
            }
            else
              Xw[i][0] += X[i].get ( basis ) * w_diff;
            double E_new_m = exp ( Xw[i].get ( 0 ) );

            S[i] += E_new_m - E_old;
          }

          // update the weight
          if ( w_new == 0.0 )
          {
            if ( weight.get ( basis ) != 0.0 )
            {
              weight.erase ( basis );
            }
          }
          else
            weight[basis] = w_new;
          // accumulate the sum of squared weight changes (the sqrt is taken after the cycle)
//#pragma omp critical
          sum2_w_diff += w_diff * w_diff;
        }
        // in any case, accumulate the old squared weight for the convergence criterion
//#pragma omp critical
        sum2_w_old += w_old * w_old;
      }
    }

    // finished a cycle, assess convergence
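    // incr is the relative weight change ||w_new - w_old||_2 / (||w_old||_2 + eps)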
    incr = sqrt ( sum2_w_diff ) / ( sqrt ( sum2_w_old ) + numeric_limits<double>::epsilon() );
#ifdef SLRDEBUG
    cout << "sum2_w_diff: " << sum2_w_diff << " sum2_w_old " << sum2_w_old << endl;
    cout << "convcrit: " << incr << " tol: " << convergence_tol << endl;
    //cout << "sum2_w_wold = " << sum2_w_old << " sum2_w_diff = " << sum2_w_diff << endl;
#endif
    if ( incr < convergence_tol )
    {
      // we converged!!!
      break;
    }
  }

  // finished updating the weights; report the number of cycles used
  cerr << "end regression after " << cycle << " steps"  << endl;
  return cycle;
}