#include "vislearning/classifier/fpclassifier/logisticregression/FPCSMLR.h"
#include "vislearning/cbaselib/FeaturePool.h"

#include "core/image/ImageT.h"
//#include "core/imagedisplay/ImageDisplay.h"

#include <iostream>

using namespace OBJREC;
using namespace std;
using namespace NICE;

/**
 * Default constructor: no configuration attached, "inpic" mode disabled and
 * no trained classes yet.
 *
 * conf is explicitly NULL and maxClassNo is 0 so that an object that was
 * never trained or restored does not read uninitialized members.
 */
FPCSMLR::FPCSMLR () : conf ( NULL )
{
  inpic = false;
  maxClassNo = 0;
}

/**
 * Constructs the classifier from a configuration.
 *
 * @param _conf   configuration; the pointer is stored and later handed to
 *                every per-class SLR (and to clones), so it must outlive this object
 * @param section config section to read; the "inpic" key (default false)
 *                selects per-position training subsets in train()
 *
 * maxClassNo starts at 0 so the object is consistent before train()/restore().
 */
FPCSMLR::FPCSMLR ( const Config *_conf, string section ) : conf ( _conf )
{
  confsection = section;
  inpic = conf->gB ( section, "inpic", false );
  maxClassNo = 0;
}

/**
 * Destructor.
 *
 * Intentionally empty: nothing here is allocated by this class. conf is not
 * owned; the same pointer is passed to every SLR and to clones.
 */
FPCSMLR::~FPCSMLR()
{
}

/**
 * Classifies one example with all per-class SLR classifiers (one-vs-all).
 *
 * @param pce example to classify
 * @return the index of the class with the highest score, together with the
 *         per-class scores normalized to sum to 1. If every score is 0 (e.g. a
 *         class whose classifier was never trained), the raw scores are returned
 *         unnormalized instead of dividing by zero.
 */
ClassificationResult FPCSMLR::classify ( Example & pce )
{
  FullVector overall_distribution ( maxClassNo );

  double maxp = -numeric_limits<double>::max();
  int classno = 0;

  double sum = 0.0;

  for ( int i = 0; i < maxClassNo; i++ )
  {
    overall_distribution[i] = classifiers[i].classify ( pce );

    sum += overall_distribution[i];

    if ( maxp < overall_distribution[i] )
    {
      classno = i;
      maxp = overall_distribution[i];
    }
  }

  // guard against 0/0 = NaN when no classifier produced a non-zero score
  if ( sum != 0.0 )
  {
    for ( int i = 0; i < maxClassNo; i++ )
    {
      overall_distribution[i] /= sum;
    }
  }

  return ClassificationResult ( classno, overall_distribution );
}

/**
 * Trains one SLR classifier per class (one-vs-all).
 *
 * @param _fp      feature pool; a copy is kept in fp and _fp itself is passed
 *                 to every per-class SLR
 * @param examples labelled training examples; labels are assumed to be 0-based,
 *                 so the number of classes becomes (largest label + 1)
 *
 * When "inpic" is enabled, the classifier for class i is trained only on the
 * examples whose `position` also holds at least one example of class i
 * (presumably: examples from images containing class i; TODO confirm the
 * meaning of `position`). Examples with a negative position are ignored in
 * that mode. Classes with two or fewer such examples are skipped and keep a
 * freshly constructed, untrained SLR.
 */
void FPCSMLR::train ( FeaturePool & _fp, Examples & examples )
{

  cout << "start train" << endl;
  fp = FeaturePool ( _fp );

  // number of training examples
  int fanz = examples.size();

  maxClassNo = -1;
  for ( int i = 0; i < fanz; i++ )
  {
    maxClassNo = std::max ( maxClassNo, examples[i].first );
  }
  maxClassNo++;

  assert ( fanz >= maxClassNo );

  classifiers.resize ( maxClassNo );
  for ( int i = 0; i < maxClassNo; i++ )
  {
    cout << "classifier no " << i << " training starts" << endl;
    classifiers[i] = SLR ( conf, confsection );

    if ( !inpic )
    {
      classifiers[i].train ( _fp, examples, i );
      continue;
    }

    // mask of positions that contain at least one example of class i
    vector<bool> classinpic;
    for ( int j = 0; j < fanz; j++ )
    {
      if ( examples[j].first != i )
        continue;

      int pos = examples[j].second.position;
      if ( pos < 0 ) // a negative position cannot index the mask
        continue;

      if ( pos >= ( int ) classinpic.size() )
        classinpic.resize ( pos + 1, false );
      classinpic[pos] = true;
    }

    // training subset: shallow copies that share vec/svec with `examples`
    Examples ex2;
    for ( int j = 0; j < fanz; j++ )
    {
      int pos = examples[j].second.position;
      if ( pos < 0 || pos >= ( int ) classinpic.size() || !classinpic[pos] )
        continue;

      Example e;
      e.svec = examples[j].second.svec;
      e.vec = examples[j].second.vec;
      ex2.push_back ( pair<int, Example> ( examples[j].first, e ) );
    }
    cout << "examples for class " << i << ": " << ex2.size() << endl;

    if ( ex2.size() > 2 )
      classifiers[i].train ( _fp, ex2, i );

    // ex2 only borrows the sparse vectors of `examples`; drop the borrowed
    // pointers on every path, including when training was skipped
    for ( int j = 0; j < ( int ) ex2.size(); j++ )
    {
      ex2[j].second.svec = NULL;
    }
  }

  cout << "end train" << endl;
}


/**
 * Restores the model: reads the number of classes, then one SLR per class.
 *
 * @param is     input stream, positioned at the class count
 * @param format format flag forwarded to every SLR::restore
 *
 * If the class count cannot be read or is negative (corrupt input), an error
 * is printed and the object is left as an empty model instead of calling
 * resize() with a negative size.
 */
void FPCSMLR::restore ( istream & is, int format )
{
  is >> maxClassNo;
  if ( !is || maxClassNo < 0 )
  {
    cerr << "FPCSMLR::restore: unable to read a valid number of classes" << endl;
    maxClassNo = 0;
    classifiers.clear();
    return;
  }

  classifiers.resize ( maxClassNo );
  for ( int i = 0; i < maxClassNo; i++ )
  {
    classifiers[i].restore ( is, format );
  }
}

/**
 * Writes the model: the number of classes, then every per-class SLR.
 *
 * @param os     output stream
 * @param format format flag forwarded to every SLR::store; the special value
 *               -9999 omits the class count (restore() always reads it, so
 *               such output cannot be restored)
 *
 * The class count is terminated by a newline. Without a separator, any
 * leading digits of the first SLR payload would be concatenated into the
 * integer that restore() extracts with operator>>.
 */
void FPCSMLR::store ( ostream & os, int format ) const
{

  if ( format != -9999 ) os << maxClassNo << endl;

  for ( int i = 0; i < maxClassNo; i++ )
  {
    classifiers[i].store ( os, format );
  }
}

void FPCSMLR::clear ()
{
//TODO: einbauen
}

/**
 * Creates a heap-allocated copy of this classifier, including its trained
 * state. The caller owns the returned object.
 *
 * Previously only maxClassNo was copied. The clone then had maxClassNo > 0
 * but no classifiers, so classify() on it indexed an empty vector.
 */
FeaturePoolClassifier *FPCSMLR::clone () const
{
  // a configuration-less object cannot go through the config constructor,
  // which dereferences conf
  FPCSMLR *o = ( conf != NULL ) ? new FPCSMLR ( conf, confsection )
                                : new FPCSMLR ();

  o->maxClassNo = maxClassNo;
  o->classifiers = classifiers;
  o->fp = fp;
  o->inpic = inpic;

  return o;
}

/**
 * Not supported by this classifier: the requested complexity is ignored and
 * a notice is printed to stderr.
 */
void FPCSMLR::setComplexity ( int /* size */ )
{
  std::cerr << "FPCSMLR: no complexity to set" << std::endl;
}