// Example invocation: BUILD_x86_64/progs/testActiveSemanticSegmentation -config <CONFIGFILE>

/**
* @file testActiveSemanticSegmentation.cpp
* @brief test semantic segmentation routines with actively selecting regions for labeling
* @author Alexander Freytag
* @date 27-02-2013
*/

#ifdef NICE_USELIB_OPENMP
#include <omp.h>
#endif

#include "core/basics/Config.h"
#include "core/basics/StringTools.h"
#include <vislearning/baselib/ICETools.h>

#include <semseg/semseg/SemanticSegmentation.h>
#include <semseg/semseg/SemSegLocal.h>
#include <semseg/semseg/SemSegCsurka.h>
#include <semseg/semseg/SemSegNovelty.h>
#include <semseg/semseg/SemSegNoveltyBinary.h>
#include <semseg/semseg/SemSegContextTree.h>

#include "core/image/FilterT.h"

#include <core/basics/ResourceStatistics.h>

#include <fstream>

using namespace OBJREC;

using namespace NICE;

using namespace std;

/**
 * @brief Accumulate pixel-wise (ground truth, prediction) label pairs into a confusion matrix.
 *
 * The ground-truth image @p gt is traversed at its own resolution. For every gt pixel the
 * corresponding pixel of the prediction @p img is found by scaling the coordinates with the
 * ratio of the two image sizes (clamped to the borders of @p img).
 *
 * @param img               predicted label image
 * @param gt                ground-truth label image
 * @param M                 confusion matrix indexed as M( gt label, predicted label ); incremented in place
 * @param forbidden_classes ground-truth labels whose pixels are ignored
 *
 * If @p img is empty nothing is accumulated. Pixels whose ground-truth or predicted label does
 * not index a valid entry of @p M are skipped (and reported on stderr) instead of writing out
 * of bounds.
 */
void updateMatrix( const NICE::ImageT<int> & img, const NICE::ImageT<int> & gt,
                   NICE::Matrix & M, const set<int> & forbidden_classes )
{
  // an empty prediction has no pixel to map onto; the clamping below would produce index -1
  if ( img.width() <= 0 || img.height() <= 0 )
    return;

  const double subsamplex = gt.width() / ( double )img.width();
  const double subsampley = gt.height() / ( double )img.height();

  const int numRows = ( int )M.rows();
  const int numCols = ( int )M.cols();

  int skippedPixels = 0;

  for ( int y = 0 ; y < gt.height() ; y++ )
  {
    // the prediction row only depends on y, so compute it once per gt row
    int yy = ( int )( y / subsampley );

    if ( yy < 0 ) yy = 0;

    if ( yy > img.height() - 1 ) yy = img.height() - 1;

    for ( int x = 0 ; x < gt.width() ; x++ )
    {
      int xx = ( int )( x / subsamplex );

      if ( xx < 0 ) xx = 0;

      if ( xx > img.width() - 1 ) xx = img.width() - 1;

      const int gimg = gt.getPixel( x, y );

      if ( forbidden_classes.find( gimg ) != forbidden_classes.end() )
        continue;

      const int cimg = img.getPixel( xx, yy );

      // a label outside the matrix would be an out-of-bounds write
      if ( gimg < 0 || gimg >= numRows || cimg < 0 || cimg >= numCols )
      {
        skippedPixels++;
        continue;
      }

      M( gimg, cimg )++;
    }
  }

  if ( skippedPixels > 0 )
  {
    std::cerr << "updateMatrix: skipped " << skippedPixels
              << " pixels whose label lies outside the confusion matrix" << std::endl;
  }
}

/**
 test semantic segmentation routines
*/
int main( int argc, char **argv )
{
  std::set_terminate( __gnu_cxx::__verbose_terminate_handler );

  Config conf( argc, argv );
  
  ResourceStatistics rs;
  
  bool show_result = conf.gB( "debug", "show_results", false );

  bool write_results = conf.gB( "debug", "write_results", false );

  bool write_results_pascal = conf.gB( "debug", "write_results_pascal", false );

  std::string resultdir = conf.gS( "debug", "resultdir", "." );
  
  //how often do we want to iterate between sem-seg and active query?
  int activeIterations = conf.gI("main", "activeIterations", 1 );
    
  if ( write_results )
  {
    cerr << "Writing Results to " << resultdir << endl;
  }

  MultiDataset md( &conf );

  const ClassNames & classNames = md.getClassNames( "train" );

  string method = conf.gS( "main", "method", "SSCsurka" );

  //currently, we only allow SemSegNovelty, because it implements addNovelExamples()
  SemanticSegmentation *semseg = NULL;
  
      Timer timer;
      timer.start();
  if ( method == "SSCsurka" )
  {
    semseg = new SemSegCsurka( &conf, &md );
  }
  else if ( method == "SSContext" )
  {
    semseg = new SemSegContextTree( &conf, &md );
  }
  else if( method == "SSNovelty" )
  {
    semseg = new SemSegNovelty( &conf, &md );
  }
  else if( method == "SSNoveltyBinary" )
  {
    semseg = new SemSegNoveltyBinary( &conf, &md );
  }  
  timer.stop();
  std::cerr << "AL time for training: " << timer.getLast() << std::endl;

  const LabeledSet *testFiles = md["test"];

  NICE::Matrix M( classNames.getMaxClassno() + 1, classNames.getMaxClassno() + 1 );

  M.set( 0 );

  std::set<int> forbidden_classes;
  std::string forbidden_classes_s = conf.gS( "analysis", "forbidden_classesTrain", "" );
  classNames.getSelection( forbidden_classes_s, forbidden_classes );
  
  std::set<int> forbidden_classesForActiveLearning;
  std::string forbidden_classesForActiveLearning_s = conf.gS( "analysis", "forbidden_classesForActiveLearning", "" );
  classNames.getSelection( forbidden_classesForActiveLearning_s, forbidden_classesForActiveLearning );
  

  for (int iterationCount = 0; iterationCount < activeIterations; iterationCount++)
  {
      //TODO shouldn't we clean the confusion matrix at the beginning of each iteration?
    
    std::cerr << "SemSeg AL Iteration: " << iterationCount << std::endl;
    semseg->setIterationCountSuffix(iterationCount);
    
//     ProgressBar pb( "Semantic Segmentation Analysis" );
// 
//     pb.show();

    int fileno = 0;

    std::cerr << "start looping over all files" << std::endl;
    LOOP_ALL_S( *testFiles )
    {
      EACH_INFO( classno, info );
      std::string file = info.img();

      NICE::ImageT<int> lm;
      NICE::MultiChannelImageT<double> probabilities;

      if ( info.hasLocalizationInfo() )
      {
        const LocalizationResult *l_gt = info.localization();

        lm.resize( l_gt->xsize, l_gt->ysize );
        //lm.set( 0 );
        l_gt->calcLabeledImage( lm, classNames.getBackgroundClass() );
      }

      semseg->semanticseg( file, lm, probabilities );

      fprintf( stderr, "testSemanticSegmentation: Segmentation finished !\n" );

      //ground truth image, needed for updating the confusion matrix
      //TODO check whether this is really needed, since we computed such a label image already within SemSegNovelty
      NICE::ImageT<int> lm_gt;

      if ( info.hasLocalizationInfo() )
      {
        const LocalizationResult *l_gt = info.localization();

        lm_gt.resize( l_gt->xsize, l_gt->ysize );
        lm_gt.set( 0 );

        fprintf( stderr, "testSemanticSegmentation: Generating Labeled NICE::Image (Ground-Truth)\n" );
        l_gt->calcLabeledImage( lm_gt, classNames.getBackgroundClass() );
      }
// // // 
// // //       std::string fname = StringTools::baseName( file, false );
// // // 
// // //       if ( write_results_pascal )
// // //       {
// // // 
// // //         NICE::Image pascal_lm( lm.width(), lm.height() );
// // //         int backgroundClass = classNames.getBackgroundClass();
// // // 
// // //         for ( int y = 0 ; y < lm.height(); y++ )
// // //           for ( int x = 0 ; x < lm.width(); x++ )
// // //           {
// // //             int v = lm.getPixel( x, y );
// // // 
// // //             if ( v == backgroundClass )
// // //               pascal_lm.setPixel( x, y, 255 );
// // //             else
// // //               pascal_lm.setPixel( x, y, 255 - v - 1 );
// // //           }
// // // 
// // //         char filename[1024];
// // // 
// // //         char *format = ( char * )"pgm";
// // //         sprintf( filename, "%s/%s.%s", resultdir.c_str(), fname.c_str(), format );
// // // 
// // //         pascal_lm.write( filename );
// // //       }
// // // 
      if ( show_result || write_results )
      {
        NICE::ColorImage orig( file );
        NICE::ColorImage rgb;
        NICE::ColorImage rgb_gt;

        classNames.labelToRGB( lm, rgb );

        classNames.labelToRGB( lm_gt, rgb_gt );

        if ( write_results )
        {
  //         char filename[1024];
  //         char *format = ( char * )"ppm";
  //         sprintf( filename, "%06d.%s", fileno, format );
  //         std::string origfilename = resultdir + "/orig_" + string( filename );
  //         cerr << "Writing to file " << origfilename << endl;
  //         orig.write( origfilename );
  //         rgb.write( resultdir + "/result_" + string( filename ) );
  //         rgb_gt.write( resultdir + "/groundtruth_" + string( filename ) );
          
          std::stringstream out;       
          std::vector< std::string > myList;
          StringTools::split ( Globals::getCurrentImgFN (), '/', myList );
          out << resultdir << "/" << myList.back();
          cerr << "Writing to file " << resultdir << "/"<< myList.back() << endl;
          
          std::string noveltyMethodString = conf.gS( "SemSegNovelty",  "noveltyMethod", "gp-variance");
          orig.write ( out.str() + "_orig.ppm" );
          rgb.write ( out.str() + "_" + noveltyMethodString + "_result_run_" + NICE::intToString(iterationCount) + ".ppm" );
          rgb_gt.write ( out.str() + "_groundtruth.ppm" );
        }

        if ( show_result )
        {
  #ifndef NOVISUAL
          showImage( rgb, "Result" );
          showImage( rgb_gt, "Groundtruth" );
          showImage( orig, "Input" );
  #endif
        }
      }

  //#pragma omp critical
      updateMatrix( lm, lm_gt, M, forbidden_classes );

      std::cerr << M << std::endl;

      fileno++;

//       pb.update( testFiles->count() );
    } //Loop over all test images

//     pb.hide();

    //**********************************************
    //                  EVALUATION 
    //   COMPUTE CONFUSION MAT AND FINAL SCORES
    //**********************************************
    timer.start();
    
    long maxMemory;
    rs.getMaximumMemory(maxMemory);
    cerr << "Maximum memory used: " << maxMemory << " KB" << endl;
    
    double overall = 0.0;
    double sumall = 0.0;

    for ( int r = 0; r < ( int )M.rows(); r++ )
    {
      for ( int c = 0; c < ( int )M.cols(); c++ )
      {
        if ( r == c )
          overall += M( r, c );

        sumall += M( r, c );
      }
    }

    overall /= sumall;

    // normalizing M using rows

    for ( int r = 0 ; r < ( int )M.rows() ; r++ )
    {
      double sum = 0.0;

      for ( int c = 0 ; c < ( int )M.cols() ; c++ )
        sum += M( r, c );

      if ( fabs( sum ) > 1e-4 )
        for ( int c = 0 ; c < ( int )M.cols() ; c++ )
          M( r, c ) /= sum;
    }

    std::cerr << M << std::endl;

    double avg_perf = 0.0;
    int classes_trained = 0;

    for ( int r = 0 ; r < ( int )M.rows() ; r++ )
    {
      if (( classNames.existsClassno( r ) ) && ( forbidden_classes.find( r ) == forbidden_classes.end() ) )
      {
        avg_perf += M( r, r );
        double lsum = 0.0;
        for(int r2 = 0; r2 < ( int )M.rows(); r2++)
        {
          lsum += M(r,r2);
        }
        if(lsum != 0.0)
        {
          classes_trained++;
        }
      }
    }

    if ( write_results )
    {
      ofstream fout(( resultdir + "/res.txt" ).c_str(), ios::out );
      fout <<  "overall: " << overall << endl;
      fout << "Average Performance " << avg_perf / ( classes_trained ) << endl;
      fout << "Lower Bound " << 1.0  / classes_trained << endl;

      for ( int r = 0 ; r < ( int )M.rows() ; r++ )
      {
        if (( classNames.existsClassno( r ) ) && ( forbidden_classes.find( r ) == forbidden_classes.end() ) )
        {
          std::string classname = classNames.text( r );
          fout << classname.c_str() << ": " << M( r, r ) << endl;
        }
      }

      fout.close();
    }

    fprintf( stderr, "overall: %f\n", overall );

    fprintf( stderr, "Average Performance %f\n", avg_perf / ( classes_trained ) );
    //fprintf(stderr, "Lower Bound %f\n", 1.0 / classes_trained);

    for ( int r = 0 ; r < ( int )M.rows() ; r++ )
    {
      if (( classNames.existsClassno( r ) ) && ( forbidden_classes.find( r ) == forbidden_classes.end() ) )
      {
        std::string classname = classNames.text( r );
        fprintf( stderr, "%s: %f\n", classname.c_str(), M( r, r ) );
      }
    }
    
    timer.stop();
    std::cout << "AL time for evaluation: " << timer.getLastAbsolute() << std::endl;
    
    //**********************************************
    //          READ QUERY SCORE IMAGES
    //   AND SELECT THE REGION TO BE LABELED
    //**********************************************
    //NOTE this is not needed anymore, since we store everything within SemSegNovelty
    //However, it is still needed if we use the NN-classifier for the feature learning approach
    
//     string alSection = "SemSegNovelty";
//     std::string noveltyMethodString = conf.gS( alSection,  "noveltyMethod", "gp-variance");
//     std::string uncertdir = conf.gS("debug", "resultdir", "result");
//     int testWSize = conf.gI(alSection, "test_window_size", 10);   
//     
//     float maxVal(0);
//     int maxValX(0);
//     int maxValY(0);
//     std::vector<ImageInfo *>::const_iterator maxValInfoIt = testFiles->begin()->second.begin();
//     
//     
//     for(LabeledSet::const_iterator outerIt = testFiles->begin() ; outerIt != testFiles->end() ; outerIt++)
//     {
//       for ( std::vector<ImageInfo *>::const_iterator imageIt = outerIt->second.begin(); imageIt != outerIt->second.end(); imageIt++ )    
//       {
//         const ImageInfo & (info) = *(*imageIt);
//         
//         std::string file = info.img();
//         
//         std::stringstream dest;
//         std::vector< std::string > list2;
//         StringTools::split ( file, '/', list2 );
//         dest << uncertdir << "/" << list2.back();      
//         
//         FloatImage noveltyImage;
//         noveltyImage.readRaw(dest.str() + "_run_" +  NICE::intToString(iterationCount) + "_" + noveltyMethodString+".rawfloat");
//         
//         int xsize ( noveltyImage.width() );
//         int ysize ( noveltyImage.height() );
//         
//         //compute the GT-image to ensure that we only query "useful" new features, i.e., not query background or similar "forbidden" stuff
//         NICE::Image lm_gt;
//         if ( (*maxValInfoIt)->hasLocalizationInfo() )
//         {
//           const LocalizationResult *l_gt = (*maxValInfoIt)->localization();
// 
//           lm_gt.resize( l_gt->xsize, l_gt->ysize );
//           lm_gt.set( 0 );
// 
//           l_gt->calcLabeledImage( lm_gt, classNames.getBackgroundClass() );
//         }                
//         
//         for ( int y = 0; y < ysize; y += testWSize )
//         {
//           for ( int x = 0; x < xsize; x += testWSize)
//           {
//             if ( (noveltyImage ( x, y ) > maxVal) && (  forbidden_classesForActiveLearning.find ( lm_gt(x, y) ) == forbidden_classesForActiveLearning.end() ) )
//             {
//               maxVal =  noveltyImage ( x, y );
//               maxValX = x;
//               maxValY = y;
//               maxValInfoIt = imageIt;
//             }
//           }
//         }
//         
//       }//iterate over inner loop
//     }//iterate over testFiles
// 
//     
//       std::cerr << "maxVal: " << maxVal << " maxValX: " << maxValX << " maxValY: " << maxValY << " maxValInfo: " << (*maxValInfoIt)->img() << std::endl;
    
    //**********************************************
    //          INCLUDE THE NEW INFORMATION
    //           AND UPDATE THE CLASSIFIER
    //**********************************************    
      
     timer.start();
     semseg->addNovelExamples(); 
     
     timer.stop();
     std::cout << "AL time for incremental update: " << timer.getLastAbsolute() << std::endl;
     //alternatively, we could call the destructor of semseg, and create it again, which does the same thing 
     // (add new features, save the classifier, re-read it after initialization)
     //BUT this would not setup the forbidden and known classes properly!!! We should fix that!
     
     const Examples * novelExamples = semseg->getNovelExamples(); 
//      std::cerr << " ==================================== " << std::endl;
//      std::cerr << "new examples to be added: " << std::endl;
//      for ( uint i = 0 ; i < novelExamples->size() ; i++ )
//      {
//         std::cerr << (*novelExamples)[i].first << " "; (*novelExamples)[i].second.store(std::cerr);
//      }
//      std::cerr << " ==================================== " << std::endl;
     
    //check which classes will be added using the features from the novel region
    std::set<int> newClassNumbers;
    newClassNumbers.clear(); //just to be sure  
    for ( uint i = 0 ; i < novelExamples->size() ; i++ )
    {
      if (newClassNumbers.find( (*novelExamples)[i].first /* classNumber*/) == newClassNumbers.end() )
      {
        newClassNumbers.insert( (*novelExamples)[i].first );
      }
    }

    //accept the new classes as valid information
    for (std::set<int>::const_iterator clNoIt = newClassNumbers.begin(); clNoIt != newClassNumbers.end(); clNoIt++)
    {
      if ( forbidden_classes.find ( *clNoIt ) != forbidden_classes.end() )
      {
        forbidden_classes.erase(*clNoIt);
      }
    }       
      
    //NOTE Below comes the old version:
    // it is not needed anymore, since we store everything within SemSegNovelty
    //However, it is still needed if we use the NN-classifier for the feature learning approach      
//     //  ----------------------------------------------------
//     //  therefore, we first recompute the features for the whole image and
//     //take the one which we desire
//       
//     //this is NOT efficient, but a nice and easy first step
//       
//     NICE::ColorImage img ( (*maxValInfoIt)->img() );
//     
//     MultiChannelImageT<double> feats;
// 
//     // extract features
//     LFColorWeijer * featExtract = new LFColorWeijer ( &conf );
//     featExtract->getFeats ( img, feats );
//     int featdim = feats.channels();
//     feats.addChannel(featdim);
// 
//     for (int c = 0; c < featdim; c++)
//     {
//       ImageT<double> tmp = feats[c];
//       ImageT<double> tmp2 = feats[c+featdim];
// 
//       NICE::FilterT<double, double, double>::gradientStrength (tmp, tmp2);
//     }
//     featdim += featdim;
// 
//     // compute integral images
//     for ( int c = 0; c < featdim; c++ )
//     {
//       feats.calcIntegral ( c );
//     }    
//     
//     //  ----------------------------------------------------
//     //now take the feature
//     NICE::Vector newFeature(featdim);
//     for ( int f = 0; f < featdim; f++ )
//     {
//       double val = feats.getIntegralValue ( maxValX - testWSize, maxValY - testWSize, maxValX + testWSize, maxValY + testWSize, f );
//       newFeature[f] = val;
//     }
//     newFeature.normalizeL1();    
//     
//     NICE::Image lm_gt;
//     // take the gt class number as well    
//     if ( (*maxValInfoIt)->hasLocalizationInfo() )
//     {
//       const LocalizationResult *l_gt = (*maxValInfoIt)->localization();
// 
//       lm_gt.resize( l_gt->xsize, l_gt->ysize );
//       lm_gt.set( 0 );
// 
//       l_gt->calcLabeledImage( lm_gt, classNames.getBackgroundClass() );
//     }
//     int classNoGT = lm_gt(maxValX, maxValY);
//     std::cerr << "class number GT: " << classNoGT << std::endl;
//     
//     
//     semseg->addNewExample(newFeature, classNoGT);
//     
//     //accept the new class as valid information
//     if ( forbidden_classes.find ( classNoGT ) != forbidden_classes.end() )
//     {
//       forbidden_classes.erase(classNoGT);
//     }    
    
    std::cerr << "iteration finished - start the next round" << std::endl;
    
  } //iterationCount

  delete semseg;

  return 0;
}