/**
* @file MultiDataset.cpp
* @brief multiple datasets
* @author Erik Rodner
* @date 02/08/2008

*/
#include <iostream>

#include <sys/stat.h>
#include <sys/types.h>

#include "vislearning/cbaselib/ClassNames.h"

#include "core/basics/StringTools.h"
#include "core/basics/FileMgt.h"

#include "vislearning/cbaselib/MultiDataset.h"

using namespace OBJREC;

using namespace std;
using namespace NICE;

#undef DEBUG_MultiDataset

/**
 * Distributes the examples of @p base into @p positives and @p negatives
 * according to a textual selection command.
 *
 * @p examples_command is a ';'-separated list of commands. Each command is
 * split at spaces and must have one of the forms
 *   - "<mode> <class selection> <parameter>"  (three parts), or
 *   - "all <class selection>"                 (two parts).
 * Supported modes are "seq", "step", "random" and "all"; they are forwarded to
 * LabeledSetSelection<LabeledSet>::selectSequential, selectSequentialStep,
 * selectRandom and selectClasses respectively. The numeric parameter is
 * truncated to an int and handed to the selection routine per class; its
 * exact meaning is defined by those routines.
 *
 * After all commands have been handled, the examples of every class that was
 * not referenced by any command are routed through selectClasses with
 * @p negatives as the receiving set.
 *
 * @param examples_command selection command string (see above)
 * @param base             set the examples are taken from
 * @param positives        receives the selected examples
 * @param negatives        receives the examples not selected
 * @param cn               class names used to resolve class selections
 *
 * @throws Exception on a malformed command, an unknown mode, or when a class
 *         is addressed by more than one command (unless the later command
 *         uses the "*" selection).
 */
void MultiDataset::selectExamples ( const std::string & examples_command,
                                    const LabeledSet & base,
                                    LabeledSet & positives,
                                    LabeledSet & negatives,
                                    const ClassNames & cn ) const
{
  vector<string> examples;
  StringTools::split ( examples_command, ';', examples );
  set<int> processed_classes;

  for ( vector<string>::const_iterator i  = examples.begin();
        i != examples.end();
        ++i )
  {
    const std::string & cmd = *i;
    vector<string> parts;
    StringTools::split ( cmd, ' ', parts );

    // valid: exactly three parts, or exactly two parts with mode "all"
    if ( (parts.size() != 3) && ((parts.size() != 2) || (parts[0] != "all")) )
      fthrow( Exception, "Syntax error " << examples_command );

    const std::string & mode = parts[0];
    const std::string & csel = parts[1];
    double parameter = (parts.size() == 3 ) ? atof(parts[2].c_str()) : 0.0;

    // class number -> integer parameter, only for classes seen for the first time
    map<int, int> fpe;

    set<int> selection;
    cn.getSelection ( csel, selection );
    for ( set<int>::const_iterator j  = selection.begin();
          j != selection.end();
          ++j )
    {
      int classno = *j;
      if ( processed_classes.find(classno) == processed_classes.end() )
      {
#ifdef DEBUG_MultiDataset
        fprintf (stderr, "class %s: %s %d\n", cn.text(classno).c_str(),
                 mode.c_str(), (int)parameter );
#endif
        fpe[classno] = (int)parameter;
        processed_classes.insert(classno);
      } else {
        // a wildcard selection may silently overlap with earlier commands,
        // an explicit selection may not
        if ( csel != "*" ) {
          fthrow ( Exception, "Example selection method for class " << cn.text(classno)
                   << " has multiple specifications" );
        }
      }
    }

    if ( mode == "seq" ) {
      LabeledSetSelection<LabeledSet>::selectSequential (
        fpe, base, positives, negatives );
#ifdef DEBUG_MultiDataset
      fprintf (stderr, "MultiDataset: after special seq selection: %d\n", positives.count() );
#endif
    } else if ( mode == "step" ) {
      LabeledSetSelection<LabeledSet>::selectSequentialStep (
        fpe, base, positives, negatives );
#ifdef DEBUG_MultiDataset
      fprintf (stderr, "MultiDataset: after special step selection: %d\n", positives.count() );
#endif
    } else if ( mode == "random" ) {
      LabeledSetSelection<LabeledSet>::selectRandom (
        fpe, base, positives, negatives );
#ifdef DEBUG_MultiDataset
      fprintf (stderr, "MultiDataset: after special random selection: %d\n", positives.count() );
#endif
    } else if ( mode == "all" ) {
      if ( (int)selection.size() == cn.numClasses() )
      {
        // every class is selected: copy references in the order given by the
        // permutation of the base set instead of going class by class
        LabeledSet::Permutation permutation;
        base.getPermutation ( permutation );
        for ( LabeledSet::Permutation::iterator p = permutation.begin(); p != permutation.end(); ++p )
        {
          int classno = p->first;
          ImageInfo *element = const_cast< ImageInfo * > ( p->second );
          positives.add_reference ( classno, element );
        }
      } else {
        LabeledSetSelection<LabeledSet>::selectClasses ( selection, base, positives, negatives );
      }
#ifdef DEBUG_MultiDataset
      fprintf (stderr, "MultiDataset: after special class selection: %d\n", positives.count() );
#endif
    } else {
      fthrow ( Exception, "Wrong value for parameter example\n");
    }
  }

#ifdef DEBUG_MultiDataset
  fprintf (stderr, "MultiDataset: after special selection operations: %d\n", positives.count() );
#endif

  set<int> allclasses;
  cn.getSelection ( "*", allclasses );

  set<int> allnegative_classes;

  // classes never mentioned in any command: allclasses \setminus processed_classes
  set_difference(allclasses.begin(), allclasses.end(), processed_classes.begin(), processed_classes.end(),
                 inserter(allnegative_classes, allnegative_classes.end()));

  // their examples go to the negatives; the complementary output is discarded
  LabeledSet dummy;
  LabeledSetSelection<LabeledSet>::selectClasses ( allnegative_classes,
      base, negatives, dummy );

}

/**
 * MultiDataset ------- constructor
 *
 * Builds all datasets described by the blocks of @p conf.
 *
 * For every block that is not disabled ("disable") and that provides a
 * "dataset" directory, the file "<dataset>/dataset.conf" is read. Blocks
 * without a "dataset" entry are dropped from further processing.
 *
 * A block named "traintest" is handled specially: its base set is split into
 * the datasets "train" and "test" using the settings "examples_train",
 * "examples_test", "classselection_train" and "classselection_test". The
 * value "reclassification" for "examples_test" selects the test set from the
 * training set itself.
 *
 * Every other usable block yields one dataset of the same name, selected with
 * the "examples" setting (default "all *").
 *
 * If "dump_selection" in block "datasets" is set, each dataset is written to
 * a directory named after it (files.txt, dataset.conf, classnames.txt).
 *
 * @param conf         configuration holding the dataset blocks
 * @param pSetFactory  factory handed to the labeled file list reader
 *
 * @throws IOException if a dump file cannot be opened
 */
MultiDataset::MultiDataset( const Config *conf , LabeledSetFactory *pSetFactory)
{
  std::set<string> blocks;
  conf->getAllBlocks ( blocks );

  lfl.setFactory( pSetFactory );

  map<string, Config> dsconfs;
  map<string, string> dirs;
  for ( set<string>::iterator i = blocks.begin();
        i != blocks.end();  )
  {
    // disabled blocks stay in the set but get no dataset config
    if ( conf->gB(*i, "disable", false) )
    {
      ++i;
      continue;
    }

    std::string dataset = conf->gS( *i, "dataset", "unknown" );
    if ( dataset == "unknown" )
      blocks.erase(i++);  // post-increment keeps the iterator valid across the erase
    else {
#ifdef DEBUG_MultiDataset
      fprintf (stderr, "Reading dataset config for block [%s]\n", i->c_str() );
#endif
      Config dsconf ( (dataset + "/dataset.conf").c_str() );

      dirs[*i] = dataset;
      dsconfs[*i] = dsconf;
      ++i;
    }
  }

  // Only handle "traintest" when a dataset config was actually loaded for it.
  // Testing the block set alone would also process a disabled traintest block
  // with an empty directory and a default-constructed config.
  if ( dsconfs.find("traintest") != dsconfs.end() )
  {
    LabeledSet ls_base;
    LabeledSet ls_train (true);
    LabeledSet ls_nontrain (true);
    LabeledSet ls_test (true);
    LabeledSet dummy (true);
    LabeledSet temp (true);


    bool localizationInfoDisabled = conf->gB("traintest", "disable_localization_info", false );

    std::string classselection_train = conf->gS("traintest", "classselection_train", "*");
    std::string classselection_test = conf->gS("traintest", "classselection_test", "*");
    classnames["traintest"] = ClassNames();

    // prefer an explicit classnames.txt over the class names in the dataset config
    std::string classNamesTxt = dirs["traintest"] + "/classnames.txt";
    if ( FileMgt::fileExists ( classNamesTxt ) )
    {
        classnames["traintest"].read ( classNamesTxt );
    } else {
        classnames["traintest"].readFromConfig ( dsconfs["traintest"], classselection_train );
    }

    lfl.get ( dirs["traintest"], dsconfs["traintest"], classnames["traintest"], ls_base,
        localizationInfoDisabled, conf->gB("traintest", "debug_dataset", false ) );

    std::string examples_train =  conf->gS("traintest", "examples_train" );
    selectExamples ( examples_train, ls_base, ls_train, ls_nontrain, classnames["traintest"] );

    set<int> selection_test;
    classnames["traintest"].getSelection ( classselection_test, selection_test );

    std::string examples_test =  conf->gS("traintest", "examples_test" );
    if ( examples_test == "reclassification" )
    {
      // test on the training examples themselves
      LabeledSetSelection<LabeledSet>::selectClasses
      ( selection_test, ls_train, ls_test, dummy );

    } else {
      // test examples are drawn from what was not used for training
      selectExamples ( examples_test, ls_nontrain, temp, dummy, classnames["traintest"] );
      LabeledSetSelection<LabeledSet>::selectClasses
      ( selection_test, temp, ls_test, dummy );
    }

    classnames["train"] = classnames["traintest"];
    classnames["test"] = ClassNames ( classnames["traintest"], classselection_test );
    datasets["test"] = ls_test;
    datasets["train"] = ls_train;
  }

  for ( set<string>::const_iterator i = blocks.begin();
        i != blocks.end();
        ++i )
  {
    std::string name = *i;

    // already handled (e.g. "traintest")
    if ( classnames.find(name) != classnames.end() )
      continue;

    if ( conf->gB(name, "disable", false) == true )
      continue;

    if ( dsconfs.find(name) == dsconfs.end() )
      continue;

    LabeledSet ls_base;
    LabeledSet ls (true);
    LabeledSet dummy (true);
    LabeledSet temp (true);

    bool localizationInfoDisabled = conf->gB(name, "disable_localization_info", false );

    std::string classselection = conf->gS(name, "classselection", "*");
    classnames[name] = ClassNames();

    std::string classNamesTxt = dirs[name] + "/classnames.txt";
    if ( FileMgt::fileExists ( classNamesTxt ) )
    {
#ifdef DEBUG_MultiDataset
      fprintf (stderr, "MultiDataset: reading class names from %s\n", classNamesTxt.c_str() );
#endif
      classnames[name].read ( classNamesTxt );
    } else {
#ifdef DEBUG_MultiDataset
      fprintf (stderr, "MultiDataset: reading class names from dataset config file\n" );
#endif
      classnames[name].readFromConfig ( dsconfs[name], classselection );
    }


    lfl.get (   dirs[name],
                dsconfs[name],
                classnames[name],
                ls_base,
                localizationInfoDisabled,
                conf->gB(name, "debug_dataset", false ) );


#ifdef DEBUG_MultiDataset
    fprintf (stderr, "MultiDataset: class names -->\n" );
    classnames[name].store ( cerr );
    fprintf (stderr, "MultiDataset: all information about %s set obtained ! (size %d)\n", name.c_str(), ls_base.count() );
#endif

    std::string examples = conf->gS(name, "examples", "all *" );
    selectExamples ( examples, ls_base, ls, dummy, classnames[name] );

#ifdef DEBUG_MultiDataset
    fprintf (stderr, "MultiDataset: size after selection %d\n", ls.count() );
#endif

    datasets[name] = ls;
  }

  bool dumpSelections = conf->gB("datasets", "dump_selection", false);
  if ( dumpSelections )
  {
    for ( map<string, LabeledSet>::const_iterator i = datasets.begin();
          i != datasets.end(); ++i )
    {
      const std::string & name = i->first;
      const LabeledSet & ls = i->second;
      const ClassNames & classNames = classnames[name];

      // result not checked here; a missing directory shows up when the
      // files below cannot be opened
      mkdir ( name.c_str(), 0755 );

      // one line per example: "<class code> <file>"
      std::string filelist = name + "/files.txt";
      ofstream olist ( filelist.c_str(), ios::out );
      if ( !olist.good() )
        fthrow (IOException, "Unable to dump selections to " << filelist );

      LOOP_ALL_S(ls)
      {
        EACH_S(classno, file);
        std::string classtext = classNames.code(classno);
        olist << classtext << " " << file << endl;
      }
      olist.close();

      // dataset.conf referencing the file list and listing all class names
      std::string datasetconf = name + "/dataset.conf";
      ofstream oconf ( datasetconf.c_str(), ios::out );
      if ( !oconf.good() )
        fthrow (IOException, "Unable to dump selections to " << datasetconf );

      set<int> classnos;
      classNames.getSelection ( "*", classnos);

      oconf << "[main]" << endl;
      oconf << "filelist = \"files.txt\"" << endl << endl;

      oconf << "[classnames]" << endl;
      for ( set<int>::const_iterator c = classnos.begin();
            c != classnos.end(); ++c )
      {
        const std::string & code = classNames.code(*c);
        const std::string & text = classNames.text(*c);
        oconf << code << "     =     \"" << text << "\"" << endl;
      }
      oconf.close();

      classNames.save ( name + "/classnames.txt" );
    }
  }

}

/** Destructor; the body performs no explicit cleanup. */
MultiDataset::~MultiDataset()
{
}

/**
 * Returns the class names stored for the dataset @p key.
 *
 * An unknown key is treated as fatal: a message is written to stderr and the
 * process is terminated with exit code -1.
 *
 * @param key name of the dataset
 * @return reference to the class names registered under @p key
 */
const ClassNames & MultiDataset::getClassNames ( const std::string & key ) const
{
  const std::map<std::string, ClassNames>::const_iterator entry = classnames.find(key);

  if ( entry != classnames.end() )
    return entry->second;

  // unknown dataset: report and terminate
  fprintf (stderr, "MultiDataSet::getClassNames() FATAL ERROR: dataset <%s> not found !\n", key.c_str() );
  exit(-1);
}

/**
 * Returns the dataset stored under @p key.
 *
 * An unknown key is treated as fatal: a message is written to stderr and the
 * process is terminated with exit code -1.
 *
 * @param key name of the dataset
 * @return pointer to the dataset registered under @p key
 */
const LabeledSet *MultiDataset::operator[] ( const std::string & key ) const
{
  const std::map<std::string, LabeledSet>::const_iterator entry = datasets.find(key);

  if ( entry != datasets.end() )
    return &(entry->second);

  // unknown dataset: report and terminate
  fprintf (stderr, "MultiDataSet: FATAL ERROR: dataset <%s> not found !\n", key.c_str() );
  exit(-1);
}

/**
 * Looks up the dataset stored under @p key without terminating on failure.
 *
 * @param key name of the dataset
 * @return pointer to the dataset, or NULL if no dataset has this name
 */
const LabeledSet *MultiDataset::at ( const std::string & key ) const
{
  const std::map<std::string, LabeledSet>::const_iterator entry = datasets.find(key);

  if ( entry != datasets.end() )
    return &(entry->second);

  return NULL;
}