#include "SemSegContextTree3D.h"

#include <core/basics/FileName.h>
#include <core/basics/numerictools.h>
#include <core/basics/quadruplet.h>
#include <core/basics/StringTools.h>
#include <core/basics/Timer.h>
#include <core/basics/vectorio.h>
#include <core/image/Filter.h>
#include <core/image/FilterT.h>
#include <core/image/Morph.h>
#include <core/imagedisplay/ImageDisplay.h>

#include <vislearning/baselib/cc.h>
#include <vislearning/baselib/Globals.h>
#include <vislearning/baselib/ICETools.h>
#include <vislearning/cbaselib/CachedExample.h>
#include <vislearning/cbaselib/PascalResults.h>

#include <segmentation/RSGraphBased.h>
#include <segmentation/RSMeanShift.h>
#include <segmentation/RSSlic.h>

#include <omp.h>
#include <iostream>

#define VERBOSE
#undef DEBUG
#undef VISUALIZE
#undef WRITEREGIONS

using namespace OBJREC;
using namespace std;
using namespace NICE;


//###################### CONSTRUCTORS #########################//


/**
 * Default constructor: sets hard-coded default parameters without a Config.
 *
 * Unlike the Config constructor, this neither fills featTypes nor calls
 * initOperations(), and all optional helper objects stay NULL.
 */
SemSegContextTree3D::SemSegContextTree3D () : SemanticSegmentation ()
{
    // optional helper objects (only the Config constructor creates them)
    lfcw         = NULL;
    fasthik      = NULL;
    segmentation = NULL;

    // forest / tree parameters
    firstiteration    = true;
    run3Dseg          = false;
    maxSamples        = 2000;
    minFeats          = 50;
    maxDepth          = 10;
    windowSize        = 15;
    contextMultiplier = 3;
    featsPerSplit     = 200;
    useShannonEntropy = true;
    nbTrees           = 10;
    randomTests       = 10;

    // raw feature channels
    useAltTristimulus  = false;
    useGradient        = true;
    useWeijer          = false;
    useAdditionalLayer = false;
    useHoiemFeatures   = false;

    // external categorization / caching / labeling mode
    useCategorization = false;
    cndir             = "";
    saveLoadData      = false;
    fileLocation      = "tmp.txt";
    pixelWiseLabeling = true;

    // feature types 0..4 (pixel pair, region, integral, integral context, pixel pair context)
    useFeat0 = true;
    useFeat1 = false;
    useFeat2 = true;
    useFeat3 = true;
    useFeat4 = false;

    // coarse mode labels only every 6th position
    labelIncrement = coarseMode ? 6 : 1;
}


/**
 * Config-driven constructor.
 *
 * Reads the tree parameters from section "SSContextTree", the raw feature switches from
 * "Features", the data-caching options from "debug" and the forbidden classes from
 * "analysis". It then creates the optional helpers (GPHIK classifier, Weijer color names,
 * unsupervised segmentation), collects the enabled feature types into featTypes and builds
 * the operation prototypes via initOperations().
 *
 * @param conf        configuration (stored in this->conf)
 * @param classNames  class names; copied into classnames
 * @throws const char* if SSContextTree::segmentation_type is not one of
 *         none, meanshift, felzenszwalb, slic
 *
 * NOTE(review): some defaults differ from the default constructor (use_weijer defaults to
 * true here vs. false there; pixelWiseLabeling false here vs. true there) — confirm intended.
 */
SemSegContextTree3D::SemSegContextTree3D (
        const Config *conf,
        const ClassNames *classNames )
    : SemanticSegmentation ( conf, classNames )
{
    this->conf = conf;

    const string section = "SSContextTree";
    const string featsec = "Features";

    // forest / tree parameters
    lfcw = NULL;
    firstiteration = true;
    run3Dseg = conf->gB ( section, "run_3dseg", false );
    maxSamples = conf->gI ( section, "max_samples", 2000 );
    minFeats = conf->gI ( section, "min_feats", 50 );
    maxDepth = conf->gI ( section, "max_depth", 10 );
    windowSize = conf->gI ( section, "window_size", 15 );
    contextMultiplier = conf->gI ( section, "context_multiplier", 3 );
    featsPerSplit = conf->gI ( section, "feats_per_split", 200 );
    useShannonEntropy = conf->gB ( section, "use_shannon_entropy", true );
    nbTrees = conf->gI ( section, "amount_trees", 10 );
    randomTests = conf->gI ( section, "random_tests", 10 );

    // raw feature channels
    useAltTristimulus = conf->gB ( featsec, "use_alt_trist", false );
    useGradient = conf->gB ( featsec, "use_gradient", true );
    useWeijer = conf->gB ( featsec, "use_weijer", true );
    useAdditionalLayer = conf->gB ( featsec, "use_additional_layer", false );
    useHoiemFeatures = conf->gB ( featsec, "use_hoiem_features", false );

    // external categorization (section == "SSContextTree")
    useCategorization = conf->gB ( section, "use_categorization", false );
    cndir = conf->gS ( section, "cndir", "" );

    // optional caching of the trained model
    saveLoadData = conf->gB ( "debug", "save_load_data", false );
    fileLocation = conf->gS ( "debug", "datafile", "tmp.txt" );

    pixelWiseLabeling = conf->gB ( section, "pixelWiseLabeling", false );

    labelIncrement = coarseMode ? conf->gI ( section, "label_increment", 6 ) : 1;

    // the GPHIK classifier is only built here when no cndir is configured
    fasthik = ( useCategorization && cndir == "" ) ? new GPHIKClassifierNICE ( conf ) : NULL;

    if ( useWeijer )
        lfcw = new LocalFeatureColorWeijer ( conf );

    classnames = *classNames;

    const string forbiddenSpec = conf->gS ( "analysis", "forbidden_classes", "" );
    classnames.getSelection ( forbiddenSpec, forbidden_classes );

    // feature types
    useFeat0 = conf->gB ( section, "use_feat_0", true );   // pixel pair features
    useFeat1 = conf->gB ( section, "use_feat_1", false );  // region feature
    useFeat2 = conf->gB ( section, "use_feat_2", true );   // integral features
    useFeat3 = conf->gB ( section, "use_feat_3", true );   // integral context features
    useFeat4 = conf->gB ( section, "use_feat_4", false );  // pixel pair context features

    // unsupervised pre-segmentation
    const string segType = conf->gS ( section, "segmentation_type", "none" );
    if ( segType == "meanshift" )
    {
        segmentation = new RSMeanShift ( conf );
    }
    else if ( segType == "felzenszwalb" )
    {
        segmentation = new RSGraphBased ( conf );
    }
    else if ( segType == "slic" )
    {
        segmentation = new RSSlic ( conf );
    }
    else if ( segType == "none" )
    {
        // without regions, labeling is pixel-wise and the region feature is unusable
        segmentation = NULL;
        pixelWiseLabeling = true;
        useFeat1 = false;
    }
    else
    {
        throw ( "no valid segmenation_type\n please choose between none, meanshift, slic and felzenszwalb\n" );
    }

    // register enabled feature types in ascending order
    const bool featEnabled[5] = { useFeat0, useFeat1, useFeat2, useFeat3, useFeat4 };
    for ( int type = 0; type < 5; type++ )
    {
        if ( featEnabled[type] )
            featTypes.push_back ( type );
    }

    initOperations();
}

//###################### DESTRUCTORS ##########################//

/**
 * Destructor. Currently releases nothing.
 *
 * NOTE(review): the Config constructor allocates fasthik, lfcw, segmentation and the
 * operation prototypes stored in ops with new, and getBestSplit() clones operations
 * that end up in the forest nodes; none of these are deleted here. Ownership and copy
 * semantics are not visible in this file — confirm before adding deletes (risk of a
 * double free if instances are copied).
 */
SemSegContextTree3D::~SemSegContextTree3D()
{
}


//#################### MEMBER FUNCTIONS #######################//

void SemSegContextTree3D::initOperations()
{
    string featsec = "Features";

    // operation prototypes
    vector<Operation3D*> tops0, tops1, tops2, tops3, tops4, tops5;

    if ( conf->gB ( featsec, "int", true ) )
    {
        tops2.push_back ( new IntegralOps3D() );
        Operation3D* o = new IntegralOps3D();
        o->setContext(true);
        tops3.push_back ( o );
    }
    if ( conf->gB ( featsec, "bi_int", true ) )
    {
        tops2.push_back ( new BiIntegralOps3D() );
        Operation3D* o = new BiIntegralOps3D();
        o->setContext(true);
        tops3.push_back ( o );
    }
    if ( conf->gB ( featsec, "bi_int_cent", true ) )
    {
        tops2.push_back ( new BiIntegralCenteredOps3D() );
        Operation3D* o = new BiIntegralCenteredOps3D();
        o->setContext(true);
        tops3.push_back ( o );
    }
    if ( conf->gB ( featsec, "int_cent", true ) )
    {
        tops2.push_back ( new IntegralCenteredOps3D() );
        Operation3D* o = new IntegralCenteredOps3D();
        o->setContext(true);
        tops3.push_back ( o );
    }
    if ( conf->gB ( featsec, "haar_horz", true ) )
    {
        tops2.push_back ( new HaarHorizontal3D() );
        Operation3D* o = new HaarHorizontal3D();
        o->setContext(true);
        tops3.push_back ( o );
    }
    if ( conf->gB ( featsec, "haar_vert", true ) )
    {
        tops2.push_back ( new HaarVertical3D() );
        Operation3D* o = new HaarVertical3D();
        o->setContext(true);
        tops3.push_back ( o );
    }
    if ( conf->gB ( featsec, "haar_stack", true ) )
    {
        tops2.push_back ( new HaarStacked3D() );
        Operation3D* o = new HaarStacked3D();
        o->setContext(true);
        tops3.push_back ( o );
    }
    if ( conf->gB ( featsec, "haar_diagxy", true ) )
    {
        tops2.push_back ( new HaarDiagXY3D() );
        Operation3D* o = new HaarDiagXY3D();
        o->setContext(true);
        tops3.push_back ( o );
    }
    if ( conf->gB ( featsec, "haar_diagxz", true ) )
    {
        tops2.push_back ( new HaarDiagXZ3D() );
        Operation3D* o = new HaarDiagXZ3D();
        o->setContext(true);
        tops3.push_back ( o );
    }
    if ( conf->gB ( featsec, "haar_diagyz", true ) )
    {
        tops2.push_back ( new HaarDiagYZ3D() );
        Operation3D* o = new HaarDiagYZ3D();
        o->setContext(true);
        tops3.push_back ( o );
    }
    if ( conf->gB ( featsec, "haar3_horz", true ) )
    {
        tops2.push_back ( new Haar3Horiz3D() );
        Operation3D* o = new Haar3Horiz3D();
        o->setContext(true);
        tops3.push_back ( o );
    }
    if ( conf->gB ( featsec, "haar3_vert", true ) )
    {
        tops2.push_back ( new Haar3Vert3D() );
        Operation3D* o = new Haar3Vert3D();
        o->setContext(true);
        tops3.push_back ( o );
    }
    if ( conf->gB ( featsec, "haar3_stack", true ) )
    {
        tops2.push_back ( new Haar3Stack3D() );
        Operation3D* o = new Haar3Stack3D();
        o->setContext(true);
        tops3.push_back ( o );
    }

    if ( conf->gB ( featsec, "minus", true ) )
    {
        tops0.push_back ( new Minus3D() );
        Operation3D* o = new Minus3D();
        o->setContext(true);
        tops4.push_back ( o );
    }
    if ( conf->gB ( featsec, "minus_abs", true ) )
    {
        tops0.push_back ( new MinusAbs3D() );
        Operation3D* o = new MinusAbs3D();
        o->setContext(true);
        tops4.push_back ( o );
    }
    if ( conf->gB ( featsec, "addition", true ) )
    {
        tops0.push_back ( new Addition3D() );
        Operation3D* o = new Addition3D();
        o->setContext(true);
        tops4.push_back ( o );
    }
    if ( conf->gB ( featsec, "only1", true ) )
    {
        tops0.push_back ( new Only13D() );
        Operation3D* o = new Only13D();
        o->setContext(true);
        tops4.push_back ( o );
    }
    if ( conf->gB ( featsec, "rel_x", true ) )
        tops0.push_back ( new RelativeXPosition3D() );
    if ( conf->gB ( featsec, "rel_y", true ) )
        tops0.push_back ( new RelativeYPosition3D() );
    if ( conf->gB ( featsec, "rel_z", true ) )
        tops0.push_back ( new RelativeZPosition3D() );

    this->ops.push_back ( tops0 );
    this->ops.push_back ( tops1 );
    this->ops.push_back ( tops2 );
    this->ops.push_back ( tops3 );
    this->ops.push_back ( tops4 );
}

/**
 * Searches for a good split test for one node of one tree.
 *
 * The voxels currently assigned to node (according to channel tree of nodeIndices) are
 * randomly subsampled per class. Acceptance probabilities are derived from maxSamples and
 * the class priors a; forbidden classes are never sampled. Then featsPerSplit random
 * feature tests are drawn: random feature type, window offsets, channels and operation.
 * For each test, randomTests random thresholds between the observed minimum and maximum
 * feature value are scored by information gain. When useShannonEntropy is set, the gain
 * is normalized.
 *
 * @param feats        per-volume feature channels
 * @param nodeIndices  per-volume node assignment, one channel per tree
 * @param labels       per-volume ground-truth labels
 * @param node         index of the node in forest[tree] to split
 * @param splitop      out: best split operation, or NULL if none was found. The returned
 *                     operation is handed to the caller; all other candidates created here
 *                     are deleted before returning.
 * @param splitval     out: threshold of the best split (-1.0 if none)
 * @param tree         index of the tree
 * @param regionProbs  per-volume region class probabilities (used by type-1 features)
 * @return the best (possibly normalized) gain. Returns 0.0 when the node has fewer than
 *         minFeats samples or its sampled label entropy is below 0.5. Returns
 *         -numeric_limits<double>::max() if no candidate produced a usable gain.
 */
double SemSegContextTree3D::getBestSplit (
        std::vector<NICE::MultiChannelImage3DT<double> > &feats,
        std::vector<NICE::MultiChannelImage3DT<unsigned short int> > &nodeIndices,
        const std::vector<NICE::MultiChannelImageT<int> > &labels,
        int node,
        Operation3D *&splitop,
        double &splitval,
        const int &tree,
        vector<vector<vector<double> > > &regionProbs )
{
    const int imgCount = ( int ) feats.size();

    double bestig = -numeric_limits< double >::max();

    splitop = NULL;
    splitval = -1.0;

    vector<quadruplet<int,int,int,int> > selFeats;
    map<int, int> e;
    int featcounter = forest[tree][node].featcounter;

    if ( featcounter < minFeats )
    {
        return 0.0;
    }

    // per-class acceptance probability for the random subsampling below
    vector<double> fraction ( a.size(), 0.0 );

    for ( uint i = 0; i < fraction.size(); i++ )
    {
        if ( forbidden_classes.find ( labelmapback[i] ) != forbidden_classes.end() )
            fraction[i] = 0;
        else
            fraction[i] = ( ( double ) maxSamples ) / ( ( double ) featcounter * a[i] * a.size() );
    }

    featcounter = 0;

    // randomly select (class-balanced) voxels that currently belong to this node
    for ( int iCounter = 0; iCounter < imgCount; iCounter++ )
    {
        int xsize = ( int ) nodeIndices[iCounter].width();
        int ysize = ( int ) nodeIndices[iCounter].height();
        int zsize = ( int ) nodeIndices[iCounter].depth();

        for ( int x = 0; x < xsize; x++ )
            for ( int y = 0; y < ysize; y++ )
                for ( int z = 0; z < zsize; z++ )
                {
                    if ( nodeIndices[iCounter].get ( x, y, z, tree ) == node )
                    {
                        int cn = labels[iCounter].get ( x, y, ( uint ) z );
                        double randD = ( double ) rand() / ( double ) RAND_MAX;

                        if ( labelmap.find ( cn ) == labelmap.end() )
                            continue;

                        if ( randD < fraction[labelmap[cn]] )
                        {
                            quadruplet<int,int,int,int> quad( iCounter, x, y, z );
                            featcounter++;
                            selFeats.push_back ( quad );
                            e[cn]++;
                        }
                    }
                }
    }

    // global entropy (in bits) of the sampled label distribution
    double globent = 0.0;
    for ( map<int, int>::iterator mapit = e.begin() ; mapit != e.end(); mapit++ )
    {
        double p = ( double ) ( *mapit ).second / ( double ) featcounter;
        globent += p * log2 ( p );
    }
    globent = -globent;

    if ( globent < 0.5 )
        return 0.0;

    // The window-geometry settings do not change inside the candidate loop: read them once.
    // The ratios are real-valued and must be read with gD; gB would collapse them to 0.0/1.0.
    const double z_ratio = conf->gD ( "SSContextTree", "z_ratio", 1.0 );
    const double y_ratio = conf->gD ( "SSContextTree", "y_ratio", 1.0 );
    // NOTE(review): despite the key name, this makes the z offsets non-negative via abs() —
    // confirm the intended direction.
    const bool zNegativeOnly = conf->gB ( "SSContextTree", "z_negative_only", false );

    // all randomly chosen candidate tests; owned here (see cleanup at the end)
    std::vector<Operation3D*> featsel;

    for ( int i = 0; i < featsPerSplit; i++ )
    {
        int x1, x2, y1, y2, z1, z2, ft;

        // random feature type among the enabled types that have channels
        do
        {
            ft = ( int ) ( rand() % featTypes.size() );
            ft = featTypes[ft];
        }
        while ( channelsPerType[ft].size() == 0 );

        int tmpws = windowSize;

        if ( ft == 3 || ft == 4 )
        {
            //use larger window size for context features
            tmpws *= contextMultiplier;
        }

        // use region feature only with reasonable pre-segmentation
        //    if ( ft == 1 && depth < 8 )
        //    {
        //      ft = 0;
        //    }

        /* random window positions */
        // clamp to >= 1 so that rand() % tmp_* can never be a modulo by zero
        const int tmp_z = std::max ( 1, ( int ) floor ( ( tmpws * z_ratio ) + 0.5 ) );
        const int tmp_y = std::max ( 1, ( int ) floor ( ( tmpws * y_ratio ) + 0.5 ) );
        x1 = ( int ) ( rand() % tmpws ) - tmpws / 2 ;
        x2 = ( int ) ( rand() % tmpws ) - tmpws / 2 ;
        y1 = ( int ) ( rand() % tmp_y ) - tmp_y / 2 ;
        y2 = ( int ) ( rand() % tmp_y ) - tmp_y / 2 ;
        z1 = ( int ) ( rand() % tmp_z ) - tmp_z / 2 ;
        z2 = ( int ) ( rand() % tmp_z ) - tmp_z / 2 ;

        if ( zNegativeOnly )
        {
            z1 = abs(z1);
            z2 = abs(z2);
        }

        /* random feature maps (channels) */
        int f1, f2;
        f1 = ( int ) ( rand() % channelsPerType[ft].size() );
        if ( (rand() % 2) == 0 )
            f2 = ( int ) ( rand() % channelsPerType[ft].size() );
        else
            f2 = f1;
        f1 = channelsPerType[ft][f1];
        f2 = channelsPerType[ft][f2];

        if ( ft == 1 )
        {
            int classes = ( int ) regionProbs[0][0].size();
            f2 = ( int ) ( rand() % classes );
        }

        /* random extraction method (operation) */
        int o = ( int ) ( rand() % ops[ft].size() );

        Operation3D *op = ops[ft][o]->clone();
        op->set ( x1, y1, z1, x2, y2, z2, f1, f2, ft );
        op->setWSize(windowSize);

        if ( ft == 3 || ft == 4 )
            op->setContext ( true );
        else
            op->setContext ( false );

        featsel.push_back ( op );
    }

    // do actual split tests
    vector<double> vals;
    for ( int f = 0; f < featsPerSplit; f++ )
    {
        double l_bestig = -numeric_limits< double >::max();
        double l_splitval = -1.0;
        vals.clear();

        double maxval = -numeric_limits<double>::max();
        double minval = numeric_limits<double>::max();

        assert ( forest.size() > ( uint ) tree );
        assert ( forest[tree][0].dist.size() > 0 );

        // evaluate the candidate on every sampled voxel
        for ( vector<quadruplet<int,int,int,int> >::const_iterator it = selFeats.begin();
              it != selFeats.end(); it++ )
        {
            Features feat;
            feat.feats = &feats[ ( *it ).first ];
            feat.rProbs = &regionProbs[ ( *it ).first ];

            double val = featsel[f]->getVal ( feat, ( *it ).second, ( *it ).third, ( *it ).fourth );
            if ( !isfinite ( val ) )
            {
#ifdef DEBUG
                cerr << "feat " << feat.feats->width() << " " << feat.feats->height() << " " << feat.feats->depth() << endl;
                cerr << "non finite value " << val << " for " << featsel[f]->writeInfos() <<  endl << (*it).second << " " <<  (*it).third << " " << (*it).fourth << endl;
#endif
                val = 0.0;
            }
            vals.push_back ( val );
            maxval = std::max ( val, maxval );
            minval = std::min ( val, minval );
        }

        // constant response: no threshold can separate anything
        if ( minval == maxval )
            continue;

        // split values
        for ( int run = 0 ; run < randomTests; run++ )
        {
            // choose threshold randomly in [minval, maxval]
            double sval = ( (double) rand() / (double) RAND_MAX*(maxval-minval) ) + minval;

            map<int, int> eL, eR;
            int counterL = 0, counterR = 0;
            int counter = 0;

            for ( vector<quadruplet<int,int,int,int> >::const_iterator it2 = selFeats.begin();
                  it2 != selFeats.end(); it2++, counter++ )
            {
                int cn = labels[ ( *it2 ).first ].get ( ( *it2 ).second, ( *it2 ).third, ( *it2 ).fourth );

                if ( vals[counter] < sval )
                {
                    // left entropy
                    eL[cn] = eL[cn] + 1;
                    counterL++;
                }
                else
                {
                    // right entropy
                    eR[cn] = eR[cn] + 1;
                    counterR++;
                }
            }

            double leftent = 0.0;
            for ( map<int, int>::iterator mapit = eL.begin() ; mapit != eL.end(); mapit++ )
            {
                double p = ( double ) ( *mapit ).second / ( double ) counterL;
                leftent -= p * log2 ( p );
            }

            double rightent = 0.0;
            for ( map<int, int>::iterator mapit = eR.begin() ; mapit != eR.end(); mapit++ )
            {
                double p = ( double ) ( *mapit ).second / ( double ) counterR;
                rightent -= p * log2 ( p );
            }

            double pl = ( double ) counterL / ( double ) ( counterL + counterR );

            // information gain
            double ig = globent - ( 1.0 - pl ) * rightent - pl * leftent;

            if ( useShannonEntropy )
            {
                // NOTE(review): globent/ig are in bits (log2) while esplit uses the natural
                // log, so this normalization mixes bases; kept as-is to preserve results.
                // For pl == 0 or 1 esplit is NaN and the candidate is never selected.
                double esplit = - ( pl * log ( pl ) + ( 1 - pl ) * log ( 1 - pl ) );
                ig = 2 * ig / ( globent + esplit );
            }

            if ( ig > l_bestig )
            {
                l_bestig = ig;
                l_splitval = sval;
            }
        }

        if ( l_bestig > bestig )
        {
            bestig = l_bestig;
            splitop = featsel[f];
            splitval = l_splitval;
        }
    }

    // Only splitop leaves this function (the caller stores it in the tree node);
    // every other cloned candidate would otherwise leak.
    for ( size_t i = 0; i < featsel.size(); i++ )
    {
        if ( featsel[i] != splitop )
            delete featsel[i];
    }

#ifdef DEBUG
    cout << "globent: " << globent <<  " bestig " << bestig << " splitval: " << splitval << endl;
#endif
    return bestig;
}

/**
 * Mean class probability at one voxel over the whole forest.
 *
 * For each of the nbTrees trees, looks up the node that voxel (x, y, z) is currently
 * assigned to (channel `tree` of nodeIndices). It takes entry `channel` of that node's
 * dist and returns the average over all trees.
 */
inline double SemSegContextTree3D::getMeanProb (
        const int &x,
        const int &y,
        const int &z,
        const int &channel,
        const MultiChannelImage3DT<unsigned short int> &nodeIndices )
{
    double sum = 0.0;
    int t = 0;

    while ( t < nbTrees )
    {
        const unsigned short int leaf = nodeIndices.get ( x, y, z, t );
        sum += forest[t][leaf].dist[channel];
        ++t;
    }

    return sum / ( double ) nbTrees;
}

/**
 * Refreshes the class-probability (context) channels of one volume.
 *
 * For each class c, channel firstChannel + c of feats is filled with the forest-mean
 * probability (getMeanProb) at every voxel. The fill happens only when context features
 * (type 3 or 4) are enabled. The integral image of that channel is then recomputed.
 * Classes are processed in parallel with OpenMP; each iteration touches only its own
 * channel.
 *
 * @param nodeIndices  current node assignment of every voxel, one channel per tree
 * @param feats        feature volume whose context channels are updated in place
 * @param firstChannel index of the channel holding class 0
 */
void SemSegContextTree3D::updateProbabilityMaps (
        const NICE::MultiChannelImage3DT<unsigned short int> &nodeIndices,
        NICE::MultiChannelImage3DT<double> &feats,
        int firstChannel )
{
    const int xsize = feats.width();
    const int ysize = feats.height();
    const int zsize = feats.depth();

    const int classes = ( int ) labelmap.size();

    // Loop-invariant: probabilities are stored only when context features are in use,
    // so skip the per-voxel forest lookup entirely otherwise.
    const bool storeProbs = useFeat3 || useFeat4;

    // integral images for context channels (probability maps for each class)
#pragma omp parallel for
    for ( int c = 0; c < classes; c++ )
    {
        if ( storeProbs )
        {
            for ( int z = 0; z < zsize; z++ )
                for ( int y = 0; y < ysize; y++ )
                    for ( int x = 0; x < xsize; x++ )
                        feats ( x, y, z, firstChannel + c ) = getMeanProb ( x, y, z, c, nodeIndices );
        }

        feats.calcIntegral ( firstChannel + c );
    }
}

/**
 * Depth weight used during training.
 *
 * @param d    current depth
 * @param dim  maximal depth
 * @return 0.0 for d == 0, otherwise 1 / 2^(dim - d + 1)
 */
inline double computeWeight ( const int &d, const int &dim )
{
    if ( d != 0 )
    {
        const double exponent = ( double ) ( dim - d + 1 );
        return 1.0 / std::pow ( 2.0, exponent );
    }

    return 0.0;
}

/**
 * Trains on the "train" subset of md.
 *
 * If saveLoadData is enabled and fileLocation already exists, that file is read via
 * read() instead of training. If saveLoadData is enabled but the file is missing, the
 * forest is trained and then written there via write(). Without saveLoadData, the
 * forest is simply trained.
 */
void SemSegContextTree3D::train ( const MultiDataset *md )
{
    const LabeledSet *trainp = ( *md ) ["train"];

    // cached data available: load it and skip training
    if ( saveLoadData && FileMgt::fileExists ( fileLocation ) )
    {
        read ( fileLocation );
        return;
    }

    train ( trainp );

    // caching enabled but nothing on disk yet: store the fresh result
    if ( saveLoadData )
        write ( fileLocation );
}

void SemSegContextTree3D::train ( const LabeledSet * trainp )
{
    int shortsize = numeric_limits<short>::max();

    Timer timer;
    timer.start();

    vector<int> zsizeVec;
    getDepthVector ( trainp, zsizeVec, run3Dseg );

    //FIXME: memory usage
    vector<MultiChannelImage3DT<double> > allfeats;   // Feature Werte
    vector<MultiChannelImage3DT<unsigned short int> > nodeIndices;    // Zuordnung Knoten/Baum für jeden Pixel
    vector<MultiChannelImageT<int> > labels;

    // für externen Klassifikator
    vector<SparseVector*> globalCategorFeats;
    vector<map<int,int> > classesPerImage;

    vector<vector<int> > rSize;   // anzahl der pixel je region
    vector<int> amountRegionpI; // ANZAHL der regionen pro bild (von unsupervised segmentation)

    int imgCounter = 0;
    int amountPixels = 0;

    // How many channels of non-integral type do we have?

    if ( imagetype == IMAGETYPE_RGB )
        rawChannels = 3;
    else
        rawChannels = 1;

    if ( useGradient )
        rawChannels *= 3;   // gx, gy, gz

    if ( useWeijer )      // Weijer Colornames
        rawChannels += 11;

    if ( useHoiemFeatures )       // geometrische Kontextmerkmale
        rawChannels += 8;

    if ( useAdditionalLayer )     // beliebige Merkmale in extra Bilddateien
        rawChannels += 1;


    ///////////////////////////// read input data /////////////////////////////////
    ///////////////////////////////////////////////////////////////////////////////
    int depthCount = 0;
    vector< string > filelist;
    NICE::MultiChannelImageT<int> pixelLabels;
    std::map<int, bool> labelExist;

    for (LabeledSet::const_iterator it = trainp->begin(); it != trainp->end(); it++)
    {
        for (std::vector<ImageInfo *>::const_iterator jt = it->second.begin();
             jt != it->second.end(); jt++)
        {
            int classno = it->first;
            ImageInfo & info = *(*jt);
            std::string file = info.img();
            filelist.push_back ( file );
            depthCount++;

            const LocalizationResult *locResult = info.localization();

            // getting groundtruth
            NICE::ImageT<int> pL;
            pL.resize ( locResult->xsize, locResult->ysize );
            pL.set ( 0 );
            locResult->calcLabeledImage ( pL, ( *classNames ).getBackgroundClass() );
            pixelLabels.addChannel ( pL );

            if ( locResult->size() <= 0 )
            {
                fprintf ( stderr, "WARNING: NO ground truth polygons found for %s !\n",
                          file.c_str() );
                continue;
            }

            fprintf ( stderr, "SSContext: Collecting pixel examples from localization info: %s\n", file.c_str() );

            int depthBoundary = 0;
            if ( run3Dseg )
            {
                depthBoundary = zsizeVec[imgCounter];
            }

            if ( depthCount < depthBoundary ) continue;

            // all image slices collected -> make a 3d image
            NICE::MultiChannelImage3DT<double> imgData;
            make3DImage ( filelist, imgData );

            int xsize = imgData.width();
            int ysize = imgData.height();
            int zsize = imgData.depth();
            amountPixels += xsize * ysize * zsize;

            MultiChannelImageT<int> tmpMat ( xsize, ysize, ( uint ) zsize );
            labels.push_back ( tmpMat );

            nodeIndices.push_back ( MultiChannelImage3DT<unsigned short int> ( xsize, ysize, zsize, nbTrees ) );
            nodeIndices[imgCounter].setAll ( 0 );

            int amountRegions;
            // convert color to L*a*b, add selected feature channels
            addFeatureMaps ( imgData, filelist, amountRegions );
            allfeats.push_back(imgData);

            if ( useFeat1 )
            {
                amountRegionpI.push_back ( amountRegions );
                rSize.push_back ( vector<int> ( amountRegions, 0 ) );
            }

            if ( useCategorization )
            {
                globalCategorFeats.push_back ( new SparseVector() );
                classesPerImage.push_back ( map<int,int>() );
            }

            for ( int x = 0; x < xsize; x++ )
                for ( int y = 0; y < ysize; y++ )
                    for ( int z = 0; z < zsize; z++ )
                    {
                        if ( useFeat1 )
                            rSize[imgCounter][allfeats[imgCounter] ( x, y, z, rawChannels ) ]++;

                        if ( run3Dseg )
                            classno = pixelLabels ( x, y, ( uint ) z );
                        else
                            classno = pL.getPixelQuick ( x,y );

                        labels[imgCounter].set ( x, y, classno, ( uint ) z );

                        if ( forbidden_classes.find ( classno ) != forbidden_classes.end() )
                            continue;

                        labelExist[classno] = true;

                        if ( useCategorization )
                            classesPerImage[imgCounter][classno] = 1;
                    }

            filelist.clear();
            pixelLabels.reInit ( 0,0,0 );
            depthCount = 0;
            imgCounter++;
        }
    }

    int classes = 0;
    for ( map<int, bool>::const_iterator mapit = labelExist.begin();
          mapit != labelExist.end(); mapit++ )
    {
        labelmap[mapit->first] = classes;
        labelmapback[classes] = mapit->first;
        classes++;
    }

    ////////////////////////// channel type configuration /////////////////////////
    ///////////////////////////////////////////////////////////////////////////////

    unsigned char shift = 0;
    std::vector<int> rawChannelsIdx, numClassesIdx;
    int idx = 0;
    for ( int i = 0; i < rawChannels; i++, idx++ )
        rawChannelsIdx.push_back ( idx );

    for ( int i = 0; i < classes; i++, idx++ )
        numClassesIdx.push_back ( idx );


    /** Type 0: single pixel & pixel-comparison features on gray value channels */
    // actual values derived from integral values
    channelsPerType.push_back ( rawChannelsIdx );

    /** Type 1: region channel with unsupervised segmentation */
    if ( useFeat1 )
    {
        channelsPerType.push_back ( vector<int>(1, rawChannels) );
        shift = 1;
    }
    else
        channelsPerType.push_back ( vector<int>() );

    /** Type 2: rectangular and Haar-like features on gray value integral channels */
    if ( useFeat2 )
        channelsPerType.push_back ( rawChannelsIdx );
    else
        channelsPerType.push_back ( vector<int>() );

    /** Type 3: type 2 features on integral probability channels (context) */
    if ( useFeat3 )
        channelsPerType.push_back ( numClassesIdx );
    else
        channelsPerType.push_back ( vector<int>() );

    /** Type 4: type 0 features on probability channels (context) */
    // Type 4 channels are now INTEGRAL
    // This remains for compatibility reasons.
    if ( useFeat4 )
        channelsPerType.push_back ( numClassesIdx );
    else
        channelsPerType.push_back ( vector<int>() );

    ///////////////////////////////////////////////////////////////////////////////
    ///////////////////////////////////////////////////////////////////////////////

    vector<vector<vector<double> > > regionProbs;
    if ( useFeat1 )
    {
        for ( int i = 0; i < imgCounter; i++ )
        {
            regionProbs.push_back ( vector<vector<double> > ( amountRegionpI[i], vector<double> ( classes, 0.0 ) ) );
        }
    }

    //balancing
    a = vector<double> ( classes, 0.0 );
    int selectionCounter = 0;
    for ( int iCounter = 0; iCounter < imgCounter; iCounter++ )
    {
        int xsize = ( int ) nodeIndices[iCounter].width();
        int ysize = ( int ) nodeIndices[iCounter].height();
        int zsize = ( int ) nodeIndices[iCounter].depth();

        for ( int x = 0; x < xsize; x++ )
            for ( int y = 0; y < ysize; y++ )
                for ( int z = 0; z < zsize; z++ )
                {
                    int cn = labels[iCounter] ( x, y, ( uint ) z );
                    if ( labelmap.find ( cn ) == labelmap.end() )
                        continue;
                    a[labelmap[cn]] ++;
                    selectionCounter++;
                }
    }

    for ( int i = 0; i < ( int ) a.size(); i++ )
        a[i] /= ( double ) selectionCounter;

#ifdef VERBOSE
    cout << "\nDistribution:" << endl;
    for ( int i = 0; i < ( int ) a.size(); i++ )
        cout << "class '" << classNames->code(labelmapback[i]) << "': "
             << a[i] << endl;
#endif

    depth = 0;
    uniquenumber = 0;

    //initialize random forest
    for ( int t = 0; t < nbTrees; t++ )
    {
        vector<TreeNode> singletree;
        singletree.push_back ( TreeNode() );
        singletree[0].dist = vector<double> ( classes, 0.0 );
        singletree[0].depth = depth;
        singletree[0].featcounter = amountPixels;
        singletree[0].nodeNumber = uniquenumber;
        uniquenumber++;
        forest.push_back ( singletree );
    }

    vector<int> startnode ( nbTrees, 0 );

    bool noNewSplit = false;

    timer.stop();
    cout << "\nTime for Pre-Processing: " << timer.getLastAbsolute() << " seconds\n" << endl;

    //////////////////////////// train the classifier ///////////////////////////
    /////////////////////////////////////////////////////////////////////////////
    timer.start();
    while ( !noNewSplit && (depth < maxDepth) )
    {
        depth++;
#ifdef DEBUG
        cout << "depth: " << depth << endl;
#endif
        noNewSplit = true;
        vector<MultiChannelImage3DT<unsigned short int> > lastNodeIndices = nodeIndices;
        vector<vector<vector<double> > > lastRegionProbs = regionProbs;

        if ( useFeat1 )
            for ( int i = 0; i < imgCounter; i++ )
            {
                int numRegions = (int) regionProbs[i].size();
                for ( int r = 0; r < numRegions; r++ )
                    for ( int c = 0; c < classes; c++ )
                        regionProbs[i][r][c] = 0.0;
            }

        // initialize & update context channels
        for ( int i = 0; i < imgCounter; i++)
            if ( useFeat3 || useFeat4 )
                this->updateProbabilityMaps ( nodeIndices[i], allfeats[i], rawChannels + shift );

#ifdef VERBOSE
        Timer timerDepth;
        timerDepth.start();
#endif

        double weight = computeWeight ( depth, maxDepth )
                - computeWeight ( depth - 1, maxDepth );

#pragma omp parallel for
        // for each tree
        for ( int tree = 0; tree < nbTrees; tree++ )
        {
            const int t = ( int ) forest[tree].size();
            const int s = startnode[tree];
            startnode[tree] = t;
            double bestig;

            // for each node
            for ( int node = s; node < t; node++ )
            {
                if ( !forest[tree][node].isleaf && forest[tree][node].left < 0 )
                {
                    // find best split
                    Operation3D *splitfeat = NULL;
                    double splitval;
                    bestig = getBestSplit ( allfeats, lastNodeIndices, labels, node,
                                            splitfeat, splitval, tree, lastRegionProbs );

                    forest[tree][node].feat = splitfeat;
                    forest[tree][node].decision = splitval;

                    // split the node
                    if ( splitfeat != NULL )
                    {
                        noNewSplit = false;
                        int left;
#pragma omp critical
                        {
                            left = forest[tree].size();
                            forest[tree].push_back ( TreeNode() );
                            forest[tree].push_back ( TreeNode() );
                        }
                        int right = left + 1;
                        forest[tree][node].left = left;
                        forest[tree][node].right = right;
                        forest[tree][left].init( depth, classes, uniquenumber);
                        int leftu = uniquenumber;
                        uniquenumber++;
                        forest[tree][right].init( depth, classes, uniquenumber);
                        int rightu = uniquenumber;
                        uniquenumber++;

#pragma omp parallel for
                        for ( int i = 0; i < imgCounter; i++ )
                        {
                            int xsize = nodeIndices[i].width();
                            int ysize = nodeIndices[i].height();
                            int zsize = nodeIndices[i].depth();

                            for ( int x = 0; x < xsize; x++ )
                            {
                                for ( int y = 0; y < ysize; y++ )
                                {
                                    for ( int z = 0; z < zsize; z++ )
                                    {
                                        if ( nodeIndices[i].get ( x, y, z, tree ) == node )
                                        {
                                            // get feature value
                                            Features feat;
                                            feat.feats = &allfeats[i];
                                            feat.rProbs = &lastRegionProbs[i];
                                            double val = 0.0;
                                            val = splitfeat->getVal ( feat, x, y, z );
                                            if ( !isfinite ( val ) ) val = 0.0;

#pragma omp critical
                                            {
                                                int curLabel = labels[i] ( x, y, ( uint ) z );
                                                // traverse to left child
                                                if ( val < splitval )
                                                {
                                                    nodeIndices[i].set ( x, y, z, left, tree );
                                                    if ( labelmap.find ( curLabel ) != labelmap.end() )
                                                        forest[tree][left].dist[labelmap[curLabel]]++;
                                                    forest[tree][left].featcounter++;
                                                    if ( useCategorization && leftu < shortsize )
                                                        ( *globalCategorFeats[i] ) [leftu]+=weight;
                                                }
                                                // traverse to right child
                                                else
                                                {
                                                    nodeIndices[i].set ( x, y, z, right, tree );
                                                    if ( labelmap.find ( curLabel ) != labelmap.end() )
                                                        forest[tree][right].dist[labelmap[curLabel]]++;
                                                    forest[tree][right].featcounter++;

                                                    if ( useCategorization && rightu < shortsize )
                                                        ( *globalCategorFeats[i] ) [rightu]+=weight;
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }

                        // normalize distributions in child leaves
                        double lcounter = 0.0, rcounter = 0.0;
                        for ( int c = 0; c < (int)forest[tree][left].dist.size(); c++ )
                        {
                            if ( forbidden_classes.find ( labelmapback[c] ) != forbidden_classes.end() )
                            {
                                forest[tree][left].dist[c] = 0;
                                forest[tree][right].dist[c] = 0;
                            }
                            else
                            {
                                forest[tree][left].dist[c] /= a[c];
                                lcounter += forest[tree][left].dist[c];
                                forest[tree][right].dist[c] /= a[c];
                                rcounter += forest[tree][right].dist[c];
                            }
                        }

                        assert ( lcounter > 0 && rcounter > 0 );

                        for ( int c = 0; c < classes; c++ )
                        {
                            forest[tree][left].dist[c] /= lcounter;
                            forest[tree][right].dist[c] /= rcounter;
                        }
                    }
                    else
                    {
                        forest[tree][node].isleaf = true;
                    }
                }
            }
        }


        if ( useFeat1 )
        {
            for ( int i = 0; i < imgCounter; i++ )
            {
                int xsize = nodeIndices[i].width();
                int ysize = nodeIndices[i].height();
                int zsize = nodeIndices[i].depth();

#pragma omp parallel for
                // set region probability distribution
                for ( int x = 0; x < xsize; x++ )
                {
                    for ( int y = 0; y < ysize; y++ )
                    {
                        for ( int z = 0; z < zsize; z++ )
                        {
                            for ( int tree = 0; tree < nbTrees; tree++ )
                            {
                                int node = nodeIndices[i].get ( x, y, z, tree );
                                for ( int c = 0; c < classes; c++ )
                                {
                                    int r = (int) ( allfeats[i] ( x, y, z, rawChannels ) );
                                    regionProbs[i][r][c] += forest[tree][node].dist[c];
                                }
                            }
                        }
                    }
                }

                // normalize distribution
                int numRegions = (int) regionProbs[i].size();
                for ( int r = 0; r < numRegions; r++ )
                {
                    for ( int c = 0; c < classes; c++ )
                    {
                        regionProbs[i][r][c] /= ( double ) ( rSize[i][r] );
                    }
                }
            }
        }

        if ( firstiteration ) firstiteration = false;

#ifdef VERBOSE
        timerDepth.stop();
        cout << "Depth " << depth << ": " << timerDepth.getLastAbsolute() << " seconds" <<endl;
#endif

        lastNodeIndices.clear();
        lastRegionProbs.clear();
    }

    timer.stop();
    cout << "Time for Learning: " << timer.getLastAbsolute() << " seconds\n" << endl;

    //////////////////////// classification using HIK ///////////////////////////
    /////////////////////////////////////////////////////////////////////////////

    if ( useCategorization && fasthik != NULL )
    {
        timer.start();
        uniquenumber = std::min ( shortsize, uniquenumber );
        for ( uint i = 0; i < globalCategorFeats.size(); i++ )
        {
            globalCategorFeats[i]->setDim ( uniquenumber );
            globalCategorFeats[i]->normalize();
        }
        map<int,Vector> ys;

        int cCounter = 0;
        for ( map<int,int>::const_iterator it = labelmap.begin();
              it != labelmap.end(); it++, cCounter++ )
        {
            ys[cCounter] = Vector ( globalCategorFeats.size() );
            for ( int i = 0; i < imgCounter; i++ )
            {
                if ( classesPerImage[i].find ( it->first ) != classesPerImage[i].end() )
                {
                    ys[cCounter][i] = 1;
                }
                else
                {
                    ys[cCounter][i] = -1;
                }
            }
        }

        fasthik->train( reinterpret_cast<vector<const NICE::SparseVector *>&>(globalCategorFeats), ys);

        timer.stop();
        cerr << "Time for Categorization: " << timer.getLastAbsolute() << " seconds\n" << endl;
    }

#ifdef VERBOSE
    cout << "\nFEATURE USAGE" << endl;
    cout << "#############\n" << endl;

    // amount of used features per feature type
    std::map<int, int> featTypeCounter;
    for ( int tree = 0; tree < nbTrees; tree++ )
    {
        int t = ( int ) forest[tree].size();

        for ( int node = 0; node < t; node++ )
        {
            if ( !forest[tree][node].isleaf && forest[tree][node].left != -1 )
            {
                featTypeCounter[ forest[tree][node].feat->getFeatType() ] += 1;
            }
        }
    }

    cout << "Types:" << endl;
    for ( map<int, int>::const_iterator it = featTypeCounter.begin(); it != featTypeCounter.end(); it++ )
        cout << it->first << ": " << it->second << endl;

    cout << "\nOperations - All:" << endl;
    // used operations
    vector<int> opOverview ( NBOPERATIONS, 0 );
    // relative use of context vs raw features per tree level
    vector<vector<double> > contextOverview ( maxDepth, vector<double> ( 2, 0.0 ) );
    for ( int tree = 0; tree < nbTrees; tree++ )
    {
        int t = ( int ) forest[tree].size();

        for ( int node = 0; node < t; node++ )
        {
#ifdef DEBUG
            printf ( "tree[%i]: left: %i, right: %i", node, forest[tree][node].left, forest[tree][node].right );
#endif

            if ( !forest[tree][node].isleaf && forest[tree][node].left != -1 )
            {
                cout <<  forest[tree][node].feat->writeInfos() << endl;
                opOverview[ forest[tree][node].feat->getOps() ]++;
                contextOverview[forest[tree][node].depth][ ( int ) forest[tree][node].feat->getContext() ]++;
            }
#ifdef DEBUG
            for ( int d = 0; d < ( int ) forest[tree][node].dist.size(); d++ )
            {
                cout << " " << forest[tree][node].dist[d];
            }
            cout << endl;
#endif
        }
    }

    // amount of used features per operation type
    cout << "\nOperations - Summary:" << endl;
    for ( int t = 0; t < ( int ) opOverview.size(); t++ )
    {
        cout << "Ops " << t << ": " << opOverview[ t ] << endl;
    }
    // ratio of used context features per depth level
    cout << "\nContext-Ratio:" << endl;
    for ( int d = 0; d < maxDepth; d++ )
    {
        double sum =  contextOverview[d][0] + contextOverview[d][1];
        if ( sum == 0 )
            sum = 1;

        contextOverview[d][0] /= sum;
        contextOverview[d][1] /= sum;

        cout << "Depth [" << d+1 << "] Normal: " << contextOverview[d][0] << " Context: " << contextOverview[d][1] << endl;
    }
#endif

}

/**
 * Append the feature channels selected by the configuration flags to a
 * 3D image stack, in this order:
 *  - IMAGETYPE_RGB: channels 0..2 are converted in place from RGB (scaled
 *    by 1/255) to CIE Lab (useAltTristimulus selects conversion variant 4
 *    instead of 0),
 *  - useGradient: for every channel present at that point, a Sobel-x and a
 *    Sobel-y channel are appended (2 * #channels new channels),
 *  - useWeijer (RGB only): the color-name features computed by lfcw are
 *    appended,
 *  - useAdditionalLayer: one channel read from "<dir of file>addlayer/<file>",
 *  - useHoiemFeatures: one channel per Hoiem et al. geometric class, read
 *    from "<hoiem_directory><basename>_c_<class>.png",
 *  - useFeat1: one channel holding the region id of each voxel as returned
 *    by segmentation->segRegions,
 *  - useFeat2: calcIntegral is called on channels 0..rawChannels-1,
 *  - useFeat3 / useFeat4: one additional channel per usable class
 *    (numClasses minus forbidden classes), left to be filled later.
 *
 * @param imgData       image stack, modified in place
 * @param filelist      filelist[z] is taken as the source file of slice z
 *                      (must hold at least imgData.depth() entries)
 * @param amountRegions set to the number of regions reported by the
 *                      segmentation when useFeat1 is set, 0 otherwise
 * @throws Exception if a Hoiem confidence image is missing or its size
 *         differs from the slice size, or if an additional layer is smaller
 *         than the slice
 */
void SemSegContextTree3D::addFeatureMaps (
        NICE::MultiChannelImage3DT<double> &imgData,
        const vector<string> &filelist,
        int &amountRegions )
{
    int xsize = imgData.width();
    int ysize = imgData.height();
    int zsize = imgData.depth();

    amountRegions = 0;

    // RGB to Lab
    if ( imagetype == IMAGETYPE_RGB )
    {
        for ( int z = 0; z < zsize; z++ )
            for ( int y = 0; y < ysize; y++ )
                for ( int x = 0; x < xsize; x++ )
                {
                    double R, G, B, X, Y, Z, L, a, b;
                    R = ( double )imgData.get( x, y, z, 0 ) / 255.0;
                    G = ( double )imgData.get( x, y, z, 1 ) / 255.0;
                    B = ( double )imgData.get( x, y, z, 2 ) / 255.0;

                    if ( useAltTristimulus )
                    {
                        ColorConversion::ccRGBtoXYZ( R, G, B, &X, &Y, &Z, 4 );
                        ColorConversion::ccXYZtoCIE_Lab( X, Y, Z, &L, &a, &b, 4 );
                    }
                    else
                    {
                        ColorConversion::ccRGBtoXYZ( R, G, B, &X, &Y, &Z, 0 );
                        ColorConversion::ccXYZtoCIE_Lab( X, Y, Z, &L, &a, &b, 0 );
                    }

                    imgData.set( x, y, z, L, 0 );
                    imgData.set( x, y, z, a, 1 );
                    imgData.set( x, y, z, b, 2 );
                }
    }

    // Gradient layers: channel c gets its Sobel-x response at c+currentsize
    // and its Sobel-y response at c+2*currentsize
    if ( useGradient )
    {
        int currentsize = imgData.channels();
        imgData.addChannel ( 2*currentsize );

        for ( int z = 0; z < zsize; z++ )
            for ( int c = 0; c < currentsize; c++ )
            {
                ImageT<double> tmp = imgData.getChannelT(z, c);
                ImageT<double> sobX( xsize, ysize );
                ImageT<double> sobY( xsize, ysize );
                NICE::FilterT<double, double, double>::sobelX ( tmp, sobX );
                NICE::FilterT<double, double, double>::sobelY ( tmp, sobY );
                for ( int y = 0; y < ysize; y++ )
                    for ( int x = 0; x < xsize; x++ )
                    {
                        imgData.set( x, y, z, sobX.getPixelQuick(x,y), c+currentsize );
                        imgData.set( x, y, z, sobY.getPixelQuick(x,y), c+(currentsize*2) );
                    }
            }
    }

    // Weijer color names
    if ( useWeijer )
    {
        if ( imagetype == IMAGETYPE_RGB )
        {
            int currentsize = imgData.channels();
            imgData.addChannel ( 11 );
            for ( int z = 0; z < zsize; z++ )
            {
                NICE::ColorImage img = imgData.getColor ( z );
                NICE::MultiChannelImageT<double> cfeats;
                lfcw->getFeats ( img, cfeats );
                for ( int c = 0; c < cfeats.channels(); c++)
                    for ( int y = 0; y < ysize; y++ )
                        for ( int x = 0; x < xsize; x++ )
                            imgData.set(x, y, z, cfeats.get(x,y,(uint)c), c+currentsize);
            }
        }
        else
        {
            cerr << "Can't compute weijer features of a grayscale image." << endl;
        }
    }

    // arbitrary additional layer as image
    if ( useAdditionalLayer )
    {
        int currentsize = imgData.channels();
        imgData.addChannel ( 1 );
        for ( int z = 0; z < zsize; z++ )
        {
            vector<string> list;
            StringTools::split ( filelist[z], '/', list );
            string layerPath = StringTools::trim ( filelist[z], list.back() ) + "addlayer/" + list.back();
            NICE::Image layer ( layerPath );

            // getPixelQuick does no bounds checking, so refuse layers that
            // do not cover the whole slice
            if ( layer.width() < xsize || layer.height() < ysize )
            {
                fthrow ( Exception, "The additional layer is smaller than the original image: " << layerPath << " (original image is " << filelist[z] << ")" );
            }

            for ( int y = 0; y < ysize; y++ )
                for ( int x = 0; x < xsize; x++ )
                    imgData.set(x, y, z, layer.getPixelQuick(x,y), currentsize);
        }
    }

    // read the geometric cues produced by Hoiem et al.
    if ( useHoiemFeatures )
    {
        string hoiemDirectory = conf->gS ( "Features", "hoiem_directory" );
        // we could also give the following set as a config option
        string hoiemClasses_s = "sky 000 090-045 090-090 090-135 090 090-por 090-sol";
        vector<string> hoiemClasses;
        StringTools::split ( hoiemClasses_s, ' ', hoiemClasses );

        // Hoiem class k is stored in channel firstHoiemChannel + k for every
        // slice; computing the index per class (instead of advancing a running
        // counter) keeps all writes inside the channels allocated here
        const int firstHoiemChannel = imgData.channels();
        imgData.addChannel ( hoiemClasses.size() );
        for ( int z = 0; z < zsize; z++ )
        {
            FileName fn ( filelist[z] );
            fn.removeExtension();
            FileName fnBase = fn.extractFileName();

            for ( size_t k = 0; k < hoiemClasses.size(); k++ )
            {
                const int channel = firstHoiemChannel + ( int ) k;
                const string &hoiemClass = hoiemClasses[k];
                FileName fnConfidenceImage ( hoiemDirectory + fnBase.str() + "_c_" + hoiemClass + ".png" );
                if ( ! fnConfidenceImage.fileExists() )
                {
                    fthrow ( Exception, "Unable to read the Hoiem geometric confidence image: " << fnConfidenceImage.str() << " (original image is " << filelist[z] << ")" );
                }

                Image confidenceImage ( fnConfidenceImage.str() );
                if ( confidenceImage.width() != xsize || confidenceImage.height() != ysize )
                {
                    fthrow ( Exception, "The size of the geometric confidence image does not match with the original image size: " << fnConfidenceImage.str() );
                }

                // copy standard image to double image
                for ( int y = 0 ; y < confidenceImage.height(); y++ )
                    for ( int x = 0 ; x < confidenceImage.width(); x++ )
                        imgData ( x, y, z, channel ) = ( double ) confidenceImage ( x, y );
            }
        }
    }

    // region feature (unsupervised segmentation): one channel of region ids
    if ( useFeat1 )
    {
        MultiChannelImageT<int> regions;
        regions.reInit( xsize, ysize, zsize );
        amountRegions = segmentation->segRegions ( imgData, regions, imagetype );

        int currentsize = imgData.channels();
        imgData.addChannel ( 1 );

        for ( int z = 0; z < ( int ) regions.channels(); z++ )
            for ( int y = 0; y < regions.height(); y++ )
                for ( int x = 0; x < regions.width(); x++ )
                    imgData.set ( x, y, z, regions ( x, y, ( uint ) z ), currentsize );

    }

    // integral images of raw channels
    if ( useFeat2 )
    {
#pragma omp parallel for
        for ( int i = 0; i < rawChannels; i++ )
            imgData.calcIntegral ( i );
    }

    int classes = classNames->numClasses() - forbidden_classes.size();

    // placeholder channels for the class probability (context) maps
    if ( useFeat3 || useFeat4 )
        imgData.addChannel ( classes );

}

void SemSegContextTree3D::classify (
        const std::vector<std::string> & filelist,
        NICE::MultiChannelImageT<int> & segresult,
        NICE::MultiChannelImage3DT<double> & probabilities )
{
    ///////////////////////// build MCI3DT from files ///////////////////////////
    /////////////////////////////////////////////////////////////////////////////

    NICE::MultiChannelImage3DT<double> imgData;
    this->make3DImage( filelist, imgData );

    int xsize = imgData.width();
    int ysize = imgData.height();
    int zsize = imgData.depth();

    ////////////////////////// initialize variables /////////////////////////////
    /////////////////////////////////////////////////////////////////////////////

    firstiteration = true;
    depth = 0;

    // anytime classification ability
    int classificationDepth = conf->gI( "SSContextTree", "classification_depth", maxDepth );
    if (classificationDepth > maxDepth || classificationDepth < 1 )
        classificationDepth = maxDepth;

    Timer timer;
    timer.start();

    // classes occurred during training step
    int classes = labelmapback.size();
    // classes defined in config file
    int numClasses = classNames->numClasses();

    // class probabilities by pixel
    probabilities.reInit ( xsize, ysize, zsize, numClasses );
    probabilities.setAll ( 0 );

    // class probabilities by region
    vector<vector<double> > regionProbs;

    // affiliation: pixel <-> (tree,node)
    MultiChannelImage3DT<unsigned short int> nodeIndices ( xsize, ysize, zsize, nbTrees );
    nodeIndices.setAll ( 0 );

    // for categorization
    SparseVector *globalCategorFeat;
    globalCategorFeat = new SparseVector();

    /////////////////////////// get feature values //////////////////////////////
    /////////////////////////////////////////////////////////////////////////////

    // Basic Features
    int amountRegions;
    addFeatureMaps ( imgData, filelist, amountRegions );

    vector<int> rSize;
    int shift = 0;
    if ( useFeat1 )
    {
        shift = 1;
        regionProbs = vector<vector<double> > ( amountRegions, vector<double> ( classes, 0.0 ) );
        rSize = vector<int> ( amountRegions, 0 );
        for ( int z = 0; z < zsize; z++ )
        {
            for ( int y = 0; y < ysize; y++ )
            {
                for ( int x = 0; x < xsize; x++ )
                {
                    rSize[imgData ( x, y, z, rawChannels ) ]++;
                }
            }
        }
    }

    ////////////////// traverse image example through trees /////////////////////
    /////////////////////////////////////////////////////////////////////////////

    bool noNewSplit = false;
    for ( int d = 0; d < classificationDepth && !noNewSplit; d++ )
    {
        depth++;
        vector<vector<double> > lastRegionProbs = regionProbs;

        if ( useFeat1 )
        {
            int numRegions = ( int ) regionProbs.size();
            for ( int r = 0; r < numRegions; r++ )
                for ( int c = 0; c < classes; c++ )
                    regionProbs[r][c] = 0.0;
        }

        if ( depth < classificationDepth )
        {
            int firstChannel = rawChannels + shift;
            if ( useFeat3 || useFeat4 )
                this->updateProbabilityMaps ( nodeIndices, imgData, firstChannel );
        }

        double weight = computeWeight ( depth, maxDepth )
                - computeWeight ( depth - 1, maxDepth );

        noNewSplit = true;

        int tree;
#pragma omp parallel for private(tree)
        for ( tree = 0; tree < nbTrees; tree++ )
            for ( int x = 0; x < xsize; x=x+labelIncrement )
                for ( int y = 0; y < ysize; y=y+labelIncrement )
                    for ( int z = 0; z < zsize; z++ )
                    {
                        int node = nodeIndices.get ( x, y, z, tree );

                        if ( forest[tree][node].left > 0 )
                        {
                            noNewSplit = false;
                            Features feat;
                            feat.feats = &imgData;
                            feat.rProbs = &lastRegionProbs;

                            double val = forest[tree][node].feat->getVal ( feat, x, y, z );
                            if ( !isfinite ( val ) ) val = 0.0;

                            // traverse to left child
                            if ( val < forest[tree][node].decision )
                            {
                                int left = forest[tree][node].left;

                                for ( int n = 0; n < labelIncrement; n++ )
                                    for ( int m = 0; m < labelIncrement; m++ )
                                        if (x+m < xsize && y+n < ysize)
                                            nodeIndices.set ( x+m, y+n, z, left, tree );
#pragma omp critical
                                {
                                    if ( fasthik != NULL
                                         && useCategorization
                                         && forest[tree][left].nodeNumber < uniquenumber )
                                        ( *globalCategorFeat ) [forest[tree][left].nodeNumber] += weight;
                                }
                            }
                            // traverse to right child
                            else
                            {
                                int right = forest[tree][node].right;
                                for ( int n = 0; n < labelIncrement; n++ )
                                    for ( int m = 0; m < labelIncrement; m++ )
                                        if (x+m < xsize && y+n < ysize)
                                            nodeIndices.set ( x+m, y+n, z, right, tree );
#pragma omp critical
                                {
                                    if ( fasthik != NULL
                                         && useCategorization
                                         && forest[tree][right].nodeNumber < uniquenumber )
                                        ( *globalCategorFeat ) [forest[tree][right].nodeNumber] += weight;
                                }
                            }
                        }
                    }

        if ( useFeat1 )
        {
            int xsize = nodeIndices.width();
            int ysize = nodeIndices.height();
            int zsize = nodeIndices.depth();

#pragma omp parallel for
            for ( int x = 0; x < xsize; x++ )
                for ( int y = 0; y < ysize; y++ )
                    for ( int z = 0; z < zsize; z++ )
                        for ( int tree = 0; tree < nbTrees; tree++ )
                        {
                            int node = nodeIndices.get ( x, y, z, tree );
                            for ( uint c = 0; c < forest[tree][node].dist.size(); c++ )
                            {
                                int r = (int) imgData ( x, y, z, rawChannels );
                                regionProbs[r][c] += forest[tree][node].dist[c];
                            }
                        }

            int numRegions = (int) regionProbs.size();
            for ( int r = 0; r < numRegions; r++ )
                for ( int c = 0; c < (int) classes; c++ )
                    regionProbs[r][c] /= ( double ) ( rSize[r] );
        }

        if ( (depth < classificationDepth) && firstiteration ) firstiteration = false;
    }

    vector<int> classesInImg;

    if ( useCategorization )
    {
        if ( cndir != "" )
        {
            for ( int z = 0; z < zsize; z++ )
            {
                vector< string > list;
                StringTools::split ( filelist[z], '/', list );
                string orgname = list.back();

                ifstream infile ( ( cndir + "/" + orgname + ".dat" ).c_str() );
                while ( !infile.eof() && infile.good() )
                {
                    int tmp;
                    infile >> tmp;
                    assert ( tmp >= 0 && tmp < numClasses );
                    classesInImg.push_back ( tmp );
                }
            }
        }
        else
        {
            globalCategorFeat->setDim ( uniquenumber );
            globalCategorFeat->normalize();
            ClassificationResult cr = fasthik->classify( globalCategorFeat);
            for ( uint i = 0; i < ( uint ) classes; i++ )
            {
                cerr << cr.scores[i] << " ";
                if ( cr.scores[i] > 0.0/*-0.3*/ )
                {
                    classesInImg.push_back ( i );
                }
            }
        }
        cerr << "amount of classes: " << classes << " used classes: " << classesInImg.size() << endl;
    }

    if ( classesInImg.size() == 0 )
    {
        for ( uint i = 0; i < ( uint ) classes; i++ )
        {
            classesInImg.push_back ( i );
        }
    }

    // final labeling step
    if ( pixelWiseLabeling )
    {
        for ( int x = 0; x < xsize; x++ )
            for ( int y = 0; y < ysize; y++ )
                for ( int z = 0; z < zsize; z++ )
                {
                    double maxProb = - numeric_limits<double>::max();
                    int maxClass = 0;

                    for ( uint c = 0; c < classesInImg.size(); c++ )
                    {
                        int i = classesInImg[c];
                        double curProb = getMeanProb ( x, y, z, i, nodeIndices );
                        probabilities.set ( x, y, z, curProb, labelmapback[i] );

                        if ( curProb > maxProb )
                        {
                            maxProb = curProb;
                            maxClass = labelmapback[i];
                        }
                    }
                    assert(maxProb <= 1);

                    // copy pixel labeling into segresults (output)
                    segresult.set ( x, y, maxClass, ( uint ) z );
                }

#ifdef VISUALIZE
        getProbabilityMap( probabilities );
#endif
    }
    else
    {
        // labeling by region
        NICE::MultiChannelImageT<int> regions;
        int xsize = imgData.width();
        int ysize = imgData.height();
        int zsize = imgData.depth();
        regions.reInit ( xsize, ysize, zsize );

        if ( useFeat1 )
        {
            for ( int z = 0; z < zsize; z++ )
                for ( int y = 0; y < ysize; y++ )
                    for ( int x = 0; x < xsize; x++ )
                        regions.set ( x, y, imgData ( x, y, z, rawChannels ), ( uint ) z );
        }
        else
        {
            amountRegions = segmentation->segRegions ( imgData, regions, imagetype );

#ifdef DEBUG
            for ( unsigned int z = 0; z < ( uint ) zsize; z++ )
            {
                NICE::Matrix regmask;
                NICE::ColorImage colorimg ( xsize, ysize );
                NICE::ColorImage marked ( xsize, ysize );
                regmask.resize ( xsize, ysize );
                for ( int y = 0; y < ysize; y++ )
                {
                    for ( int x = 0; x < xsize; x++ )
                    {
                        regmask ( x,y ) = regions ( x,y,z );
                        colorimg.setPixelQuick ( x, y, 0, imgData.get ( x,y,z,0 ) );
                        colorimg.setPixelQuick ( x, y, 1, imgData.get ( x,y,z,0 ) );
                        colorimg.setPixelQuick ( x, y, 2, imgData.get ( x,y,z,0 ) );
                    }
                }
                vector<int> colorvals;
                colorvals.push_back ( 255 );
                colorvals.push_back ( 0 );
                colorvals.push_back ( 0 );
                segmentation->markContours ( colorimg, regmask, colorvals, marked );
                std::vector<string> list;
                StringTools::split ( filelist[z], '/', list );
                string savePath = StringTools::trim ( filelist[z], list.back() ) + "marked/" + list.back();
                marked.write ( savePath );
            }
#endif
        }

        regionProbs.clear();
        regionProbs = vector<vector<double> > ( amountRegions, vector<double> ( classes, 0.0 ) );
        vector<vector<double> > regionProbsCount ( amountRegions, vector<double> ( classes, 0.0 ) );

        vector<int> bestlabels ( amountRegions, labelmapback[classesInImg[0]] );
        for ( int z = 0; z < zsize; z++ )
        {
            for ( int y = 0; y < ysize; y++ )
            {
                for ( int x = 0; x < xsize; x++ )
                {
                    int r = regions ( x, y, ( uint ) z );
                    for ( uint i = 0; i < classesInImg.size(); i++ )
                    {
                        int c = classesInImg[i];
                        // get mean voting of all trees
                        regionProbs[r][c] += getMeanProb ( x, y, z, c, nodeIndices );
                        regionProbsCount[r][c]++;
                    }
                }
            }
        }

        for ( int r = 0; r < amountRegions; r++ )
            for ( int c = 0; c < classes; c++ )
                regionProbs[r][c] /= regionProbsCount[r][c];


        for ( int r = 0; r < amountRegions; r++ )
        {
            double maxProb = regionProbs[r][classesInImg[0]];
            bestlabels[r] = classesInImg[0];

            for ( int c = 1; c < classes; c++ )
                if ( maxProb < regionProbs[r][c] )
                {
                    maxProb = regionProbs[r][c];
                    bestlabels[r] = c;
                }

            bestlabels[r] = labelmapback[bestlabels[r]];
        }

        // copy region labeling into segresults (output)
        for ( int z = 0; z < zsize; z++ )
            for ( int y = 0; y < ysize; y++ )
                for ( int x = 0; x < xsize; x++ )
                {
                    int r = regions ( x,y, (uint) z );
                    int l = bestlabels[ r ];

                    segresult.set ( x, y, l, (uint) z );
                    for ( int c = 0; c < classes; c++ )
                    {
                        double curProb = regionProbs[r][c];
                        probabilities.set( x, y, z, curProb, c );
                    }
                }

#ifdef WRITEREGIONS
        for ( int z = 0; z < zsize; z++ )
        {
            RegionGraph rg;
            NICE::ColorImage img ( xsize,ysize );
            if ( imagetype == IMAGETYPE_RGB )
            {
                img = imgData.getColor ( z );
            }
            else
            {
                NICE::Image gray = imgData.getChannel ( z );
                for ( int y = 0; y < ysize; y++ )
                {
                    for ( int x = 0; x < xsize; x++ )
                    {
                        int val = gray.getPixelQuick ( x,y );
                        img.setPixelQuick ( x, y, val, val, val );
                    }
                }
            }

            Matrix regions_tmp ( xsize,ysize );
            for ( int y = 0; y < ysize; y++ )
            {
                for ( int x = 0; x < xsize; x++ )
                {
                    regions_tmp ( x,y ) = regions ( x,y, ( uint ) z );
                }
            }
            segmentation->getGraphRepresentation ( img, regions_tmp,  rg );
            for ( uint pos = 0; pos < regionProbs.size(); pos++ )
            {
                rg[pos]->setProbs ( regionProbs[pos] );
            }

            std::string s;
            std::stringstream out;
            std::vector< std::string > list;
            StringTools::split ( filelist[z], '/', list );

            out << "rgout/" << list.back() << ".graph";
            string writefile = out.str();
            rg.write ( writefile );
        }
#endif
    }

    timer.stop();
    cout << "\nTime for Classification: " << timer.getLastAbsolute() << endl;

    // CLEANING UP
    // TODO: operations in "forest"
    while( !ops.empty() )
    {
        vector<Operation3D*> &tops = ops.back();
        while ( !tops.empty() )
            tops.pop_back();

        ops.pop_back();
    }

    delete globalCategorFeat;
}

void SemSegContextTree3D::store ( std::ostream & os, int format ) const
{
    os.precision ( numeric_limits<double>::digits10 + 1 );
    os << nbTrees << endl;
    classnames.store ( os );

    map<int, int>::const_iterator it;

    os << labelmap.size() << endl;
    for ( it = labelmap.begin() ; it != labelmap.end(); it++ )
        os << ( *it ).first << " " << ( *it ).second << endl;

    os << labelmapback.size() << endl;
    for ( it = labelmapback.begin() ; it != labelmapback.end(); it++ )
        os << ( *it ).first << " " << ( *it ).second << endl;

    int trees = forest.size();
    os << trees << endl;

    for ( int t = 0; t < trees; t++ )
    {
        int nodes = forest[t].size();
        os << nodes << endl;
        for ( int n = 0; n < nodes; n++ )
        {
            os << forest[t][n].left << " " << forest[t][n].right << " " << forest[t][n].decision << " " << forest[t][n].isleaf << " " << forest[t][n].depth << " " << forest[t][n].featcounter << " " << forest[t][n].nodeNumber << endl;
            os << forest[t][n].dist << endl;

            if ( forest[t][n].feat == NULL )
                os << -1 << endl;
            else
            {
                os << forest[t][n].feat->getOps() << endl;
                forest[t][n].feat->store ( os );
            }
        }
    }

    vector<int> channelType;
    if ( useFeat0 )
        channelType.push_back(0);
    if ( useFeat1 )
        channelType.push_back(1);
    if ( useFeat2 )
        channelType.push_back(2);
    if ( useFeat3 )
        channelType.push_back(3);
    if ( useFeat4 )
        channelType.push_back(4);

    os << channelType.size() << endl;
    for ( int i = 0; i < ( int ) channelType.size(); i++ )
    {
        os << channelType[i] << " ";
    }
    os << endl;

    os << rawChannels << endl;

    os << uniquenumber << endl;
}

void SemSegContextTree3D::restore ( std::istream & is, int format )
{
    is >> nbTrees;

    classnames.restore ( is );

    int lsize;
    is >> lsize;

    labelmap.clear();
    for ( int l = 0; l < lsize; l++ )
    {
        int first, second;
        is >> first;
        is >> second;
        labelmap[first] = second;
    }

    is >> lsize;
    labelmapback.clear();
    for ( int l = 0; l < lsize; l++ )
    {
        int first, second;
        is >> first;
        is >> second;
        labelmapback[first] = second;
    }

    int trees;
    is >> trees;
    forest.clear();

    for ( int t = 0; t < trees; t++ )
    {
        vector<TreeNode> tmptree;
        forest.push_back ( tmptree );
        int nodes;
        is >> nodes;

        for ( int n = 0; n < nodes; n++ )
        {
            TreeNode tmpnode;
            forest[t].push_back ( tmpnode );
            is >> forest[t][n].left;
            is >> forest[t][n].right;
            is >> forest[t][n].decision;
            is >> forest[t][n].isleaf;
            is >> forest[t][n].depth;
            is >> forest[t][n].featcounter;
            is >> forest[t][n].nodeNumber;

            is >> forest[t][n].dist;

            int feattype;
            is >> feattype;
            assert ( feattype < NBOPERATIONS );
            forest[t][n].feat = NULL;

            if ( feattype >= 0 )
            {
                for ( uint o = 0; o < ops.size(); o++ )
                {
                    for ( uint o2 = 0; o2 < ops[o].size(); o2++ )
                    {
                        if ( forest[t][n].feat == NULL )
                        {
                            if ( ops[o][o2]->getOps() == feattype )
                            {
                                forest[t][n].feat = ops[o][o2]->clone();
                                break;
                            }
                        }
                    }
                }

                assert ( forest[t][n].feat != NULL );
                forest[t][n].feat->restore ( is );

            }
        }
    }

    // channel type configuration
    int ctsize;
    is >> ctsize;
    for ( int i = 0; i < ctsize; i++ )
    {
        int tmp;
        is >> tmp;
        switch (tmp)
        {
            case 0: useFeat0 = true; break;
            case 1: useFeat1 = true; break;
            case 2: useFeat2 = true; break;
            case 3: useFeat3 = true; break;
            case 4: useFeat4 = true; break;
        }
    }

    is >> rawChannels;

    is >> uniquenumber;
}