#include "SemSegContextTree.h"
#include "vislearning/baselib/Globals.h"
#include "vislearning/baselib/ProgressBar.h"
#include "core/basics/StringTools.h"

#include "vislearning/cbaselib/CachedExample.h"
#include "vislearning/cbaselib/PascalResults.h"
#include "vislearning/baselib/ColorSpace.h"
#include "objrec/segmentation/RSMeanShift.h"
#include "objrec/segmentation/RSGraphBased.h"
#include "core/basics/numerictools.h"
#include "core/basics/StringTools.h"
#include "core/basics/FileName.h"
#include "vislearning/baselib/ICETools.h"

#include "core/basics/Timer.h"
#include "core/basics/vectorio.h"
#include "core/image/FilterT.h"

#include <omp.h>
#include <iostream>

#undef WRITEGLOB
#undef TEXTONMAP

#define DEBUG

using namespace OBJREC;
using namespace std;
using namespace NICE;

SemSegContextTree::SemSegContextTree (const Config *conf, const MultiDataset *md)
    : SemanticSegmentation (conf, & (md->getClassNames ("train")))
{
  this->conf = conf;
  string section = "SSContextTree";
  lfcw = new LFColorWeijer (conf);
  firstiteration = true;

  grid = conf->gI (section, "grid", 10);

  maxSamples = conf->gI (section, "max_samples", 2000);

  minFeats = conf->gI (section, "min_feats", 50);

  maxDepth = conf->gI (section, "max_depth", 10);

  windowSize = conf->gI (section, "window_size", 16);

  featsPerSplit = conf->gI (section, "feats_per_split", 200);

  useShannonEntropy = conf->gB (section, "use_shannon_entropy", true);

  nbTrees = conf->gI (section, "amount_trees", 1);

  string segmentationtype = conf->gS (section, "segmentation_type", "meanshift");

  randomTests = conf->gI (section, "random_tests", 10);

  bool saveLoadData = conf->gB ("debug", "save_load_data", false);
  string fileLocation = conf->gS ("debug", "datafile", "tmp.txt");

  pixelWiseLabeling = false;

  useRegionFeature = conf->gB (section, "use_reagion_feat", true);
  if (segmentationtype == "meanshift")
    segmentation = new RSMeanShift (conf);
  else if (segmentationtype == "none")
  {
    segmentation = NULL;
    pixelWiseLabeling = true;
    useRegionFeature = false;
  }
  else if (segmentationtype == "felzenszwalb")
    segmentation = new RSGraphBased (conf);
  else
    throw ("no valid segmenation_type\n please choose between none, meanshift and felzenszwalb\n");

  ftypes = conf->gI (section, "features", 100);;

  string featsec = "Features";

  vector<Operation*> tops;

  if (conf->gB (featsec, "minus", true))
    tops.push_back (new Minus());
  if (conf->gB (featsec, "minus_abs", true))
    tops.push_back (new MinusAbs());
  if (conf->gB (featsec, "addition", true))
    tops.push_back (new Addition());
  if (conf->gB (featsec, "only1", true))
    tops.push_back (new Only1());
  if (conf->gB (featsec, "rel_x", true))
    tops.push_back (new RelativeXPosition());
  if (conf->gB (featsec, "rel_y", true))
    tops.push_back (new RelativeYPosition());

  ops.push_back (tops);

  tops.clear();
  tops.push_back (new RegionFeat());
  ops.push_back (tops);

  tops.clear();
  if (conf->gB (featsec, "int", true))
    tops.push_back (new IntegralOps());
  if (conf->gB (featsec, "bi_int_cent", true))
    tops.push_back (new BiIntegralCenteredOps());
  if (conf->gB (featsec, "int_cent", true))
    tops.push_back (new IntegralCenteredOps());
  if (conf->gB (featsec, "haar_horz", true))
    tops.push_back (new HaarHorizontal());
  if (conf->gB (featsec, "haar_vert", true))
    tops.push_back (new HaarVertical());
  if (conf->gB (featsec, "haar_diag", true))
    tops.push_back (new HaarDiag());
  if (conf->gB (featsec, "haar3_horz", true))
    tops.push_back (new Haar3Horiz());
  if (conf->gB (featsec, "haar3_vert", true))
    tops.push_back (new Haar3Vert());
  if (conf->gB (featsec, "glob", true))
    tops.push_back (new GlobalFeats());

  ops.push_back (tops);
  ops.push_back (tops);

  tops.clear();
  if (conf->gB (featsec, "minus", true))
    tops.push_back (new Minus());
  if (conf->gB (featsec, "minus_abs", true))
    tops.push_back (new MinusAbs());
  if (conf->gB (featsec, "addition", true))
    tops.push_back (new Addition());
  if (conf->gB (featsec, "only1", true))
    tops.push_back (new Only1());
  if (conf->gB (featsec, "rel_x", true))
    tops.push_back (new RelativeXPosition());
  if (conf->gB (featsec, "rel_y", true))
    tops.push_back (new RelativeYPosition());

  ops.push_back (tops);

  useGradient = conf->gB (featsec, "use_gradient", true);

  useWeijer = conf->gB (featsec, "use_weijer", true);

  // geometric features of hoiem
  useHoiemFeatures = conf->gB (featsec, "use_hoiem_features", false);
  if (useHoiemFeatures)
  {
    hoiemDirectory = conf->gS (featsec, "hoiem_directory");
  }

  opOverview = vector<int> (NBOPERATIONS, 0);
  contextOverview = vector<vector<double> > (maxDepth, vector<double> (2, 0.0));

  calcVal.push_back (new MCImageAccess());
  calcVal.push_back (new MCImageAccess());
  calcVal.push_back (new MCImageAccess());
  calcVal.push_back (new MCImageAccess());
  calcVal.push_back (new ClassificationResultAccess());


  classnames = md->getClassNames ("train");

  ///////////////////////////////////
  // Train Segmentation Context Trees
  ///////////////////////////////////

  if (saveLoadData)
  {
    if (FileMgt::fileExists (fileLocation))
      read (fileLocation);
    else
    {
      train (md);
      write (fileLocation);
    }
  }
  else
  {
    train (md);
  }
}

/**
 * Destructor. Currently releases nothing explicitly.
 *
 * NOTE(review): the constructor allocates lfcw, segmentation and every
 * Operation / value accessor stored in ops and calcVal with new, and none of
 * them is deleted here — confirm whether ownership is handled elsewhere or
 * whether this leaks. Beware that ops[2] and ops[3] hold the very same
 * Operation pointers, so a naive cleanup loop would double-delete them.
 */
SemSegContextTree::~SemSegContextTree()
{
}

namespace
{
  /**
   * Uniformly distributed integer in [0, n), drawn with rand().
   *
   * (double)rand() / RAND_MAX lies in the CLOSED interval [0, 1]; scaling it by
   * n yields n itself whenever rand() returns RAND_MAX, i.e. one past the last
   * valid index. Dividing by RAND_MAX + 1 keeps the result strictly below n.
   */
  inline int randomIndexBelow (int n)
  {
    return (int)((double)rand() / ((double)RAND_MAX + 1.0) * (double)n);
  }
}

/**
 * Randomised search for a split of node @p node in tree @p tree.
 *
 * A class-balanced random subset of the pixels currently assigned to the node
 * is drawn (about maxSamples pixels; classes in forbidden_classes are never
 * sampled). Then featsPerSplit random feature operations are generated and, for
 * each of them, randomTests random thresholds between the minimum and maximum
 * feature value of the sample are evaluated. The candidate with the highest
 * information gain (normalised by the split entropy if useShannonEntropy is
 * set) is returned.
 *
 * @param feats        per-image feature channels
 * @param currentfeats per-image node assignment, one channel per tree
 * @param labels       per-image ground-truth class of every pixel
 * @param node         index of the node to split within forest[tree]
 * @param splitop      [out] the selected operation, NULL if no split was found
 * @param splitval     [out] the selected threshold, -1.0 if no split was found
 * @param tree         index of the tree
 * @param regionProbs  per-image, per-region class values used by region features
 * @return the best (normalised) information gain; 0.0 if the node holds fewer
 *         than minFeats pixels or the entropy of the sample is below 0.5
 */
double SemSegContextTree::getBestSplit (std::vector<NICE::MultiChannelImageT<double> > &feats, std::vector<NICE::MultiChannelImageT<unsigned short int> > &currentfeats, const std::vector<NICE::MatrixT<int> > &labels, int node, Operation *&splitop, double &splitval, const int &tree, vector<vector<vector<double> > > &regionProbs)
{
  // std::vector::size() cannot throw, no exception handling needed
  const int imgCount = (int)feats.size();

  double bestig = -numeric_limits< double >::max();

  splitop = NULL;
  splitval = -1.0;

  set<vector<int> > selFeats;
  map<int, int> e;
  int featcounter = forest[tree][node].featcounter;

  if (featcounter < minFeats)
  {
    return 0.0;
  }

  // per-class sampling probability, inversely proportional to the class prior
  vector<double> fraction (a.size(), 0.0);

  for (uint i = 0; i < fraction.size(); i++)
  {
    if (forbidden_classes.find (labelmapback[i]) != forbidden_classes.end())
      fraction[i] = 0;
    else
      fraction[i] = ((double)maxSamples) / ((double)featcounter * a[i] * a.size());
  }

  // from here on featcounter counts the sampled pixels
  featcounter = 0;

  for (int iCounter = 0; iCounter < imgCount; iCounter++)
  {
    int xsize = (int)currentfeats[iCounter].width();
    int ysize = (int)currentfeats[iCounter].height();

    for (int x = 0; x < xsize; x++)
    {
      for (int y = 0; y < ysize; y++)
      {
        if (currentfeats[iCounter].get (x, y, tree) == node)
        {
          int cn = labels[iCounter] (x, y);
          double randD = (double)rand() / (double)RAND_MAX;

          if (labelmap.find (cn) == labelmap.end())
            continue;

          if (randD < fraction[labelmap[cn]])
          {
            vector<int> tmp (3, 0);
            tmp[0] = iCounter;
            tmp[1] = x;
            tmp[2] = y;
            featcounter++;
            selFeats.insert (tmp);
            e[cn]++;
          }
        }
      }
    }
  }

  // entropy (in bits) of the class distribution of the sample
  double globent = 0.0;

  for (map<int, int>::iterator mapit = e.begin() ; mapit != e.end(); mapit++)
  {
    double p = (double)mapit->second / (double)featcounter;
    globent -= p * log2 (p);
  }

  // (nearly) pure node: not worth splitting
  if (globent < 0.5)
  {
    return 0.0;
  }

  /** randomly generated candidate features */
  std::vector<Operation*> featsel;

  for (int i = 0; i < featsPerSplit; i++)
  {
    int x1, x2, y1, y2;
    int ft = randomIndexBelow (ftypes);

    int tmpws = windowSize;

    if (firstiteration)
      ft = 0;

    if (channelsPerType[ft].size() == 0)
    {
      ft = 0;
    }

    if (ft > 1)
    {
      //use larger window size for context features
      tmpws *= 4;
    }

    x1 = (int)((double)rand() / (double)RAND_MAX * (double)tmpws) - tmpws / 2;
    x2 = (int)((double)rand() / (double)RAND_MAX * (double)tmpws) - tmpws / 2;
    y1 = (int)((double)rand() / (double)RAND_MAX * (double)tmpws) - tmpws / 2;
    y2 = (int)((double)rand() / (double)RAND_MAX * (double)tmpws) - tmpws / 2;

    int f1 = randomIndexBelow ((int)channelsPerType[ft].size());
    int f2 = f1;
    if ((double)rand() / (double)RAND_MAX > 0.5)
      f2 = randomIndexBelow ((int)channelsPerType[ft].size());
    int o = randomIndexBelow ((int)ops[ft].size());

    f1 = channelsPerType[ft][f1];
    f2 = channelsPerType[ft][f2];
    if (ft == 1)
    {
      // region feature: the second index selects a class entry of regionProbs
      int classes = (int)regionProbs[0][0].size();
      f2 = randomIndexBelow (classes);
    }

    Operation *op = ops[ft][o]->clone();

    op->set (x1, y1, x2, y2, f1, f2, calcVal[ft]);
    op->setFeatType (ft);
    op->setContext (ft == 3 || ft == 4);

    featsel.push_back (op);
  }

  for (int f = 0; f < featsPerSplit; f++)
  {
    double l_bestig = -numeric_limits< double >::max();
    double l_splitval = -1.0;
    vector<double> vals;
    vals.reserve (selFeats.size());

    double maxval = -numeric_limits<double>::max();
    double minval = numeric_limits<double>::max();
    for (set<vector<int> >::iterator it = selFeats.begin() ; it != selFeats.end(); it++)
    {
      Features feat;
      feat.feats = &feats[ (*it) [0]];
      feat.cfeats = &currentfeats[ (*it) [0]];
      feat.cTree = tree;
      feat.tree = &forest[tree];
      feat.rProbs = &regionProbs[ (*it) [0]];

      double val = featsel[f]->getVal (feat, (*it) [1], (*it) [2]);
      vals.push_back (val);
      maxval = std::max (val, maxval);
      minval = std::min (val, minval);
    }

    // constant feature on the sample: cannot separate anything
    if (minval == maxval)
      continue;

    double scale = maxval - minval;
    vector<double> splits;

    for (int r = 0; r < randomTests; r++)
    {
      splits.push_back (((double)rand() / (double)RAND_MAX * scale) + minval);
    }

    for (int run = 0 ; run < randomTests; run++)
    {
      double val = splits[run];

      map<int, int> eL, eR;
      int counterL = 0, counterR = 0;
      int counter2 = 0;

      for (set<vector<int> >::iterator it2 = selFeats.begin() ; it2 != selFeats.end(); it2++, counter2++)
      {
        int cn = labels[ (*it2) [0]] ((*it2) [1], (*it2) [2]);

        if (vals[counter2] < val)
        {
          eL[cn]++;
          counterL++;
        }
        else
        {
          eR[cn]++;
          counterR++;
        }
      }

      // A threshold that sends every sample to the same side separates
      // nothing; it would also make the split entropy below log(0) -> NaN.
      if (counterL == 0 || counterR == 0)
        continue;

      double leftent = 0.0;

      for (map<int, int>::iterator mapit = eL.begin() ; mapit != eL.end(); mapit++)
      {
        double p = (double)mapit->second / (double)counterL;
        leftent -= p * log2 (p);
      }

      double rightent = 0.0;

      for (map<int, int>::iterator mapit = eR.begin() ; mapit != eR.end(); mapit++)
      {
        double p = (double)mapit->second / (double)counterR;
        rightent -= p * log2 (p);
      }

      double pl = (double)counterL / (double)(counterL + counterR);

      double ig = globent - (1.0 - pl) * rightent - pl * leftent;

      if (useShannonEntropy)
      {
        // normalised gain 2 * IG / (H_classes + H_split); the split entropy
        // has to be measured in bits, like globent and ig
        double esplit = - (pl * log2 (pl) + (1 - pl) * log2 (1 - pl));
        ig = 2 * ig / (globent + esplit);
      }

      if (ig > l_bestig)
      {
        l_bestig = ig;
        l_splitval = val;
      }
    }

    if (l_bestig > bestig)
    {
      bestig = l_bestig;
      splitop = featsel[f];
      splitval = l_splitval;
    }
  }

  // FIXME: the candidate operations that were not selected are never deleted
  // (memory leak). Deleting them is disabled; presumably the clones share
  // state with other objects (e.g. the accessor passed to set()) — confirm
  // before enabling:
  /*for(int i = 0; i < featsPerSplit; i++)
  {
   if(featsel[i] != splitop)
    delete featsel[i];
  }*/

#ifdef DEBUG
  //cout << "globent: " << globent <<  " bestig " << bestig << " splitval: " << splitval << endl;
#endif
  return bestig;
}

/**
 * Average, over all nbTrees trees, of the class-@p channel entry of the node
 * distribution that pixel (@p x, @p y) is assigned to. For each tree, the node
 * index is read from channel `tree` of @p currentfeats.
 */
inline double SemSegContextTree::getMeanProb (const int &x, const int &y, const int &channel, const MultiChannelImageT<unsigned short int> &currentfeats)
{
  double sum = 0.0;
  int t = 0;

  while (t < nbTrees)
  {
    const int nodeIndex = currentfeats.get (x, y, t);
    sum += forest[t][nodeIndex].dist[channel];
    ++t;
  }

  return sum / (double)nbTrees;
}

/**
 * Builds the 2-D integral image (summed-area table) of @p infeats in
 * @p integralImage: afterwards integralImage(x, y) holds the sum of infeats over
 * all positions (x', y') with x' <= x and y' <= y.
 *
 * NOTE(review): values are add()-ed onto the existing content of
 * integralImage, so it presumably has to be empty/zero and of the same size as
 * infeats on entry — confirm at the call sites.
 */
void SemSegContextTree::computeIntegralImage (const NICE::MultiChannelImageT<SparseVectorInt> &infeats, NICE::MultiChannelImageT<SparseVectorInt> &integralImage)
{
  const int width = infeats.width();
  const int height = infeats.height();

  // Row-major sweep: the cells above, left and above-left of (col, row) are
  // always final before (col, row) is computed.
  for (int row = 0; row < height; ++row)
  {
    for (int col = 0; col < width; ++col)
    {
      SparseVectorInt &cell = integralImage (col, row);
      cell.add (infeats.get (col, row));

      if (row > 0)
        cell.add (integralImage (col, row - 1));

      if (col > 0)
        cell.add (integralImage (col - 1, row));

      if (row > 0 && col > 0)
        cell.sub (integralImage (col - 1, row - 1));
    }
  }
}

/**
 * Updates the integral-image channels of @p feats.
 *
 * - In the first iteration only: every raw channel listed in integralMap
 *   (pair source -> target) is copied into its target channel, which is then
 *   turned into an integral image via calcIntegral().
 * - Always: for every class c (as many as forest[0][0].dist has entries),
 *   channel firstChannel + c receives the integral image of getMeanProb(x, y, c)
 *   computed from the node assignment in @p currentfeats.
 *
 * The image extent is taken from @p feats; currentfeats is assumed to have the
 * same size (NOTE(review): not checked here).
 */
void SemSegContextTree::computeIntegralImage (const NICE::MultiChannelImageT<unsigned short int> &currentfeats, NICE::MultiChannelImageT<double> &feats, int firstChannel)
{
  const int width = feats.width();
  const int height = feats.height();

  if (firstiteration)
  {
#pragma omp parallel for
    for (int m = 0; m < (int)integralMap.size(); m++)
    {
      const int source = integralMap[m].first;
      const int target = integralMap[m].second;

      for (int y = 0; y < height; y++)
      {
        for (int x = 0; x < width; x++)
        {
          feats (x, y, target) = feats (x, y, source);
        }
      }
      feats.calcIntegral (target);
    }
  }

  const int classCount = (int)forest[0][0].dist.size();

#pragma omp parallel for
  for (int c = 0; c < classCount; c++)
  {
    const int ch = firstChannel + c;

    // Row-major sweep: neighbours above / left / above-left are already final.
    for (int y = 0; y < height; y++)
    {
      for (int x = 0; x < width; x++)
      {
        double sum = getMeanProb (x, y, c, currentfeats);

        if (y > 0)
          sum += feats (x, y - 1, ch);

        if (x > 0)
          sum += feats (x - 1, y, ch);

        if (x > 0 && y > 0)
          sum -= feats (x - 1, y - 1, ch);

        feats (x, y, ch) = sum;
      }
    }
  }
}

/**
 * Weight of tree level @p d for a maximal depth @p dim:
 * 1 / 2^(dim - d + 1). It doubles with every additional level and equals 1/2
 * at d == dim.
 */
inline double computeWeight (const double &d, const double &dim)
{
  const double exponent = dim - d + 1;
  const double denominator = pow (2, exponent);
  return 1.0 / denominator;
}

/**
 * Trains the forest of context trees on the "train" split of @p md.
 *
 * 1. For every training image with ground truth, the basic feature channels
 *    (extractBasicFeatures), the per-pixel labels and — if region features are
 *    used — the pixel count of every region are collected.
 * 2. Class <-> index maps (labelmap / labelmapback), the channel-type table
 *    (channelType / channelsPerType) and the raw -> integral channel map
 *    (integralMap) are built; `a` receives the relative frequency of each class.
 * 3. All trees are grown breadth-first, one depth level per iteration, until no
 *    node could be split or maxDepth is reached. After each level the region
 *    values and the integral images of the per-class mean node distributions
 *    are recomputed, so that deeper levels can use them as features.
 */
void SemSegContextTree::train (const MultiDataset *md)
{
  Timer timer;
  timer.start();
  const LabeledSet train = * (*md) ["train"];
  const LabeledSet *trainp = &train;

  ProgressBar pb ("compute feats");
  pb.show();

  // TODO: memory hog! would a sparse representation pay off?
  vector<MultiChannelImageT<double> > allfeats;
  vector<MultiChannelImageT<unsigned short int> > currentfeats;
  vector<MatrixT<int> > labels;
#ifdef TEXTONMAP
  vector<MultiChannelImageT<SparseVectorInt> > textonMap;
  vector<MultiChannelImageT<SparseVectorInt> > integralTexton;
#endif

  std::string forbidden_classes_s = conf->gS ("analysis", "donttrain", "");

  // per image, per region: accumulated class values
  vector<vector<vector<double> > > regionProbs;
  // per image: number of pixels of every region
  vector<vector<int> > rSize;
  // per image: number of regions
  vector<int> amountRegionpI;

  if (forbidden_classes_s == "")
  {
    forbidden_classes_s = conf->gS ("analysis", "forbidden_classes", "");
  }

  classnames.getSelection (forbidden_classes_s, forbidden_classes);

  int imgcounter = 0;

  int amountPixels = 0;

  ////////////////////////////////////////////////////
  // define which feature extraction methods should be used for each channel
  rawChannels = 3;

  // how many channels without integral image
  int shift = 0;

  if (useGradient)
    rawChannels *= 2;

  if (useWeijer)
    rawChannels += 11;

  if (useHoiemFeatures)
    rawChannels += 8;

  // gray value images
  for (int i = 0; i < rawChannels; i++)
  {
    channelType.push_back (0);
  }

  // regions
  if (useRegionFeature)
  {
    channelType.push_back (1);
    shift++;
  }

  ///////////////////////////////////////////////////////////////////

  LOOP_ALL_S (*trainp)
  {
    EACH_INFO (classno, info);

    NICE::ColorImage img;

    std::string currentFile = info.img();

    const LocalizationResult *locResult = info.localization();

    if (locResult == NULL || locResult->size() <= 0)
    {
      fprintf (stderr, "WARNING: NO ground truth polygons found for %s !\n",
               currentFile.c_str());
      continue;
    }

    fprintf (stderr, "SemSegCsurka: Collecting pixel examples from localization info: %s\n", currentFile.c_str());

    // Load the image BEFORE anything is appended to the per-image vectors, so
    // that skipping an unreadable image keeps allfeats, currentfeats, labels,
    // rSize, ... aligned with imgcounter.
    try {
      img = ColorImage (currentFile);
    } catch (const Exception &) {
      cerr << "SemSeg: error opening image file <" << currentFile << ">" << endl;
      continue;
    }

    // stack object: released on every path out of this iteration
    CachedExample ce (currentFile);

    int xsize, ysize;
    ce.getImageSize (xsize, ysize);
    amountPixels += xsize * ysize;

    MatrixT<int> tmpMat (xsize, ysize);

    currentfeats.push_back (MultiChannelImageT<unsigned short int> (xsize, ysize, nbTrees));
    currentfeats[imgcounter].setAll (0);
#ifdef TEXTONMAP
    textonMap.push_back (MultiChannelImageT<SparseVectorInt> (xsize / grid + 1, ysize / grid + 1, 1));
    integralTexton.push_back (MultiChannelImageT<SparseVectorInt> (xsize / grid + 1, ysize / grid + 1, 1));
#endif

    labels.push_back (tmpMat);

    Globals::setCurrentImgFN (currentFile);

    //TODO: resize image?!
    MultiChannelImageT<double> feats;
    allfeats.push_back (feats);

    int amountRegions;
    // read image and do some simple transformations
    extractBasicFeatures (allfeats[imgcounter], img, currentFile, amountRegions);

    if (useRegionFeature)
    {
      amountRegionpI.push_back (amountRegions);
      rSize.push_back (vector<int> (amountRegions, 0));
      for (int y = 0; y < ysize; y++)
      {
        for (int x = 0; x < xsize; x++)
        {
          rSize[imgcounter][ (int)allfeats[imgcounter] (x, y, rawChannels)]++;
        }
      }
    }

    // getting groundtruth
    NICE::Image pixelLabels (xsize, ysize);

    pixelLabels.set (0);

    locResult->calcLabeledImage (pixelLabels, (*classNames).getBackgroundClass());

    for (int x = 0; x < xsize; x++)
    {
      for (int y = 0; y < ysize; y++)
      {
        classno = pixelLabels.getPixel (x, y);
        labels[imgcounter] (x, y) = classno;

        if (forbidden_classes.find (classno) != forbidden_classes.end())
          continue;

        labelcounter[classno]++;
      }
    }

    imgcounter++;

    pb.update (trainp->count());
  }

  pb.hide();

  map<int, int>::iterator mapit;
  int classes = 0;

  for (mapit = labelcounter.begin(); mapit != labelcounter.end(); mapit++)
  {
    labelmap[mapit->first] = classes;
    labelmapback[classes] = mapit->first;
    classes++;
  }

  ///////////////////////////////////////////////////////////////////
  for (int i = 0; i < rawChannels; i++)
  {
    channelType.push_back (2);
  }

  // integral images
  for (int i = 0; i < classes; i++)
  {
    channelType.push_back (3);
  }

  integralMap.clear();
  int integralImageAmount = rawChannels;
  for (int ii = 0; ii < integralImageAmount; ii++)
  {
    integralMap.push_back (pair<int, int> (ii, ii + integralImageAmount + shift));
  }

  int amountTypes = 5;

  channelsPerType = vector<vector<int> > (amountTypes, vector<int>());

  for (int i = 0; i < (int)channelType.size(); i++)
  {
    channelsPerType[channelType[i]].push_back (i);
  }

  for (int i = 0; i < classes; i++)
  {
    channelsPerType[channelsPerType.size()-1].push_back (i);
  }

  ftypes = std::min (amountTypes, ftypes);

  ////////////////////////////////////////////////////

  if (useRegionFeature)
  {
    for (int im = 0; im < (int)amountRegionpI.size(); im++)
    {
      regionProbs.push_back (vector<vector<double> > (amountRegionpI[im], vector<double> (classes, 0.0)));
    }
  }

  //balancing
  int featcounter = 0;

  a = vector<double> (classes, 0.0);

  for (int iCounter = 0; iCounter < imgcounter; iCounter++)
  {
    int xsize = (int)currentfeats[iCounter].width();
    int ysize = (int)currentfeats[iCounter].height();

    for (int x = 0; x < xsize; x++)
    {
      for (int y = 0; y < ysize; y++)
      {
        featcounter++;
        int cn = labels[iCounter] (x, y);
        if (labelmap.find (cn) == labelmap.end())
          continue;
        a[labelmap[cn]] ++;
      }
    }
  }

  for (int i = 0; i < (int)a.size(); i++)
  {
    a[i] /= (double)featcounter;
  }

#ifdef DEBUG
  for (int i = 0; i < (int)a.size(); i++)
  {
    cout << "a[" << i << "]: " << a[i] << endl;
  }

  cout << "a.size: " << a.size() << endl;
#endif

  depth = 0;

  int uniquenumber = 0;

  for (int t = 0; t < nbTrees; t++)
  {
    vector<TreeNode> tree;
    tree.push_back (TreeNode());
    tree[0].dist = vector<double> (classes, 0.0);
    tree[0].depth = depth;
    tree[0].featcounter = amountPixels;
    tree[0].nodeNumber = uniquenumber;
    uniquenumber++;
    forest.push_back (tree);
  }

  vector<int> startnode (nbTrees, 0);

  bool allleaf = false;

  timer.stop();
  cerr << "preprocessing finished in: " << timer.getLastAbsolute() << " seconds" << endl;
  timer.start();

  while (!allleaf && depth < maxDepth)
  {
    depth++;
#ifdef DEBUG
    cout << "depth: " << depth << endl;
#endif
    allleaf = true;
    // read-only snapshots of the state before this level is grown
    vector<MultiChannelImageT<unsigned short int> > lastfeats = currentfeats;
    vector<vector<vector<double> > > lastRegionProbs = regionProbs;

    if (useRegionFeature)
    {
      // reset; refilled after this level has been grown
      for (int im = 0; im < (int)regionProbs.size(); im++)
      {
        for (int r = 0; r < (int)regionProbs[im].size(); r++)
        {
          for (int c = 0; c < (int)regionProbs[im][r].size(); c++)
          {
            regionProbs[im][r][c] = 0.0;
          }
        }
      }
    }

    Timer timerDepth;
    timerDepth.start();

#ifdef TEXTONMAP
    // weight of this level in the texton histograms
    double weight = computeWeight (depth, maxDepth) - computeWeight (depth - 1, maxDepth);

    if (depth == 1)
    {
      weight = computeWeight (1, maxDepth);
    }
#endif

    for (int tree = 0; tree < nbTrees; tree++)
    {
      const int t = (int)forest[tree].size();
      const int s = startnode[tree];
      startnode[tree] = t;

      // Every node in [s, t) adds at most two children. Reserving that space
      // up front guarantees that the push_backs below never reallocate
      // forest[tree] while other threads hold references into it.
      forest[tree].reserve (t + 2 * (t - s));

#pragma omp parallel for
      for (int i = s; i < t; i++)
      {
        if (!forest[tree][i].isleaf && forest[tree][i].left < 0)
        {
          Operation *splitfeat = NULL;
          double splitval;
          double bestig = getBestSplit (allfeats, lastfeats, labels, i, splitfeat, splitval, tree, lastRegionProbs);

          forest[tree][i].feat = splitfeat;
          forest[tree][i].decision = splitval;

          if (splitfeat != NULL)
          {
            int left, leftu, rightu;
            // child slots and unique node numbers must be handed out atomically
#pragma omp critical
            {
              allleaf = false;
              left = (int)forest[tree].size();
              forest[tree].push_back (TreeNode());
              forest[tree].push_back (TreeNode());
              leftu = uniquenumber++;
              rightu = uniquenumber++;
            }
            const int right = left + 1;
            forest[tree][i].left = left;
            forest[tree][i].right = right;
            forest[tree][left].dist = vector<double> (classes, 0.0);
            forest[tree][right].dist = vector<double> (classes, 0.0);
            forest[tree][left].depth = depth;
            forest[tree][right].depth = depth;
            forest[tree][left].featcounter = 0;
            forest[tree][right].featcounter = 0;
            forest[tree][left].nodeNumber = leftu;
            forest[tree][right].nodeNumber = rightu;

#pragma omp parallel for
            for (int iCounter = 0; iCounter < imgcounter; iCounter++)
            {
              int xsize = currentfeats[iCounter].width();
              int ysize = currentfeats[iCounter].height();

              for (int x = 0; x < xsize; x++)
              {
                for (int y = 0; y < ysize; y++)
                {
                  // membership is checked on the read-only snapshot, since
                  // currentfeats is written concurrently by other nodes' threads
                  if (lastfeats[iCounter].get (x, y, tree) != i)
                    continue;

                  Features feat;
                  feat.feats = &allfeats[iCounter];
                  feat.cfeats = &lastfeats[iCounter];
                  feat.cTree = tree;
                  feat.tree = &forest[tree];
                  feat.rProbs = &lastRegionProbs[iCounter];
                  double val = splitfeat->getVal (feat, x, y);

#ifdef TEXTONMAP
                  int subx = x / grid;
                  int suby = y / grid;
#endif
#pragma omp critical
                  if (val < splitval)
                  {
                    currentfeats[iCounter].set (x, y, left, tree);
                    if (labelmap.find (labels[iCounter] (x, y)) != labelmap.end())
                      forest[tree][left].dist[labelmap[labels[iCounter] (x, y) ]]++;
                    forest[tree][left].featcounter++;
#ifdef TEXTONMAP
                    SparseVectorInt v;
                    v.insert (pair<int, double> (leftu, weight));
                    textonMap[iCounter] (subx, suby).add (v);
#endif
                  }
                  else
                  {
                    currentfeats[iCounter].set (x, y, right, tree);
                    if (labelmap.find (labels[iCounter] (x, y)) != labelmap.end())
                      forest[tree][right].dist[labelmap[labels[iCounter] (x, y) ]]++;
                    forest[tree][right].featcounter++;
#ifdef TEXTONMAP
                    // count this node in the matching cell of the subsampled map
                    SparseVectorInt v;
                    v.insert (pair<int, double> (rightu, weight));
                    textonMap[iCounter] (subx, suby).add (v);
#endif
                  }
                }
              }
            }

            // balance by the class priors and normalise the child distributions
            double lcounter = 0.0, rcounter = 0.0;

            for (uint d = 0; d < forest[tree][left].dist.size(); d++)
            {
              if (forbidden_classes.find (labelmapback[d]) != forbidden_classes.end())
              {
                forest[tree][left].dist[d] = 0;
                forest[tree][right].dist[d] = 0;
              }
              else
              {
                forest[tree][left].dist[d] /= a[d];
                lcounter += forest[tree][left].dist[d];
                forest[tree][right].dist[d] /= a[d];
                rcounter += forest[tree][right].dist[d];
              }
            }

            if (lcounter <= 0 || rcounter <= 0)
            {
              cout << "lcounter : " << lcounter << " rcounter: " << rcounter << endl;
              cout << "splitval: " << splitval << " splittype: " << splitfeat->writeInfos() << endl;
              cout << "bestig: " << bestig << endl;

              for (int iCounter = 0; iCounter < imgcounter; iCounter++)
              {
                int xsize = currentfeats[iCounter].width();
                int ysize = currentfeats[iCounter].height();
                int counter = 0;

                for (int x = 0; x < xsize; x++)
                {
                  for (int y = 0; y < ysize; y++)
                  {
                    if (lastfeats[iCounter].get (x, y, tree) == i)
                    {
                      if (++counter > 30)
                        break;

                      Features feat;

                      feat.feats = &allfeats[iCounter];
                      feat.cfeats = &lastfeats[iCounter];
                      feat.cTree = tree;
                      feat.tree = &forest[tree];
                      feat.rProbs = &lastRegionProbs[iCounter];

                      double val = splitfeat->getVal (feat, x, y);

                      cout << "splitval: " << splitval << " val: " << val << endl;
                    }
                  }
                }
              }

              assert (lcounter > 0 && rcounter > 0);
            }

            for (uint d = 0; d < forest[tree][left].dist.size(); d++)
            {
              forest[tree][left].dist[d] /= lcounter;
              forest[tree][right].dist[d] /= rcounter;
            }
          }
          else
          {
            forest[tree][i].isleaf = true;
          }
        }
      }
    }

    if (useRegionFeature)
    {
      // Accumulate, per region, the node distributions of all trees over the
      // region's pixels. Parallelised over images: each thread writes only
      // regionProbs[iCounter], so regions spanning several columns cannot race.
#pragma omp parallel for
      for (int iCounter = 0; iCounter < imgcounter; iCounter++)
      {
        int xsize = currentfeats[iCounter].width();
        int ysize = currentfeats[iCounter].height();

        for (int x = 0; x < xsize; x++)
        {
          for (int y = 0; y < ysize; y++)
          {
            const int region = (int)(allfeats[iCounter] (x, y, rawChannels));
            for (int tree = 0; tree < nbTrees; tree++)
            {
              int node = currentfeats[iCounter].get (x, y, tree);
              for (uint d = 0; d < forest[tree][node].dist.size(); d++)
              {
                regionProbs[iCounter][region][d] += forest[tree][node].dist[d];
              }
            }
          }
        }
      }

      // divide by the number of pixels of each region
      for (int im = 0; im < (int)regionProbs.size(); im++)
      {
        for (int r = 0; r < (int)regionProbs[im].size(); r++)
        {
          // empty regions keep their all-zero entry instead of becoming NaN
          if (rSize[im][r] <= 0)
            continue;

          for (int c = 0; c < (int)regionProbs[im][r].size(); c++)
          {
            regionProbs[im][r][c] /= (double)(rSize[im][r]);
          }
        }
      }
    }

    //compute integral images
    if (firstiteration)
    {
      for (int i = 0; i < imgcounter; i++)
      {
        allfeats[i].addChannel ((int)(classes + rawChannels));
      }
    }

    for (int i = 0; i < imgcounter; i++)
    {
      computeIntegralImage (currentfeats[i], allfeats[i], channelType.size() - classes);
#ifdef TEXTONMAP
      computeIntegralImage (textonMap[i], integralTexton[i]);
#endif
    }

    firstiteration = false;

    timerDepth.stop();

    cout << "time for depth " << depth << ": " << timerDepth.getLastAbsolute() << endl;
  }

  timer.stop();
  cerr << "learning finished in: " << timer.getLastAbsolute() << " seconds" << endl;
  timer.start();

#ifdef WRITEGLOB
  ofstream outstream ("globtrain.feat");

  for (int i = 0; i < textonMap.size(); i++)
  {
    set<int> usedclasses;
    for (uint x = 0; x < labels[i].rows(); x++)
    {
      for (uint y = 0; y < labels[i].cols(); y++)
      {
        int classno = labels[i] (x, y);

        if (forbidden_classes.find (classno) != forbidden_classes.end())
          continue;

        usedclasses.insert (classno);
      }
    }

    cout << "labels.cols: " << labels[i].cols() << " labels.rows " << labels[i].rows() << endl;
    cout << "currentfeats : " << allfeats[i].width() << " allfeats[i].height(); " << allfeats[i].height() << endl;

    set<int>::iterator it;
    for (it = usedclasses.begin() ; it != usedclasses.end(); it++)
      outstream << *it << " ";
    outstream << endl;
    integralTexton[i] (integralTexton[i].width() - 1, integralTexton[i].height() - 1).store (outstream);
  }

  outstream.close();
#endif
  cout << "uniquenumber " << uniquenumber << endl;
#ifdef DEBUG
  for (int tree = 0; tree < nbTrees; tree++)
  {
    int t = (int)forest[tree].size();

    for (int i = 0; i < t; i++)
    {
      printf ("tree[%i]: left: %i, right: %i", i, forest[tree][i].left, forest[tree][i].right);

      if (!forest[tree][i].isleaf && forest[tree][i].left != -1)
      {
        cout <<  ", feat: " << forest[tree][i].feat->writeInfos() << " ";
        opOverview[forest[tree][i].feat->getOps() ]++;
        contextOverview[forest[tree][i].depth][ (int)forest[tree][i].feat->getContext() ]++;
      }

      for (int d = 0; d < (int)forest[tree][i].dist.size(); d++)
      {
        cout << " " << forest[tree][i].dist[d];
      }

      cout << endl;
    }
  }

  std::map<int, int> featTypeCounter;

  for (int tree = 0; tree < nbTrees; tree++)
  {
    int t = (int)forest[tree].size();

    for (int i = 0; i < t; i++)
    {
      if (!forest[tree][i].isleaf && forest[tree][i].left != -1)
      {
        featTypeCounter[forest[tree][i].feat->getFeatType()] += 1;
      }
    }
  }

  cout << "evaluation of featuretypes" << endl;
  for (map<int, int>::const_iterator it = featTypeCounter.begin(); it != featTypeCounter.end(); it++)
  {
    cerr << it->first << ": " << it->second << endl;
  }

  for (uint c = 0; c < ops.size(); c++)
  {
    for (uint t = 0; t < ops[c].size(); t++)
    {
      cout << ops[c][t]->writeInfos() << ": " << opOverview[ops[c][t]->getOps() ] << endl;
    }
  }

  for (int d = 0; d < maxDepth; d++)
  {
    double sum =  contextOverview[d][0] + contextOverview[d][1];

    // levels without any split would otherwise print 0/0 = NaN
    if (sum > 0)
    {
      contextOverview[d][0] /= sum;
      contextOverview[d][1] /= sum;
    }

    cout << "depth: " << d << " woContext: " << contextOverview[d][0] << " wContext: " << contextOverview[d][1] << endl;
  }
#endif

  timer.stop();
  cerr << "rest finished in: " << timer.getLastAbsolute() << " seconds" << endl;
  timer.start();
}

/**
 * Build the raw (non-context) feature channels for one image.
 *
 * Channels are appended in this order:
 *  - 3 channels: the pixel values of @p img converted with ColorSpace::rgbtolab
 *  - if useGradient: one gradient-strength channel per channel present so far
 *  - if useWeijer: the channels produced by LFColorWeijer::getFeats
 *  - if useHoiemFeatures: one channel per Hoiem geometric class, read from
 *    "<hoiemDirectory><image base name>_c_<class>.png"
 *  - if useRegionFeature: one channel holding the region id of every pixel,
 *    as returned by segmentation->segRegions
 *
 * @param feats         re-initialized and filled with the channels above
 * @param img           the input color image
 * @param currentFile   file name of @p img; only used to locate Hoiem images
 * @param amountRegions number of regions from the segmentation, or -1 if
 *                      useRegionFeature is off
 * @throws Exception if a Hoiem confidence image is missing or its size does
 *         not match the image
 */
void SemSegContextTree::extractBasicFeatures (NICE::MultiChannelImageT<double> &feats, const ColorImage &img, const string &currentFile, int &amountRegions)
{
  int xsize = img.width();
  int ysize = img.height();
  //TODO: resize image?!

  feats.reInit (xsize, ysize, 3);

  for (int x = 0; x < xsize; x++)
  {
    for (int y = 0; y < ysize; y++)
    {
      for (int r = 0; r < 3; r++)
      {
        feats.set (x, y, img.getPixel (x, y, r), r);
      }
    }
  }

  feats = ColorSpace::rgbtolab (feats);

  if (useGradient)
  {
    int currentsize = feats.channels();
    feats.addChannel (currentsize);

    for (int c = 0; c < currentsize; c++)
    {
      // the gradient strength of channel c is meant to end up in channel
      // c + currentsize (tmp/tmp2 are views obtained via feats[])
      ImageT<double> tmp = feats[c];
      ImageT<double> tmp2 = feats[c+currentsize];

      NICE::FilterT<double, double, double>::gradientStrength (tmp, tmp2);
    }
  }

  if (useWeijer)
  {
    NICE::MultiChannelImageT<double> cfeats;
    lfcw->getFeats (img, cfeats);
    feats.addChannel (cfeats);
  }

  // read the geometric cues produced by Hoiem et al.
  if (useHoiemFeatures)
  {
    // we could also give the following set as a config option
    string hoiemClasses_s = "sky 000 090-045 090-090 090-135 090 090-por 090-sol";
    vector<string> hoiemClasses;
    StringTools::split (hoiemClasses_s, ' ', hoiemClasses);

    // Derive the Hoiem file names from the image file name:
    // Original image filename: basel_000083.jpg
    // hoiem result: basel_000083_c_sky.png
    FileName fn (currentFile);
    fn.removeExtension();
    FileName fnBase = fn.extractFileName();

    // counter for the channel index, starts with the current size of the destination multi-channel image
    int currentChannel = feats.channels();

    // add a channel for each feature in advance
    feats.addChannel (hoiemClasses.size());

    // loop through all geometric categories and add the images
    for (vector<string>::const_iterator i = hoiemClasses.begin(); i != hoiemClasses.end(); i++, currentChannel++)
    {
      string hoiemClass = *i;
      FileName fnConfidenceImage (hoiemDirectory + fnBase.str() + "_c_" + hoiemClass + ".png");
      if (! fnConfidenceImage.fileExists())
      {
        fthrow (Exception, "Unable to read the Hoiem geometric confidence image: " << fnConfidenceImage.str() << " (original image is " << currentFile << ")");
      } else {
        Image confidenceImage (fnConfidenceImage.str());
        // check whether the image size is consistent
        if (confidenceImage.width() != feats.width() || confidenceImage.height() != feats.height())
        {
          fthrow (Exception, "The size of the geometric confidence image does not match with the original image size: " << fnConfidenceImage.str());
        }

        // copy standard image to double image
        for (uint y = 0 ; y < (uint) confidenceImage.height(); y++)
          for (uint x = 0 ; x < (uint) confidenceImage.width(); x++)
            feats (x, y, currentChannel) = (double)confidenceImage (x, y);
      }
    }
  }

  if (useRegionFeature)
  {
    //using segmentation
    Matrix regions;
    amountRegions = segmentation->segRegions (img, regions);

    // region ids go into one extra channel appended at the end
    int cchannel = feats.channels();
    feats.addChannel(1);

    // regions is indexed (x, y): rows correspond to x, cols to y
    for (int y = 0; y < regions.cols(); y++)
    {
      for (int x = 0; x < regions.rows(); x++)
      {
        feats(x, y, cchannel) = regions(x, y);
      }
    }
  }
  else
  {
    amountRegions = -1;
  }
}

/**
 * Segment the image whose file name is Globals::getCurrentImgFN().
 *
 * Each pixel is pushed down every tree of the forest by one level per
 * iteration. There are at most maxDepth iterations, and the loop stops early
 * once every pixel has reached a leaf. After each level, when useRegionFeature
 * is set, the per-region class probabilities are recomputed. The context
 * channels of the feature image are then refreshed with computeIntegralImage()
 * so that deeper split tests can read them.
 *
 * Labeling:
 *  - pixelWiseLabeling: every pixel gets the class with the highest
 *    getMeanProb() value, and @p probabilities is filled per pixel.
 *  - otherwise: getMeanProb() is summed over each region and all pixels of a
 *    region get that region's best class. @p probabilities is not written by
 *    this step. Because WRITEREGIONS is defined below, the region graph is also
 *    written to "rgout/<image name>.graph".
 *
 * @param ce            only used to query the image size
 * @param segresult     receives one class number per pixel (NOTE(review): it is
 *                      not resized here, so the caller presumably sizes it)
 * @param probabilities re-initialized to (width, height, numClasses) and zeroed
 */
void SemSegContextTree::semanticseg (CachedExample *ce, NICE::Image & segresult, NICE::MultiChannelImageT<double> & probabilities)
{
  int xsize;
  int ysize;
  ce->getImageSize (xsize, ysize);
  firstiteration = true;

  int classes = labelmapback.size();

  int numClasses = classNames->numClasses();

  fprintf (stderr, "ContextTree classification !\n");

  probabilities.reInit (xsize, ysize, numClasses);
  probabilities.setAll (0);

#ifdef TEXTONMAP
  MultiChannelImageT<SparseVectorInt> textonMap (xsize / grid + 1, ysize / grid + 1, 1);
  MultiChannelImageT<SparseVectorInt> integralTexton (xsize / grid + 1, ysize / grid + 1, 1);
#endif

  std::string currentFile = Globals::getCurrentImgFN();
  MultiChannelImageT<double> feats;

  NICE::ColorImage img;

  try {
    img = ColorImage (currentFile);
  } catch (const Exception &) {
    // catch by reference: catching by value would slice derived exception types
    cerr << "SemSeg: error opening image file <" << currentFile << ">" << endl;
    return;
  }

  //TODO add to features!
  int amountRegions;
  extractBasicFeatures (feats, img, currentFile, amountRegions); //read image and do some simple transformations

  // number of pixels per region; used to turn summed region probabilities into means
  vector<int> rSize;
  if (useRegionFeature)
  {
    rSize = vector<int>(amountRegions, 0);
    for (int y = 0; y < ysize; y++)
    {
      for (int x = 0; x < xsize; x++)
      {
        rSize[(int)feats(x, y, rawChannels)]++;
      }
    }
  }

  bool allleaf = false;

  // current node index of every pixel, one channel per tree
  MultiChannelImageT<unsigned short int> currentfeats (xsize, ysize, nbTrees);

  currentfeats.setAll (0);

  depth = 0;

  vector<vector<double> > regionProbs;
  if (useRegionFeature)
  {
    regionProbs = vector<vector<double> > (amountRegions, vector<double> (classes, 0.0));
  }

  for (int d = 0; d < maxDepth && !allleaf; d++)
  {
    depth++;
    // split tests at this level read the region probabilities of the previous
    // level; the current ones are reset here and recomputed below
    vector<vector<double> > lastRegionProbs = regionProbs;
    if (useRegionFeature)
    {
      int rSize2 = (int)regionProbs.size();
      for (int b = 0; b < rSize2; b++)
      {
        int rSize3 = (int)regionProbs[b].size();
        for (int c = 0; c < rSize3; c++)
        {
          regionProbs[b][c] = 0.0;
        }
      }
    }

#ifdef TEXTONMAP
    double weight = computeWeight (depth, maxDepth) - computeWeight (depth - 1, maxDepth);

    if (depth == 1)
    {
      weight = computeWeight (1, maxDepth);
    }
#endif

    allleaf = true;

    // split tests read the node assignment of the previous level
    MultiChannelImageT<unsigned short int> lastfeats = currentfeats;

    int tree;
    // Each thread handles whole trees, so writes into currentfeats (one channel
    // per tree) never collide. allleaf is combined with an && reduction instead
    // of being written concurrently by all threads.
#pragma omp parallel for private(tree) reduction(&&:allleaf)
    for (tree = 0; tree < nbTrees; tree++)
    {
      for (int x = 0; x < xsize; x++)
      {
        for (int y = 0; y < ysize; y++)
        {
          int t = currentfeats.get (x, y, tree);

          if (forest[tree][t].left > 0)
          {
            allleaf = false;
            Features feat;
            feat.feats = &feats;
            feat.cfeats = &lastfeats;
            feat.cTree = tree;
            feat.tree = &forest[tree];
            feat.rProbs = &lastRegionProbs;

            double val = forest[tree][t].feat->getVal (feat, x, y);

#ifdef TEXTONMAP
            // texton-map cell of this pixel (needed by the blocks below)
            int subx = x / grid;
            int suby = y / grid;
#endif
            if (val < forest[tree][t].decision)
            {
              currentfeats.set (x, y, forest[tree][t].left, tree);
#ifdef TEXTONMAP
#pragma omp critical
              {
                SparseVectorInt v;
                v.insert (pair<int, double> (forest[tree][forest[tree][t].left].nodeNumber, weight));
                textonMap (subx, suby).add (v);
              }
#endif
            }
            else
            {
              currentfeats.set (x, y, forest[tree][t].right, tree);
#ifdef TEXTONMAP
#pragma omp critical
              {
                SparseVectorInt v;
                v.insert (pair<int, double> (forest[tree][forest[tree][t].right].nodeNumber, weight));

                textonMap (subx, suby).add (v);
              }
#endif
            }
          }
        }
      }
    }

    if (useRegionFeature)
    {
      // Accumulated serially on purpose: pixels processed by different threads
      // can belong to the same region, so a parallel loop here races on
      // regionProbs[region][c] and loses updates. Serial summation also keeps
      // the floating-point result deterministic.
      for (int x = 0; x < xsize; x++)
      {
        for (int y = 0; y < ysize; y++)
        {
          int region = (int)feats(x, y, rawChannels);
          for (int tree = 0; tree < nbTrees; tree++)
          {
            int node = currentfeats.get(x, y, tree);
            for (uint c = 0; c < forest[tree][node].dist.size(); c++)
            {
              regionProbs[region][c] += forest[tree][node].dist[c];
            }
          }
        }
      }

      int rSize2 = (int)regionProbs.size();
      for (int b = 0; b < rSize2; b++)
      {
        // an empty region accumulated nothing; dividing would produce 0/0 = NaN
        if (rSize[b] == 0)
          continue;

        int rSize3 = (int)regionProbs[b].size();
        for (int c = 0; c < rSize3; c++)
        {
          regionProbs[b][c] /= (double)(rSize[b]);
        }
      }
    }

    if (depth < maxDepth)
    {
      //compute integral images
      if (firstiteration)
      {
        feats.addChannel (classes + rawChannels);
      }
      computeIntegralImage (currentfeats, feats, channelType.size() - classes);
#ifdef TEXTONMAP
      computeIntegralImage (textonMap, integralTexton);
#endif
      if (firstiteration)
      {
        firstiteration = false;
      }
    }
  }

#ifdef WRITEGLOB
  ofstream outstream ("globtest.feat", ofstream::app);
  outstream << 0 << endl;
  integralTexton (integralTexton.width() - 1, integralTexton.height() - 1).store (outstream);
  outstream.close();
#endif

  string cndir = conf->gS ("SSContextTree", "cndir", "");

  int allClasses = (int)probabilities.channels();
  vector<int> useclass (allClasses, 1);
#ifdef WRITEGLOB


  std::vector< std::string > list;
  StringTools::split (currentFile, '/', list);

  string orgname = list.back();

  ofstream ostream ("filelist.txt", ofstream::app);
  ostream << orgname << ".dat" << endl;
  ostream.close();

  if (cndir != "")
  {
    useclass = vector<int> (allClasses, 0);
    ifstream infile ((cndir + "/" + orgname + ".dat").c_str());

#undef OLD
#ifdef OLD
    while (!infile.eof() && infile.good())
    {
      int tmp;
      infile >> tmp;
      assert (tmp >= 0 && tmp < allClasses);
      useclass[tmp] = 1;
    }
#else
    int c = 0;
    vector<double> probs (allClasses, 0.0);

    // bounded by allClasses so a longer file cannot write past the end of probs
    while (c < allClasses && !infile.eof() && infile.good())
    {
      infile >> probs[c];
      c++;
    }

    vector<double> sorted = probs;
    sort (sorted.begin(), sorted.end());

    double thr = sorted[10];

    if (thr < 0.0)
      thr = 0.0;

    for (int c = 0; c < allClasses; c++)
    {
      if (probs[c] < thr)
      {
        useclass[c] = 1;
      }
    }

#endif

    for (int c = 0; c < allClasses; c++)
    {
      if (useclass[c] == 0)
      {
        probabilities.set (-numeric_limits< double >::max(), c);
      }
    }
  }
#endif

  if (pixelWiseLabeling)
  {
    //final labeling:
    //long int offset = 0;

    for (int x = 0; x < xsize; x++)
    {
      for (int y = 0; y < ysize; y++)
      {
        double maxvalue = - numeric_limits<double>::max(); //TODO: this could also be done once per node instead of per pixel
        int maxindex = 0;

        for (int i = 0; i < classes; i++)
        {
          int currentclass = labelmapback[i];
          if (useclass[currentclass])
          {
            probabilities (x, y, currentclass) = getMeanProb (x, y, i, currentfeats);

            if (probabilities (x, y, currentclass) > maxvalue)
            {
              maxvalue = probabilities (x, y, currentclass);
              maxindex = currentclass;
            }
          }
        }
        segresult.setPixel (x, y, maxindex);
        if (maxvalue > 1)
          cout << "maxvalue: " << maxvalue << endl;
      }
    }
#undef VISUALIZE
#ifdef VISUALIZE
    for (int j = 0 ; j < (int)probabilities.numChannels; j++)
    {
      //cout << "class: " << j << endl;//" " << cn.text (j) << endl;

      NICE::Matrix tmp (probabilities.height(), probabilities.width());
      double maxval = -numeric_limits<double>::max();
      double minval = numeric_limits<double>::max();


      for (int y = 0; y < probabilities.height(); y++)
        for (int x = 0; x < probabilities.width(); x++)
        {
          double val = probabilities (x, y, j);
          tmp (y, x) = val;
          maxval = std::max (val, maxval);
          minval = std::min (val, minval);
        }
      tmp (0, 0) = 1.0;
      tmp (0, 1) = 0.0;

      NICE::ColorImage imgrgb (probabilities.width(), probabilities.height());
      ICETools::convertToRGB (tmp, imgrgb);

      cout << "maxval = " << maxval << " minval: " << minval << " for class " << j << endl; //cn.text (j) << endl;

      std::string s;
      std::stringstream out;
      out << "tmpprebmap" << j << ".ppm";
      s = out.str();
      imgrgb.write (s);
      //showImage(imgrgb, "Ergebnis");
      //getchar();
    }
    cout << "fertsch" << endl;
    getchar();
    cout << "weiter gehtsch" << endl;
#endif
  }
  else
  {
    //using segmentation
    Matrix regions;

    if (useRegionFeature)
    {
      // reuse the region ids already stored in the feature image: the region
      // channel is the first one whose channelType is 1
      int rchannel = -1;
      for (uint i = 0; i < channelType.size(); i++)
      {
        if (channelType[i] == 1)
        {
          rchannel = i;
          break;
        }
      }

      assert(rchannel > -1);

      int xsize = feats.width();
      int ysize = feats.height();
      regions.resize(xsize, ysize);
      for (int y = 0; y < ysize; y++)
      {
        for (int x = 0; x < xsize; x++)
        {
          regions(x, y) = feats(x, y, rchannel);
        }
      }
    }
    else
    {
      amountRegions = segmentation->segRegions (img, regions);
    }

    regionProbs = vector<vector<double> >(amountRegions, vector<double> (classes, 0.0));

    vector<int> bestlabels (amountRegions, 0);

    // sum the leaf probabilities of all pixels per region
    for (int y = 0; y < img.height(); y++)
    {
      for (int x = 0; x < img.width(); x++)
      {
        int cregion = regions (x, y);

        for (int d = 0; d < classes; d++)
        {
          regionProbs[cregion][d] += getMeanProb (x, y, d, currentfeats);
        }
      }
    }

    // best internal class index per region, mapped back to the original class number
    for (int r = 0; r < amountRegions; r++)
    {
      double maxval = regionProbs[r][0];
      bestlabels[r] = 0;

      for (int d = 1; d < classes; d++)
      {
        if (maxval < regionProbs[r][d])
        {
          maxval = regionProbs[r][d];
          bestlabels[r] = d;
        }
      }

      bestlabels[r] = labelmapback[bestlabels[r]];
    }

    for (int y = 0; y < img.height(); y++)
    {
      for (int x = 0; x < img.width(); x++)
      {

        segresult.setPixel (x, y, bestlabels[regions (x,y) ]);
      }
    }

#define WRITEREGIONS
#ifdef WRITEREGIONS
    RegionGraph rg;
    segmentation->getGraphRepresentation (img, regions,  rg);
    for (uint pos = 0; pos < regionProbs.size(); pos++)
    {
      rg[pos]->setProbs (regionProbs[pos]);
    }

    std::string s;
    std::stringstream out;
    std::vector< std::string > list;
    StringTools::split (Globals::getCurrentImgFN (), '/', list);

    out << "rgout/" << list.back() << ".graph";
    string writefile = out.str();
    rg.write (writefile);
#endif
  }

  cout << "segmentation finished" << endl;
}

/**
 * Serialize the trained model to a text stream.
 *
 * Layout (must stay in sync with restore()):
 *  nbTrees, class names, labelmap, labelmapback, number of trees, then per
 *  tree its node count and per node: left, right, decision, isleaf, depth,
 *  featcounter, nodeNumber, the class distribution, and either -1 (no
 *  feature) or the feature's operation id followed by the feature's own
 *  store() output. Finally channelType, integralMap and rawChannels.
 *
 * @param os     destination stream (its precision is changed)
 * @param format not used by this implementation
 */
void SemSegContextTree::store (std::ostream & os, int format) const
{
  // digits10 + 2 significant digits (17 for IEEE double, i.e. max_digits10)
  // are needed so that every double -- e.g. the split thresholds -- reads back
  // bit-identical; digits10 + 1 (16 digits) can lose the last bit.
  os.precision (numeric_limits<double>::digits10 + 2);
  os << nbTrees << '\n';
  classnames.store (os);

  map<int, int>::const_iterator it;

  os << labelmap.size() << '\n';
  for (it = labelmap.begin() ; it != labelmap.end(); it++)
    os << (*it).first << " " << (*it).second << '\n';

  os << labelmapback.size() << '\n';
  for (it = labelmapback.begin() ; it != labelmapback.end(); it++)
    os << (*it).first << " " << (*it).second << '\n';

  int trees = forest.size();
  os << trees << '\n';

  for (int t = 0; t < trees; t++)
  {
    int nodes = forest[t].size();
    os << nodes << '\n';
    for (int n = 0; n < nodes; n++)
    {
      os << forest[t][n].left << " " << forest[t][n].right << " " << forest[t][n].decision << " " << forest[t][n].isleaf << " " << forest[t][n].depth << " " << forest[t][n].featcounter << " " << forest[t][n].nodeNumber << '\n';
      os << forest[t][n].dist << '\n';

      // -1 marks a node without a split feature (see restore())
      if (forest[t][n].feat == NULL)
        os << -1 << '\n';
      else
      {
        os << forest[t][n].feat->getOps() << '\n';
        forest[t][n].feat->store (os);
      }
    }
  }

  os << channelType.size() << '\n';
  for (int i = 0; i < (int)channelType.size(); i++)
  {
    os << channelType[i] << " ";
  }
  os << '\n';

  os << integralMap.size() << '\n';
  for (int i = 0; i < (int)integralMap.size(); i++)
  {
    os << integralMap[i].first << " " << integralMap[i].second << '\n';
  }

  os << rawChannels << '\n';
}

void SemSegContextTree::restore (std::istream & is, int format)
{
  is >> nbTrees;

  classnames.restore (is);

  int lsize;
  is >> lsize;

  labelmap.clear();
  for (int l = 0; l < lsize; l++)
  {
    int first, second;
    is >> first;
    is >> second;
    labelmap[first] = second;
  }

  is >> lsize;
  labelmapback.clear();
  for (int l = 0; l < lsize; l++)
  {
    int first, second;
    is >> first;
    is >> second;
    labelmapback[first] = second;
  }

  int trees;
  is >> trees;
  forest.clear();

  for (int t = 0; t < trees; t++)
  {
    vector<TreeNode> tmptree;
    forest.push_back (tmptree);
    int nodes;
    is >> nodes;
    //cout << "nodes: " << nodes << endl;
    for (int n = 0; n < nodes; n++)
    {
      TreeNode tmpnode;
      forest[t].push_back (tmpnode);
      is >> forest[t][n].left;
      is >> forest[t][n].right;
      is >> forest[t][n].decision;
      is >> forest[t][n].isleaf;
      is >> forest[t][n].depth;
      is >> forest[t][n].featcounter;

      is >> forest[t][n].nodeNumber;
      is >> forest[t][n].dist;

      int feattype;
      is >> feattype;
      assert (feattype < NBOPERATIONS);
      forest[t][n].feat = NULL;
      if (feattype >= 0)
      {
        for (uint o = 0; o < ops.size(); o++)
        {
          for (uint o2 = 0; o2 < ops[o].size(); o2++)
          {
            if (forest[t][n].feat == NULL)
            {
              for (uint c = 0; c < ops[o].size(); c++)
              {
                if (ops[o][o2]->getOps() == feattype)
                {
                  forest[t][n].feat = ops[o][o2]->clone();
                  break;
                }
              }
            }
          }
        }

        assert (forest[t][n].feat != NULL);
        forest[t][n].feat->restore (is);
      }
    }
  }

  channelType.clear();
  int ctsize;
  is >> ctsize;
  for (int i = 0; i < ctsize; i++)
  {
    int tmp;
    is >> tmp;
    channelType.push_back (tmp);
  }

  integralMap.clear();
  int iMapSize;
  is >> iMapSize;
  for (int i = 0; i < iMapSize; i++)
  {
    int first;
    int second;
    is >> first;
    is >> second;
    integralMap.push_back (pair<int, int> (first, second));
  }

  is >> rawChannels;
}