/**
* @file SemSegTools.cpp
* @brief tools for semantic segmentation
* @author Erik Rodner, Sven Sickert
* @date 03/19/2009

*/
#include <iostream>
#include <iomanip>

#include "core/basics/StringTools.h"
#include "SemSegTools.h"

using namespace OBJREC;

using namespace std;
using namespace NICE;

#undef DEBUG_LOCALIZATION
#undef DEBUG

/**
 * @brief Blend a grey-value image with a color-coded segmentation.
 *
 * Every channel c of every output pixel is computed as
 *   alpha * orig(x,y) + (1 - alpha) * segment(x,y,c)   with alpha = 0.3
 * and then truncated (not rounded) to an 8 bit value.
 *
 * @param orig    grey-value image; its width and height define the size of
 *                the result. The pointer is dereferenced without a check,
 *                so it must not be NULL.
 * @param segment color image that is read at every pixel position of orig,
 *                i.e. it has to be at least as large as orig
 * @param result  output image; resized to the dimensions of orig and
 *                overwritten completely
 */
void SemSegTools::segmentToOverlay (
        const NICE::Image *orig,
        const NICE::ColorImage & segment,
        NICE::ColorImage & result )
{
    const int xsize = orig->width();
    const int ysize = orig->height();

    result.resize( xsize, ysize );

    // weight of the grey-value image in the blend
    const double alpha = .3;

    // Blend straight into the result image. Buffering the blended values in
    // three full-size double matrices and copying them over in a second pass
    // is not needed, because every output value depends on a single pixel.
    for (int y = 0; y < ysize; y++)
        for (int x = 0; x < xsize; x++)
        {
            const uchar val = orig->getPixelQuick(x,y);
            for (int c = 0; c < 3; c++)
            {
                const double blended = alpha*(double)val
                        + (1.0-alpha)*(double)segment.getPixel( x, y, c );

                // double -> int -> uchar: truncation towards zero
                const int quantized = (int)blended;
                result.setPixel( x, y, c, (uchar)quantized );
            }
        }
}

/**
 * @brief Add the pixel-wise label co-occurrences of one image pair to a
 *        confusion matrix.
 *
 * The loop runs over every pixel of the ground truth image. The matching
 * position in img is obtained by scaling the coordinates with the size
 * ratio of both images (nearest pixel, clamped to the borders of img), so
 * img may have a different resolution than gt.
 *
 * @param img              label image to evaluate (presumably the predicted
 *                         segmentation)
 * @param gt               ground truth label image
 * @param M                confusion matrix that is incremented in place; the
 *                         row is the mapped ground truth label, the column
 *                         the mapped label of img
 * @param forbiddenClasses ground truth labels that are skipped
 * @param classMapping     maps a label to its row/column index in M. The
 *                         lookup uses operator[], so a label that is not
 *                         contained yet is inserted with index 0.
 */
void SemSegTools::updateConfusionMatrix(
        const ImageT<int> &img,
        const ImageT<int> &gt,
        Matrix &M,
        const std::set<int> &forbiddenClasses,
        map<int,int> & classMapping )
{
    // size ratio between ground truth and evaluated image
    const double scaleX = gt.width() / ( double ) img.width();
    const double scaleY = gt.height() / ( double ) img.height();

    // largest valid coordinates inside img
    const int lastCol = img.width() - 1;
    const int lastRow = img.height() - 1;

    for ( int gy = 0 ; gy < gt.height() ; gy++ )
    {
        // the row inside img only depends on gy
        int py = ( int ) ( gy / scaleY );
        if ( py < 0 ) py = 0;
        if ( py > lastRow ) py = lastRow;

        for ( int gx = 0 ; gx < gt.width() ; gx++ )
        {
            int px = ( int ) ( gx / scaleX );
            if ( px < 0 ) px = 0;
            if ( px > lastCol ) px = lastCol;

            const int labelImg = img.getPixel ( px, py );
            const int labelGt = gt.getPixel ( gx, gy );

            // pixels with a forbidden ground truth label do not count
            if ( forbiddenClasses.count ( labelGt ) != 0 )
                continue;

            M ( classMapping[labelGt], classMapping[labelImg] ) ++;
        }
    }
}

/**
 * @brief Print per-class and averaged classification statistics together
 *        with the row-normalized confusion matrix to std::cout.
 *
 * Precision uses the column sums and recall the row sums of confMat, i.e.
 * rows are taken as ground truth classes and columns as predicted classes
 * (the layout written by updateConfusionMatrix). Precision, recall and
 * intersection-over-union are averaged over those classes that have a
 * non-zero row sum and a non-zero column sum. All ratios are reported as 0
 * instead of NaN if their denominator is zero.
 *
 * Side effects: confMat is normalized row-wise in place, and the stream
 * state of std::cout is changed (precision 6, std::fixed).
 *
 * @param confMat          confusion matrix with absolute counts; overwritten
 *                         by its row-normalized version
 * @param classNames       used to look up the printed class names
 * @param forbiddenClasses classes that are left out when printing the
 *                         confusion matrix rows and the header
 * @param classMappingInv  maps a row/column index of confMat to a class
 *                         number. The lookup uses operator[], so a missing
 *                         index is inserted with class number 0.
 */
void SemSegTools::computeClassificationStatistics(
        Matrix &confMat,
        const ClassNames &classNames,
        const std::set<int> &forbiddenClasses,
        map<int,int> & classMappingInv )
{
    std::cout << "\nPERFORMANCE" << std::endl;
    std::cout << "###########\n" << std::endl;

    double accuracy = confMat.trace();
    double sumAll  = 0.0;

    // overall recognition rate
    for ( int r = 0; r < (int) confMat.rows(); r++ )
        for ( int c = 0; c < (int) confMat.cols(); c++ )
            sumAll += confMat( r, c );

    // an empty confusion matrix would otherwise yield 0/0 = NaN
    if ( sumAll > 0.0 )
        accuracy /= sumAll;
    else
        accuracy = 0.0;

    double prec = 0.0, rec = 0.0, f1score = 0.0, iuScore = 0.0;

    // classification
    // The number of classes is fixed once up front: classMappingInv[c] below
    // inserts missing keys, which would otherwise move the loop bound while
    // the loops are running.
    const int numClasses = (int) classMappingInv.size();
    int normConst = numClasses;
    for ( int c = 0; c < numClasses; c++ )
    {
        std::cout << "Class " << classNames.text( classMappingInv[c] ) << ":" << std::endl;

        double precBase = 0.0, recBase = 0.0;
        // sum of column c: everything that was assigned to class c
        for ( int r = 0; r < numClasses; r++ )
            precBase += confMat(r,c);

        // sum of row c: everything that belongs to class c
        for ( int cc = 0; cc < numClasses; cc++ )
            recBase += confMat(c,cc);

        double precClass = 0, recClass = 0;

        if (precBase > 0) precClass = confMat(c,c) / precBase;

        if (recBase > 0) recClass = confMat(c,c) / recBase;

        std::cout << "  Precision: " << precClass << std::endl;
        std::cout << "  Recall:    " << recClass << std::endl;
        prec += precClass;
        rec += recClass;

        // intersection over union; classes without any ground truth or
        // without any prediction are excluded from the averages
        if (precBase > 0 && recBase > 0)
            iuScore += confMat(c,c) / (precBase+recBase-confMat(c,c));
        else
            normConst--;
    }

    // without a single valid class the sums are all zero and stay zero
    if ( normConst > 0 )
    {
        prec /= (double)normConst;
        rec /= (double)normConst;
        iuScore /= (double)normConst;
    }

    // harmonic mean of the averaged precision and recall
    if ( prec + rec > 0.0 )
        f1score = 2.0*(prec*rec)/(prec+rec);

    // row-wise normalization of confMat
    for ( int r = 0 ; r < (int) confMat.rows() ; r++ )
    {
        double sum = 0.0;

        for ( int c = 0 ; c < (int) confMat.cols() ; c++ )
            sum += confMat ( r, c );

        // rows that are (almost) empty are left untouched
        if ( std::fabs ( sum ) > 1e-4 )
            for ( int c = 0 ; c < (int) confMat.cols() ; c++ )
                confMat ( r, c ) /= sum;
    }

    // printing confusion matrix
    const short int printWidth = 16;
    std::cout.precision(6);
    // empty top-left cell above the column of row labels
    std::cout << std::setw(printWidth) << "";
    for (int r = 0; r < (int) confMat.rows(); r++)
    {
        int cl = classMappingInv[r];
        if ( classNames.existsClassno ( cl )
             && ( forbiddenClasses.find ( cl ) == forbiddenClasses.end() ) )
        {
            std::string cname = classNames.text ( cl );
            std::cout << std::setw(printWidth) << cname.c_str();
        }
    }
    std::cout << std::endl;
    for (int r = 0; r < (int) confMat.rows(); r++)
    {
        int cl = classMappingInv[r];
        if ( classNames.existsClassno ( cl )
             && ( forbiddenClasses.find ( cl ) == forbiddenClasses.end() ) )
        {
            std::string cname = classNames.text ( cl );
            std::cout << std::setw(printWidth) << cname.c_str();

            for (int c = 0; c < (int) confMat.cols(); c++)
                std::cout << std::setw(printWidth) << std::fixed << confMat (r, c);

            std::cout << std::endl;
        }
    }

    // print classification statistics
    std::cout << "\nAccuracy: " << accuracy;
    std::cout << "\nPrecision: " << prec;
    std::cout << "\nRecall: " << rec;
    std::cout << "\nF1Score: " << f1score;
    std::cout << "\nIU: " << iuScore;
    //std::cout << "\n\nAverage Recognition Rate: " << confMat.trace() / (double)classMappingInv.size();
    //std::cout << "\nLower Bound: " << 1.0 /(double)classMappingInv.size();
    std::cout << std::endl;
}

/**
 * @brief Print the memory and CPU time figures of a ResourceStatistics
 *        object to std::cout.
 *
 * @param rs statistics object that is queried via getStatistics()
 */
void SemSegTools::computeResourceStatistics (
        NICE::ResourceStatistics &rs )
{
    std::cout << "\nSTATISTICS" << std::endl
              << "##########\n" << std::endl;

    // all three values are handed to getStatistics() by reference
    long peakMemory;
    double cpuUser, cpuSystem;
    rs.getStatistics ( peakMemory, cpuUser, cpuSystem );

    std::cout << "Memory (max):    " << peakMemory << " KB" << std::endl
              << "CPU Time (user): " << cpuUser << " seconds" << std::endl
              << "CPU Time (sys):  " << cpuSystem << " seconds" << std::endl;
}

void SemSegTools::saveResultsToImageFile(
        const Config *conf,
        const string &section,
        const ColorImage &orig,
        const ColorImage &gtruth,
        const ColorImage &segment,
        const string &file,
        string & outStr )
{
    std::string resultDir = conf->gS ( section, "resultdir", "." );
    std::string outputType = conf->gS ( section, "output_type", "ppm" );
    std::string outputPostfix = conf->gS ( section, "output_postfix", "" );

    NICE::ColorImage overlaySegment, overlayGTruth;

    NICE::Image* origGrey = orig.getChannel(1);
    segmentToOverlay( origGrey, segment, overlaySegment );
    segmentToOverlay( origGrey, gtruth, overlayGTruth );

    std::stringstream out;
    out << resultDir << "/" << file << outputPostfix;

#ifdef DEBUG
    std::cout << "Writing to file " << out.str() << "_*." << outputType << std::endl;
#endif

    orig.write ( out.str() + "_orig." + outputType );
    segment.write ( out.str() + "_result." + outputType );
    gtruth.write ( out.str() + "_groundtruth." + outputType );
    overlaySegment.write ( out.str() + "_overlay_res." + outputType );
    overlayGTruth.write ( out.str() + "_overlay_gt." + outputType );

    outStr = out.str();
}

/**
 * @brief Determine the number of consecutive files (slices) that belong to
 *        the same 3D image.
 *
 * With run3Dseg enabled, the three characters directly in front of the last
 * underscore of a file name are taken as identifier of the 3D image.
 * Consecutive files with the same identifier are counted, and the count is
 * appended to depthVec as soon as the identifier changes. Files without a
 * usable identifier (no underscore, or fewer than four characters in front
 * of it) and all files in 2D mode append a depth of 1 each.
 *
 * After the last file the current counter is appended once more. This
 * closes the last 3D image; if the last file went through the depth-1
 * branch it results in one additional trailing entry (0 for an empty set).
 *
 * @param Files    labeled set of images, visited in iteration order; the
 *                 pointer is dereferenced without a check
 * @param depthVec output; depths are appended, existing entries are kept
 * @param run3Dseg whether file names are interpreted as slices of 3D images
 */
void SemSegTools::getDepthVector (
        const LabeledSet *Files,
        std::vector<int> & depthVec,
        const bool run3Dseg )
{
    std::string oldName;
    int zsize = 0;
    bool isInit = false;

    for (LabeledSet::const_iterator it = Files->begin(); it != Files->end(); it++)
    {
        for (std::vector<ImageInfo *>::const_iterator jt = it->second.begin();
             jt != it->second.end(); jt++)
        {
            ImageInfo & info = *(*jt);
            std::string file = info.img();

            // strip the directory part; fall back to the complete string if
            // the split does not deliver any token (back() on an empty
            // vector is undefined behavior)
            std::vector< std::string > list;
            StringTools::split ( file, '/', list );
            const std::string filename = list.empty() ? file : list.back();

            // size_type keeps npos intact, and the explicit "> 3" replaces
            // the unsigned expression "found-3 > 0", which wrapped around
            // for an underscore at index 0..2 and made substr() throw
            const std::string::size_type found = filename.find_last_of ( "_" );
            if ( run3Dseg && found != std::string::npos && found > 3 )
            {
                std::string curName = filename.substr ( found-3,3 );
                if ( !isInit )
                {
                    oldName = curName;
                    isInit = true;
                }
                if ( curName.compare ( oldName ) == 0 ) // if strings match up
                {
                    zsize++;
                }
                else
                {
                    // identifier changed: store the finished 3D image and
                    // start counting the new one
                    depthVec.push_back ( zsize );
                    zsize = 1;
                    oldName = curName;
                }
            }
            else
            {
                zsize = 1;
                depthVec.push_back ( zsize );
            }

        }
    }
    // depth of the 3D image that was still being counted
    depthVec.push_back ( zsize );
}

/**
 * @brief Sample labeled pixel positions on a regular grid from all training
 *        images.
 *
 * For every image (2D) or every stack of slices (3D, stack sizes taken from
 * getDepthVector) a CachedExample is created and appended to imgexamples.
 * Grid positions that keep the configured distance to the image border are
 * labeled with the ground truth of the localization info and appended to
 * examples if the label is part of the training selection. Other labels are
 * either dropped or mapped to the class "various".
 *
 * Settings read from the config section:
 *  - grid_size_x/y/z   (default 5)  step width of the grid
 *  - grid_border_x/y/z (default 20) margin without examples
 *  - train_selection                selection of classes to train on
 *  - use_excluded_as_background (default false) map non-selected labels to
 *    class "various" instead of dropping them
 * In 2D mode grid_size_z is forced to 1 and grid_border_z to 0.
 *
 * @param conf        configuration
 * @param section     config section to read the settings from
 * @param train       training set; must not be empty (assert)
 * @param cn          class names, used for the class selection and for
 *                    labeling the ground truth
 * @param examples    output; cleared first, then filled with
 *                    (class number, example) pairs
 * @param imgexamples output; cleared first, then filled with pointers to
 *                    CachedExample objects allocated with new. They are not
 *                    deleted here.
 * @param run3Dseg    whether consecutive files are combined to 3D images
 */
void SemSegTools::collectTrainingExamples (
        const Config * conf,
        const std::string & section,
        const LabeledSet & train,
        const ClassNames & cn,
        Examples & examples,
        vector<CachedExample *> & imgexamples,
        const bool run3Dseg )
{
    assert ( train.count() > 0 );
    examples.clear();
    imgexamples.clear();

    // number of slices per 3D image, only evaluated in 3D mode
    vector<int> zsizeVec;
    SemSegTools::getDepthVector ( &train, zsizeVec, run3Dseg );

    int grid_size_x = conf->gI(section, "grid_size_x", 5 );
    int grid_size_y = conf->gI(section, "grid_size_y", 5 );
    int grid_size_z = conf->gI(section, "grid_size_z", 5 );
    int grid_border_x = conf->gI(section, "grid_border_x", 20 );
    int grid_border_y = conf->gI(section, "grid_border_y", 20 );
    int grid_border_z = conf->gI(section, "grid_border_z", 20 );

    if (!run3Dseg)
    {
        // a 2D image is a single slice without border in z direction
        grid_size_z = 1;
        grid_border_z = 0;
    }

    std::string selection = conf->gS(section, "train_selection" );

    set<int> classnoSelection;
    cn.getSelection ( selection, classnoSelection );

    bool useExcludedAsBG = conf->gB(section, "use_excluded_as_background", false );

    int backgroundClassNo = 0;

    if ( useExcludedAsBG )
    {
        backgroundClassNo = cn.classno("various");
        assert ( backgroundClassNo >= 0 );
    }

    // state of the image (stack) that is currently being assembled
    int depthCount = 0;                         // slices collected so far
    int imgCounter = 0;                         // index into zsizeVec
    vector<std::string> filelist;               // file names of the slices
    NICE::MultiChannelImageT<int> pixelLabels;  // one label channel per slice

    for (LabeledSet::const_iterator it = train.begin(); it != train.end(); it++)
    {
        for (std::vector<ImageInfo *>::const_iterator jt = it->second.begin();
             jt != it->second.end(); jt++)
        {
            ImageInfo & info = *(*jt);
            std::string file = info.img();
            filelist.push_back ( file );
            depthCount++;

            const LocalizationResult *locResult = info.localization();

            // getting groundtruth
            NICE::ImageT<int> pL;
            pL.resize ( locResult->xsize, locResult->ysize );
            pL.set ( 0 );
            locResult->calcLabeledImage ( pL, cn.getBackgroundClass() );
            pixelLabels.addChannel ( pL );

            if ( locResult->size() <= 0 ) {
                std::cerr << "WARNING: NO ground truth polygons found for "
                          << file << " !" << std::endl;

                if ( !run3Dseg )
                {
                    // In 2D mode every file is an image of its own. Drop the
                    // skipped image completely, otherwise its file name and
                    // its label channel stay in the accumulators and get
                    // merged into the example of the next image.
                    filelist.clear();
                    pixelLabels.reInit ( 0,0,0 );
                    depthCount = 0;
                }
                // NOTE(review): in 3D mode the slice stays part of the
                // current stack, but the stack is not finalized here even if
                // this was its last slice - confirm that this is intended.
                continue;
            }

            std::cerr << "SemSegTools: Collecting pixel examples from localization info: "
                      << file << std::endl;

            int depthBoundary = 0;
            if ( run3Dseg )
            {
                depthBoundary = zsizeVec[imgCounter];
            }

            // keep collecting slices until the 3D image is complete
            if ( depthCount < depthBoundary ) continue;

            int xsize, ysize, zsize;
            CachedExample *ce = new CachedExample ( filelist );
            ce->getImageSize3 ( xsize, ysize, zsize );
            imgexamples.push_back ( ce );

            // drawing actual examples
            Example pce ( ce, 0, 0, 0 );
            for ( int z = 0; z < zsize; z += grid_size_z )
                for ( int x = 0 ; x < xsize ; x += grid_size_x )
                    for ( int y = 0 ; y < ysize ; y += grid_size_y )
                        if ( ( x >= grid_border_x ) &&
                             ( y >= grid_border_y ) &&
                             ( z >= grid_border_z ) &&
                             ( x < xsize - grid_border_x ) &&
                             ( y < ysize - grid_border_y ) &&
                             ( z < zsize - grid_border_z ) )
                        {
                            pce.x = x; pce.y = y; pce.z = z;
                            int classno = pixelLabels.get(x,y,(unsigned int)z);

                            if ( classnoSelection.find(classno) != classnoSelection.end() )
                            {
                                examples.push_back (
                                            pair<int, Example> ( classno, pce ) );
                            } else if ( useExcludedAsBG )
                            {
                                examples.push_back (
                                            pair<int, Example> ( backgroundClassNo, pce ) );
                            }
                        }

            // prepare for new 3D image
            filelist.clear();
            pixelLabels.reInit ( 0,0,0 );
            depthCount = 0;
            imgCounter++;
        }
    }

    std::cerr << "total number of examples: " << (int)examples.size() << std::endl;
}