#include "SemSegContextTree.h"
#include "vislearning/baselib/Globals.h"
#include "vislearning/baselib/ProgressBar.h"
#include "core/basics/StringTools.h"

#include "vislearning/cbaselib/CachedExample.h"
#include "vislearning/cbaselib/PascalResults.h"
#include "vislearning/baselib/ColorSpace.h"
#include "objrec/segmentation/RSMeanShift.h"
#include "objrec/segmentation/RSGraphBased.h"
#include "core/basics/numerictools.h"

#include "core/basics/Timer.h"
#include "core/basics/vectorio.h"

#include <omp.h>

#include <cmath>
#include <iostream>

#define BOUND(x,min,max) (((x)<(min))?(min):((x)>(max)?(max):(x)))
#undef LOCALFEATS
//#define LOCALFEATS

using namespace OBJREC;

using namespace std;

using namespace NICE;

/** Direct access to the raw multi-channel feature image. */
class MCImageAccess: public ValueAccess
{

  public:
    /** raw feature value of pixel (x,y) in the given channel */
    virtual double getVal ( const Features &feats, const int &x, const int &y, const int &channel )
    {
      return feats.feats->get ( x, y, channel );
    }

    virtual ValueTypes getType()
    {
      return RAWFEAT;
    }

    virtual string writeInfos()
    {
      return "raw";
    }
};

/** Access to the class distribution stored at a pixel's current tree node. */
class ClassificationResultAcess: public ValueAccess
{

  public:
    /** probability of class 'channel' at the node pixel (x,y) currently sits in */
    virtual double getVal ( const Features &feats, const int &x, const int &y, const int &channel )
    {
      // node index of this pixel in the tree that is being evaluated
      const int node = feats.cfeats->get ( x, y, feats.cTree );
      return ( *feats.tree ) [node].dist[channel];
    }

    virtual ValueTypes getType()
    {
      return CONTEXT;
    }

    virtual string writeInfos()
    {
      return "context";
    }
};

/** Placeholder access to a sparse texton map — not implemented yet. */
class SparseImageAcess: public ValueAccess
{
  private:
    // intended scaling of the sparse features; currently unused
    double scale;

  public:
    virtual double getVal ( const Features &feats, const int &x, const int &y, const int &channel )
    {
      //MultiChannelImageT<SparseVectorInt> textonMap;
      //TODO: implement access
      // stub: always returns -1.0 until texton-map access is implemented
      return -1.0;
    }

    virtual string writeInfos()
    {
      return "context";
    }

    virtual ValueTypes getType()
    {
      return CONTEXT;
    }
};

/**
 * Restores an Operation from a stream: reads the two window offsets
 * (x1,x2,y1,y2), the two channel indices and finally an integer code
 * identifying the ValueAccess type (negative = none).
 * NOTE(review): assumes the serialized code matches the RAWFEAT/CONTEXT
 * enum values, and overwrites any previously assigned 'values' object
 * without freeing it — callers should restore into a fresh Operation.
 */
void Operation::restore ( std::istream &is )
{
  is >> x1;
  is >> x2;
  is >> y1;
  is >> y2;
  is >> channel1;
  is >> channel2;

  // type code of the ValueAccess object
  int tmp;
  is >> tmp;

  if ( tmp >= 0 )
  {
    if ( tmp == RAWFEAT )
    {
      values = new MCImageAccess();
    }
    else if ( tmp == CONTEXT )
    {
      values = new ClassificationResultAcess();
    }
    else
    {
      throw ( "no valid ValueAccess" );
    }
  }
  else
  {
    values = NULL;
  }
}

/** Serializes the offsets and channels of this operation for logging. */
std::string Operation::writeInfos()
{
  std::stringstream out;
  out << " x1: " << x1;
  out << " y1: " << y1;
  out << " x2: " << x2;
  out << " y2: " << y2;
  out << " c1: " << channel1;
  out << " c2: " << channel2;
  return out.str();
}

/**
 * Randomly initializes this operation: picks a feature type, draws the two
 * window offsets (gaussian or uniform in roughly [-ws/2, ws/2)) and creates
 * the matching ValueAccess object.
 * @param ws window size the offsets are drawn from (quadrupled for context types)
 * @param c1size,c2size,c3size channel counts of the three feature types;
 *        a size of 0 disables the corresponding type
 * @param useGaussian draw offsets from a gaussian instead of uniformly
 */
void Operation::set ( int ws, int c1size, int c2size, int c3size, bool useGaussian )
{
  // count the available feature types (raw features are always available)
  int types = 1;
  if ( c2size > 0 )
  {
    types++;
  }
  if ( c3size > 0 )
  {
    types++;
  }
  
  types = std::min(types, maxtypes);

  // random feature type; relies on RAWFEAT==0 and CONTEXT==1 (see bottom)
  int ft = ( int ) ( ( double ) rand() / ( double ) RAND_MAX * ( double ) types );

  // context/sparse features look at a larger neighbourhood
  if ( ft > 0 )
  {
    ws *= 4;
  }

  if ( useGaussian )
  {
    // NOTE(review): sigma scales with ws AND the sample is multiplied by ws
    // again — this is the gaussian sampling the constructor flags as broken
    double sigma = ( double ) ws * 2.0;
    x1 = randGaussDouble ( sigma ) * ( double ) ws;
    x2 = randGaussDouble ( sigma ) * ( double ) ws;
    y1 = randGaussDouble ( sigma ) * ( double ) ws;
    y2 = randGaussDouble ( sigma ) * ( double ) ws;
  }
  else
  {
    // uniform offsets in [-ws/2, ws/2)
    x1 = ( int ) ( ( double ) rand() / ( double ) RAND_MAX * ( double ) ws ) - ws / 2;
    x2 = ( int ) ( ( double ) rand() / ( double ) RAND_MAX * ( double ) ws ) - ws / 2;
    y1 = ( int ) ( ( double ) rand() / ( double ) RAND_MAX * ( double ) ws ) - ws / 2;
    y2 = ( int ) ( ( double ) rand() / ( double ) RAND_MAX * ( double ) ws ) - ws / 2;
  }

  // create the value access object matching the drawn feature type
  if ( ft == RAWFEAT )
  {
    values = new MCImageAccess();
  }
  else if ( ft == CONTEXT )
  {
    values = new ClassificationResultAcess();
  }
  else
  {
    values = new SparseImageAcess();
  }
}

class Minus: public Operation
{
  public:
    virtual double getVal ( const Features &feats, const int &x, const int &y )
    {
      int xsize, ysize;
      getXY ( feats, xsize, ysize );
      double v1 = values->getVal ( feats, BOUND ( x + x1, 0, xsize - 1 ), BOUND ( y + y1, 0, ysize - 1 ), channel1 );
      double v2 = values->getVal ( feats, BOUND ( x + x2, 0, xsize - 1 ), BOUND ( y + y2, 0, ysize - 1 ), channel2 );
      return v1 -v2;
    }

    virtual Operation* clone()
    {
      return new Minus();
    }

    virtual string writeInfos()
    {
      string out = "Minus";

      if ( values != NULL )
        out += values->writeInfos();

      return out + Operation::writeInfos();
    }

    virtual OperationTypes getOps()
    {
      return MINUS;
    }
};



/** Absolute difference of two feature values at (clamped) offset positions. */
class MinusAbs: public Operation
{

  public:
    /** @return |v1 - v2| of the two offset sample positions */
    virtual double getVal ( const Features &feats, const int &x, const int &y )
    {
      int xsize, ysize;
      getXY ( feats, xsize, ysize );
      double v1 = values->getVal ( feats, BOUND ( x + x1, 0, xsize - 1 ), BOUND ( y + y1, 0, ysize - 1 ), channel1 );
      double v2 = values->getVal ( feats, BOUND ( x + x2, 0, xsize - 1 ), BOUND ( y + y2, 0, ysize - 1 ), channel2 );
      // bugfix: fabs instead of abs — plain abs() can resolve to the
      // integer overload and silently truncate the double difference
      return fabs ( v1 - v2 );
    }

    virtual Operation* clone()
    {
      return new MinusAbs();
    }

    virtual string writeInfos()
    {
      string out = "MinusAbs";

      if ( values != NULL )
        out += values->writeInfos();

      // consistency fix: append offsets/channels like every other operation
      return out + Operation::writeInfos();
    }

    virtual OperationTypes getOps()
    {
      return MINUSABS;
    }
};

class Addition: public Operation
{

  public:
    virtual double getVal ( const Features &feats, const int &x, const int &y )
    {
      int xsize, ysize;
      getXY ( feats, xsize, ysize );
      double v1 = values->getVal ( feats, BOUND ( x + x1, 0, xsize - 1 ), BOUND ( y + y1, 0, ysize - 1 ), channel1 );
      double v2 = values->getVal ( feats, BOUND ( x + x2, 0, xsize - 1 ), BOUND ( y + y2, 0, ysize -
                                   1 ), channel2 );
      return v1 + v2;
    }

    virtual Operation* clone()
    {
      return new Addition();
    }

    virtual string writeInfos()
    {
      string out = "Addition";

      if ( values != NULL )
        out += values->writeInfos();

      return out + Operation::writeInfos();
    }

    virtual OperationTypes getOps()
    {
      return ADDITION;
    }
};

class Only1: public Operation
{

  public:
    virtual double getVal ( const Features &feats, const int &x, const int &y )
    {
      int xsize, ysize;
      getXY ( feats, xsize, ysize );
      double v1 = values->getVal ( feats, BOUND ( x + x1, 0, xsize - 1 ), BOUND ( y + y1, 0, ysize - 1 ), channel1 );
      return v1;
    }

    virtual Operation* clone()
    {
      return new Only1();
    }

    virtual string writeInfos()
    {
      string out = "Only1";

      if ( values != NULL )
        out += values->writeInfos();

      return out + Operation::writeInfos();
    }

    virtual OperationTypes getOps()
    {
      return ONLY1;
    }
};

/** Horizontal position of the pixel, normalized by the image width. */
class RelativeXPosition: public Operation
{

  public:
    virtual double getVal ( const Features &feats, const int &x, const int &y )
    {
      int w, h;
      getXY ( feats, w, h );
      return ( double ) x / ( double ) w;
    }

    virtual Operation* clone()
    {
      return new RelativeXPosition();
    }

    virtual string writeInfos()
    {
      return "RelativeXPosition" + Operation::writeInfos();
    }

    virtual OperationTypes getOps()
    {
      return RELATIVEXPOSITION;
    }
};

/** Vertical position of the pixel, normalized by the image height. */
class RelativeYPosition: public Operation
{

  public:
    virtual double getVal ( const Features &feats, const int &x, const int &y )
    {
      int xsize, ysize;
      getXY ( feats, xsize, ysize );
      // bugfix: use y / ysize here — the old code returned x / xsize,
      // which made this feature an exact duplicate of RelativeXPosition
      return ( double ) y / ( double ) ysize;
    }

    virtual Operation* clone()
    {
      return new RelativeYPosition();
    }

    virtual string writeInfos()
    {
      return "RelativeYPosition" + Operation::writeInfos();
    }

    virtual OperationTypes getOps()
    {
      return RELATIVEYPOSITION;
    }
};

/** Mean over the window spanned by (x1,y1)-(x2,y2) around the current
 *  pixel, computed in O(1) from the integral image. */
class IntegralOps: public Operation
{

  public:
    /** stores a normalized window: (x1,y1) upper-left, (x2,y2) lower-right */
    virtual void set ( int _x1, int _y1, int _x2, int _y2, int _channel1, int _channel2, ValueAccess *_values )
    {
      x1 = min ( _x1, _x2 );
      y1 = min ( _y1, _y2 );
      x2 = max ( _x1, _x2 );
      y2 = max ( _y1, _y2 );
      channel1 = _channel1;
      channel2 = _channel2;
      values = _values;
    }

    virtual double getVal ( const Features &feats, const int &x, const int &y )
    {
      int w, h;
      getXY ( feats, w, h );
      const int ulx = BOUND ( x + x1, 0, w - 1 );
      const int uly = BOUND ( y + y1, 0, h - 1 );
      const int lrx = BOUND ( x + x2, 0, w - 1 );
      const int lry = BOUND ( y + y2, 0, h - 1 );
      return computeMean ( *feats.integralImg, ulx, uly, lrx, lry, channel1 );
    }

    /** mean over the rectangle (uLx,uLy)-(lRx,lRy); 0.0 for empty windows */
    inline double computeMean ( const NICE::MultiChannelImageT<double> &intImg, const int &uLx, const int &uLy, const int &lRx, const int &lRy, const int &chan )
    {
      double area = ( lRx - uLx ) * ( lRy - uLy );

      if ( area == 0 )
        return 0.0;

      // standard summed-area-table evaluation of the rectangle sum
      const double upperLeft = intImg.get ( uLx, uLy, chan );
      const double upperRight = intImg.get ( lRx, uLy, chan );
      const double lowerLeft = intImg.get ( uLx, lRy, chan );
      const double lowerRight = intImg.get ( lRx, lRy, chan );
      return ( upperLeft + lowerRight - upperRight - lowerLeft ) / area;
    }

    virtual Operation* clone()
    {
      return new IntegralOps();
    }

    virtual string writeInfos()
    {
      return "IntegralOps" + Operation::writeInfos();
    }

    virtual OperationTypes getOps()
    {
      return INTEGRAL;
    }
};

/** Global mean of a channel over the whole image — a bag-of-words-like
 *  cue describing which classes are present without local context. */
class GlobalFeats: public IntegralOps
{

  public:
    virtual double getVal ( const Features &feats, const int &x, const int &y )
    {
      int w, h;
      getXY ( feats, w, h );
      // the window always covers the full image; (x,y) is ignored
      return computeMean ( *feats.integralImg, 0, 0, w - 1, h - 1, channel1 );
    }

    virtual Operation* clone()
    {
      return new GlobalFeats();
    }

    virtual string writeInfos()
    {
      return "GlobalFeats" + Operation::writeInfos();
    }

    virtual OperationTypes getOps()
    {
      return GLOBALFEATS;
    }
};

/** Mean of a window of half size (x1,y1) centered at the current pixel,
 *  computed from the integral image. */
class IntegralCenteredOps: public IntegralOps
{

  public:
    /** stores absolute offsets: they act as half window extents */
    virtual void set ( int _x1, int _y1, int _x2, int _y2, int _channel1, int _channel2, ValueAccess *_values )
    {
      x1 = abs ( _x1 );
      y1 = abs ( _y1 );
      x2 = abs ( _x2 );
      y2 = abs ( _y2 );
      channel1 = _channel1;
      channel2 = _channel2;
      values = _values;
    }

    virtual double getVal ( const Features &feats, const int &x, const int &y )
    {
      int w, h;
      getXY ( feats, w, h );
      const int left = BOUND ( x - x1, 0, w - 1 );
      const int top = BOUND ( y - y1, 0, h - 1 );
      const int right = BOUND ( x + x1, 0, w - 1 );
      const int bottom = BOUND ( y + y1, 0, h - 1 );
      return computeMean ( *feats.integralImg, left, top, right, bottom, channel1 );
    }

    virtual Operation* clone()
    {
      return new IntegralCenteredOps();
    }

    virtual string writeInfos()
    {
      return "IntegralCenteredOps" + Operation::writeInfos();
    }

    virtual OperationTypes getOps()
    {
      return INTEGRALCENT;
    }
};

//uses different of mean of Integral image given by two windows, where (x1,y1) is the width and height of window1 and (x2,y2) of window 2

class BiIntegralCenteredOps: public IntegralCenteredOps
{

  public:
    virtual void set ( int _x1, int _y1, int _x2, int _y2, int _channel1, int _channel2, ValueAccess *_values )
    {
      x1 = min ( abs ( _x1 ), abs ( _x2 ) );
      y1 = min ( abs ( _y1 ), abs ( _y2 ) );
      x2 = max ( abs ( _x1 ), abs ( _x2 ) );
      y2 = max ( abs ( _y1 ), abs ( _y2 ) );
      channel1 = _channel1;
      channel2 = _channel2;
      values = _values;
    }

    virtual double getVal ( const Features &feats, const int &x, const int &y )
    {
      int xsize, ysize;
      getXY ( feats, xsize, ysize );
      return computeMean ( *feats.integralImg, BOUND ( x - x1, 0, xsize - 1 ), BOUND ( y - y1, 0, ysize - 1 ), BOUND ( x + x1, 0, xsize - 1 ), BOUND ( y + y1, 0, ysize - 1 ), channel1 ) - computeMean ( *feats.integralImg, BOUND ( x - x2, 0, xsize - 1 ), BOUND ( y - y2, 0, ysize - 1 ), BOUND ( x + x2, 0, xsize - 1 ), BOUND ( y + y2, 0, ysize - 1 ), channel1 );
    }

    virtual Operation* clone()
    {
      return new BiIntegralCenteredOps();
    }

    virtual string writeInfos()
    {
      return "BiIntegralCenteredOps" + Operation::writeInfos();
    }

    virtual OperationTypes getOps()
    {
      return BIINTEGRALCENT;
    }
};

/** horizontal Haar feature: upper half positive, lower half negative
 * ++
 * --
 */
class HaarHorizontal: public IntegralCenteredOps
{
    virtual double getVal ( const Features &feats, const int &x, const int &y )
    {
      int w, h;
      getXY ( feats, w, h );

      const int left = BOUND ( x - x1, 0, w - 1 );
      const int top = BOUND ( y - y1, 0, h - 1 );
      const int right = BOUND ( x + x1, 0, w - 1 );
      const int bottom = BOUND ( y + y1, 0, h - 1 );

      const double upper = computeMean ( *feats.integralImg, left, top, right, y, channel1 );
      const double lower = computeMean ( *feats.integralImg, left, y, right, bottom, channel1 );
      return upper - lower;
    }

    virtual Operation* clone()
    {
      return new HaarHorizontal();
    }

    virtual string writeInfos()
    {
      return "HaarHorizontal" + Operation::writeInfos();
    }

    virtual OperationTypes getOps()
    {
      return HAARHORIZ;
    }
};

/** vertical Haar feature: left half positive, right half negative
 * +-
 * +-
 */
class HaarVertical: public IntegralCenteredOps
{
    virtual double getVal ( const Features &feats, const int &x, const int &y )
    {
      int w, h;
      getXY ( feats, w, h );

      const int left = BOUND ( x - x1, 0, w - 1 );
      const int top = BOUND ( y - y1, 0, h - 1 );
      const int right = BOUND ( x + x1, 0, w - 1 );
      const int bottom = BOUND ( y + y1, 0, h - 1 );

      const double leftMean = computeMean ( *feats.integralImg, left, top, x, bottom, channel1 );
      const double rightMean = computeMean ( *feats.integralImg, x, top, right, bottom, channel1 );
      return leftMean - rightMean;
    }

    virtual Operation* clone()
    {
      return new HaarVertical();
    }

    virtual string writeInfos()
    {
      return "HaarVertical" + Operation::writeInfos();
    }

    virtual OperationTypes getOps()
    {
      return HAARVERT;
    }
};

/** diagonal Haar feature: opposite quadrants share a sign
 * +-
 * -+
 */
class HaarDiag: public IntegralCenteredOps
{
    virtual double getVal ( const Features &feats, const int &x, const int &y )
    {
      int w, h;
      getXY ( feats, w, h );

      const int left = BOUND ( x - x1, 0, w - 1 );
      const int top = BOUND ( y - y1, 0, h - 1 );
      const int right = BOUND ( x + x1, 0, w - 1 );
      const int bottom = BOUND ( y + y1, 0, h - 1 );

      // positive: upper-left + lower-right, negative: the other two quadrants
      const double positive = computeMean ( *feats.integralImg, left, top, x, y, channel1 ) + computeMean ( *feats.integralImg, x, y, right, bottom, channel1 );
      const double negative = computeMean ( *feats.integralImg, left, y, x, bottom, channel1 ) + computeMean ( *feats.integralImg, x, top, right, y, channel1 );
      return positive - negative;
    }

    virtual Operation* clone()
    {
      return new HaarDiag();
    }

    virtual string writeInfos()
    {
      return "HaarDiag" + Operation::writeInfos();
    }

    virtual OperationTypes getOps()
    {
      return HAARDIAG;
    }
};

/** three-band horizontal Haar feature: middle band negative
 * +++
 * ---
 * +++
 */
class Haar3Horiz: public BiIntegralCenteredOps
{
    virtual double getVal ( const Features &feats, const int &x, const int &y )
    {
      int w, h;
      getXY ( feats, w, h );

      // outer window (x2,y2) and inner band boundaries (y1)
      const int left = BOUND ( x - x2, 0, w - 1 );
      const int top = BOUND ( y - y2, 0, h - 1 );
      const int midTop = BOUND ( y - y1, 0, h - 1 );
      const int midBottom = BOUND ( y + y1, 0, h - 1 );
      const int right = BOUND ( x + x2, 0, w - 1 );
      const int bottom = BOUND ( y + y2, 0, h - 1 );

      const double upperBand = computeMean ( *feats.integralImg, left, top, right, midTop, channel1 );
      const double middleBand = computeMean ( *feats.integralImg, left, midTop, right, midBottom, channel1 );
      const double lowerBand = computeMean ( *feats.integralImg, left, midBottom, right, bottom, channel1 );
      return upperBand - middleBand + lowerBand;
    }

    virtual Operation* clone()
    {
      return new Haar3Horiz();
    }

    virtual string writeInfos()
    {
      return "Haar3Horiz" + Operation::writeInfos();
    }

    virtual OperationTypes getOps()
    {
      return HAAR3HORIZ;
    }
};

/** three-band vertical Haar feature: middle band negative
 * +-+
 * +-+
 * +-+
 */
class Haar3Vert: public BiIntegralCenteredOps
{
    virtual double getVal ( const Features &feats, const int &x, const int &y )
    {
      int w, h;
      getXY ( feats, w, h );

      // outer window (x2,y2) and inner band boundaries (x1)
      const int left = BOUND ( x - x2, 0, w - 1 );
      const int top = BOUND ( y - y2, 0, h - 1 );
      const int midLeft = BOUND ( x - x1, 0, w - 1 );
      const int midRight = BOUND ( x + x1, 0, w - 1 );
      const int right = BOUND ( x + x2, 0, w - 1 );
      const int bottom = BOUND ( y + y2, 0, h - 1 );

      const double leftBand = computeMean ( *feats.integralImg, left, top, midLeft, bottom, channel1 );
      const double middleBand = computeMean ( *feats.integralImg, midLeft, top, midRight, bottom, channel1 );
      const double rightBand = computeMean ( *feats.integralImg, midRight, top, right, bottom, channel1 );
      return leftBand - middleBand + rightBand;
    }

    virtual Operation* clone()
    {
      return new Haar3Vert();
    }

    virtual string writeInfos()
    {
      return "Haar3Vert" + Operation::writeInfos();
    }

    virtual OperationTypes getOps()
    {
      return HAAR3VERT;
    }
};

/**
 * Constructor: reads all parameters from the config, sets up the
 * segmentation backend and the pools of random feature operations, then
 * either loads a trained forest from file or trains a new one.
 * @param conf configuration (sections "SSContextTree", "Features", "debug")
 * @param md   dataset providing the class names and the training images
 * @throws const char* on an invalid segmentation type or if the (broken)
 *         gaussian offset sampling is enabled
 */
SemSegContextTree::SemSegContextTree ( const Config *conf, const MultiDataset *md )
    : SemanticSegmentation ( conf, & ( md->getClassNames ( "train" ) ) )
{
  this->conf = conf;
  string section = "SSContextTree";
  lfcw = new LFColorWeijer ( conf );

  // spacing of the pixel sampling grid
  grid = conf->gI ( section, "grid", 10 );

  // cap on the number of samples drawn per node when searching a split
  maxSamples = conf->gI ( section, "max_samples", 2000 );

  // nodes with fewer features than this become leaves
  minFeats = conf->gI ( section, "min_feats", 50 );

  maxDepth = conf->gI ( section, "max_depth", 10 );

  // window size the random feature offsets are drawn from
  windowSize = conf->gI ( section, "window_size", 16 );

  // number of random feature candidates evaluated per split
  featsPerSplit = conf->gI ( section, "feats_per_split", 200 );

  // normalize information gain by split entropy (Shannon) if enabled
  useShannonEntropy = conf->gB ( section, "use_shannon_entropy", true );

  nbTrees = conf->gI ( section, "amount_trees", 1 );

  string segmentationtype = conf->gS ( section, "segmentation_type", "meanshift" );

  useGaussian = conf->gB ( section, "use_gaussian", true );

  // number of random thresholds tested per feature candidate
  randomTests = conf->gI ( section, "random_tests", 10 );

  bool saveLoadData = conf->gB ( "debug", "save_load_data", false );
  string fileLocation = conf->gS ( "debug", "datafile", "tmp.txt" );

  // gaussian offset sampling is known to be broken; refuse to run with it
  if ( useGaussian )
    throw ( "there something wrong with using gaussian! first fix it!" );

  pixelWiseLabeling = false;

  if ( segmentationtype == "meanshift" )
    segmentation = new RSMeanShift ( conf );
  else if ( segmentationtype == "none" )
  {
    // no segmentation backend: every pixel is labeled independently
    segmentation = NULL;
    pixelWiseLabeling = true;
  }
  else if ( segmentationtype == "felzenszwalb" )
    segmentation = new RSGraphBased ( conf );
  else
    throw ( "no valid segmenation_type\n please choose between none, meanshift and felzenszwalb\n" );

  // number of feature types used during split search
  ftypes = conf->gI ( section, "features", 2 );;

  string featsec = "Features";

  // pool of pixel-pair operations (work on raw and context features)
  if ( conf->gB ( featsec, "minus", true ) )
    ops.push_back ( new Minus() );
  if ( conf->gB ( featsec, "minus_abs", true ) )
    ops.push_back ( new MinusAbs() );
  if ( conf->gB ( featsec, "addition", true ) )
    ops.push_back ( new Addition() );
  if ( conf->gB ( featsec, "only1", true ) )
    ops.push_back ( new Only1() );
  if ( conf->gB ( featsec, "rel_x", true ) )
    ops.push_back ( new RelativeXPosition() );
  if ( conf->gB ( featsec, "rel_y", true ) )
    ops.push_back ( new RelativeYPosition() );

  // pool of integral-image (window mean / Haar) operations
  if ( conf->gB ( featsec, "bi_int_cent", true ) )
    cops.push_back ( new BiIntegralCenteredOps() );
  if ( conf->gB ( featsec, "int_cent", true ) )
    cops.push_back ( new IntegralCenteredOps() );
  if ( conf->gB ( featsec, "int", true ) )
    cops.push_back ( new IntegralOps() );
  if ( conf->gB ( featsec, "haar_horz", true ) )
    cops.push_back ( new HaarHorizontal() );
  if ( conf->gB ( featsec, "haar_vert", true ) )
    cops.push_back ( new HaarVertical() );
  if ( conf->gB ( featsec, "haar_diag", true ) )
    cops.push_back ( new HaarDiag() );
  if ( conf->gB ( featsec, "haar3_horz", true ) )
    cops.push_back ( new Haar3Horiz() );
  if ( conf->gB ( featsec, "haar3_vert", true ) )
    cops.push_back ( new Haar3Vert() );
  if ( conf->gB ( featsec, "glob", true ) )
    cops.push_back ( new GlobalFeats() );

  // usage statistics: operation type counts and context-vs-raw per depth
  opOverview = vector<int> ( NBOPERATIONS, 0 );
  contextOverview = vector<vector<double> > ( maxDepth, vector<double> ( 2, 0.0 ) );

  // value access objects indexed by feature type (0 = raw, 1 = context)
  calcVal.push_back ( new MCImageAccess() );
  calcVal.push_back ( new ClassificationResultAcess() );

  classnames = md->getClassNames ( "train" );

  ///////////////////////////////////
  // Train Segmentation Context Trees
  ///////////////////////////////////

  if ( saveLoadData )
  {
    // reuse a previously trained forest if the data file exists
    if ( FileMgt::fileExists ( fileLocation ) )
      read ( fileLocation );
    else
    {
      train ( md );
      write ( fileLocation );
    }
  }
  else
  {
    train ( md );
  }
}

SemSegContextTree::~SemSegContextTree()
{
  // NOTE(review): lfcw, segmentation and the Operation/ValueAccess objects
  // allocated in the constructor (ops, cops, calcVal) are never freed here,
  // so they leak. Deleting them safely requires confirming that Operation
  // and ValueAccess declare virtual destructors and that no other object
  // shares ownership of the calcVal pointers — verify before adding deletes.
}

/**
 * Finds the best split (operation + threshold) for one node of one tree.
 * Subsamples the pixels currently assigned to the node, draws featsPerSplit
 * random feature operations, evaluates randomTests random thresholds per
 * operation in parallel and keeps the pair with the highest information gain.
 * @param feats         raw multi-channel features, one image per entry
 * @param currentfeats  per-pixel node assignment per tree
 * @param integralImgs  integral images (context channels + raw features)
 * @param labels        ground-truth label matrix per image
 * @param node          index of the node to split
 * @param splitop       [out] best operation, NULL if no split was found
 * @param splitval      [out] threshold belonging to the best operation
 * @param tree          index of the tree being grown
 * @return information gain of the best split; 0.0 if the node becomes a leaf
 */
double SemSegContextTree::getBestSplit ( std::vector<NICE::MultiChannelImageT<double> > &feats, std::vector<NICE::MultiChannelImageT<unsigned short int> > &currentfeats, std::vector<NICE::MultiChannelImageT<double> > &integralImgs, const std::vector<NICE::MatrixT<int> > &labels, int node, Operation *&splitop, double &splitval, const int &tree )
{
  Timer t;
  t.start();
  int imgCount = 0, featdim = 0;

  // feats may be empty if feature computation failed upstream
  // NOTE(review): catches by value; catching by const reference is preferred
  try
  {
    imgCount = ( int ) feats.size();
    featdim = feats[0].channels();
  }
  catch ( Exception )
  {
    cerr << "no features computed?" << endl;
  }

  double bestig = -numeric_limits< double >::max();

  splitop = NULL;
  splitval = -1.0;

  // subsampled pixel positions: each entry is (image index, x, y)
  set<vector<int> >selFeats;
  map<int, int> e;
  int featcounter = forest[tree][node].featcounter;

  // too few features in this node -> declare it a leaf
  if ( featcounter < minFeats )
  {
    //cout << "only " << featcounter << " feats in current node -> it's a leaf" << endl;
    return 0.0;
  }

  // per-class sampling fraction to balance classes and cap at maxSamples
  vector<double> fraction ( a.size(), 0.0 );

  for ( uint i = 0; i < fraction.size(); i++ )
  {
    if ( forbidden_classes.find ( labelmapback[i] ) != forbidden_classes.end() )
      fraction[i] = 0;
    else
      fraction[i] = ( ( double ) maxSamples ) / ( ( double ) featcounter * a[i] * a.size() );

    //cout << "fraction["<<i<<"]: "<< fraction[i] << " a[" << i << "]: " << a[i] << endl;
  }

  featcounter = 0;

  // randomly subsample pixels of this node according to the class fractions
  for ( int iCounter = 0; iCounter < imgCount; iCounter++ )
  {
    int xsize = ( int ) currentfeats[iCounter].width();
    int ysize = ( int ) currentfeats[iCounter].height();

    for ( int x = 0; x < xsize; x++ )
    {
      for ( int y = 0; y < ysize; y++ )
      {
        if ( currentfeats[iCounter].get ( x, y, tree ) == node )
        {
          int cn = labels[iCounter] ( x, y );
          double randD = ( double ) rand() / ( double ) RAND_MAX;

          // skip labels that are not part of the training label map
          if ( labelmap.find ( cn ) == labelmap.end() )
            continue;

          if ( randD < fraction[labelmap[cn]] )
          {
            vector<int> tmp ( 3, 0 );
            tmp[0] = iCounter;
            tmp[1] = x;
            tmp[2] = y;
            featcounter++;
            selFeats.insert ( tmp );
            e[cn]++;   // class histogram of the sampled pixels
          }
        }
      }
    }
  }
  //cout << "size: " << selFeats.size() << endl;
  //getchar();

  map<int, int>::iterator mapit;

  // entropy of the class distribution over the sampled pixels
  double globent = 0.0;

  for ( mapit = e.begin() ; mapit != e.end(); mapit++ )
  {
    //cout << "class: " << mapit->first << ": " << mapit->second << endl;
    double p = ( double ) ( *mapit ).second / ( double ) featcounter;
    globent += p * log2 ( p );
  }

  globent = -globent;

  // nearly pure node -> no point in splitting further
  if ( globent < 0.5 )
  {
    //cout << "globent to small: " << globent << endl;
    return 0.0;
  }

  int classes = ( int ) forest[tree][0].dist.size();

  featsel.clear();

  // draw featsPerSplit random candidate operations
  for ( int i = 0; i < featsPerSplit; i++ )
  {
    int x1, x2, y1, y2;
    // random feature type: 0 = raw features, 1 = context/integral features
    int ft = ( int ) ( ( double ) rand() / ( double ) RAND_MAX * ( double ) ftypes );

    int tmpws = windowSize;

    // no integral images yet (first level) -> only raw features possible
    if ( integralImgs[0].width() == 0 )
      ft = 0;

    // context features use a larger neighbourhood
    if ( ft > 0 )
    {
      tmpws *= 4;
    }

    if ( useGaussian )
    {
      double sigma = ( double ) tmpws * 2.0;
      x1 = randGaussDouble ( sigma ) * ( double ) tmpws;
      x2 = randGaussDouble ( sigma ) * ( double ) tmpws;
      y1 = randGaussDouble ( sigma ) * ( double ) tmpws;
      y2 = randGaussDouble ( sigma ) * ( double ) tmpws;
    }
    else
    {
      // uniform offsets in [-tmpws/2, tmpws/2)
      x1 = ( int ) ( ( double ) rand() / ( double ) RAND_MAX * ( double ) tmpws ) - tmpws / 2;
      x2 = ( int ) ( ( double ) rand() / ( double ) RAND_MAX * ( double ) tmpws ) - tmpws / 2;
      y1 = ( int ) ( ( double ) rand() / ( double ) RAND_MAX * ( double ) tmpws ) - tmpws / 2;
      y2 = ( int ) ( ( double ) rand() / ( double ) RAND_MAX * ( double ) tmpws ) - tmpws / 2;
    }

    if ( ft == 0 )
    {
      // raw feature operation on two random raw channels
      int f1 = ( int ) ( ( double ) rand() / ( double ) RAND_MAX * ( double ) featdim );
      int f2 = ( int ) ( ( double ) rand() / ( double ) RAND_MAX * ( double ) featdim );
      int o = ( int ) ( ( double ) rand() / ( double ) RAND_MAX * ( double ) ops.size() );
      Operation *op = ops[o]->clone();
      op->set ( x1, y1, x2, y2, f1, f2, calcVal[ft] );
      op->setContext ( false );
      featsel.push_back ( op );
    }
    else if ( ft == 1 )
    {

      // context feature: either a pixel-pair op on the class distributions
      // or an integral-image op on the integral image channels
      int opssize = ( int ) ops.size();
      //opssize = 0;
      int o = ( int ) ( ( double ) rand() / ( double ) RAND_MAX * ( ( ( double ) cops.size() ) + ( double ) opssize ) );

      Operation *op;

      if ( o < opssize )
      {
        // pixel-pair operation on two random class-probability channels
        int chans = ( int ) forest[0][0].dist.size();
        int f1 = ( int ) ( ( double ) rand() / ( double ) RAND_MAX * ( double ) chans );
        int f2 = ( int ) ( ( double ) rand() / ( double ) RAND_MAX * ( double ) chans );
        op = ops[o]->clone();
        op->set ( x1, y1, x2, y2, f1, f2, calcVal[ft] );
        op->setContext ( true );
      }
      else
      {
        // integral-image operation on a random integral channel; channels
        // below the class count are context, the rest are raw features
        int chans = integralImgs[0].channels();
        int f1 = ( int ) ( ( double ) rand() / ( double ) RAND_MAX * ( double ) chans );
        int f2 = ( int ) ( ( double ) rand() / ( double ) RAND_MAX * ( double ) chans );

        o -= opssize;
        op = cops[o]->clone();
        op->set ( x1, y1, x2, y2, f1, f2, calcVal[ft] );
        if ( f1 < forest[0][0].dist.size() )
          op->setContext ( true );
        else
          op->setContext ( false );
      }

      featsel.push_back ( op );
    }
  }

// evaluate all candidates in parallel; each thread keeps its local best
#pragma omp parallel for private(mapit)
  for ( int f = 0; f < featsPerSplit; f++ )
  {
    double l_bestig = -numeric_limits< double >::max();
    double l_splitval = -1.0;
    set<vector<int> >::iterator it;
    vector<double> vals;

    // compute the feature value for every sampled pixel and its value range
    double maxval = -numeric_limits<double>::max();
    double minval = numeric_limits<double>::max();
    for ( it = selFeats.begin() ; it != selFeats.end(); it++ )
    {
      Features feat;
      feat.feats = &feats[ ( *it ) [0]];
      feat.cfeats = &currentfeats[ ( *it ) [0]];
      feat.cTree = tree;
      feat.tree = &forest[tree];
      feat.integralImg = &integralImgs[ ( *it ) [0]];
      double val = featsel[f]->getVal ( feat, ( *it ) [1], ( *it ) [2] );
      vals.push_back ( val );
      maxval = std::max ( val, maxval );
      minval = std::min ( val, minval );
    }

    // constant feature -> cannot split on it
    if ( minval == maxval )
      continue;

    // draw random thresholds inside the observed value range
    double scale = maxval - minval;
    vector<double> splits;

    for ( int r = 0; r < randomTests; r++ )
    {
      splits.push_back ( ( ( double ) rand() / ( double ) RAND_MAX*scale ) + minval );
    }

    for ( int run = 0 ; run < randomTests; run++ )
    {
      set<vector<int> >::iterator it2;
      double val = splits[run];

      // class histograms of the left/right partitions for this threshold
      map<int, int> eL, eR;
      int counterL = 0, counterR = 0;
      int counter2 = 0;

      for ( it2 = selFeats.begin() ; it2 != selFeats.end(); it2++, counter2++ )
      {
        int cn = labels[ ( *it2 ) [0]] ( ( *it2 ) [1], ( *it2 ) [2] );
        //cout << "vals[counter2] " << vals[counter2] << " val: " <<  val << endl;

        if ( vals[counter2] < val )
        {
          //left entropie:
          eL[cn] = eL[cn] + 1;
          counterL++;
        }
        else
        {
          //right entropie:
          eR[cn] = eR[cn] + 1;
          counterR++;
        }
      }

      double leftent = 0.0;

      for ( mapit = eL.begin() ; mapit != eL.end(); mapit++ )
      {
        double p = ( double ) ( *mapit ).second / ( double ) counterL;
        leftent -= p * log2 ( p );
      }

      double rightent = 0.0;

      for ( mapit = eR.begin() ; mapit != eR.end(); mapit++ )
      {
        double p = ( double ) ( *mapit ).second / ( double ) counterR;
        rightent -= p * log2 ( p );
      }

      //cout << "rightent: " << rightent << " leftent: " << leftent << endl;

      // information gain of this threshold
      double pl = ( double ) counterL / ( double ) ( counterL + counterR );

      double ig = globent - ( 1.0 - pl ) * rightent - pl * leftent;

      //double ig = globent - rightent - leftent;

      if ( useShannonEntropy )
      {
        // normalize by the split entropy (gain ratio style)
        double esplit = - ( pl * log ( pl ) + ( 1 - pl ) * log ( 1 - pl ) );
        ig = 2 * ig / ( globent + esplit );
      }

      if ( ig > l_bestig )
      {
        l_bestig = ig;
        l_splitval = val;
      }
    }

// merge the thread-local best split into the global one
#pragma omp critical
    {
      //cout << "globent: " << globent <<  " bestig " << bestig << " splitfeat: " << splitfeat << " splitval: " << splitval << endl;
      //cout << "globent: " << globent <<  " l_bestig " << l_bestig << " f: " << p << " l_splitval: " << l_splitval << endl;
      //cout << "p: " << featsubset[f] << endl;

      if ( l_bestig > bestig )
      {
        bestig = l_bestig;
        splitop = featsel[f];
        splitval = l_splitval;
      }
    }
  }

  //getchar();
  //splitop->writeInfos();
  //cout<< "ig: " << bestig << endl;
  //FIXME: delete all features!
  /*for(int i = 0; i < featsPerSplit; i++)
  {
   if(featsel[i] != splitop)
    delete featsel[i];
  }*/


#ifdef debug
  cout << "globent: " << globent <<  " bestig " << bestig << " splitval: " << splitval << endl;

#endif
  return bestig;
}

/** Average probability of 'channel' at pixel (x,y) over all trees of the forest. */
inline double SemSegContextTree::getMeanProb ( const int &x, const int &y, const int &channel, const MultiChannelImageT<unsigned short int> &currentfeats )
{
  double sum = 0.0;

  for ( int t = 0; t < nbTrees; t++ )
  {
    // node the pixel currently sits in for tree t
    const int node = currentfeats.get ( x, y, t );
    sum += forest[t][node].dist[channel];
  }

  return sum / ( double ) nbTrees;
}

/**
 * Builds an integral image of sparse vectors in place, using the standard
 * recurrence I(x,y) = f(x,y) + I(x-1,y) + I(x,y-1) - I(x-1,y-1).
 * @param infeats       input sparse-vector image
 * @param integralImage output; assumed to be same-sized and zero/empty,
 *        since values are accumulated via add/sub — TODO confirm caller
 *        always passes a freshly initialized image
 */
void SemSegContextTree::computeIntegralImage ( const NICE::MultiChannelImageT<SparseVectorInt> &infeats, NICE::MultiChannelImageT<SparseVectorInt> &integralImage )
{
  int xsize = infeats.width();
  int ysize = infeats.height();
  integralImage ( 0, 0 ).add ( infeats.get ( 0, 0 ) );

  //first column
  for ( int y = 1; y < ysize; y++ )
  {
    integralImage ( 0, y ).add ( infeats.get ( 0, y ) );
    integralImage ( 0, y ).add ( integralImage ( 0, y - 1 ) );
  }

  //first row
  for ( int x = 1; x < xsize; x++ )
  {
    integralImage ( x, 0 ).add ( infeats.get ( x, 0 ) );
    integralImage ( x, 0 ).add ( integralImage ( x - 1, 0 ) );
  }

  //rest
  for ( int y = 1; y < ysize; y++ )
  {
    for ( int x = 1; x < xsize; x++ )
    {
      // I(x,y) = f(x,y) + I(x,y-1) + I(x-1,y) - I(x-1,y-1)
      integralImage ( x, y ).add ( infeats.get ( x, y ) );
      integralImage ( x, y ).add ( integralImage ( x, y - 1 ) );
      integralImage ( x, y ).add ( integralImage ( x - 1, y ) );
      integralImage ( x, y ).sub ( integralImage ( x - 1, y - 1 ) );
    }
  }
}

/**
 * Updates the integral image used by the context features. The first
 * 'channels' channels hold the integral of the per-class mean probabilities
 * of the forest (recomputed on every call); the remaining channels hold the
 * integral of the raw local features, which is only computed once — detected
 * by the last pixel of the first raw channel still being 0.0.
 * @param currentfeats per-pixel node assignment per tree
 * @param lfeats       raw local features
 * @param integralImage output integral image (class channels + raw channels)
 */
void SemSegContextTree::computeIntegralImage ( const NICE::MultiChannelImageT<unsigned short int> &currentfeats, const NICE::MultiChannelImageT<double> &lfeats, NICE::MultiChannelImageT<double> &integralImage )
{
  int xsize = currentfeats.width();
  int ysize = currentfeats.height();

  int channels = ( int ) forest[0][0].dist.size();
#pragma omp parallel for
  for ( int c = 0; c < channels; c++ )
  {
    integralImage.set ( 0, 0, getMeanProb ( 0, 0, c, currentfeats ), c );

    //first column
    for ( int y = 1; y < ysize; y++ )
    {
      integralImage.set ( 0, y, getMeanProb ( 0, y, c, currentfeats ) + integralImage.get ( 0, y - 1, c ), c );
    }

    //first row
    for ( int x = 1; x < xsize; x++ )
    {
      integralImage.set ( x, 0, getMeanProb ( x, 0, c, currentfeats ) + integralImage.get ( x - 1, 0, c ), c );
    }

    //rest: I(x,y) = f(x,y) + I(x,y-1) + I(x-1,y) - I(x-1,y-1)
    for ( int y = 1; y < ysize; y++ )
    {
      for ( int x = 1; x < xsize; x++ )
      {
        double val = getMeanProb ( x, y, c, currentfeats ) + integralImage.get ( x, y - 1, c ) + integralImage.get ( x - 1, y, c ) - integralImage.get ( x - 1, y - 1, c );
        integralImage.set ( x, y, val, c );
      }
    }
  }

  int channels2 = ( int ) lfeats.channels();

  xsize = lfeats.width();
  ysize = lfeats.height();

  // the raw-feature part only has to be computed once per image
  if ( integralImage.get ( xsize - 1, ysize - 1, channels ) == 0.0 )
  {
#pragma omp parallel for
    for ( int c1 = 0; c1 < channels2; c1++ )
    {
      int c = channels + c1;
      integralImage.set ( 0, 0, lfeats.get ( 0, 0, c1 ), c );

      //first column
      // bugfix: accumulate with the previous pixel (y-1) — the old code
      // read the still-unset current pixel, breaking the running sums
      for ( int y = 1; y < ysize; y++ )
      {
        integralImage.set ( 0, y, lfeats.get ( 0, y, c1 ) + integralImage.get ( 0, y - 1, c ), c );
      }

      //first row
      // bugfix: same here — accumulate with pixel (x-1), not (x)
      for ( int x = 1; x < xsize; x++ )
      {
        integralImage.set ( x, 0, lfeats.get ( x, 0, c1 ) + integralImage.get ( x - 1, 0, c ), c );
      }

      //rest
      for ( int y = 1; y < ysize; y++ )
      {
        for ( int x = 1; x < xsize; x++ )
        {
          double val = lfeats.get ( x, y, c1 ) + integralImage.get ( x, y - 1, c ) + integralImage.get ( x - 1, y, c ) - integralImage.get ( x - 1, y - 1, c );
          integralImage.set ( x, y, val, c );
        }
      }
    }
  }
}

/**
 * Depth-dependent texton weight: 2^-(dim - d + 1), i.e. deeper tree levels
 * (d closer to dim) contribute exponentially more.
 *
 * @param d   current depth level
 * @param dim maximum depth of the tree
 * @return weight in (0, 0.5] for valid depths
 */
inline double computeWeight ( const double &d, const double &dim )
{
  // 1 / 2^(dim-d+1) written as a single negative power of two
  return std::pow ( 2.0, d - dim - 1.0 );
}

/**
 * Trains the random forest on the "train" part of the dataset.
 *
 * Pipeline: collect per-pixel LAB features and rasterized ground-truth
 * labels for every training image, derive class balancing weights, then grow
 * all trees breadth-first level by level. After each level the context
 * features (class-probability integral images and weighted texton maps) are
 * recomputed so deeper splits can use them.
 *
 * Fixes over the previous revision:
 *  - the CachedExample is no longer leaked on either skip path
 *  - an image that fails to load no longer leaves half-initialized entries
 *    in currentfeats/labels/textonMap (the load is attempted before any
 *    per-image container is grown, keeping them in sync with imgcounter)
 *  - exceptions are caught by const reference instead of by value
 *
 * @param md multi-dataset providing the "train" labeled set
 */
void SemSegContextTree::train ( const MultiDataset *md )
{
  const LabeledSet train = * ( *md ) ["train"];
  const LabeledSet *trainp = &train;

  ProgressBar pb ( "compute feats" );
  pb.show();

  //TODO: memory hog — would a sparse representation pay off?
  vector<MultiChannelImageT<double> > allfeats;
  vector<MultiChannelImageT<unsigned short int> > currentfeats;
  vector<MatrixT<int> > labels;
  vector<MultiChannelImageT<SparseVectorInt> > textonMap;
  vector<MultiChannelImageT<SparseVectorInt> > integralTexton;

  // classes excluded from training, taken from the configuration
  std::string forbidden_classes_s = conf->gS ( "analysis", "donttrain", "" );

  if ( forbidden_classes_s == "" )
  {
    forbidden_classes_s = conf->gS ( "analysis", "forbidden_classes", "" );
  }

  classnames.getSelection ( forbidden_classes_s, forbidden_classes );

  int imgcounter = 0;

  int amountPixels = 0;

  LOOP_ALL_S ( *trainp )
  {
    EACH_INFO ( classno, info );

    NICE::ColorImage img;

    std::string currentFile = info.img();

    CachedExample *ce = new CachedExample ( currentFile );

    const LocalizationResult *locResult = info.localization();

    if ( locResult->size() <= 0 )
    {
      fprintf ( stderr, "WARNING: NO ground truth polygons found for %s !\n",
                currentFile.c_str() );
      delete ce; // was leaked before
      continue;
    }

    fprintf ( stderr, "SemSegCsurka: Collecting pixel examples from localization info: %s\n", currentFile.c_str() );

    // load the image before growing any per-image container so that a failed
    // load cannot desynchronize them from imgcounter
    try {
      img = ColorImage ( currentFile );
    } catch ( const Exception & ) {
      cerr << "SemSeg: error opening image file <" << currentFile << ">" << endl;
      delete ce; // was leaked before
      continue;
    }

    int xsize, ysize;
    ce->getImageSize ( xsize, ysize );
    amountPixels += xsize * ysize;

    MatrixT<int> tmpMat ( xsize, ysize );

    currentfeats.push_back ( MultiChannelImageT<unsigned short int> ( xsize, ysize, nbTrees ) );
    currentfeats[imgcounter].setAll ( 0 );
    // texton maps live on a coarser grid (subsampled by 'grid') to save memory
    textonMap.push_back ( MultiChannelImageT<SparseVectorInt> ( xsize / grid + 1, ysize / grid + 1, 1 ));
    integralTexton.push_back ( MultiChannelImageT<SparseVectorInt> ( xsize / grid + 1, ysize / grid + 1, 1 ));

    labels.push_back ( tmpMat );

    Globals::setCurrentImgFN ( currentFile );

    //TODO: resize image?!
    MultiChannelImageT<double> feats;
    allfeats.push_back ( feats );
#ifdef LOCALFEATS
    lfcw->getFeats ( img, allfeats[imgcounter] );
#else
    allfeats[imgcounter].reInit ( xsize, ysize, 3, true );

    for ( int x = 0; x < xsize; x++ )
    {
      for ( int y = 0; y < ysize; y++ )
      {
        for ( int r = 0; r < 3; r++ )
        {
          allfeats[imgcounter].set ( x, y, img.getPixel ( x, y, r ), r );
        }
      }
    }

    // work in LAB color space instead of RGB
    allfeats[imgcounter] = ColorSpace::rgbtolab ( allfeats[imgcounter] );
#endif

    // rasterize the ground-truth polygons into a per-pixel label image
    NICE::Image pixelLabels ( xsize, ysize );

    pixelLabels.set ( 0 );

    locResult->calcLabeledImage ( pixelLabels, ( *classNames ).getBackgroundClass() );

    for ( int x = 0; x < xsize; x++ )
    {
      for ( int y = 0; y < ysize; y++ )
      {
        classno = pixelLabels.getPixel ( x, y );
        labels[imgcounter] ( x, y ) = classno;

        if ( forbidden_classes.find ( classno ) != forbidden_classes.end() )
          continue;

        // count pixels per class for the balancing weights below
        labelcounter[classno]++;
      }
    }

    imgcounter++;

    pb.update ( trainp->count() );
    delete ce;
  }

  pb.hide();

  // build a dense class index (labelmap) and its inverse (labelmapback)
  map<int, int>::iterator mapit;
  int classes = 0;

  for ( mapit = labelcounter.begin(); mapit != labelcounter.end(); mapit++ )
  {
    labelmap[mapit->first] = classes;
    labelmapback[classes] = mapit->first;
    classes++;
  }

  //balancing: a[c] = fraction of training pixels belonging to class c,
  //used later to normalize the per-node class histograms
  int featcounter = 0;

  a = vector<double> ( classes, 0.0 );

  for ( int iCounter = 0; iCounter < imgcounter; iCounter++ )
  {
    int xsize = ( int ) currentfeats[iCounter].width();
    int ysize = ( int ) currentfeats[iCounter].height();

    for ( int x = 0; x < xsize; x++ )
    {
      for ( int y = 0; y < ysize; y++ )
      {
        featcounter++;
        int cn = labels[iCounter] ( x, y );
        if ( labelmap.find ( cn ) == labelmap.end() )
          continue; // forbidden class
        a[labelmap[cn]] ++;
      }
    }
  }

  for ( int i = 0; i < ( int ) a.size(); i++ )
  {
    a[i] /= ( double ) featcounter;
  }

#ifdef DEBUG
  for ( int i = 0; i < ( int ) a.size(); i++ )
  {
    cout << "a[" << i << "]: " << a[i] << endl;
  }

  cout << "a.size: " << a.size() << endl;

#endif

  depth = 0;

  // globally unique node id across all trees (also used as texton index)
  int uniquenumber = 0;

  // each tree starts as a single root node covering all pixels
  for ( int t = 0; t < nbTrees; t++ )
  {
    vector<TreeNode> tree;
    tree.push_back ( TreeNode() );
    tree[0].dist = vector<double> ( classes, 0.0 );
    tree[0].depth = depth;
    tree[0].featcounter = amountPixels;
    tree[0].nodeNumber = uniquenumber;
    uniquenumber++;
    forest.push_back ( tree );
  }

  // index of the first node created in the previous level, per tree
  vector<int> startnode ( nbTrees, 0 );

  bool allleaf = false;
  //int baseFeatSize = allfeats[0].size();

  vector<MultiChannelImageT<double> > integralImgs ( imgcounter, MultiChannelImageT<double>() );

  // grow all trees breadth-first, one level per iteration
  while ( !allleaf && depth < maxDepth )
  {
    depth++;
#ifdef DEBUG
    cout << "depth: " << depth << endl;
#endif
    allleaf = true;
    vector<MultiChannelImageT<unsigned short int> > lastfeats = currentfeats;

#if 1
    Timer timer;
    timer.start();
#endif

    // texton weight contributed by this level (difference of cumulative weights)
    double weight = computeWeight(depth,maxDepth) - computeWeight(depth-1,maxDepth);

    if(depth == 1)
    {
      weight = computeWeight(1,maxDepth);
    }

    for ( int tree = 0; tree < nbTrees; tree++ )
    {
      int t = ( int ) forest[tree].size();
      int s = startnode[tree];
      startnode[tree] = t;
      //TODO: maybe parallelize this loop if the next one would still be
      //parallelized as well — that one carries more weight
      //#pragma omp parallel for
      for ( int i = s; i < t; i++ )
      {
        // only process nodes of the previous level that are not yet split
        if ( !forest[tree][i].isleaf && forest[tree][i].left < 0 )
        {
#if 0
          timer.stop();
          cout << "time 1: " << timer.getLast() << endl;
          timer.start();
#endif
          Operation *splitfeat = NULL;
          double splitval;
          double bestig = getBestSplit ( allfeats, lastfeats, integralImgs, labels, i, splitfeat, splitval, tree );
#if 0
          timer.stop();
          double tl = timer.getLast();

          if ( tl > 10.0 )
          {
            cout << "time 2: " << tl << endl;
            cout << "slow split: " << splitfeat->writeInfos() << endl;
            getchar();
          }
          timer.start();
#endif
          forest[tree][i].feat = splitfeat;
          forest[tree][i].decision = splitval;

          if ( splitfeat != NULL )
          {
            // a valid split was found: create the two child nodes
            allleaf = false;
            int left = forest[tree].size();
            forest[tree].push_back ( TreeNode() );
            forest[tree].push_back ( TreeNode() );
            int right = left + 1;
            forest[tree][i].left = left;
            forest[tree][i].right = right;
            forest[tree][left].dist = vector<double> ( classes, 0.0 );
            forest[tree][right].dist = vector<double> ( classes, 0.0 );
            forest[tree][left].depth = depth;
            forest[tree][right].depth = depth;
            forest[tree][left].featcounter = 0;
            forest[tree][right].featcounter = 0;
            forest[tree][left].nodeNumber = uniquenumber;
            int leftu = uniquenumber;
            uniquenumber++;
            forest[tree][right].nodeNumber = uniquenumber;
            int rightu = uniquenumber;
            uniquenumber++;
            forest[tree][right].featcounter = 0;

#if 0
            timer.stop();
            cout << "time 3: " << timer.getLast() << endl;
            timer.start();
#endif

            // route every pixel currently in node i to one of the children
#pragma omp parallel for
            for ( int iCounter = 0; iCounter < imgcounter; iCounter++ )
            {
              int xsize = currentfeats[iCounter].width();
              int ysize = currentfeats[iCounter].height();

              for ( int x = 0; x < xsize; x++ )
              {
                for ( int y = 0; y < ysize; y++ )
                {
                  if ( currentfeats[iCounter].get ( x, y, tree ) == i )
                  {
                    Features feat;
                    feat.feats = &allfeats[iCounter];
                    feat.cfeats = &lastfeats[iCounter];
                    feat.cTree = tree;
                    feat.tree = &forest[tree];
                    feat.integralImg = &integralImgs[iCounter];
                    double val = splitfeat->getVal ( feat, x, y );

                    int subx = x / grid;
                    int suby = y / grid;

                    // node histograms/counters and texton maps are shared
                    // between image threads, hence the critical section
#pragma omp critical
                    if ( val < splitval )
                    {
                      currentfeats[iCounter].set ( x, y, left, tree );
                      if ( labelmap.find ( labels[iCounter] ( x, y ) ) != labelmap.end() )
                        forest[tree][left].dist[labelmap[labels[iCounter] ( x, y ) ]]++;
                      forest[tree][left].featcounter++;
                      SparseVectorInt v;
                      v.insert ( pair<int, double> ( leftu, weight ) );
                      textonMap[iCounter] ( subx, suby ).add ( v );
                    }
                    else
                    {
                      currentfeats[iCounter].set ( x, y, right, tree );
                      if ( labelmap.find ( labels[iCounter] ( x, y ) ) != labelmap.end() )
                        forest[tree][right].dist[labelmap[labels[iCounter] ( x, y ) ]]++;
                      forest[tree][right].featcounter++;
                      //find the cell in the subsampled map and increment the right-child entry there
                      SparseVectorInt v;
                      v.insert ( pair<int, double> ( rightu, weight ) );
                      textonMap[iCounter] ( subx, suby ).add ( v );
                    }
                  }
                }
              }
            }
#if 0
            timer.stop();
            cout << "time 4: " << timer.getLast() << endl;
            timer.start();
#endif
//            forest[tree][right].featcounter = forest[tree][i].featcounter - forest[tree][left].featcounter;

            // normalize child histograms by the class priors (balancing)
            double lcounter = 0.0, rcounter = 0.0;

            for ( uint d = 0; d < forest[tree][left].dist.size(); d++ )
            {
              if ( forbidden_classes.find ( labelmapback[d] ) != forbidden_classes.end() )
              {
                forest[tree][left].dist[d] = 0;
                forest[tree][right].dist[d] = 0;
              }
              else
              {
                forest[tree][left].dist[d] /= a[d];
                lcounter += forest[tree][left].dist[d];
                forest[tree][right].dist[d] /= a[d];
                rcounter += forest[tree][right].dist[d];
              }
            }
#if 0
            timer.stop();
            cout << "time 5: " << timer.getLast() << endl;
            timer.start();
#endif
            // a split must send at least one pixel each way; dump diagnostics
            // before asserting otherwise
            if ( lcounter <= 0 || rcounter <= 0 )
            {
              cout << "lcounter : " << lcounter << " rcounter: " << rcounter << endl;
              cout << "splitval: " << splitval << " splittype: " << splitfeat->writeInfos() << endl;
              cout << "bestig: " << bestig << endl;

              for ( int iCounter = 0; iCounter < imgcounter; iCounter++ )
              {
                int xsize = currentfeats[iCounter].width();
                int ysize = currentfeats[iCounter].height();
                int counter = 0;

                for ( int x = 0; x < xsize; x++ )
                {
                  for ( int y = 0; y < ysize; y++ )
                  {
                    if ( lastfeats[iCounter].get ( x, y, tree ) == i )
                    {
                      if ( ++counter > 30 )
                        break;

                      Features feat;

                      feat.feats = &allfeats[iCounter];
                      feat.cfeats = &lastfeats[iCounter];
                      feat.cTree = tree;
                      feat.tree = &forest[tree];
                      feat.integralImg = &integralImgs[iCounter];

                      double val = splitfeat->getVal ( feat, x, y );

                      cout << "splitval: " << splitval << " val: " << val << endl;
                    }
                  }
                }
              }

              assert ( lcounter > 0 && rcounter > 0 );
            }

            for ( uint d = 0; d < forest[tree][left].dist.size(); d++ )
            {
              forest[tree][left].dist[d] /= lcounter;
              forest[tree][right].dist[d] /= rcounter;
            }
          }
          else
          {
            // no useful split found: this node becomes a leaf
            forest[tree][i].isleaf = true;
          }
        }
      }
#if 0
      timer.stop();
      cout << "time after tree: " << timer.getLast() << endl;
      timer.start();
#endif
    }

    //compute integral image (class probabilities + raw features)
    int channels = classes + allfeats[0].channels();
#if 0
    timer.stop();
    cout << "time for part0: " << timer.getLast() << endl;
    timer.start();
#endif

    // lazily allocate the integral images on the first level
    if ( integralImgs[0].width() == 0 )
    {
      for ( int i = 0; i < imgcounter; i++ )
      {
        int xsize = allfeats[i].width();
        int ysize = allfeats[i].height();
        integralImgs[i].reInit ( xsize, ysize, channels );
        integralImgs[i].setAll ( 0.0 );
      }
    }
#if 0
    timer.stop();
    cout << "time for part1: " << timer.getLast() << endl;
    timer.start();
#endif

#pragma omp parallel for
    for ( int i = 0; i < imgcounter; i++ )
    {
      computeIntegralImage ( currentfeats[i], allfeats[i], integralImgs[i] );
      computeIntegralImage ( textonMap[i], integralTexton[i] );
    }

#if 1
    timer.stop();

    cout << "time for depth " << depth << ": " << timer.getLast() << endl;
#endif
  }

#define WRITEGLOB
#ifdef WRITEGLOB
  // dump per-image global texton features plus the set of present classes
  // for external (global) classifier training
  ofstream outstream("globtrain.feat");

  for(int i = 0; i < ( int ) textonMap.size(); i++)
  {
    set<int> usedclasses;
    for ( uint x = 0; x < labels[i].rows(); x++ )
    {
      for ( uint y = 0; y < labels[i].cols(); y++ )
      {
        int classno = labels[i] ( x, y );

        if ( forbidden_classes.find ( classno ) != forbidden_classes.end() )
          continue;

        usedclasses.insert(classno);
      }
    }

    cout << "labels.cols: " << labels[i].cols() << " labels.rows " << labels[i].rows() << endl;
    cout << "currentfeats : " << allfeats[i].width() << " allfeats[i].height(); " << allfeats[i].height() << endl;

    set<int>::iterator it;
    for ( it=usedclasses.begin() ; it != usedclasses.end(); it++ )
      outstream << *it << " ";
    outstream << endl;
    // the bottom-right cell of the integral texton map holds the whole-image sum
    integralTexton[i](integralTexton[i].width()-1, integralTexton[i].height()-1).store(outstream);
  }

  outstream.close();
#endif
    cout << "uniquenumber " << uniquenumber << endl;
    //getchar();
#ifdef DEBUG
  for ( int tree = 0; tree < nbTrees; tree++ )
  {
    int t = ( int ) forest[tree].size();

    for ( int i = 0; i < t; i++ )
    {
      printf ( "tree[%i]: left: %i, right: %i", i, forest[tree][i].left, forest[tree][i].right );

      if ( !forest[tree][i].isleaf && forest[tree][i].left != -1 )
      {
        cout <<  ", feat: " << forest[tree][i].feat->writeInfos() << " ";
        opOverview[forest[tree][i].feat->getOps() ]++;
        contextOverview[forest[tree][i].depth][ ( int ) forest[tree][i].feat->getContext() ]++;
      }

      for ( int d = 0; d < ( int ) forest[tree][i].dist.size(); d++ )
      {
        cout << " " << forest[tree][i].dist[d];
      }

      cout << endl;
    }
  }

  for ( uint c = 0; c < ops.size(); c++ )
  {
    cout << ops[c]->writeInfos() << ": " << opOverview[ops[c]->getOps() ] << endl;
  }

  for ( uint c = 0; c < cops.size(); c++ )
  {
    cout << cops[c]->writeInfos() << ": " << opOverview[cops[c]->getOps() ] << endl;
  }

  for ( int d = 0; d < maxDepth; d++ )
  {
    double sum =  contextOverview[d][0] + contextOverview[d][1];

    contextOverview[d][0] /= sum;
    contextOverview[d][1] /= sum;

    cout << "depth: " << d << " woContext: " << contextOverview[d][0] << " wContext: " << contextOverview[d][1] << endl;
  }

#endif
}

/**
 * Classifies a single image with the trained forest.
 *
 * The image is converted to LAB features, then every pixel is pushed one
 * tree level deeper per iteration while the context features (integral
 * image, texton map) are rebuilt after each level — mirroring training.
 * The final labeling is either per-pixel argmax or majority voting over an
 * unsupervised segmentation, depending on pixelWiseLabeling.
 *
 * Fixes over the previous revision:
 *  - with WRITEGLOB defined, the second `ofstream outstream` redeclared the
 *    first in the same scope (compile error); renamed to `filelist` and
 *    properly closed
 *  - exceptions are caught by const reference instead of by value
 *  - segresult is written once per pixel instead of once per class
 *
 * @param ce cached example (used only for the image size)
 * @param segresult output label image
 * @param probabilities output per-class probability maps
 */
void SemSegContextTree::semanticseg ( CachedExample *ce, NICE::Image & segresult, NICE::MultiChannelImageT<double> & probabilities )
{
  // debug probe pixel: split decisions at (xpos, ypos) are logged below
  int xpos = 8;
  //int xpos = 15;
  int ypos = 78;

  int xsize;
  int ysize;
  ce->getImageSize ( xsize, ysize );

  int numClasses = classNames->numClasses();

  fprintf ( stderr, "ContextTree classification !\n" );

  probabilities.reInit ( xsize, ysize, numClasses, true );
  probabilities.setAll ( 0 );

  NICE::ColorImage img;

  std::string currentFile = Globals::getCurrentImgFN();

  try {
    img = ColorImage ( currentFile );
  } catch ( const Exception & ) {
    cerr << "SemSeg: error opening image file <" << currentFile << ">" << endl;
    return;
  }

  //TODO: resize image?!

  MultiChannelImageT<double> feats;
  // texton maps live on a coarser grid (subsampled by 'grid'), as in training
  MultiChannelImageT<SparseVectorInt> textonMap ( xsize / grid + 1, ysize / grid + 1, 1 );
  MultiChannelImageT<SparseVectorInt> integralTexton ( xsize / grid + 1, ysize / grid + 1, 1 );

#ifdef LOCALFEATS
  lfcw->getFeats ( img, feats );

#else
  feats.reInit ( xsize, ysize, 3, true );

  for ( int x = 0; x < xsize; x++ )
  {
    for ( int y = 0; y < ysize; y++ )
    {
      for ( int r = 0; r < 3; r++ )
      {
        feats.set ( x, y, img.getPixel ( x, y, r ), r );
      }
    }
  }

  // work in LAB color space, as in training
  feats = ColorSpace::rgbtolab ( feats );
#endif

  bool allleaf = false;

  MultiChannelImageT<double> integralImg;

  // per-pixel current node index in every tree (all pixels start at the root)
  MultiChannelImageT<unsigned short int> currentfeats ( xsize, ysize, nbTrees );

  currentfeats.setAll ( 0 );

  depth = 0;

  // descend one tree level per iteration until every pixel sits in a leaf
  for ( int d = 0; d < maxDepth && !allleaf; d++ )
  {
    depth++;

    // texton weight contributed by this level, same scheme as in training
    double weight = computeWeight(depth,maxDepth) - computeWeight(depth-1,maxDepth);

    if(depth == 1)
    {
      weight = computeWeight(1,maxDepth);
    }

    allleaf = true;

    MultiChannelImageT<unsigned short int> lastfeats = currentfeats;

    for ( int tree = 0; tree < nbTrees; tree++ )
    {
      for ( int x = 0; x < xsize; x++ )
      {
        for ( int y = 0; y < ysize; y++ )
        {
          int t = currentfeats.get ( x, y, tree );

          // an inner node has left > 0; leaves keep their node index
          if ( forest[tree][t].left > 0 )
          {
            allleaf = false;
            Features feat;
            feat.feats = &feats;
            feat.cfeats = &lastfeats;
            feat.cTree = tree;
            feat.tree = &forest[tree];
            feat.integralImg = &integralImg;

            double val = forest[tree][t].feat->getVal ( feat, x, y );

            int subx = x / grid;
            int suby = y / grid;

            if ( val < forest[tree][t].decision )
            {
              currentfeats.set ( x, y, forest[tree][t].left, tree );
              SparseVectorInt v;
              v.insert ( pair<int, double> ( forest[tree][forest[tree][t].left].nodeNumber, weight ) );
              textonMap ( subx, suby ).add ( v );
            }
            else
            {
              currentfeats.set ( x, y, forest[tree][t].right, tree );
              SparseVectorInt v;
              v.insert ( pair<int, double> ( forest[tree][forest[tree][t].right].nodeNumber, weight ) );
              textonMap ( subx, suby ).add ( v );
            }

            if ( x == xpos && y == ypos )
            {
              cout << "val: " << val << " decision: " << forest[tree][t].decision << " details: " << forest[tree][t].feat->writeInfos() << endl;

            }
          }
        }
      }

      //compute integral image (lazy allocation on the first level)
      int channels = ( int ) labelmap.size() + feats.channels();

      if ( integralImg.width() == 0 )
      {
        int xsize = feats.width();
        int ysize = feats.height();

        integralImg.reInit ( xsize, ysize, channels );
        integralImg.setAll ( 0.0 );
      }
    }

    computeIntegralImage ( currentfeats, feats, integralImg );
    computeIntegralImage ( textonMap, integralTexton );
  }

  cout << forest[0][currentfeats.get ( xpos, ypos, 0 ) ].dist << endl;

#ifdef WRITEGLOB
  // append this image's global texton feature for external evaluation
  ofstream outstream("globtest.feat",ofstream::app);
  outstream << 0 << endl;
  integralTexton(integralTexton.width()-1, integralTexton.height()-1).store(outstream);
  outstream.close();
#endif

  string cndir = conf->gS ( "SSContextTree", "cndir", "" );
  int classes = ( int ) probabilities.numChannels;
  vector<int> useclass ( classes, 1 );

  std::vector< std::string > list;
  StringTools::split ( currentFile, '/', list );

  string orgname = list.back();
#ifdef WRITEGLOB
  // renamed from `outstream` (redefinition in the same scope) and closed
  ofstream filelist("filelist.txt",ofstream::app);
  filelist << orgname << ".dat" << endl;
  filelist.close();
#endif
  // optionally restrict the usable classes to those listed in a per-image file
  if ( cndir != "" )
  {
    useclass = vector<int> ( classes, 0 );
    ifstream infile ( ( cndir + "/" + orgname + ".dat" ).c_str() );
    while ( !infile.eof() && infile.good() )
    {
      int tmp;
      infile >> tmp;
      assert(tmp >= 0 && tmp < classes);
      useclass[tmp] = 1;
    }

    // classes not listed are suppressed by flooding their channel
    for(int c = 0; c < classes; c++)
    {
      if(useclass[c] == 0)
      {
        probabilities.set(-numeric_limits< double >::max(), c);
      }
    }
  }

  if ( pixelWiseLabeling )
  {
    //final labeling: per-pixel argmax over the mean leaf distributions
    long int offset = 0;

    for ( int x = 0; x < xsize; x++ )
    {
      for ( int y = 0; y < ysize; y++, offset++ )
      {
        double maxvalue = - numeric_limits<double>::max(); //TODO: this could be done per node instead of per pixel
        int maxindex = 0;
        uint s = forest[0][0].dist.size();

        for ( uint i = 0; i < s; i++ )
        {
          int currentclass = labelmapback[i];
          probabilities.data[currentclass][offset] = getMeanProb ( x, y, i, currentfeats );

          if ( probabilities.data[currentclass][offset] > maxvalue )
          {
            maxvalue = probabilities.data[currentclass][offset];
            maxindex = labelmapback[i];
          }
        }

        // write once per pixel (was inside the class loop before)
        segresult.setPixel ( x, y, maxindex );

        if ( maxvalue > 1 )
          cout << "maxvalue: " << maxvalue << endl;
      }
    }
  }
  else
  {
    //final labeling using an unsupervised segmentation: each region gets the
    //class with the highest accumulated probability mass
    Matrix regions;
    //showImage(img);
    int regionNumber = segmentation->segRegions ( img, regions );
    cout << "regions: " << regionNumber << endl;

    int dSize = forest[0][0].dist.size();
    vector<vector<double> > regionProbs ( regionNumber, vector<double> ( dSize, 0.0 ) );
    vector<int> bestlabels ( regionNumber, 0 );

    /*
    for(int r = 0; r < regionNumber; r++)
    {
     Image over(img.width(), img.height());
     for(int y = 0; y < img.height(); y++)
     {
      for(int x = 0; x < img.width(); x++)
      {
       if(((int)regions(x,y)) == r)
        over.setPixel(x,y,1);
       else
        over.setPixel(x,y,0);
      }
     }
     cout << "r: " << r << endl;
     showImageOverlay(img, over);
    }
    */

    // accumulate per-region class probabilities
    for ( int y = 0; y < img.height(); y++ )
    {
      for ( int x = 0; x < img.width(); x++ )
      {
        int cregion = regions ( x, y );

        for ( int d = 0; d < dSize; d++ )
        {
          regionProbs[cregion][d] += getMeanProb ( x, y, d, currentfeats );
        }
      }
    }

    // argmax per region, mapped back to the original class numbers
    for ( int r = 0; r < regionNumber; r++ )
    {
      double maxval = regionProbs[r][0];
      bestlabels[r] = 0;

      for ( int d = 1; d < dSize; d++ )
      {
        if ( maxval < regionProbs[r][d] )
        {
          maxval = regionProbs[r][d];
          bestlabels[r] = d;
        }
      }

      bestlabels[r] = labelmapback[bestlabels[r]];
    }

    for ( int y = 0; y < img.height(); y++ )
    {
      for ( int x = 0; x < img.width(); x++ )
      {

        segresult.setPixel ( x, y, bestlabels[regions ( x,y ) ] );
      }
    }
  }

  cout << "segmentation finished" << endl;
}

void SemSegContextTree::store ( std::ostream & os, int format ) const
{
  os << nbTrees << endl;
  classnames.store ( os );

  map<int, int>::const_iterator it;

  os << labelmap.size() << endl;
  for ( it = labelmap.begin() ; it != labelmap.end(); it++ )
    os << ( *it ).first << " " << ( *it ).second << endl;

  os << labelmapback.size() << endl;
  for ( it = labelmapback.begin() ; it != labelmapback.end(); it++ )
    os << ( *it ).first << " " << ( *it ).second << endl;

  int trees = forest.size();
  os << trees << endl;

  for ( int t = 0; t < trees; t++ )
  {
    int nodes = forest[t].size();
    os << nodes << endl;
    for ( int n = 0; n < nodes; n++ )
    {
      os << forest[t][n].left << " " << forest[t][n].right << " " << forest[t][n].decision << " " << forest[t][n].isleaf << " " << forest[t][n].depth << " " << forest[t][n].featcounter << " " << forest[t][n].nodeNumber << endl;
      os << forest[t][n].dist << endl;

      if ( forest[t][n].feat == NULL )
        os << -1 << endl;
      else
      {
        os << forest[t][n].feat->getOps() << endl;
        forest[t][n].feat->store ( os );
      }
    }
  }
}

void SemSegContextTree::restore ( std::istream & is, int format )
{
  is >> nbTrees;

  classnames.restore ( is );

  int lsize;
  is >> lsize;

  labelmap.clear();
  for ( int l = 0; l < lsize; l++ )
  {
    int first, second;
    is >> first;
    is >> second;
    labelmap[first] = second;
  }

  is >> lsize;
  labelmapback.clear();
  for ( int l = 0; l < lsize; l++ )
  {
    int first, second;
    is >> first;
    is >> second;
    labelmapback[first] = second;
  }

  int trees;
  is >> trees;
  forest.clear();

  for ( int t = 0; t < trees; t++ )
  {
    vector<TreeNode> tmptree;
    forest.push_back ( tmptree );
    int nodes;
    is >> nodes;
    //cout << "nodes: " << nodes << endl;
    for ( int n = 0; n < nodes; n++ )
    {
      TreeNode tmpnode;
      forest[t].push_back ( tmpnode );
      is >> forest[t][n].left;
      is >> forest[t][n].right;
      is >> forest[t][n].decision;
      is >> forest[t][n].isleaf;
      is >> forest[t][n].depth;
      is >> forest[t][n].featcounter;

      is >> forest[t][n].nodeNumber;
      is >> forest[t][n].dist;

      int feattype;
      is >> feattype;
      assert ( feattype < NBOPERATIONS );
      forest[t][n].feat = NULL;
      if ( feattype >= 0 )
      {
        for ( uint o = 0; o < ops.size(); o++ )
        {
          if ( ops[o]->getOps() == feattype )
          {
            forest[t][n].feat = ops[o]->clone();
            break;
          }
        }

        if ( forest[t][n].feat == NULL )
        {
          for ( uint o = 0; o < cops.size(); o++ )
          {
            if ( cops[o]->getOps() == feattype )
            {
              forest[t][n].feat = cops[o]->clone();
              break;
            }
          }
        }
        assert ( forest[t][n].feat != NULL );
        forest[t][n].feat->restore ( is );
      }
    }
  }
}