// Example invocation: BUILD_x86_64/progs/testSemanticSegmentation -config <CONFIGFILE>

/**
* @file testSemanticSegmentation.cpp
* @brief test semantic segmentation routines
* @author Erik Rodner
* @date 03/20/2008
*/

#ifdef NICE_USELIB_OPENMP
#include <omp.h>
#endif

// STL includes
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <exception>
#include <fstream>
#include <iostream>
#include <set>
#include <sstream>
#include <string>
#include <vector>

// nice-core includes
#include <core/basics/Config.h>
#include <core/basics/StringTools.h>
#include <core/basics/ResourceStatistics.h>

// nice-vislearning includes
#include <vislearning/baselib/ICETools.h>

// nice-semseg includes
#include <semseg/semseg/SemanticSegmentation.h>
#include <semseg/semseg/SemSegLocal.h>
#include <semseg/semseg/SemSegCsurka.h>
#include <semseg/semseg/SemSegNovelty.h>
#include <semseg/semseg/SemSegContextTree.h>





using namespace OBJREC;

using namespace NICE;

using namespace std;

/**
 * @brief Accumulate a pixel-wise confusion matrix for one image.
 *
 * For every pixel (x,y) of the ground-truth image @p gt, the function looks up
 * the predicted label at the corresponding position of @p img. The position is
 * (x,y) scaled by img.size / gt.size, truncated and clamped to the image border,
 * so the two images may have different resolutions. The function then
 * increments M( ground-truth label, predicted label ).
 *
 * @param img               predicted label image
 * @param gt                ground-truth label image (iteration runs over its pixels)
 * @param M                 confusion matrix (rows: ground truth, cols: prediction),
 *                          updated in place
 * @param forbidden_classes ground-truth labels whose pixels are not counted
 *
 * Nothing is counted when either image is empty (e.g. no ground truth was
 * available). Label pairs that fall outside M are skipped rather than written
 * out of bounds.
 */
void updateMatrix( const NICE::ImageT<int> & img, const NICE::ImageT<int> & gt,
                   NICE::Matrix & M, const set<int> & forbidden_classes )
{
  // With an empty prediction the scale factors become infinite and the
  // clamped lookup below would read pixel (0,0) of an empty image.
  if ( img.width() <= 0 || img.height() <= 0 || gt.width() <= 0 || gt.height() <= 0 )
    return;

  const double subsamplex = gt.width() / static_cast<double>( img.width() );
  const double subsampley = gt.height() / static_cast<double>( img.height() );

  const int numRows = static_cast<int>( M.rows() );
  const int numCols = static_cast<int>( M.cols() );

  for ( int y = 0 ; y < gt.height() ; y++ )
  {
    // The source row depends only on y, so compute it once per row.
    int yy = static_cast<int>( y / subsampley );
    if ( yy < 0 ) yy = 0;
    if ( yy > img.height() - 1 ) yy = img.height() - 1;

    for ( int x = 0 ; x < gt.width() ; x++ )
    {
      int xx = static_cast<int>( x / subsamplex );
      if ( xx < 0 ) xx = 0;
      if ( xx > img.width() - 1 ) xx = img.width() - 1;

      const int gimg = gt.getPixel( x, y );

      if ( forbidden_classes.find( gimg ) != forbidden_classes.end() )
        continue;

      const int cimg = img.getPixel( xx, yy );

      // Do not rely on NICE::Matrix::operator() to check its indices.
      if ( gimg < 0 || gimg >= numRows || cimg < 0 || cimg >= numCols )
        continue;

      M( gimg, cimg )++;
    }
  }
}

/**
 test semantic segmentation routines
*/
int main( int argc, char **argv )
{
  std::set_terminate( __gnu_cxx::__verbose_terminate_handler );

  Config conf( argc, argv );
  
  ResourceStatistics rs;
  
  bool show_result = conf.gB( "debug", "show_results", false );

  bool write_results = conf.gB( "debug", "write_results", false );

  bool write_results_pascal = conf.gB( "debug", "write_results_pascal", false );

  std::string resultdir = conf.gS( "debug", "resultdir", "." );

  if ( write_results )
  {
    cerr << "Writing Results to " << resultdir << endl;
  }

  MultiDataset md( &conf );

  const ClassNames & classNames = md.getClassNames( "train" );

  string method = conf.gS( "main", "method", "SSCsurka" );

  SemanticSegmentation *semseg = NULL;

  if ( method == "SSCsurka" )
  {
    semseg = new SemSegCsurka( &conf, &md );
  }
  else if ( method == "SSContext" )
  {
    semseg = new SemSegContextTree( &conf, &md );
  }
  else if( method == "SSNovelty" )
  {
    semseg = new SemSegNovelty( &conf, &md );
  }

  //SemanticSegmentation *semseg = new SemSegLocal ( &conf, &md );
  //SemanticSegmentation *semseg = new SemSegSTF ( &conf, &md );
  //SemanticSegmentation *semseg = new SemSegRegionBased(&conf, &md);

  const LabeledSet *testFiles = md["test"];

  NICE::Matrix M( classNames.getMaxClassno() + 1, classNames.getMaxClassno() + 1 );

  M.set( 0 );

  set<int> forbidden_classes;

  std::string forbidden_classes_s = conf.gS( "analysis", "forbidden_classes", "" );

  classNames.getSelection( forbidden_classes_s, forbidden_classes );

  ProgressBar pb( "Semantic Segmentation Analysis" );

  pb.show();

  int fileno = 0;

  LOOP_ALL_S( *testFiles )
  {
    EACH_INFO( classno, info );
    std::string file = info.img();

    NICE::ImageT<int> lm;
    NICE::MultiChannelImageT<double> probabilities;

    if ( info.hasLocalizationInfo() )
    {
      const LocalizationResult *l_gt = info.localization();

      lm.resize( l_gt->xsize, l_gt->ysize );
      //lm.set( 0 );
      l_gt->calcLabeledImage( lm, classNames.getBackgroundClass() );
    }

    semseg->semanticseg( file, lm, probabilities );

    fprintf( stderr, "testSemanticSegmentation: Segmentation finished !\n" );

    NICE::ImageT<int> lm_gt;

    if ( info.hasLocalizationInfo() )
    {
      const LocalizationResult *l_gt = info.localization();

      lm_gt.resize( l_gt->xsize, l_gt->ysize );
      lm_gt.set( 0 );

      fprintf( stderr, "testSemanticSegmentation: Generating Labeled NICE::Image (Ground-Truth)\n" );
      l_gt->calcLabeledImage( lm_gt, classNames.getBackgroundClass() );
    }

    std::string fname = StringTools::baseName( file, false );

    if ( write_results_pascal )
    {

      NICE::Image pascal_lm( lm.width(), lm.height() );
      int backgroundClass = classNames.getBackgroundClass();

      for ( int y = 0 ; y < lm.height(); y++ )
        for ( int x = 0 ; x < lm.width(); x++ )
        {
          int v = lm.getPixel( x, y );

          if ( v == backgroundClass )
            pascal_lm.setPixel( x, y, 255 );
          else
            pascal_lm.setPixel( x, y, 255 - v - 1 );
        }

      char filename[1024];

      char *format = ( char * )"pgm";
      sprintf( filename, "%s/%s.%s", resultdir.c_str(), fname.c_str(), format );

      pascal_lm.write( filename );
    }

    if ( show_result || write_results )
    {
      NICE::ColorImage orig( file );
      NICE::ColorImage rgb;
      NICE::ColorImage rgb_gt;

      classNames.labelToRGB( lm, rgb );

      classNames.labelToRGB( lm_gt, rgb_gt );

      if ( write_results )
      {
        std::stringstream out;       
        std::vector< std::string > myList;
        StringTools::split ( Globals::getCurrentImgFN (), '/', myList );
        out << resultdir << "/" << myList.back();
        cerr << "Writing to file " << resultdir << "/"<< myList.back() << endl;
        orig.write ( out.str() + "_orig.jpg" );
        rgb.write ( out.str() + "_result.png" );
        rgb_gt.write ( out.str() + "_groundtruth.png" );
      }

      if ( show_result )
      {
#ifndef NOVISUAL
        showImage( rgb, "Result" );
        showImage( rgb_gt, "Groundtruth" );
        showImage( orig, "Input" );
#endif
      }
    }

//#pragma omp critical
    updateMatrix( lm, lm_gt, M, forbidden_classes );

    cerr << M << endl;

    fileno++;

    pb.update( testFiles->count() );
  }

  pb.hide();

  long maxMemory;
  rs.getMaximumMemory(maxMemory);
  cerr << "Maximum memory used: " << maxMemory << " KB" << endl;
  
  double overall = 0.0;
  double sumall = 0.0;

  for ( int r = 0; r < ( int )M.rows(); r++ )
  {
    for ( int c = 0; c < ( int )M.cols(); c++ )
    {
      if ( r == c )
        overall += M( r, c );

      sumall += M( r, c );
    }
  }

  overall /= sumall;

  // normalizing M using rows

  for ( int r = 0 ; r < ( int )M.rows() ; r++ )
  {
    double sum = 0.0;

    for ( int c = 0 ; c < ( int )M.cols() ; c++ )
      sum += M( r, c );

    if ( fabs( sum ) > 1e-4 )
      for ( int c = 0 ; c < ( int )M.cols() ; c++ )
        M( r, c ) /= sum;
  }

  cerr << M << endl;

  double avg_perf = 0.0;
  int classes_trained = 0;

  for ( int r = 0 ; r < ( int )M.rows() ; r++ )
  {
    if (( classNames.existsClassno( r ) ) && ( forbidden_classes.find( r ) == forbidden_classes.end() ) )
    {
      avg_perf += M( r, r );
      double lsum = 0.0;
      for(int r2 = 0; r2 < ( int )M.rows(); r2++)
      {
        lsum += M(r,r2);
      }
      if(lsum != 0.0)
      {
        classes_trained++;
      }
    }
  }

  if ( write_results )
  {
    ofstream fout(( resultdir + "/res.txt" ).c_str(), ios::out );
    fout <<  "overall: " << overall << endl;
    fout << "Average Performance " << avg_perf / ( classes_trained ) << endl;
    fout << "Lower Bound " << 1.0  / classes_trained << endl;

    for ( int r = 0 ; r < ( int )M.rows() ; r++ )
    {
      if (( classNames.existsClassno( r ) ) && ( forbidden_classes.find( r ) == forbidden_classes.end() ) )
      {
        std::string classname = classNames.text( r );
        fout << classname.c_str() << ": " << M( r, r ) << endl;
      }
    }

    fout.close();
  }

  fprintf( stderr, "overall: %f\n", overall );

  fprintf( stderr, "Average Performance %f\n", avg_perf / ( classes_trained ) );
  //fprintf(stderr, "Lower Bound %f\n", 1.0 / classes_trained);

  for ( int r = 0 ; r < ( int )M.rows() ; r++ )
  {
    if (( classNames.existsClassno( r ) ) && ( forbidden_classes.find( r ) == forbidden_classes.end() ) )
    {
      std::string classname = classNames.text( r );
      fprintf( stderr, "%s: %f\n", classname.c_str(), M( r, r ) );
    }
  }

  delete semseg;

  return 0;
}