/**
* @file SemSegLocal.cpp
* @brief semantic segmentation using image patches only
* @author Erik Rodner
* @date 05/08/2008

*/
#include <cassert>
#include <iostream>
#include <stdexcept>

#include "SemSegLocal.h"
#include "vislearning/cbaselib/CachedExample.h"
#include "vislearning/classifier/fpclassifier/randomforest/FPCRandomForests.h"
#include "vislearning/features/fpfeatures/PixelPairFeature.h"

#include "SemSegTools.h"

using namespace OBJREC;

using namespace std;
using namespace NICE;


/**
 * Builds the local (patch-based) segmentation model.
 *
 * Reads the cache settings from config section "FPCPixel", creates a
 * FPCRandomForests classifier and either loads it from the cache file
 * (read_cache) or trains it on the "train" split of @p md.
 *
 * @param conf configuration (sections "FPCPixel" and those used by train())
 * @param md   datasets; the class names of the "train" split are used
 */
SemSegLocal::SemSegLocal ( const Config *conf,
                           const MultiDataset *md )
    : SemanticSegmentation ( conf, & ( md->getClassNames ( "train" ) ) )
{
  save_cache = conf->gB ( "FPCPixel", "save_cache", true );
  read_cache = conf->gB ( "FPCPixel", "read_cache", false );
  cache = conf->gS ( "FPCPixel", "cache", "fpc.data" );
  fpc = new FPCRandomForests ( conf, "FPCPixel" );

  try {
    fpc->setMaxClassNo ( classNames->getMaxClassno() );

    if ( read_cache ) {
      fprintf ( stderr, "LocSSimpleFP:: Reading classifier data from %s\n", cache.c_str() );
      fpc->read ( cache );
      fprintf ( stderr, "LocSSimpleFP:: successfully read\n" );
    } else {
      train ( conf, md );
    }
  } catch ( ... ) {
    // A throwing constructor never runs the destructor, so release the
    // classifier allocated above before propagating the error.
    delete fpc;
    fpc = NULL;
    throw;
  }
}

void SemSegLocal::train ( const Config *conf, const MultiDataset *md )
{
  Examples examples;
  vector<CachedExample *> imgexamples;

  SemSegTools::collectTrainingExamples (
    conf,
    "FPCPixel", // config section for grid settings
    * ( ( *md ) ["train"] ),
    *classNames,
    examples,
    imgexamples );

  assert ( examples.size() > 0 );

  FeaturePool fp;
  PixelPairFeature hf ( conf );
  hf.explode ( fp );

  fpc->train ( fp, examples );

  // clean up memory !!
  for ( vector<CachedExample *>::iterator i = imgexamples.begin();
        i != imgexamples.end();
        i++ )
    delete ( *i );

  if ( save_cache ) {
    fpc->save ( cache );
  }

  fp.destroy();
}


/**
 * Destroys the segmentation object and the classifier it allocated.
 */
SemSegLocal::~SemSegLocal()
{
  // delete on a NULL pointer has no effect, so no separate check is needed
  delete fpc;
}


/**
 * Classifies every pixel of @p ce independently with the random forest.
 *
 * @param ce            image to segment
 * @param segresult     resized to the image size; receives the winning
 *                      class number (r.classno) per pixel
 * @param probabilities re-initialised with getMaxClassno()+1 channels;
 *                      channel i at (x,y) receives r.scores[i]
 * @throws std::runtime_error if the classifier is not an FPCRandomForests
 */
void SemSegLocal::semanticseg ( CachedExample *ce,
                                NICE::Image & segresult,
                                NICE::MultiChannelImageT<double> & probabilities )
{
  // for speed optimization: cast once, outside the per-pixel loop
  FPCRandomForests *fpcrf = dynamic_cast<FPCRandomForests *> ( fpc );
  if ( fpcrf == NULL )
    throw std::runtime_error ( "SemSegLocal::semanticseg: classifier is not a FPCRandomForests instance" );

  int xsize, ysize;
  ce->getImageSize ( xsize, ysize );
  probabilities.reInit ( xsize, ysize, classNames->getMaxClassno() + 1 );
  segresult.resize ( xsize, ysize );

  // loop-invariant: number of class channels written per pixel
  const int numChannels = ( int ) probabilities.channels();

  // one example object reused for every pixel; only its position changes
  Example pce ( ce, 0, 0 );
  for ( int y = 0 ; y < ysize ; y++ )
  {
    for ( int x = 0 ; x < xsize ; x++ )
    {
      pce.x = x;
      pce.y = y;
      ClassificationResult r = fpcrf->classify ( pce );
      segresult.setPixel ( x, y, r.classno );
      for ( int i = 0 ; i < numChannels ; i++ )
        probabilities[i] ( x, y ) = r.scores[i];
    }
  }
}