/**
* @file testClassifier.cpp
* @brief main program for classifier evaluation
* @author Erik Rodner
* @date 2007-10-12
*/

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <exception>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#include <vislearning/cbaselib/MultiDataset.h>
#include "vislearning/classifier/genericClassifierSelection.h"
#include <vislearning/cbaselib/ClassificationResults.h>
#include <vislearning/cbaselib/MutualInformation.h>

#include "vislearning/classifier/classifierbase/FeaturePoolClassifier.h"
#include <vislearning/classifier/fpclassifier/randomforest/FPCRandomForestTransfer.h>
#include <vislearning/classifier/classifierinterfaces/VCFeaturePool.h>

#include "core/basics/Config.h"
#include <vislearning/baselib/Preprocess.h>
#include <core/basics/StringTools.h>

#undef DEBUG

using namespace OBJREC;

using namespace NICE;

using namespace std;

/**
 * @brief Turn a feature vector into a 0/1 indicator vector.
 *
 * Component d of the result is 1.0 when |x[d]| is strictly greater than
 * thresholds[d], and 0.0 otherwise.
 *
 * @param xout       receives the binarized vector (resized to x.size())
 * @param x          input feature vector
 * @param thresholds per-dimension thresholds; must hold at least x.size() entries
 */
void binarizeVector( NICE::Vector & xout, const NICE::Vector & x, const NICE::Vector & thresholds )
{
  const size_t dim = x.size();
  xout.resize( dim );

  for ( size_t d = 0 ; d < dim ; ++d )
  {
    // x[d] is read before xout[d] is written, so xout may alias x
    const bool above = fabs( x[d] ) > thresholds[d];
    xout[d] = above ? 1.0 : 0.0;
  }
}

/**
 * @brief Binarize every example of a labeled set and append it to another set.
 *
 * Each example of @p src is passed through binarizeVector() with the given
 * thresholds, and the result is added to @p dst under the example's class label.
 * Existing contents of @p dst are kept.
 *
 * @param dst        set receiving the binarized examples
 * @param src        set of examples to binarize
 * @param thresholds per-dimension thresholds forwarded to binarizeVector()
 */
void binarizeSet( LabeledSetVector & dst, const LabeledSetVector & src, const NICE::Vector & thresholds )
{
  LOOP_ALL( src )
  {
    EACH( label, feature );

    NICE::Vector binarized;
    binarizeVector( binarized, feature, thresholds );

    dst.add( label, binarized );
  }
}

int main( int argc, char **argv )
{
  fprintf( stderr, "testClassifier: init\n" );

  std::set_terminate( __gnu_cxx::__verbose_terminate_handler );

  Config conf( argc, argv );

  string wekafile = conf.gS( "main", "weka", "" );
  string trainfn = conf.gS( "main", "train", "train.vec" );
  string testfn = conf.gS( "main", "test", "test.vec" );
  int format = conf.gI( "main", "format", 0 );
  bool binarize = conf.gB( "main", "binarize", false );
  int wekaclass = conf.gI( "main", "wekaclass", 1 );
  string classifier_cache = conf.gS( "main", "classifiercache", "" );
  string classifier_cache_in = conf.gS( "main", "classifierin", "" );
  int numRuns = conf.gI( "main", "runs", 1 );
  string writeImgNet = conf.gS( "main", "imgnet", "" );

  // classno:text,classno:text,...
  string classes = conf.gS( "main", "classes", "" );
  int classesnb = conf.gI( "main", "classes", 0 );
  string classesconf = conf.gS( "main", "classesconf", "" );

  fprintf( stderr, "testClassifier: reading config\n" );
  Preprocess::Init( &conf );

  fprintf( stderr, "testClassifier: reading multi dataset\n" );
  int testMaxClassNo;
  int trainMaxClassNo;

  ClassNames *classNames;

  if ( classes.size() == 0 && classesnb != 0 )
  {
    classNames = new ClassNames();

    for ( int classno = 0 ; classno < classesnb ; classno++ )
    {
      classNames->addClass( classno, StringTools::convertToString<int> ( classno ), StringTools::convertToString<int> ( classno ) );
    }

    trainMaxClassNo = classNames->getMaxClassno();

    testMaxClassNo = trainMaxClassNo;
  }
  else
    if ( classes.size() > 0 )
    {
      classNames = new ClassNames();

      vector<string> classes_sub;
      StringTools::split( string( classes ), ',', classes_sub );

      for ( vector<string>::const_iterator i = classes_sub.begin();
            i != classes_sub.end(); i++ )
      {
        vector<string> desc;
        StringTools::split( *i, ':', desc );

        if ( desc.size() != 2 )
          break;

        int classno = StringTools::convert<int> ( desc[0] );

        classNames->addClass( classno, desc[1], desc[1] );
      }

      trainMaxClassNo = classNames->getMaxClassno();

      testMaxClassNo = trainMaxClassNo;

      classNames->store( cout );
    }
    else if ( classesconf.size() > 0 ) {
      classNames = new ClassNames();
      Config cConf( classesconf );
      classNames->readFromConfig( cConf, "*" );
      trainMaxClassNo = classNames->getMaxClassno();
      testMaxClassNo = trainMaxClassNo;
    }
    else
    {
      MultiDataset md( &conf );
      classNames = new ClassNames( md.getClassNames( "train" ), "*" );
      testMaxClassNo = md.getClassNames( "test" ).getMaxClassno();
      trainMaxClassNo = md.getClassNames( "train" ).getMaxClassno();
    }

  LabeledSetVector train;

  if ( classifier_cache_in.size() <= 0 )
  {
    fprintf( stderr, "testClassifier: Reading training dataset from %s\n", trainfn.c_str() );
    train.read( trainfn, format );
    train.printInformation();
  } else {
    fprintf( stderr, "testClassifier: skipping training set %s\n", trainfn.c_str() );
  }

  LabeledSetVector test;

  fprintf( stderr, "testClassifier: Reading test dataset from %s\n", testfn.c_str() );
  test.read( testfn, format );

  ClassificationResults cresults;

  ofstream outinet;

  if ( writeImgNet.length() > 0 )
  {
    outinet.open( writeImgNet.c_str() );
  }

  for ( int runs = 0 ; runs < numRuns ; runs++ ) {
    VecClassifier *vec_classifier = NULL;

    if ( conf.gS( "main", "classifier" ) == "random_forest_transfer" )
    {
      FeaturePoolClassifier *fpc = new FPCRandomForestTransfer( &conf, classNames );
      vec_classifier = new VCFeaturePool( &conf, fpc );
    } else {
      string classifierselection = conf.gS("main","classifier");
      vec_classifier = GenericClassifierSelection::selectVecClassifier( &conf, classifierselection );
    }

    NICE::Vector thresholds;

    if ( classifier_cache_in.size() <= 0 )
    {
      if ( binarize ) {
        LabeledSetVector trainbin;
        NICE::Vector mis;
        MutualInformation mi;
        fprintf( stderr, "testClassifier: computing mutual information\n" );
        mi.computeThresholdsOverall( train, thresholds, mis );
        fprintf( stderr, "testClassifier: done!\n" );
        binarizeSet( trainbin, train, thresholds );
        vec_classifier->teach( trainbin );
      } else {

        vec_classifier->teach( train );

      }

      vec_classifier->finishTeaching();

      if ( classifier_cache.size() > 0 )
        vec_classifier->save( classifier_cache );
    } else {
      vec_classifier->setMaxClassNo( classNames->getMaxClassno() );
      vec_classifier->read( classifier_cache_in );
    }

    ProgressBar pb( "Classification" );

    pb.show();

    std::vector<int> count( testMaxClassNo + 1, 0 );

    std::vector<int> correct( testMaxClassNo + 1, 0 );

    MatrixT<int> confusionMatrix( testMaxClassNo + 1, trainMaxClassNo + 1, 0 );

    int n = test.count();
    LOOP_ALL( test )
    {
      EACH( classno, v );
      pb.update( n );
#ifdef DEBUG
      fprintf( stderr, "\tclassification\n" );
#endif
      ClassificationResult r;

      if ( binarize )
      {
        NICE::Vector vout;
        binarizeVector( vout, v, thresholds );
        r = vec_classifier->classify( vout );
      } else {
        r = vec_classifier->classify( v );
      }

      r.classno_groundtruth = classno;

      r.classname = classNames->text( r.classno );

#ifdef DEBUG

      if ( r.classno == classno )
        fprintf( stderr, "+ classification %d (\"%s\") <-> %d (\"%s\") score=%f\n", classno,
                 classNames->text( classno ).c_str(), r.classno, r.classname.c_str(), r.scores[r.classno] );
      else
        fprintf( stderr, "- classification %d (\"%s\") <-> %d (\"%s\") score=%f\n", classno,
                 classNames->text( classno ).c_str(), r.classno, r.classname.c_str(), r.scores[r.classno] );

      r.scores.store( cerr );

#endif

      if ( writeImgNet.length() > 0 )
      {
        for ( int z = 1; z < r.scores.size() - 1; z++ )
        {
          outinet << r.scores[z] << " ";
        }

        outinet << r.scores[r.scores.size()-1] << endl;
      }

      if ( r.classno >= 0 )
      {
        if ( classno == r.classno ) correct[classno]++;

        count[classno]++;

        if ( r.ok() ) {
          confusionMatrix( classno, r.classno )++;
        }

        cresults.push_back( r );
      }
    }

    pb.hide();

    if ( wekafile.size() > 0 )
    {
      string wekafile_s = wekafile;

      if ( numRuns > 1 )
        wekafile_s = wekafile_s + "." + StringTools::convertToString<int>( runs ) + ".txt";

      cresults.writeWEKA( wekafile_s, wekaclass );
    }

    int count_total = 0;

    int correct_total = 0;
    int classes_tested = 0;
    double avg_recognition = 0.0;

    for ( size_t classno = 0; classno < correct.size(); classno++ )
    {
      if ( count[classno] == 0 ) {
        fprintf( stdout, "class %d not tested !!\n", ( int )classno );
      } else {
        fprintf( stdout, "classification result class %d (\"%s\") : %5.2f %%\n",
                 ( int )classno, classNames->text( classno ).c_str(), correct[classno]*100.0 / count[classno] );
        avg_recognition += correct[classno] / ( double )count[classno];
        classes_tested++;
      }

      count_total += count[classno];

      correct_total += correct[classno];
    }

    avg_recognition /= classes_tested;


    fprintf( stdout, "overall recognition rate : %-5.3f %%\n", correct_total*100.0 / count_total );
    fprintf( stdout, "average recognition rate : %-5.3f %%\n", avg_recognition*100 );
    fprintf( stdout, "total:%d misclassified:%d\n", count_total, count_total - correct_total );

    int max_count = *( max_element( count.begin(), count.end() ) );
    fprintf( stdout, "no of classes : %d\n", classNames->numClasses() );
    fprintf( stdout, "lower bound 1 : %f\n", 100.0 / ( classNames->numClasses() ) );
    fprintf( stdout, "lower bound 2 : %f\n", max_count * 100.0 / ( double ) count_total );

    cout << confusionMatrix << endl;

    delete vec_classifier;
  }

  delete classNames;

  return 0;
}