#include "SemSegContextTree3D.h"
#include "vislearning/baselib/Globals.h"
#include "core/basics/StringTools.h"

#include "core/imagedisplay/ImageDisplay.h"

#include "vislearning/cbaselib/CachedExample.h"
#include "vislearning/cbaselib/PascalResults.h"
#include "vislearning/baselib/cc.h"
#include "segmentation/RSMeanShift.h"
#include "segmentation/RSGraphBased.h"
#include "segmentation/RSSlic.h"
#include "core/basics/numerictools.h"
#include "core/basics/StringTools.h"
#include "core/basics/FileName.h"
#include "vislearning/baselib/ICETools.h"

#include "core/basics/Timer.h"
#include "core/basics/vectorio.h"
#include "core/basics/quadruplet.h"
#include <core/image/Filter.h>
#include "core/image/FilterT.h"
#include <core/image/Morph.h>

#include <omp.h>
#include <iostream>

#define VERBOSE
#undef DEBUG
#undef VISUALIZE
#undef WRITEREGIONS

using namespace OBJREC;
using namespace std;
using namespace NICE;


//###################### CONSTRUCTORS #########################//


/**
 * Default constructor.
 *
 * Sets every parameter to its built-in default. No helper objects are
 * allocated and initOperations() is not called, so the operation prototypes
 * stay empty. NOTE(review): the member `conf` is not assigned here; confirm
 * the base class initializes it before getBestSplit()/initOperations() is used.
 */
SemSegContextTree3D::SemSegContextTree3D () : SemanticSegmentation ()
{
  // optional helper objects: none allocated by default
  lfcw         = NULL;
  fasthik      = NULL;
  segmentation = NULL;

  // training state
  firstiteration = true;
  run3Dseg       = false;

  // forest and split-search parameters
  maxSamples        = 2000;
  minFeats          = 50;
  maxDepth          = 10;
  windowSize        = 15;
  contextMultiplier = 3;
  featsPerSplit     = 200;
  useShannonEntropy = true;
  nbTrees           = 10;
  randomTests       = 10;

  // raw feature channel switches
  useAltTristimulus  = false;
  useGradient        = true;
  useWeijer          = false;
  useAdditionalLayer = false;
  useHoiemFeatures   = false;

  // categorization, persistence and labeling mode
  useCategorization = false;
  cndir             = "";
  saveLoadData      = false;
  fileLocation      = "tmp.txt";
  pixelWiseLabeling = true;

  // enabled feature types (0: pixel pair, 1: region, 2: integral,
  // 3: integral context, 4: pixel pair context, 5: ray)
  useFeat0 = true;
  useFeat1 = false;
  useFeat2 = true;
  useFeat3 = true;
  useFeat4 = false;
  useFeat5 = false;
}


/**
 * Configurable constructor.
 *
 * Reads the forest parameters from section "SSContextTree", the raw-feature
 * switches from section "Features" and the persistence options from section
 * "debug". Allocates the optional helpers (GPHIK categorizer, Weijer color
 * features, unsupervised pre-segmentation), collects the enabled feature types
 * in featTypes and builds the operation prototypes via initOperations().
 *
 * @param conf configuration; stored in this->conf for later use
 * @param md   dataset; its "train" class names are used
 * @throws const char* if SSContextTree::segmentation_type is none of
 *         "none", "meanshift", "slic" or "felzenszwalb"
 */
SemSegContextTree3D::SemSegContextTree3D (
    const Config *conf,
    const MultiDataset *md )
    : SemanticSegmentation ( conf, & ( md->getClassNames ( "train" ) ) )
{
  this->conf = conf;

  string section = "SSContextTree";
  string featsec = "Features";

  this->lfcw                = NULL;
  this->firstiteration      = true;
  this->run3Dseg            = conf->gB ( section, "run_3dseg", false );
  this->maxSamples          = conf->gI ( section, "max_samples", 2000 );
  this->minFeats            = conf->gI ( section, "min_feats", 50 );
  this->maxDepth            = conf->gI ( section, "max_depth", 10 );
  this->windowSize          = conf->gI ( section, "window_size", 15 );
  this->contextMultiplier   = conf->gI ( section, "context_multiplier", 3 );
  this->featsPerSplit       = conf->gI ( section, "feats_per_split", 200 );
  this->useShannonEntropy   = conf->gB ( section, "use_shannon_entropy", true );
  this->nbTrees             = conf->gI ( section, "amount_trees", 10 );
  this->randomTests         = conf->gI ( section, "random_tests", 10 );

  // NOTE(review): use_weijer defaults to true here but to false in the
  // default constructor; confirm which default is intended.
  this->useAltTristimulus   = conf->gB ( featsec, "use_alt_trist", false );
  this->useGradient         = conf->gB ( featsec, "use_gradient", true );
  this->useWeijer           = conf->gB ( featsec, "use_weijer", true );
  this->useAdditionalLayer  = conf->gB ( featsec, "use_additional_layer", false );
  this->useHoiemFeatures    = conf->gB ( featsec, "use_hoiem_features", false );

  this->useCategorization   = conf->gB ( section, "use_categorization", false );
  this->cndir               = conf->gS ( section, "cndir", "" );

  this->saveLoadData        = conf->gB ( "debug", "save_load_data", false );
  this->fileLocation        = conf->gS ( "debug", "datafile", "tmp.txt" );

  this->pixelWiseLabeling   = conf->gB ( section, "pixelWiseLabeling", false );

  // a categorizer is only trained here when no precomputed directory is given
  if ( useCategorization && cndir == "" )
    this->fasthik = new GPHIKClassifierNICE ( conf );
  else
    this->fasthik = NULL;

  if ( useWeijer )
    this->lfcw    = new LocalFeatureColorWeijer ( conf );

  this->classnames = md->getClassNames ( "train" );

  // feature types
  this->useFeat0 = conf->gB ( section, "use_feat_0", true);     // pixel pair features
  this->useFeat1 = conf->gB ( section, "use_feat_1", false);    // region feature
  this->useFeat2 = conf->gB ( section, "use_feat_2", true);     // integral features
  this->useFeat3 = conf->gB ( section, "use_feat_3", true);     // integral context features
  this->useFeat4 = conf->gB ( section, "use_feat_4", false);    // pixel pair context features
  this->useFeat5 = conf->gB ( section, "use_feat_5", false);    // ray features

  string segmentationtype = conf->gS ( section, "segmentation_type", "slic" );
  if ( segmentationtype == "meanshift" )
    this->segmentation = new RSMeanShift ( conf );
  else if ( segmentationtype == "felzenszwalb" )
    this->segmentation = new RSGraphBased ( conf );
  else if ( segmentationtype == "slic" )
    this->segmentation = new RSSlic ( conf );
  else if ( segmentationtype == "none" )
  {
    // without a pre-segmentation there are no regions: label per pixel and
    // disable the region feature
    this->segmentation = NULL;
    this->pixelWiseLabeling = true;
    this->useFeat1 = false;
  }
  else
    throw ( "no valid segmentation_type\n please choose between none, meanshift, slic and felzenszwalb\n" );

  // ray features read gradient channels of a single-channel image
  // (see computeRayFeatImage), so they require gradients and a non-RGB image
  if ( !useGradient || imagetype == IMAGETYPE_RGB)
    this->useFeat5 = false;

  if ( useFeat0 )
    this->featTypes.push_back(0);
  if ( useFeat1 )
    this->featTypes.push_back(1);
  if ( useFeat2 )
    this->featTypes.push_back(2);
  if ( useFeat3 )
    this->featTypes.push_back(3);
  if ( useFeat4 )
    this->featTypes.push_back(4);
  if ( useFeat5 )
    this->featTypes.push_back(5);

  this->initOperations();
}

//###################### DESTRUCTORS ##########################//

/**
 * Destructor (currently empty).
 *
 * NOTE(review): the configurable constructor allocates lfcw, fasthik and
 * segmentation with new, and initOperations() allocates the Operation3D
 * prototypes stored in ops; none of them is released here. Whether other
 * code frees them, and whether instances are ever copied (which would make
 * naive deletes a double free), is not visible from this section — confirm
 * before adding cleanup.
 */
SemSegContextTree3D::~SemSegContextTree3D()
{
}


//#################### MEMBER FUNCTIONS #######################//

void SemSegContextTree3D::initOperations()
{
  string featsec = "Features";

  // operation prototypes
  vector<Operation3D*> tops0, tops1, tops2, tops3, tops4, tops5;

  if ( conf->gB ( featsec, "int", true ) )
  {
    tops2.push_back ( new IntegralOps() );
    Operation3D* o = new IntegralOps();
    o->setContext(true);
    tops3.push_back ( o );
  }
  if ( conf->gB ( featsec, "bi_int_cent", true ) )
  {
    tops2.push_back ( new BiIntegralCenteredOps() );
    Operation3D* o = new BiIntegralCenteredOps();
    o->setContext(true);
    tops3.push_back ( o );
  }
  if ( conf->gB ( featsec, "int_cent", true ) )
  {
    tops2.push_back ( new IntegralCenteredOps() );
    Operation3D* o = new IntegralCenteredOps();
    o->setContext(true);
    tops3.push_back ( o );
  }
  if ( conf->gB ( featsec, "haar_horz", true ) )
  {
    tops2.push_back ( new HaarHorizontal() );
    Operation3D* o = new HaarHorizontal();
    o->setContext(true);
    tops3.push_back ( o );
  }
  if ( conf->gB ( featsec, "haar_vert", true ) )
  {
    tops2.push_back ( new HaarVertical );
    Operation3D* o = new HaarVertical();
    o->setContext(true);
    tops3.push_back ( o );
  }
  if ( conf->gB ( featsec, "haar_stack", true ) )
  {
    tops2.push_back ( new HaarStacked() );
    Operation3D* o = new HaarStacked();
    o->setContext(true);
    tops3.push_back ( o );
  }
  if ( conf->gB ( featsec, "haar_diagxy", true ) )
  {
    tops2.push_back ( new HaarDiagXY() );
    Operation3D* o = new HaarDiagXY();
    o->setContext(true);
    tops3.push_back ( o );
  }
  if ( conf->gB ( featsec, "haar_diagxz", true ) )
  {
    tops2.push_back ( new HaarDiagXZ() );
    Operation3D* o = new HaarDiagXZ();
    o->setContext(true);
    tops3.push_back ( o );
  }
  if ( conf->gB ( featsec, "haar_diagyz", true ) )
  {
    tops2.push_back ( new HaarDiagYZ() );
    Operation3D* o = new HaarDiagYZ();
    o->setContext(true);
    tops3.push_back ( o );
  }
  if ( conf->gB ( featsec, "haar3_horz", true ) )
  {
    tops2.push_back ( new Haar3Horiz() );
    Operation3D* o = new Haar3Horiz();
    o->setContext(true);
    tops3.push_back ( o );
  }
  if ( conf->gB ( featsec, "haar3_vert", true ) )
  {
    tops2.push_back ( new Haar3Vert() );
    Operation3D* o = new Haar3Vert();
    o->setContext(true);
    tops3.push_back ( o );
  }
  if ( conf->gB ( featsec, "haar3_stack", true ) )
  {
    tops2.push_back ( new Haar3Stack() );
    Operation3D* o = new Haar3Stack();
    o->setContext(true);
    tops3.push_back ( o );
  }

  if ( conf->gB ( featsec, "minus", true ) )
  {
    tops0.push_back ( new Minus() );
    Operation3D* o = new Minus();
    o->setContext(true);
    tops4.push_back ( o );
  }
  if ( conf->gB ( featsec, "minus_abs", true ) )
  {
    tops0.push_back ( new MinusAbs() );
    Operation3D* o = new MinusAbs();
    o->setContext(true);
    tops4.push_back ( o );
  }
  if ( conf->gB ( featsec, "addition", true ) )
  {
    tops0.push_back ( new Addition() );
    Operation3D* o = new Addition();
    o->setContext(true);
    tops4.push_back ( o );
  }
  if ( conf->gB ( featsec, "only1", true ) )
  {
    tops0.push_back ( new Only1() );
    Operation3D* o = new Only1();
    o->setContext(true);
    tops4.push_back ( o );
  }
  if ( conf->gB ( featsec, "rel_x", true ) )
    tops0.push_back ( new RelativeXPosition() );
  if ( conf->gB ( featsec, "rel_y", true ) )
    tops0.push_back ( new RelativeYPosition() );
  if ( conf->gB ( featsec, "rel_z", true ) )
    tops0.push_back ( new RelativeZPosition() );

  if ( conf->gB ( featsec, "ray_diff", false ) )
    tops5.push_back ( new RayDiff() );
  if ( conf->gB ( featsec, "ray_dist", false ) )
    tops5.push_back ( new RayDist() );
  if ( conf->gB ( featsec, "ray_orient", false ) )
    tops5.push_back ( new RayOrient() );
  if ( conf->gB ( featsec, "ray_norm", false ) )
    tops5.push_back ( new RayNorm() );

  this->ops.push_back ( tops0 );
  this->ops.push_back ( tops1 );
  this->ops.push_back ( tops2 );
  this->ops.push_back ( tops3 );
  this->ops.push_back ( tops4 );
  this->ops.push_back ( tops5 );
}

/**
 * Searches a split (feature operation + threshold) for one node of one tree.
 *
 * Voxels currently assigned to @p node are subsampled class-balanced
 * (roughly maxSamples in expectation). Then featsPerSplit random feature
 * operations are drawn, and for each of them randomTests random thresholds
 * are scored by information gain. With useShannonEntropy the gain is
 * normalized as 2*IG / (H(labels) + H(split)).
 *
 * @param feats        feature volumes, one per training volume
 * @param nodeIndices  current node per voxel; channel index = tree index
 * @param labels       ground-truth label per voxel; channel index = slice z
 * @param node         index of the node inside forest[tree]
 * @param splitop      out: best operation, or NULL if no split was found.
 *                     The caller takes ownership; all other sampled
 *                     candidates are deleted here.
 * @param splitval     out: threshold for @p splitop (-1 if none)
 * @param tree         tree index
 * @param regionProbs  per-volume, per-region class probabilities (type 1)
 * @return best gain, or 0.0 when the node should not be split
 */
double SemSegContextTree3D::getBestSplit (
    std::vector<NICE::MultiChannelImage3DT<double> > &feats,
    std::vector<NICE::MultiChannelImage3DT<unsigned short int> > &nodeIndices,
    const std::vector<NICE::MultiChannelImageT<int> > &labels,
    int node,
    Operation3D *&splitop,
    double &splitval,
    const int &tree,
    vector<vector<vector<double> > > &regionProbs )
{
  const int imgCount = ( int ) feats.size();

  double bestig = -numeric_limits< double >::max();

  splitop = NULL;
  splitval = -1.0;

  vector<quadruplet<int,int,int,int> > selFeats;
  map<int, int> e;
  int featcounter = forest[tree][node].featcounter;

  if ( featcounter < minFeats )
  {
    return 0.0;
  }

  // per-class sampling probability: balances by the class prior a[i] and
  // limits the expected number of samples to about maxSamples
  vector<double> fraction ( a.size(), 0.0 );

  for ( uint i = 0; i < fraction.size(); i++ )
  {
    if ( forbidden_classes.find ( labelmapback[i] ) != forbidden_classes.end() )
      fraction[i] = 0;
    else
      fraction[i] = ( ( double ) maxSamples ) / ( ( double ) featcounter * a[i] * a.size() );
  }

  featcounter = 0;

  for ( int iCounter = 0; iCounter < imgCount; iCounter++ )
  {
    int xsize = ( int ) nodeIndices[iCounter].width();
    int ysize = ( int ) nodeIndices[iCounter].height();
    int zsize = ( int ) nodeIndices[iCounter].depth();

    for ( int x = 0; x < xsize; x++ )
      for ( int y = 0; y < ysize; y++ )
        for ( int z = 0; z < zsize; z++ )
        {
          if ( nodeIndices[iCounter].get ( x, y, z, tree ) == node )
          {
            int cn = labels[iCounter].get ( x, y, ( uint ) z );
            double randD = ( double ) rand() / ( double ) RAND_MAX;

            if ( labelmap.find ( cn ) == labelmap.end() )
              continue;

            if ( randD < fraction[labelmap[cn]] )
            {
              quadruplet<int,int,int,int> quad( iCounter, x, y, z );
              featcounter++;
              selFeats.push_back ( quad );
              e[cn]++;
            }
          }
        }
  }

  // global entropy (bits)
  double globent = 0.0;
  for ( map<int, int>::iterator mapit = e.begin() ; mapit != e.end(); mapit++ )
  {
    double p = ( double ) ( *mapit ).second / ( double ) featcounter;
    globent += p * log2 ( p );
  }
  globent = -globent;

  if ( globent < 0.5 )
    return 0.0;

  // A feature type can only be sampled if it has channels AND operations.
  // Otherwise the rejection loop below would never terminate, or
  // "rand() % ops[ft].size()" would divide by zero (e.g. type 1 has no ops).
  bool anyUsableType = false;
  for ( uint i = 0; i < featTypes.size() && !anyUsableType; i++ )
  {
    const int ftype = featTypes[i];
    anyUsableType = ( channelsPerType[ftype].size() > 0 && ops[ftype].size() > 0 );
  }
  if ( !anyUsableType )
    return 0.0;

  // Window aspect ratios. They are doubles and must be read with gD: gB yields
  // a bool, so a configured ratio such as 0.5 could never take effect.
  // They are read once per call instead of once per sampled feature.
  const double z_ratio = conf->gD ( "SSContextTree", "z_ratio", 1.0 );
  const double y_ratio = conf->gD ( "SSContextTree", "y_ratio", 1.0 );

  // pointers to all randomly chosen features
  std::vector<Operation3D*> featsel;

  for ( int i = 0; i < featsPerSplit; i++ )
  {
    int x1, x2, y1, y2, z1, z2, ft;

    do
    {
      ft = ( int ) ( rand() % featTypes.size() );
      ft = featTypes[ft];
    }
    while ( channelsPerType[ft].size() == 0 || ops[ft].size() == 0 );

    int tmpws = windowSize;

    // NOTE(review): context channels are types 3 and 4 (see setContext
    // below), while type 2 works on raw integral channels; confirm whether
    // "ft == 3 || ft == 4" was intended here.
    if ( ft == 2 || ft == 4 )
    {
      //use larger window size for context features
      tmpws *= contextMultiplier;
    }

    /* random window positions; every extent is kept >= 1 to avoid "% 0" */
    const int tmp_z = std::max ( 1, ( int ) floor ( ( tmpws * z_ratio ) + 0.5 ) );
    const int tmp_y = std::max ( 1, ( int ) floor ( ( tmpws * y_ratio ) + 0.5 ) );
    x1 = ( int ) ( rand() % tmpws ) - tmpws / 2 ;
    x2 = ( int ) ( rand() % tmpws ) - tmpws / 2 ;
    y1 = ( int ) ( rand() % tmp_y ) - tmp_y / 2 ;
    y2 = ( int ) ( rand() % tmp_y ) - tmp_y / 2 ;
    z1 = ( int ) ( rand() % tmp_z ) - tmp_z / 2 ;
    z2 = ( int ) ( rand() % tmp_z ) - tmp_z / 2 ;

    // use z1/z2 as directions (angles) in ray features
    if ( ft == 5 )
    {
      z1 = ( int ) ( rand() % 8 );
      z2 = ( int ) ( rand() % 8 );
    }

    /* random feature maps (channels) */
    int f1, f2;
    f1 = ( int ) ( rand() % channelsPerType[ft].size() );
    if ( (rand() % 2) == 0 )
      f2 = ( int ) ( rand() % channelsPerType[ft].size() );
    else
      f2 = f1;
    f1 = channelsPerType[ft][f1];
    f2 = channelsPerType[ft][f2];

    // region features use f2 as a class index
    if ( ft == 1 )
    {
      int classes = ( int ) regionProbs[0][0].size();
      f2 = ( int ) ( rand() % classes );
    }

    /* random extraction method (operation) */
    int o = ( int ) ( rand() % ops[ft].size() );

    Operation3D *op = ops[ft][o]->clone();
    op->set ( x1, y1, z1, x2, y2, z2, f1, f2, ft );

    if ( ft == 3 || ft == 4 )
      op->setContext ( true );
    else
      op->setContext ( false );

    featsel.push_back ( op );
  }

  // loop-invariant sanity checks
  assert ( forest.size() > ( uint ) tree );
  assert ( forest[tree][0].dist.size() > 0 );

  // do actual split tests
  for ( int f = 0; f < featsPerSplit; f++ )
  {
    double l_bestig = -numeric_limits< double >::max();
    double l_splitval = -1.0;
    vector<double> vals;
    vals.reserve ( selFeats.size() );

    double maxval = -numeric_limits<double>::max();
    double minval = numeric_limits<double>::max();
    for ( vector<quadruplet<int,int,int,int> >::const_iterator it = selFeats.begin();
         it != selFeats.end(); it++ )
    {
      Features feat;
      feat.feats = &feats[ ( *it ).first ];
      feat.rProbs = &regionProbs[ ( *it ).first ];

      double val = featsel[f]->getVal ( feat, ( *it ).second, ( *it ).third, ( *it ).fourth );
      if ( !isfinite ( val ) )
      {
#ifdef DEBUG
        cerr << "feat " << feat.feats->width() << " " << feat.feats->height() << " " << feat.feats->depth() << endl;
        cerr << "non finite value " << val << " for " << featsel[f]->writeInfos() <<  endl << (*it).second << " " <<  (*it).third << " " << (*it).fourth << endl;
#endif
        val = 0.0;
      }
      vals.push_back ( val );
      maxval = std::max ( val, maxval );
      minval = std::min ( val, minval );
    }

    if ( minval == maxval )
      continue;

    // split values
    for ( int run = 0 ; run < randomTests; run++ )
    {
      // choose threshold randomly
      double sval = ( (double) rand() / (double) RAND_MAX*(maxval-minval) ) + minval;

      map<int, int> eL, eR;
      int counterL = 0, counterR = 0;
      int counter = 0;

      for ( vector<quadruplet<int,int,int,int> >::const_iterator it2 = selFeats.begin();
            it2 != selFeats.end(); it2++, counter++ )
      {
        int cn = labels[ ( *it2 ).first ].get ( ( *it2 ).second, ( *it2 ).third, ( *it2 ).fourth );

        if ( vals[counter] < sval )
        {
          //left entropy:
          eL[cn] = eL[cn] + 1;
          counterL++;
        }
        else
        {
          //right entropy:
          eR[cn] = eR[cn] + 1;
          counterR++;
        }
      }

      // A threshold that sends every sample to one side is no split: it
      // yields NaN when normalized, and otherwise a zero-gain "best" split
      // that would create an empty child node.
      if ( counterL == 0 || counterR == 0 )
        continue;

      double leftent = 0.0;
      for ( map<int, int>::iterator mapit = eL.begin() ; mapit != eL.end(); mapit++ )
      {
        double p = ( double ) ( *mapit ).second / ( double ) counterL;
        leftent -= p * log2 ( p );
      }

      double rightent = 0.0;
      for ( map<int, int>::iterator mapit = eR.begin() ; mapit != eR.end(); mapit++ )
      {
        double p = ( double ) ( *mapit ).second / ( double ) counterR;
        rightent -= p * log2 ( p );
      }

      double pl = ( double ) counterL / ( double ) ( counterL + counterR );

      //information gain
      double ig = globent - ( 1.0 - pl ) * rightent - pl * leftent;

      if ( useShannonEntropy )
      {
        // split entropy in bits, consistent with globent and ig
        double esplit = - ( pl * log2 ( pl ) + ( 1 - pl ) * log2 ( 1 - pl ) );
        ig = 2 * ig / ( globent + esplit );
      }

      if ( ig > l_bestig )
      {
        l_bestig = ig;
        l_splitval = sval;
      }
    }

    if ( l_bestig > bestig )
    {
      bestig = l_bestig;
      splitop = featsel[f];
      splitval = l_splitval;
    }
  }

  // release every sampled candidate except the chosen one (owned by caller)
  for ( uint i = 0; i < featsel.size(); i++ )
  {
    if ( featsel[i] != splitop )
      delete featsel[i];
  }

#ifdef DEBUG
  cout << "globent: " << globent <<  " bestig " << bestig << " splitval: " << splitval << endl;
#endif
  return bestig;
}

/**
 * Averages, over all trees, the probability of class @p channel stored in the
 * node that voxel (x, y, z) currently reaches in each tree.
 *
 * @param nodeIndices node index per voxel; channel index = tree index
 * @return mean class probability over nbTrees trees
 */
inline double SemSegContextTree3D::getMeanProb (
    const int &x,
    const int &y,
    const int &z,
    const int &channel,
    const MultiChannelImage3DT<unsigned short int> &nodeIndices )
{
  double sum = 0.0;

  int t = 0;
  while ( t < nbTrees )
  {
    const unsigned short int reached = nodeIndices.get ( x, y, z, t );
    sum += forest[t][reached].dist[channel];
    ++t;
  }

  return sum / static_cast<double> ( nbTrees );
}

/**
 * Fills 24 ray-feature channels, starting at @p firstChannel, for every slice.
 *
 * Channel 0 of each slice is median-filtered and Canny edges are extracted.
 * Then, for each of 8 directions (0..315 degrees in 45-degree steps),
 * three maps are written:
 *   - firstChannel + dir      : steps from (x, y) to the nearest edge pixel
 *                               (or image border) along the ray
 *   - firstChannel + 8 + dir  : gradient magnitude at that pixel
 *   - firstChannel + 16 + dir : gradient at that pixel projected onto the
 *                               direction's (cos, sin) and normalized by the
 *                               magnitude (0 for a zero gradient)
 * Assumes channels 1 and 2 hold the x / y gradients (TODO confirm against the
 * feature-map layout).
 *
 * @param feats        feature volume; read from channels 0..2, written in place
 * @param firstChannel first of the 24 output channels
 */
void SemSegContextTree3D::computeRayFeatImage (
    NICE::MultiChannelImage3DT<double> &feats,
    int firstChannel )
{
  const int xsize = feats.width();
  const int ysize = feats.height();
  const int zsize = feats.depth();

  const int amountDirs = 8;
  const double degToRad = 3.14159265358979323846 / 180.0;

  // Per direction: angle in degrees and the step (xo, yo) to the predecessor
  // pixel whose ray values are propagated. 135 degrees is down-right (1, 1);
  // previously it duplicated 45 degrees (1, -1).
  static const int dirTheta[amountDirs] = {  0, 45, 90, 135, 180, 225, 270, 315 };
  static const int dirXo[amountDirs]    = {  0,  1,  1,   1,   0,  -1,  -1,  -1 };
  static const int dirYo[amountDirs]    = { -1, -1,  0,   1,   1,   1,   0,  -1 };

  // compute ray feature maps from canny image
  for ( int z = 0; z < zsize; z++ )
  {
    // canny image from raw channel
    NICE::Image med ( xsize, ysize );
    NICE::median ( feats.getChannel ( z, 0 ), &med, 2 );
    NICE::Image* can = NICE::canny ( med, 5, 25 );

    for ( int dir = 0; dir < amountDirs; dir++ )
    {
      const int xo = dirXo[dir];
      const int yo = dirYo[dir];
      // cos/sin expect radians
      const double theta = dirTheta[dir] * degToRad;
      const double cosT = cos ( theta );
      const double sinT = sin ( theta );

      NICE::Matrix dist ( xsize, ysize, 0 );
      NICE::Matrix norm ( xsize, ysize, 0 );
      NICE::Matrix orient ( xsize, ysize, 0 );

      // Traverse so that the predecessor (x+xo, y+yo) is always visited before
      // (x, y): rows in reverse when yo > 0, columns in reverse when xo > 0.
      // Separate loop counters are used; the old code overwrote x / y
      // inside the loop and corrupted the traversal.
      for ( int yi = 0; yi < ysize; yi++ )
      {
        const int y = ( yo > 0 ) ? ( ysize - 1 ) - yi : yi;

        for ( int xi = 0; xi < xsize; xi++ )
        {
          const int x = ( xo > 0 ) ? ( xsize - 1 ) - xi : xi;
          const int px = x + xo;
          const int py = y + yo;

          if ( can->getPixelQuick ( x, y ) != 0
               || px < 0
               || px >= xsize
               || py < 0
               || py >= ysize )
          {
            const double gx = feats.get ( x, y, z, 1 );
            const double gy = feats.get ( x, y, z, 2 );
            const double magnitude = sqrt ( gx*gx + gy*gy );

            norm ( x, y )   = magnitude;
            // avoid NaN for a vanishing gradient
            orient ( x, y ) = ( magnitude > 0.0 ) ? ( gx*cosT + gy*sinT ) / magnitude : 0.0;
            dist ( x, y )   = 0;
          }
          else
          {
            orient ( x, y ) = orient ( px, py );
            norm ( x, y )   = norm ( px, py );
            dist ( x, y )   = dist ( px, py ) + 1;
          }
        }
      }

      for ( int y = 0; y < ysize; y++ )
        for ( int x = 0; x < xsize; x++ )
        {
          // distance feature maps
          feats.set ( x, y, z, dist ( x, y ), firstChannel + dir );
          // norm feature maps
          feats.set ( x, y, z, norm ( x, y ), firstChannel + amountDirs + dir );
          // orientation feature maps
          feats.set ( x, y, z, orient ( x, y ), firstChannel + ( amountDirs*2 ) + dir );
        }
    }

    delete can;
  }
}

/**
 * Writes the current forest-averaged class probabilities into the context
 * channels of @p feats.
 *
 * For every class c and voxel, the mean probability (getMeanProb) is stored
 * in channel firstChannel + c if useFeat3 is set, and in channel
 * firstChannel + classes + c if useFeat4 is set. Afterwards
 * feats.calcIntegral(firstChannel + c) is called for every class.
 * Classes are processed in parallel via OpenMP; each iteration touches only
 * its own channels.
 *
 * NOTE(review): the type-4 offset assumes the type-3 channels exist. train()
 * places type-4 channels directly after type-3 channels only when useFeat3 is
 * set, so for useFeat4 without useFeat3 the offsets may disagree — confirm
 * against the channel layout created by addFeatureMaps.
 */
void SemSegContextTree3D::updateProbabilityMaps (
    const NICE::MultiChannelImage3DT<unsigned short int> &nodeIndices,
    NICE::MultiChannelImage3DT<double> &feats,
    int firstChannel )
{
  const int width      = feats.width();
  const int height     = feats.height();
  const int slices     = feats.depth();
  const int numClasses = ( int ) forest[0][0].dist.size();

#pragma omp parallel for
  for ( int c = 0; c < numClasses; c++ )
  {
    const int integralChannel = firstChannel + c;
    const int pixelChannel    = firstChannel + numClasses + c;

    for ( int z = 0; z < slices; z++ )
      for ( int y = 0; y < height; y++ )
        for ( int x = 0; x < width; x++ )
        {
          const double meanProb = getMeanProb ( x, y, z, c, nodeIndices );

          if ( useFeat3 )
            feats ( x, y, z, integralChannel ) = meanProb;
          if ( useFeat4 )
            feats ( x, y, z, pixelChannel ) = meanProb;
        }

    feats.calcIntegral ( integralChannel );
  }
}

/**
 * Level weight of depth @p d in a tree of maximum depth @p dim.
 *
 * @return 0 for the root level (d == 0), otherwise 1 / 2^(dim - d + 1)
 */
inline double computeWeight ( const int &d, const int &dim )
{
  if ( d != 0 )
  {
    const double exponent = ( double ) ( dim - d + 1 );
    return 1.0 / pow ( 2, exponent );
  }
  return 0.0;
}

/**
 * Trains the forest on the "train" set of @p md.
 *
 * With saveLoadData set and an existing file at fileLocation, the model is
 * read from that file instead of being trained. Otherwise the forest is
 * trained, and with saveLoadData set it is then written to fileLocation.
 */
void SemSegContextTree3D::train ( const MultiDataset *md )
{
  // Use the dataset's labeled set directly; copying the whole LabeledSet
  // just to take its address is unnecessary work.
  const LabeledSet *trainp = ( *md ) ["train"];

  if ( saveLoadData && FileMgt::fileExists ( fileLocation ) )
  {
    read ( fileLocation );
    return;
  }

  train ( trainp );

  if ( saveLoadData )
    write ( fileLocation );
}

void SemSegContextTree3D::train ( const LabeledSet * trainp )
{
  int shortsize = numeric_limits<short>::max();

  Timer timer;
  timer.start();
  
  vector<int> zsizeVec;
  getDepthVector ( trainp, zsizeVec, run3Dseg );

  //FIXME: memory usage
  vector<MultiChannelImage3DT<double> > allfeats;
  vector<MultiChannelImage3DT<unsigned short int> > nodeIndices;
  vector<MultiChannelImageT<int> > labels;

  vector<SparseVector*> globalCategorFeats;
  vector<map<int,int> > classesPerImage;

  vector<vector<int> > rSize;
  vector<int> amountRegionpI;

  std::string forbidden_classes_s = conf->gS ( "analysis", "forbidden_classes", "" );
  classnames.getSelection ( forbidden_classes_s, forbidden_classes );

  int imgCounter = 0;
  int amountPixels = 0;

  // How many channels of non-integral type do we have?

  if ( imagetype == IMAGETYPE_RGB )
    rawChannels = 3;
  else
    rawChannels = 1;

  if ( useGradient )
    rawChannels *= 3;

  if ( useWeijer )
    rawChannels += 11;

  if ( useHoiemFeatures )
    rawChannels += 8;

  if ( useAdditionalLayer )
    rawChannels += 1;


///////////////////////////// read input data /////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
  int depthCount = 0;
  vector< string > filelist;
  NICE::MultiChannelImageT<uchar> pixelLabels;

  for (LabeledSet::const_iterator it = trainp->begin(); it != trainp->end(); it++)
  {
    for (std::vector<ImageInfo *>::const_iterator jt = it->second.begin();
         jt != it->second.end(); jt++)
    {
      int classno = it->first;
      ImageInfo & info = *(*jt);
      std::string file = info.img();
      filelist.push_back ( file );
      depthCount++;

      const LocalizationResult *locResult = info.localization();

      // getting groundtruth
      NICE::Image pL;
      pL.resize ( locResult->xsize, locResult->ysize );
      pL.set ( 0 );
      locResult->calcLabeledImage ( pL, ( *classNames ).getBackgroundClass() );
      pixelLabels.addChannel ( pL );

      if ( locResult->size() <= 0 )
      {
        fprintf ( stderr, "WARNING: NO ground truth polygons found for %s !\n",
                  file.c_str() );
        continue;
      }

      fprintf ( stderr, "SSContext: Collecting pixel examples from localization info: %s\n", file.c_str() );

      int depthBoundary = 0;
      if ( run3Dseg )
      {
        depthBoundary = zsizeVec[imgCounter];
      }

      if ( depthCount < depthBoundary ) continue;

      // all image slices collected -> make a 3d image
      NICE::MultiChannelImage3DT<double> imgData;
      make3DImage ( filelist, imgData );

      int xsize = imgData.width();
      int ysize = imgData.height();
      int zsize = imgData.depth();
      amountPixels += xsize * ysize * zsize;

      MultiChannelImageT<int> tmpMat ( xsize, ysize, ( uint ) zsize );
      labels.push_back ( tmpMat );

      nodeIndices.push_back ( MultiChannelImage3DT<unsigned short int> ( xsize, ysize, zsize, nbTrees ) );
      nodeIndices[imgCounter].setAll ( 0 );

//      MultiChannelImage3DT<double> feats;
//      allfeats.push_back ( feats );

      int amountRegions;
      // convert color to L*a*b, add selected feature channels
      addFeatureMaps ( imgData, filelist, amountRegions );
      allfeats.push_back(imgData);

      if ( useFeat1 )
      {
        amountRegionpI.push_back ( amountRegions );
        rSize.push_back ( vector<int> ( amountRegions, 0 ) );
      }

      if ( useCategorization )
      {
        globalCategorFeats.push_back ( new SparseVector() );
        classesPerImage.push_back ( map<int,int>() );
      }

      for ( int x = 0; x < xsize; x++ )
      {
        for ( int y = 0; y < ysize; y++ )
        {
          for ( int z = 0; z < zsize; z++ )
          {
            if ( useFeat1 )
              rSize[imgCounter][allfeats[imgCounter] ( x, y, z, rawChannels ) ]++;

            if ( run3Dseg )
              classno = pixelLabels ( x, y, ( uint ) z );
            else
              classno = pL.getPixelQuick ( x,y );

            labels[imgCounter].set ( x, y, classno, ( uint ) z );

            if ( forbidden_classes.find ( classno ) != forbidden_classes.end() )
              continue;

            labelcounter[classno]++;

            if ( useCategorization )
              classesPerImage[imgCounter][classno] = 1;
          }
        }
      }

      filelist.clear();
      pixelLabels.reInit ( 0,0,0 );
      depthCount = 0;
      imgCounter++;
    }
  }

  int classes = 0;
  for ( map<int, int>::const_iterator mapit = labelcounter.begin();
        mapit != labelcounter.end(); mapit++ )
  {
    labelmap[mapit->first] = classes;
    labelmapback[classes] = mapit->first;
    classes++;
  }

////////////////////////// channel type configuration /////////////////////////
///////////////////////////////////////////////////////////////////////////////

  // Type 0: single pixel & pixel-comparison features on gray value channels
  for ( int i = 0; i < rawChannels; i++ )
    channelType.push_back ( 0 );

  // Type 1: region channel with unsupervised segmentation
  int shift = 0;
  if ( useFeat1 )
  {
    channelType.push_back ( 1 );
    shift++;
  }

  // Type 2: rectangular and Haar-like features on gray value integral channels
  if ( useFeat2 )
    for ( int i = 0; i < rawChannels; i++ )
      channelType.push_back ( 2 );

  // Type 3: type 2 features on context channels
  if ( useFeat3 )
    for ( int i = 0; i < classes; i++ )
      channelType.push_back ( 3 );

  // Type 4: type 0 features on context channels
  if ( useFeat4 )
    for ( int i = 0; i < classes; i++ )
      channelType.push_back ( 4 );

  // Type 5: ray features for shape modeling on canny-map
  if ( useFeat5 )
    for ( int i = 0; i < 24; i++ )
      channelType.push_back ( 5 );

  // 'amountTypes' sets upper bound for usable feature types
  int amountTypes = 6;
  channelsPerType = vector<vector<int> > ( amountTypes, vector<int>() );

  for ( int i = 0; i < ( int ) channelType.size(); i++ )
  {
    channelsPerType[channelType[i]].push_back ( i );
  }

///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////

  vector<vector<vector<double> > > regionProbs;
  if ( useFeat1 )
  {
    for ( int i = 0; i < imgCounter; i++ )
    {
      regionProbs.push_back ( vector<vector<double> > ( amountRegionpI[i], vector<double> ( classes, 0.0 ) ) );
    }
  }

  //balancing
  a = vector<double> ( classes, 0.0 );
  int featcounter = 0;
  for ( int iCounter = 0; iCounter < imgCounter; iCounter++ )
  {
    int xsize = ( int ) nodeIndices[iCounter].width();
    int ysize = ( int ) nodeIndices[iCounter].height();
    int zsize = ( int ) nodeIndices[iCounter].depth();

    for ( int x = 0; x < xsize; x++ )
    {
      for ( int y = 0; y < ysize; y++ )
      {
        for ( int z = 0; z < zsize; z++ )
        {
          featcounter++;
          int cn = labels[iCounter] ( x, y, ( uint ) z );
          if ( labelmap.find ( cn ) == labelmap.end() )
            continue;
          a[labelmap[cn]] ++;
        }
      }
    }
  }

  for ( int i = 0; i < ( int ) a.size(); i++ )
  {
    a[i] /= ( double ) featcounter;
  }

#ifdef VERBOSE
  cout << "\nDistribution:" << endl;
  for ( int i = 0; i < ( int ) a.size(); i++ )
    cout << "class " << i << ": " << a[i] << endl;
#endif

  depth = 0;
  uniquenumber = 0;

  //initialize random forest
  for ( int t = 0; t < nbTrees; t++ )
  {
    vector<TreeNode> singletree;
    singletree.push_back ( TreeNode() );
    singletree[0].dist = vector<double> ( classes, 0.0 );
    singletree[0].depth = depth;
    singletree[0].featcounter = amountPixels;
    singletree[0].nodeNumber = uniquenumber;
    uniquenumber++;
    forest.push_back ( singletree );
  }

  vector<int> startnode ( nbTrees, 0 );

  bool noNewSplit = false;

  timer.stop();
  cout << "\nTime for Pre-Processing: " << timer.getLastAbsolute() << " seconds\n" << endl;

  //////////////////////////// train the classifier ///////////////////////////
  /////////////////////////////////////////////////////////////////////////////
  timer.start();
  while ( !noNewSplit && (depth < maxDepth) )
  {
    depth++;
#ifdef DEBUG
    cout << "depth: " << depth << endl;
#endif
    noNewSplit = true;
    vector<MultiChannelImage3DT<unsigned short int> > lastNodeIndices = nodeIndices;
    vector<vector<vector<double> > > lastRegionProbs = regionProbs;

    if ( useFeat1 )
      for ( int i = 0; i < imgCounter; i++ )
      {
        int numRegions = (int) regionProbs[i].size();
        for ( int r = 0; r < numRegions; r++ )
          for ( int c = 0; c < classes; c++ )
            regionProbs[i][r][c] = 0.0;
      }

    // initialize & update context channels
    for ( int i = 0; i < imgCounter; i++)
      if ( useFeat3 || useFeat4 )
        this->updateProbabilityMaps ( nodeIndices[i], allfeats[i], rawChannels + shift );

#ifdef VERBOSE
    Timer timerDepth;
    timerDepth.start();
#endif

    double weight = computeWeight ( depth, maxDepth )
                    - computeWeight ( depth - 1, maxDepth );

#pragma omp parallel for
    // for each tree
    for ( int tree = 0; tree < nbTrees; tree++ )
    {
      const int t = ( int ) forest[tree].size();
      const int s = startnode[tree];
      startnode[tree] = t;
      double bestig;

      // for each node
      for ( int node = s; node < t; node++ )
      {
        if ( !forest[tree][node].isleaf && forest[tree][node].left < 0 )
        {
          // find best split
          Operation3D *splitfeat = NULL;
          double splitval;
          bestig = getBestSplit ( allfeats, lastNodeIndices, labels, node,
                                   splitfeat, splitval, tree, lastRegionProbs );

          forest[tree][node].feat = splitfeat;
          forest[tree][node].decision = splitval;

          // split the node
          if ( splitfeat != NULL )
          {
            noNewSplit = false;
            int left;
#pragma omp critical
            {
              left = forest[tree].size();
              forest[tree].push_back ( TreeNode() );
              forest[tree].push_back ( TreeNode() );
            }
            int right = left + 1;
            forest[tree][node].left = left;
            forest[tree][node].right = right;
            forest[tree][left].init( depth, classes, uniquenumber);
            int leftu = uniquenumber;
            uniquenumber++;
            forest[tree][right].init( depth, classes, uniquenumber);
            int rightu = uniquenumber;
            uniquenumber++;

#pragma omp parallel for
            for ( int i = 0; i < imgCounter; i++ )
            {
              int xsize = nodeIndices[i].width();
              int ysize = nodeIndices[i].height();
              int zsize = nodeIndices[i].depth();

              for ( int x = 0; x < xsize; x++ )
              {
                for ( int y = 0; y < ysize; y++ )
                {
                  for ( int z = 0; z < zsize; z++ )
                  {
                    if ( nodeIndices[i].get ( x, y, z, tree ) == node )
                    {
                      // get feature value
                      Features feat;
                      feat.feats = &allfeats[i];
                      feat.rProbs = &lastRegionProbs[i];
                      double val = 0.0;
                      val = splitfeat->getVal ( feat, x, y, z );
                      if ( !isfinite ( val ) ) val = 0.0;

#pragma omp critical
                      {
                        int curLabel = labels[i] ( x, y, ( uint ) z );
                        // traverse to left child
                        if ( val < splitval )
                        {
                          nodeIndices[i].set ( x, y, z, left, tree );
                          if ( labelmap.find ( curLabel ) != labelmap.end() )
                            forest[tree][left].dist[labelmap[curLabel]]++;
                          forest[tree][left].featcounter++;
                          if ( useCategorization && leftu < shortsize )
                            ( *globalCategorFeats[i] ) [leftu]+=weight;
                        }
                        // traverse to right child
                        else
                        {
                          nodeIndices[i].set ( x, y, z, right, tree );
                          if ( labelmap.find ( curLabel ) != labelmap.end() )
                            forest[tree][right].dist[labelmap[curLabel]]++;
                          forest[tree][right].featcounter++;

                          if ( useCategorization && rightu < shortsize )
                            ( *globalCategorFeats[i] ) [rightu]+=weight;
                        }
                      }
                    }
                  }
                }
              }
            }

            // normalize distributions in child leaves
            double lcounter = 0.0, rcounter = 0.0;
            for ( int c = 0; c < (int)forest[tree][left].dist.size(); c++ )
            {
              if ( forbidden_classes.find ( labelmapback[c] ) != forbidden_classes.end() )
              {
                forest[tree][left].dist[c] = 0;
                forest[tree][right].dist[c] = 0;
              }
              else
              {
                forest[tree][left].dist[c] /= a[c];
                lcounter += forest[tree][left].dist[c];
                forest[tree][right].dist[c] /= a[c];
                rcounter += forest[tree][right].dist[c];
              }
            }

            assert ( lcounter > 0 && rcounter > 0 );

//            if ( lcounter <= 0 || rcounter <= 0 )
//            {
//              cout << "lcounter : " << lcounter << " rcounter: " << rcounter << endl;
//              cout << "splitval: " << splitval << " splittype: " << splitfeat->writeInfos() << endl;
//              cout << "bestig: " << bestig << endl;

//              for ( int i = 0; i < imgCounter; i++ )
//              {
//                int xsize = nodeIndices[i].width();
//                int ysize = nodeIndices[i].height();
//                int zsize = nodeIndices[i].depth();
//                int counter = 0;

//                for ( int x = 0; x < xsize; x++ )
//                {
//                  for ( int y = 0; y < ysize; y++ )
//                  {
//                    for ( int z = 0; z < zsize; z++ )
//                    {
//                      if ( lastNodeIndices[i].get ( x, y, tree ) == node )
//                      {
//                        if ( ++counter > 30 )
//                          break;

//                        Features feat;
//                        feat.feats = &allfeats[i];
//                        feat.rProbs = &lastRegionProbs[i];

//                        double val = splitfeat->getVal ( feat, x, y, z );
//                        if ( !isfinite ( val ) ) val = 0.0;

//                        cout << "splitval: " << splitval << " val: " << val << endl;
//                      }
//                    }
//                  }
//                }
//              }

//              assert ( lcounter > 0 && rcounter > 0 );
//            }

            for ( int c = 0; c < classes; c++ )
            {
              forest[tree][left].dist[c] /= lcounter;
              forest[tree][right].dist[c] /= rcounter;
            }
          }
          else
          {
            forest[tree][node].isleaf = true;
          }
        }
      }
    }


    if ( useFeat1 )
    {
      for ( int i = 0; i < imgCounter; i++ )
      {
        int xsize = nodeIndices[i].width();
        int ysize = nodeIndices[i].height();
        int zsize = nodeIndices[i].depth();

#pragma omp parallel for
        // set region probability distribution
        for ( int x = 0; x < xsize; x++ )
        {
          for ( int y = 0; y < ysize; y++ )
          {
            for ( int z = 0; z < zsize; z++ )
            {
              for ( int tree = 0; tree < nbTrees; tree++ )
              {
                int node = nodeIndices[i].get ( x, y, z, tree );
                for ( int c = 0; c < classes; c++ )
                {
                  int r = (int) ( allfeats[i] ( x, y, z, rawChannels ) );
                  regionProbs[i][r][c] += forest[tree][node].dist[c];
                }
              }
            }
          }
        }

        // normalize distribution
        int numRegions = (int) regionProbs[i].size();
        for ( int r = 0; r < numRegions; r++ )
        {
          for ( int c = 0; c < classes; c++ )
          {
            regionProbs[i][r][c] /= ( double ) ( rSize[i][r] );
          }
        }
      }
    }

    if ( firstiteration ) firstiteration = false;

#ifdef VERBOSE
    timerDepth.stop();
    cout << "Depth " << depth << ": " << timerDepth.getLastAbsolute() << " seconds" <<endl;
#endif

    lastNodeIndices.clear();
    lastRegionProbs.clear();
  }

  timer.stop();
  cout << "Time for Learning: " << timer.getLastAbsolute() << " seconds\n" << endl;

  //////////////////////// classification using HIK ///////////////////////////
  /////////////////////////////////////////////////////////////////////////////

  if ( useCategorization && fasthik != NULL )
  {
    timer.start();
    uniquenumber = std::min ( shortsize, uniquenumber );
    for ( uint i = 0; i < globalCategorFeats.size(); i++ )
    {
      globalCategorFeats[i]->setDim ( uniquenumber );
      globalCategorFeats[i]->normalize();
    }
    map<int,Vector> ys;

    int cCounter = 0;
    for ( map<int,int>::const_iterator it = labelmap.begin();
          it != labelmap.end(); it++, cCounter++ )
    {
      ys[cCounter] = Vector ( globalCategorFeats.size() );
      for ( int i = 0; i < imgCounter; i++ )
      {
        if ( classesPerImage[i].find ( it->first ) != classesPerImage[i].end() )
        {
          ys[cCounter][i] = 1;
        }
        else
        {
          ys[cCounter][i] = -1;
        }
      }
    }

    fasthik->train( reinterpret_cast<vector<const NICE::SparseVector *>&>(globalCategorFeats), ys);

    timer.stop();
    cerr << "Time for Categorization: " << timer.getLastAbsolute() << " seconds\n" << endl;
  }

#ifdef VERBOSE
  cout << "\nFEATURE USAGE" << endl;
  cout << "#############\n" << endl;

  // amount of used features per feature type
  std::map<int, int> featTypeCounter;
  for ( int tree = 0; tree < nbTrees; tree++ )
  {
    int t = ( int ) forest[tree].size();

    for ( int node = 0; node < t; node++ )
    {
      if ( !forest[tree][node].isleaf && forest[tree][node].left != -1 )
      {
        featTypeCounter[ forest[tree][node].feat->getFeatType() ] += 1;
      }
    }
  }

  cout << "Types:" << endl;
  for ( map<int, int>::const_iterator it = featTypeCounter.begin(); it != featTypeCounter.end(); it++ )
    cout << it->first << ": " << it->second << endl;

  cout << "\nOperations - All:" << endl;
  // used operations
  vector<int> opOverview ( NBOPERATIONS, 0 );
  // relative use of context vs raw features per tree level
  vector<vector<double> > contextOverview ( maxDepth, vector<double> ( 2, 0.0 ) );
  for ( int tree = 0; tree < nbTrees; tree++ )
  {
    int t = ( int ) forest[tree].size();

    for ( int node = 0; node < t; node++ )
    {
#ifdef DEBUG
      printf ( "tree[%i]: left: %i, right: %i", node, forest[tree][node].left, forest[tree][node].right );
#endif

      if ( !forest[tree][node].isleaf && forest[tree][node].left != -1 )
      {
        cout <<  forest[tree][node].feat->writeInfos() << endl;
        opOverview[ forest[tree][node].feat->getOps() ]++;
        contextOverview[forest[tree][node].depth][ ( int ) forest[tree][node].feat->getContext() ]++;
      }
#ifdef DEBUG
      for ( int d = 0; d < ( int ) forest[tree][node].dist.size(); d++ )
      {
        cout << " " << forest[tree][node].dist[d];
      }
      cout << endl;
#endif
    }
  }

  // amount of used features per operation type
  cout << "\nOperations - Summary:" << endl;
  for ( int t = 0; t < ( int ) opOverview.size(); t++ )
  {
    cout << "Ops " << t << ": " << opOverview[ t ] << endl;
  }
  // ratio of used context features per depth level
  cout << "\nContext-Ratio:" << endl;
  for ( int d = 0; d < maxDepth; d++ )
  {
    double sum =  contextOverview[d][0] + contextOverview[d][1];
    if ( sum == 0 )
      sum = 1;

    contextOverview[d][0] /= sum;
    contextOverview[d][1] /= sum;

    cout << "Depth [" << d+1 << "] Normal: " << contextOverview[d][0] << " Context: " << contextOverview[d][1] << endl;
  }
#endif

}

/**
 * Appends the feature channels used by the context forest to imgData.
 *
 * Depending on the configuration flags, imgData is
 *  - converted from RGB to CIE Lab in place (channels 0-2, IMAGETYPE_RGB only),
 *  - extended by Sobel x/y gradient channels of every existing channel (useGradient),
 *  - extended by 11 Weijer color name channels (useWeijer, RGB only),
 *  - extended by one additional layer read from "<dir>/addlayer/<file>" (useAdditionalLayer),
 *  - extended by Hoiem geometric confidence maps (useHoiemFeatures),
 *  - extended by one region index channel from an unsupervised segmentation (useFeat1),
 *  - extended by integral images of the raw channels (useFeat2),
 *  - extended by `classes` empty channels for each of useFeat3 / useFeat4
 *    (presumably filled later, e.g. by updateProbabilityMaps — verify at caller),
 *  - extended by 24 ray feature channels (useFeat5).
 *
 * @param imgData       3D multi-channel image, modified in place
 * @param filelist      one file name per slice z, used to locate side files
 * @param amountRegions output: number of regions returned by the segmentation,
 *                      0 if useFeat1 is disabled
 * @throws Exception if a Hoiem confidence image is missing or has a wrong size
 */
void SemSegContextTree3D::addFeatureMaps (
    NICE::MultiChannelImage3DT<double> &imgData,
    const vector<string> &filelist,
    int &amountRegions )
{
  int xsize = imgData.width();
  int ysize = imgData.height();
  int zsize = imgData.depth();

  amountRegions = 0;

  // RGB to Lab
  if ( imagetype == IMAGETYPE_RGB )
  {
    for ( int z = 0; z < zsize; z++ )
      for ( int y = 0; y < ysize; y++ )
        for ( int x = 0; x < xsize; x++ )
        {
          double R, G, B, X, Y, Z, L, a, b;
          R = ( double )imgData.get( x, y, z, 0 ) / 255.0;
          G = ( double )imgData.get( x, y, z, 1 ) / 255.0;
          B = ( double )imgData.get( x, y, z, 2 ) / 255.0;

          if ( useAltTristimulus )
          {
            ColorConversion::ccRGBtoXYZ( R, G, B, &X, &Y, &Z, 4 );
            ColorConversion::ccXYZtoCIE_Lab( X, Y, Z, &L, &a, &b, 4 );
          }
          else
          {
            ColorConversion::ccRGBtoXYZ( R, G, B, &X, &Y, &Z, 0 );
            ColorConversion::ccXYZtoCIE_Lab( X, Y, Z, &L, &a, &b, 0 );
          }

          imgData.set( x, y, z, L, 0 );
          imgData.set( x, y, z, a, 1 );
          imgData.set( x, y, z, b, 2 );
        }
  }

  // Gradient layers: for every existing channel c, sobelX goes to
  // c+currentsize and sobelY to c+2*currentsize
  if ( useGradient )
  {
    int currentsize = imgData.channels();
    imgData.addChannel ( 2*currentsize );

    for ( int z = 0; z < zsize; z++ )
      for ( int c = 0; c < currentsize; c++ )
      {
        ImageT<double> tmp = imgData.getChannelT(z, c);
        ImageT<double> sobX( xsize, ysize );
        ImageT<double> sobY( xsize, ysize );
        NICE::FilterT<double, double, double>::sobelX ( tmp, sobX );
        NICE::FilterT<double, double, double>::sobelY ( tmp, sobY );
        for ( int y = 0; y < ysize; y++ )
          for ( int x = 0; x < xsize; x++ )
          {
            imgData.set( x, y, z, sobX.getPixelQuick(x,y), c+currentsize );
            imgData.set( x, y, z, sobY.getPixelQuick(x,y), c+(currentsize*2) );
          }
      }
  }

  // Weijer color names
  if ( useWeijer )
  {
    if ( imagetype == IMAGETYPE_RGB )
    {
      int currentsize = imgData.channels();
      imgData.addChannel ( 11 );
      for ( int z = 0; z < zsize; z++ )
      {
        NICE::ColorImage img = imgData.getColor ( z );
        NICE::MultiChannelImageT<double> cfeats;
        lfcw->getFeats ( img, cfeats );
        for ( int c = 0; c < cfeats.channels(); c++)
          for ( int y = 0; y < ysize; y++ )
            for ( int x = 0; x < xsize; x++ )
              imgData.set(x, y, z, cfeats.get(x,y,(uint)c), c+currentsize);
      }
    }
    else
    {
      cerr << "Can't compute weijer features of a grayscale image." << endl;
    }
  }

  // arbitrary additional layer as image (read from an "addlayer/" sibling directory)
  if ( useAdditionalLayer )
  {
    int currentsize = imgData.channels();
    imgData.addChannel ( 1 );
    for ( int z = 0; z < zsize; z++ )
    {
      vector<string> list;
      StringTools::split ( filelist[z], '/', list );
      string layerPath = StringTools::trim ( filelist[z], list.back() ) + "addlayer/" + list.back();
      NICE::Image layer ( layerPath );
      for ( int y = 0; y < ysize; y++ )
        for ( int x = 0; x < xsize; x++ )
          imgData.set(x, y, z, layer.getPixelQuick(x,y), currentsize);
    }
  }

  // read the geometric cues produced by Hoiem et al.
  if ( useHoiemFeatures )
  {
    string hoiemDirectory = conf->gS ( "Features", "hoiem_directory" );
    // we could also give the following set as a config option
    string hoiemClasses_s = "sky 000 090-045 090-090 090-135 090 090-por 090-sol";
    vector<string> hoiemClasses;
    StringTools::split ( hoiemClasses_s, ' ', hoiemClasses );

    // one channel per Hoiem class; the same channel is reused for every slice z
    const int firstHoiemChannel = imgData.channels();
    const int numHoiemClasses = ( int ) hoiemClasses.size();
    imgData.addChannel ( numHoiemClasses );
    for ( int z = 0; z < zsize; z++ )
    {
      FileName fn ( filelist[z] );
      fn.removeExtension();
      FileName fnBase = fn.extractFileName();

      for ( int h = 0; h < numHoiemClasses; h++ )
      {
        // channel index derived from the class index (not a running counter),
        // so it stays inside the channels added above for every slice
        const int channel = firstHoiemChannel + h;
        const string & hoiemClass = hoiemClasses[h];
        FileName fnConfidenceImage ( hoiemDirectory + fnBase.str() + "_c_" + hoiemClass + ".png" );
        if ( ! fnConfidenceImage.fileExists() )
        {
          fthrow ( Exception, "Unable to read the Hoiem geometric confidence image: " << fnConfidenceImage.str() << " (original image is " << filelist[z] << ")" );
        }
        else
        {
          Image confidenceImage ( fnConfidenceImage.str() );
          if ( confidenceImage.width() != xsize || confidenceImage.height() != ysize )
          {
            fthrow ( Exception, "The size of the geometric confidence image does not match with the original image size: " << fnConfidenceImage.str() );
          }

          // copy standard image to double image
          for ( int y = 0 ; y < confidenceImage.height(); y++ )
            for ( int x = 0 ; x < confidenceImage.width(); x++ )
              imgData ( x, y, z, channel ) = ( double ) confidenceImage ( x, y );
        }
      }
    }
  }

  // region feature (unsupervised segmentation); the region index is stored
  // as one extra channel
  int shift = 0;
  if ( useFeat1 )
  {
    shift = 1;
    MultiChannelImageT<int> regions;
    regions.reInit( xsize, ysize, zsize );
    amountRegions = segmentation->segRegions ( imgData, regions, imagetype );

    int currentsize = imgData.channels();
    imgData.addChannel ( 1 );

    for ( int z = 0; z < ( int ) regions.channels(); z++ )
      for ( int y = 0; y < regions.height(); y++ )
        for ( int x = 0; x < regions.width(); x++ )
          imgData.set ( x, y, z, regions ( x, y, ( uint ) z ), currentsize );

  }

  // integral images of raw channels
  // NOTE(review): assumes imgData had exactly rawChannels + shift channels
  // before this addChannel call — confirm against how rawChannels is set.
  if ( useFeat2 )
  {
    imgData.addChannel ( rawChannels );

#pragma omp parallel for
    for ( int i = 0; i < rawChannels; i++ )
    {
      int corg = i;
      int cint = i + rawChannels + shift;

      for ( int z = 0; z < zsize; z++ )
        for ( int y = 0; y < ysize; y++ )
          for ( int x = 0; x < xsize; x++ )
            imgData ( x, y, z, cint ) = imgData ( x, y, z, corg );

      imgData.calcIntegral ( cint );
    }
  }

  int classes = classNames->numClasses();

  if ( useFeat3 )
    imgData.addChannel ( classes );

  if ( useFeat4 )
    imgData.addChannel ( classes );

  if ( useFeat5 )
  {
    imgData.addChannel ( 24 );
    this->computeRayFeatImage( imgData, imgData.channels()-24);
  }

}

/**
 * Classifies a 3D image (stack of slices) with the trained context forest.
 *
 * Every voxel is passed down all trees one level per iteration; context
 * channels (useFeat3/useFeat4) and region probabilities (useFeat1) are
 * refreshed between levels. The final label is chosen per voxel
 * (pixelWiseLabeling) or per region of an unsupervised segmentation.
 *
 * @param imgData       input image; feature channels are appended in place
 *                      by addFeatureMaps
 * @param segresult     output label per voxel (class numbers mapped back via
 *                      labelmapback)
 * @param probabilities output class probabilities per voxel, one channel per
 *                      configured class; only filled in pixel-wise mode,
 *                      otherwise all zero
 * @param filelist      file name per slice, used to locate side information
 */
void SemSegContextTree3D::classify (
    NICE::MultiChannelImage3DT<double> & imgData,
    NICE::MultiChannelImageT<double> & segresult,
    NICE::MultiChannelImage3DT<double> & probabilities,
    const std::vector<std::string> & filelist )
{
  int xsize = imgData.width();
  int ysize = imgData.height();
  int zsize = imgData.depth();

  ////////////////////////// initialize variables /////////////////////////////
  /////////////////////////////////////////////////////////////////////////////

  firstiteration = true;
  depth = 0;

  Timer timer;
  timer.start();

  // classes occurred during training step
  int classes = labelmapback.size();
  // classes defined in config file
  int numClasses = classNames->numClasses();

  // class probabilities by pixel
  probabilities.reInit ( xsize, ysize, zsize, numClasses );
  probabilities.setAll ( 0 );

  // class probabilities by region
  vector<vector<double> > regionProbs;

  // affiliation: pixel <-> (tree,node)
  MultiChannelImage3DT<unsigned short int> nodeIndices ( xsize, ysize, zsize, nbTrees );
  nodeIndices.setAll ( 0 );

  // for categorization; a local object so nothing leaks if a later step throws
  SparseVector globalCategorFeat;

  /////////////////////////// get feature values //////////////////////////////
  /////////////////////////////////////////////////////////////////////////////

  // Basic Features
  int amountRegions;
  addFeatureMaps ( imgData, filelist, amountRegions );

  vector<int> rSize;
  int shift = 0;
  if ( useFeat1 )
  {
    shift = 1;
    regionProbs = vector<vector<double> > ( amountRegions, vector<double> ( classes, 0.0 ) );
    rSize = vector<int> ( amountRegions, 0 );
    // channel rawChannels holds the region index of each voxel
    for ( int z = 0; z < zsize; z++ )
    {
      for ( int y = 0; y < ysize; y++ )
      {
        for ( int x = 0; x < xsize; x++ )
        {
          rSize[ ( int ) imgData ( x, y, z, rawChannels ) ]++;
        }
      }
    }
  }

  ////////////////// traverse image example through trees /////////////////////
  /////////////////////////////////////////////////////////////////////////////

  bool noNewSplit = false;
  for ( int d = 0; d < maxDepth && !noNewSplit; d++ )
  {
    depth++;
    vector<vector<double> > lastRegionProbs = regionProbs;

    if ( useFeat1 )
    {
      int numRegions = ( int ) regionProbs.size();
      for ( int r = 0; r < numRegions; r++ )
        for ( int c = 0; c < classes; c++ )
          regionProbs[r][c] = 0.0;
    }

    if ( depth < maxDepth )
    {
      int firstChannel = rawChannels + shift;
      if ( useFeat3 || useFeat4 )
        this->updateProbabilityMaps ( nodeIndices, imgData, firstChannel );
    }

    double weight = computeWeight ( depth, maxDepth )
                    - computeWeight ( depth - 1, maxDepth );

    noNewSplit = true;

    int tree;
#pragma omp parallel for private(tree)
    for ( tree = 0; tree < nbTrees; tree++ )
    {
      for ( int x = 0; x < xsize; x++ )
      {
        for ( int y = 0; y < ysize; y++ )
        {
          for ( int z = 0; z < zsize; z++ )
          {
            int node = nodeIndices.get ( x, y, z, tree );

            if ( forest[tree][node].left > 0 )
            {
              // several threads may store here, but only ever the value false
              noNewSplit = false;
              Features feat;
              feat.feats = &imgData;
              feat.rProbs = &lastRegionProbs;

              double val = forest[tree][node].feat->getVal ( feat, x, y, z );
              if ( !isfinite ( val ) ) val = 0.0;

              // traverse to left child
              if ( val < forest[tree][node].decision )
              {
                int left = forest[tree][node].left;
                nodeIndices.set ( x, y, z, left, tree );
#pragma omp critical
                {
                  if ( fasthik != NULL
                       && useCategorization
                       && forest[tree][left].nodeNumber < uniquenumber )
                    globalCategorFeat[forest[tree][left].nodeNumber] += weight;
                }
              }
              // traverse to right child
              else
              {
                int right = forest[tree][node].right;
                nodeIndices.set ( x, y, z, right, tree );
#pragma omp critical
                {
                  if ( fasthik != NULL
                       && useCategorization
                       && forest[tree][right].nodeNumber < uniquenumber )
                    globalCategorFeat[forest[tree][right].nodeNumber] += weight;
                }
              }
            }
          }
        }
      }
    }

    if ( useFeat1 )
    {
#pragma omp parallel for
      for ( int x = 0; x < xsize; x++ )
      {
        for ( int y = 0; y < ysize; y++ )
        {
          for ( int z = 0; z < zsize; z++ )
          {
            // region index of this voxel
            const int r = ( int ) imgData ( x, y, z, rawChannels );
            for ( int tree = 0; tree < nbTrees; tree++ )
            {
              int node = nodeIndices.get ( x, y, z, tree );
              for ( uint c = 0; c < forest[tree][node].dist.size(); c++ )
              {
                // voxels processed by different threads can share a region
#pragma omp atomic
                regionProbs[r][c] += forest[tree][node].dist[c];
              }
            }
          }
        }
      }

      int numRegions = (int) regionProbs.size();
      for ( int r = 0; r < numRegions; r++ )
      {
        // an empty region would otherwise produce NaN probabilities
        if ( rSize[r] <= 0 )
          continue;

        for ( int c = 0; c < (int) classes; c++ )
        {
          regionProbs[r][c] /= ( double ) ( rSize[r] );
        }
      }
    }

    if ( (depth < maxDepth) && firstiteration ) firstiteration = false;
  }

  vector<int> classesInImg;

  if ( useCategorization )
  {
    if ( cndir != "" )
    {
      for ( int z = 0; z < zsize; z++ )
      {
        vector< string > list;
        StringTools::split ( filelist[z], '/', list );
        string orgname = list.back();

        ifstream infile ( ( cndir + "/" + orgname + ".dat" ).c_str() );
        int tmp;
        // only accept values that were actually extracted
        while ( infile >> tmp )
        {
          assert ( tmp >= 0 && tmp < numClasses );
          classesInImg.push_back ( tmp );
        }
      }
    }
    else if ( fasthik != NULL )
    {
      globalCategorFeat.setDim ( uniquenumber );
      globalCategorFeat.normalize();
      ClassificationResult cr = fasthik->classify( &globalCategorFeat );
      for ( uint i = 0; i < ( uint ) classes; i++ )
      {
        cerr << cr.scores[i] << " ";
        if ( cr.scores[i] > 0.0/*-0.3*/ )
        {
          classesInImg.push_back ( i );
        }
      }
    }
    cerr << "amount of classes: " << classes << " used classes: " << classesInImg.size() << endl;
  }

  // no restriction available: consider every trained class
  if ( classesInImg.size() == 0 )
  {
    for ( uint i = 0; i < ( uint ) classes; i++ )
    {
      classesInImg.push_back ( i );
    }
  }

  // final labeling step
  if ( pixelWiseLabeling )
  {
    for ( int x = 0; x < xsize; x++ )
    {
      for ( int y = 0; y < ysize; y++ )
      {
        for ( int z = 0; z < zsize; z++ )
        {
          //TODO by nodes instead of pixel?
          double maxProb = - numeric_limits<double>::max();
          int maxClass = 0;

          for ( uint c = 0; c < classesInImg.size(); c++ )
          {
            int i = classesInImg[c];
            double curProb = getMeanProb ( x, y, z, i, nodeIndices );
            probabilities ( x, y, z, labelmapback[i] ) = curProb;

            if ( curProb > maxProb )
            {
              maxProb = curProb;
              maxClass = labelmapback[i];
            }
          }
          assert(maxProb <= 1);

          // copy pixel labeling into segresults (output)
          segresult.set ( x, y, maxClass, ( uint ) z );
        }
      }
    }

#ifdef VISUALIZE
    getProbabilityMap( probabilities );
#endif
  }
  else
  {
    // labeling by region
    NICE::MultiChannelImageT<int> regions;
    regions.reInit ( xsize, ysize, zsize );

    if ( useFeat1 )
    {
      int rchannel = -1;
      for ( uint i = 0; i < channelType.size(); i++ )
      {
        if ( channelType[i] == 1 )
        {
          rchannel = i;
          break;
        }
      }

      assert ( rchannel > -1 );

      for ( int z = 0; z < zsize; z++ )
      {
        for ( int y = 0; y < ysize; y++ )
        {
          for ( int x = 0; x < xsize; x++ )
          {
            regions.set ( x, y, imgData ( x, y, z, rchannel ), ( uint ) z );
          }
        }
      }
    }
    else
    {
      amountRegions = segmentation->segRegions ( imgData, regions, imagetype );

#ifdef DEBUG
      for ( unsigned int z = 0; z < ( uint ) zsize; z++ )
      {
        NICE::Matrix regmask;
        NICE::ColorImage colorimg ( xsize, ysize );
        NICE::ColorImage marked ( xsize, ysize );
        regmask.resize ( xsize, ysize );
        for ( int y = 0; y < ysize; y++ )
        {
          for ( int x = 0; x < xsize; x++ )
          {
            regmask ( x,y ) = regions ( x,y,z );
            colorimg.setPixelQuick ( x, y, 0, imgData.get ( x,y,z,0 ) );
            colorimg.setPixelQuick ( x, y, 1, imgData.get ( x,y,z,0 ) );
            colorimg.setPixelQuick ( x, y, 2, imgData.get ( x,y,z,0 ) );
          }
        }
        vector<int> colorvals;
        colorvals.push_back ( 255 );
        colorvals.push_back ( 0 );
        colorvals.push_back ( 0 );
        segmentation->markContours ( colorimg, regmask, colorvals, marked );
        std::vector<string> list;
        StringTools::split ( filelist[z], '/', list );
        string savePath = StringTools::trim ( filelist[z], list.back() ) + "marked/" + list.back();
        marked.write ( savePath );
      }
#endif
    }

    regionProbs.clear();
    regionProbs = vector<vector<double> > ( amountRegions, vector<double> ( classes, 0.0 ) );

    vector<int> bestlabels ( amountRegions, labelmapback[classesInImg[0]] );
    for ( int z = 0; z < zsize; z++ )
    {
      for ( int y = 0; y < ysize; y++ )
      {
        for ( int x = 0; x < xsize; x++ )
        {
          int r = regions ( x, y, ( uint ) z );
          for ( uint i = 0; i < classesInImg.size(); i++ )
          {
            int c = classesInImg[i];
            // get mean voting of all trees
            regionProbs[r][c] += getMeanProb ( x, y, z, c, nodeIndices );
          }
        }
      }
    }

    // pick the most probable admissible class per region (same candidate
    // set as the pixel-wise path)
    for ( int r = 0; r < amountRegions; r++ )
    {
      int bestClass = classesInImg[0];
      double maxProb = regionProbs[r][bestClass];

      for ( uint i = 1; i < classesInImg.size(); i++ )
      {
        int c = classesInImg[i];
        if ( maxProb < regionProbs[r][c] )
        {
          maxProb = regionProbs[r][c];
          bestClass = c;
        }
      }

      bestlabels[r] = labelmapback[bestClass];
    }

    // copy region labeling into segresults (output)
    for ( int z = 0; z < zsize; z++ )
    {
      for ( int y = 0; y < ysize; y++ )
      {
        for ( int x = 0; x < xsize; x++ )
        {
          segresult.set ( x, y, bestlabels[regions ( x,y, ( uint ) z ) ], ( uint ) z );
        }
      }
    }

#ifdef WRITEREGIONS
    for ( int z = 0; z < zsize; z++ )
    {
      RegionGraph rg;
      NICE::ColorImage img ( xsize,ysize );
      if ( imagetype == IMAGETYPE_RGB )
      {
        img = imgData.getColor ( z );
      }
      else
      {
        NICE::Image gray = imgData.getChannel ( z );
        for ( int y = 0; y < ysize; y++ )
        {
          for ( int x = 0; x < xsize; x++ )
          {
            int val = gray.getPixelQuick ( x,y );
            img.setPixelQuick ( x, y, val, val, val );
          }
        }
      }

      Matrix regions_tmp ( xsize,ysize );
      for ( int y = 0; y < ysize; y++ )
      {
        for ( int x = 0; x < xsize; x++ )
        {
          regions_tmp ( x,y ) = regions ( x,y, ( uint ) z );
        }
      }
      segmentation->getGraphRepresentation ( img, regions_tmp,  rg );
      for ( uint pos = 0; pos < regionProbs.size(); pos++ )
      {
        rg[pos]->setProbs ( regionProbs[pos] );
      }

      std::string s;
      std::stringstream out;
      std::vector< std::string > list;
      StringTools::split ( filelist[z], '/', list );

      out << "rgout/" << list.back() << ".graph";
      string writefile = out.str();
      rg.write ( writefile );
    }
#endif
  }

  timer.stop();
  cout << "\nTime for Classification: " << timer.getLastAbsolute() << endl;

  // CLEANING UP
  // TODO: operations in "forest"
  // NOTE(review): the Operation3D pointers in ops are only dropped here, not
  // deleted — confirm ownership before adding a delete (forest nodes may
  // reference the same objects).
  while( !ops.empty() )
  {
    vector<Operation3D*> &tops = ops.back();
    while ( !tops.empty() )
      tops.pop_back();

    ops.pop_back();
  }
}

void SemSegContextTree3D::store ( std::ostream & os, int format ) const
{
  os.precision ( numeric_limits<double>::digits10 + 1 );
  os << nbTrees << endl;
  classnames.store ( os );

  map<int, int>::const_iterator it;

  os << labelmap.size() << endl;
  for ( it = labelmap.begin() ; it != labelmap.end(); it++ )
    os << ( *it ).first << " " << ( *it ).second << endl;

  os << labelmapback.size() << endl;
  for ( it = labelmapback.begin() ; it != labelmapback.end(); it++ )
    os << ( *it ).first << " " << ( *it ).second << endl;

  int trees = forest.size();
  os << trees << endl;

  for ( int t = 0; t < trees; t++ )
  {
    int nodes = forest[t].size();
    os << nodes << endl;
    for ( int n = 0; n < nodes; n++ )
    {
      os << forest[t][n].left << " " << forest[t][n].right << " " << forest[t][n].decision << " " << forest[t][n].isleaf << " " << forest[t][n].depth << " " << forest[t][n].featcounter << " " << forest[t][n].nodeNumber << endl;
      os << forest[t][n].dist << endl;

      if ( forest[t][n].feat == NULL )
        os << -1 << endl;
      else
      {
        os << forest[t][n].feat->getOps() << endl;
        forest[t][n].feat->store ( os );
      }
    }
  }

  os << channelType.size() << endl;
  for ( int i = 0; i < ( int ) channelType.size(); i++ )
  {
    os << channelType[i] << " ";
  }
  os << endl;

  os << rawChannels << endl;

  os << uniquenumber << endl;
}

void SemSegContextTree3D::restore ( std::istream & is, int format )
{
  is >> nbTrees;

  classnames.restore ( is );

  int lsize;
  is >> lsize;

  labelmap.clear();
  for ( int l = 0; l < lsize; l++ )
  {
    int first, second;
    is >> first;
    is >> second;
    labelmap[first] = second;
  }

  is >> lsize;
  labelmapback.clear();
  for ( int l = 0; l < lsize; l++ )
  {
    int first, second;
    is >> first;
    is >> second;
    labelmapback[first] = second;
  }

  int trees;
  is >> trees;
  forest.clear();

  for ( int t = 0; t < trees; t++ )
  {
    vector<TreeNode> tmptree;
    forest.push_back ( tmptree );
    int nodes;
    is >> nodes;

    for ( int n = 0; n < nodes; n++ )
    {
      TreeNode tmpnode;
      forest[t].push_back ( tmpnode );
      is >> forest[t][n].left;
      is >> forest[t][n].right;
      is >> forest[t][n].decision;
      is >> forest[t][n].isleaf;
      is >> forest[t][n].depth;
      is >> forest[t][n].featcounter;
      is >> forest[t][n].nodeNumber;

      is >> forest[t][n].dist;

      int feattype;
      is >> feattype;
      assert ( feattype < NBOPERATIONS );
      forest[t][n].feat = NULL;

      if ( feattype >= 0 )
      {
        for ( uint o = 0; o < ops.size(); o++ )
        {
          for ( uint o2 = 0; o2 < ops[o].size(); o2++ )
          {
            if ( forest[t][n].feat == NULL )
            {
              if ( ops[o][o2]->getOps() == feattype )
              {
                forest[t][n].feat = ops[o][o2]->clone();
                break;
              }
            }
          }
        }

        assert ( forest[t][n].feat != NULL );
        forest[t][n].feat->restore ( is );

      }
    }
  }

  channelType.clear();
  int ctsize;
  is >> ctsize;
  for ( int i = 0; i < ctsize; i++ )
  {
    int tmp;
    is >> tmp;
    switch (tmp)
    {
      case 0: useFeat0 = true; break;
      case 1: useFeat1 = true; break;
      case 2: useFeat2 = true; break;
      case 3: useFeat3 = true; break;
      case 4: useFeat4 = true; break;
      case 5: useFeat5 = true; break;
    }
    channelType.push_back ( tmp );
  }

  // integralMap is deprecated but kept in RESTORE
  // for downwards compatibility!
//  std::vector<std::pair<int, int> > integralMap;
//  integralMap.clear();
//  int iMapSize;
//  is >> iMapSize;
//  for ( int i = 0; i < iMapSize; i++ )
//  {
//    int first;
//    int second;
//    is >> first;
//    is >> second;
//    integralMap.push_back ( pair<int, int> ( first, second ) );
//  }

  is >> rawChannels;

  is >> uniquenumber;
}