/**
 * @file SemSegConvolutionalTree.cpp
 * @brief Semantic Segmentation using Convolutional Trees
 * @author Sven Sickert
 * @date 10/17/2014

*/

#include <iostream>

#include "SemSegConvolutionalTree.h"
#include "SemSegTools.h"

#include "vislearning/classifier/fpclassifier/randomforest/FPCRandomForests.h"
#include "vislearning/features/fpfeatures/ConvolutionFeature.h"
#include "vislearning/baselib/cc.h"

using namespace OBJREC;

using namespace std;
using namespace NICE;

//###################### CONSTRUCTORS #########################//

/**
 * Default constructor: no configuration is attached, persistence is
 * disabled and an unconfigured random forest classifier is created.
 */
SemSegConvolutionalTree::SemSegConvolutionalTree ()
    : SemanticSegmentation ()
{
    // no Config object available
    conf         = NULL;

    // persistence of the trained classifier (disabled by default)
    saveLoadData = false;
    fileLocation = "classifier.data";

    // classifier with default settings
    fpc          = new FPCRandomForests ();
}

/**
 * Config-driven constructor.
 *
 * @param conf       configuration used for all settings of this object
 * @param classNames class names handed on to SemanticSegmentation
 */
SemSegConvolutionalTree::SemSegConvolutionalTree (
        const Config *conf,
        const ClassNames *classNames )
    : SemanticSegmentation( conf, classNames )
{
    // persistence settings and the random forest are set up from the config
    this->initFromConfig( conf );
}

//###################### DESTRUCTORS ##########################//

/** Releases the random forest classifier owned by this object. */
SemSegConvolutionalTree::~SemSegConvolutionalTree ()
{
    // deleting a null pointer is a no-op, so no explicit check is required
    delete fpc;
}

//#################### MEMBER FUNCTIONS #######################//

/**
 * Replaces the integer RGB image of a cached example by an HSV image.
 *
 * The colour image stored in CachedExample::I_COLOR is converted pixel by
 * pixel with ColorConversion::ccRGBtoHSV. The hue, saturation and value
 * components are written into channels 0, 1 and 2 of the double channel
 * CachedExample::D_EOH, which must still be empty on entry. Afterwards the
 * integer RGB data is freed.
 *
 * Images that do not have exactly three colour channels are left
 * untouched; in that case D_EOH stays empty.
 *
 * @param ce cached example holding the colour image (modified in place)
 */
void SemSegConvolutionalTree::convertRGBToHSV ( CachedExample *ce ) const
{
    NICE::MultiChannelImageT<int> * img =
            & ce->getIChannel( CachedExample::I_COLOR );
    NICE::MultiChannelImageT<double> * imgHSV =
            & ce->getDChannel( CachedExample::D_EOH );

    // the HSV target channel must not hold any data yet
    assert( imgHSV->channels() == 0 );

    if ( img->channels() != 3 )
    {
        // NOTE(review): non-RGB input is silently skipped and D_EOH stays
        // empty -- confirm that downstream features can cope with this
        return;
    }

    imgHSV->reInit ( img->width(), img->height(), 3 );

    for ( int y = 0; y < img->height(); y++ )
        for ( int x = 0; x < img->width(); x++ )
        {
            double h, s, v;
            const double r = static_cast<double>( img->get( x, y, 0 ) );
            const double g = static_cast<double>( img->get( x, y, 1 ) );
            const double b = static_cast<double>( img->get( x, y, 2 ) );

            ColorConversion::ccRGBtoHSV( r, g, b, &h, &s, &v );

            // one output channel per HSV component
            // (previously hue was written into all three channels)
            imgHSV->set( x, y, h, 0 );
            imgHSV->set( x, y, s, 1 );
            imgHSV->set( x, y, v, 2 );
        }

    // the integer r,g,b channels are no longer needed
    img->freeData();
}


/**
 * Reads all settings from a configuration and creates the classifier.
 *
 * @param _conf         configuration object (kept as member `conf`)
 * @param s_confSection section holding "save_load_data" (default false) and
 *                      "datafile" (default "classifier.data")
 *
 * The random forest reads its own parameters from the "FPCRandomForests"
 * section and is sized for the maximum class number of `classNames`.
 */
void SemSegConvolutionalTree::initFromConfig( const Config *_conf,
                                         const string &s_confSection )
{
    conf = _conf;

    // persistence of the trained classifier
    const bool persist = conf->gB ( s_confSection, "save_load_data", false );
    const string dataFile =
            conf->gS ( s_confSection, "datafile", "classifier.data" );
    saveLoadData = persist;
    fileLocation = dataFile;

    // classifier configured from its own section
    fpc = new FPCRandomForests ( _conf, "FPCRandomForests" );
    fpc->setMaxClassNo( classNames->getMaxClassno() );
}

/** training function */
void SemSegConvolutionalTree::train ( const MultiDataset *md )
{
    if ( saveLoadData && FileMgt::fileExists( fileLocation ) )
    {
        read( fileLocation );
    }
    else
    {
        Examples examples;

        // image storage
        vector<CachedExample *> imgexamples;

        // create pixel-wise training examples
        SemSegTools::collectTrainingExamples (
          conf,
          "FPCRandomForests",
          * ( ( *md ) ["train"] ),
          *classNames,
          examples,
          imgexamples );

        assert ( examples.size() > 0 );

        for ( vector<CachedExample *>::iterator cei = imgexamples.begin();
              cei != imgexamples.end(); cei++ )
            convertRGBToHSV ( *cei );

        FeaturePool fp;
        ConvolutionFeature cf ( conf );
        cf.explode( fp );

        // start training using random forests
        fpc->train( fp, examples);

        // save trained classifier to file
        if (saveLoadData) save( fileLocation );

        // Cleaning up
        for ( vector<CachedExample *>::iterator i = imgexamples.begin();
              i != imgexamples.end();
              i++ )
            delete ( *i );

        fp.destroy();
    }
}

/** classification function */
/**
 * Classifies every pixel of a cached example.
 *
 * @param ce            cached example; its colour image is converted to
 *                      HSV (and the RGB data freed) as a side effect
 * @param segresult     resized to the image size; receives the predicted
 *                      class number per pixel
 * @param probabilities re-initialised with getMaxClassno()+1 channels;
 *                      channel i receives the score of class i per pixel
 */
void SemSegConvolutionalTree::semanticseg(
        CachedExample *ce,
        Image &segresult,
        NICE::MultiChannelImageT<double> &probabilities )
{
    // classify through the concrete forest type
    // NOTE(review): the cast result is not checked -- assumes fpc is always
    // an FPCRandomForests (true for every allocation in this file)
    FPCRandomForests *forest = dynamic_cast<FPCRandomForests *> ( fpc );

    int xsize, ysize;
    ce->getImageSize ( xsize, ysize );
    probabilities.reInit ( xsize, ysize, classNames->getMaxClassno() + 1 );
    segresult.resize ( xsize, ysize );

    convertRGBToHSV( ce );

    const int numScores = ( int ) probabilities.channels();

    Example pixel ( ce, 0, 0 );
    for ( int y = 0; y < ysize; ++y )
    {
        for ( int x = 0; x < xsize; ++x )
        {
            pixel.x = x;
            pixel.y = y;

            ClassificationResult res = forest->classify ( pixel );
            segresult.setPixel ( x, y, res.classno );

            for ( int c = 0; c < numScores; ++c )
                probabilities[c](x, y) = res.scores[c];
        }
    }
}

///////////////////// INTERFACE PERSISTENT /////////////////////
// interface specific methods for store and restore
///////////////////// INTERFACE PERSISTENT /////////////////////

void SemSegConvolutionalTree::restore( istream &is, int format )
{
    //dirty solution to circumvent the const-flag
    const_cast<ClassNames*>(this->classNames)->restore ( is, format );

    fpc->restore( is, format );
}

void SemSegConvolutionalTree::store ( ostream &os, int format ) const
{
    classNames->store( os, format );
    fpc->store( os, format );
}

void SemSegConvolutionalTree::clear ( )
{
    fpc->clear();
}