/**
* @file ClassNames.cpp
* @brief simple interface for mapping between class names, class codes and class numbers
* @author Erik Rodner
* @date 02/08/2008

*/
#include "core/image/ImageT.h"
#include "core/vector/VectorT.h"
#include "core/vector/MatrixT.h"

#include <iostream>
#include <algorithm>
#include <string>
#include <functional>
#include <assert.h>

#include "vislearning/cbaselib/ClassNames.h"
#include "core/basics/StringTools.h"
#include "core/image/ImageTools.h"
#include "vislearning/baselib/ICETools.h"

using namespace OBJREC;
using namespace std;
using namespace NICE;


/** Creates an empty mapping: no classes are registered and maxClassNo starts at 0. */
ClassNames::ClassNames()
    : maxClassNo ( 0 )
{
}

/**
 * Creates a mapping that contains only a subset of the classes of @p cn.
 *
 * @param cn base mapping whose entries are inherited
 * @param classselection comma-separated selection as understood by
 *        getSelection() (e.g. "*,-name0" or "name1,name2"); an empty
 *        selection yields an empty mapping
 *
 * Class numbers are taken over unchanged (so the result may have gaps),
 * and the color coding of every inherited class is copied as well.
 */
ClassNames::ClassNames ( const ClassNames & cn,
                         const std::string & classselection )
    : maxClassNo ( 0 )
{
  std::set<int> selection;
  cn.getSelection ( classselection, selection );

#ifdef DEBUG_ClassNames
  // dump the base mapping only when debugging
  cn.store ( cerr );
#endif

  for ( std::map<std::string, std::string>::const_iterator i = cn.tbl_code_text.begin();
        i != cn.tbl_code_text.end(); ++i )
  {
    const std::string & code = i->first;
    const std::string & classname = i->second;

    // codes present in the name table but without a class number were not
    // selected when the base mapping was built
    std::map<std::string, int>::const_iterator c = cn.tbl_code_classno.find ( code );
    if ( c == cn.tbl_code_classno.end() )
    {
      fprintf ( stderr, "class %s excluded in base classnames\n", code.c_str() );
      continue;
    }

    const int myclassno = c->second;
    if ( selection.find ( myclassno ) == selection.end() )
    {
#ifdef DEBUG_ClassNames
      fprintf ( stderr, "class %s (%d) excluded in selection\n", code.c_str(), myclassno );
#endif
      continue;
    }

    // addClass() also raises maxClassNo when necessary
    addClass ( myclassno, code, classname );

    // inherit the color coding, otherwise getRGBColor()/getClassnoFromColor()
    // would not know the colors of the subset
    std::map<int, long>::const_iterator col = cn.tbl_classno_color.find ( myclassno );
    if ( col != cn.tbl_classno_color.end() )
    {
      tbl_classno_color[myclassno] = col->second;
      tbl_color_classno[col->second] = myclassno;
    }

#ifdef DEBUG_ClassNames
    fprintf ( stderr, "class %s (%d) inherited\n", code.c_str(), myclassno );
#endif
  }
}

/**
 * Copy constructor: duplicates every lookup table (code/text, class number,
 * color) and the maximum class number of @p cn.
 */
ClassNames::ClassNames ( const ClassNames & cn )
    : tbl_code_text ( cn.tbl_code_text ),
      tbl_text_code ( cn.tbl_text_code ),
      tbl_classno_code ( cn.tbl_classno_code ),
      tbl_code_classno ( cn.tbl_code_classno ),
      tbl_color_classno ( cn.tbl_color_classno ),
      tbl_classno_color ( cn.tbl_classno_color ),
      maxClassNo ( cn.maxClassNo )
{
  // all work is done by the member initializers above
}

/** Destructor. */
ClassNames::~ClassNames()
{
  // nothing to release explicitly
}


/**
 * Looks up the class number belonging to a class name (text).
 *
 * @param text class name
 * @return the class number, or -1 if the name is unknown or its code has
 *         no class number
 */
int ClassNames::classnoFromText ( std::string text ) const
{
  std::map<std::string, std::string>::const_iterator codeIt = tbl_text_code.find ( text );
  if ( codeIt != tbl_text_code.end() )
  {
    std::map<std::string, int>::const_iterator noIt = tbl_code_classno.find ( codeIt->second );
    if ( noIt != tbl_code_classno.end() )
      return noIt->second;
  }
  return -1;
}

void ClassNames::getSelection ( const std::string & classselection,
                                std::set<int> & classnos ) const
{
  if ( classselection.size() <= 0 ) return;

  std::vector<string> classlist;
  StringTools::split ( classselection, ',', classlist );

  if ( classlist.size() <= 0 )
  {
    fprintf ( stderr, "FATAL ERROR: wrong format for classselection\n" );
    exit ( -1 );
  } else if ( classlist[0] == "*" )
  {
    map<string, bool> forbidden_classes;
    for ( size_t k = 1 ; k < classlist.size() ; k++ )
      if ( classlist[k].substr ( 0, 1 ) == "-" )
      {
        std::string f_class = classlist[k].substr ( 1 );
#if defined DEBUG_ClassNames
        fprintf ( stderr, "ClassNames: class %s excluded !\n", f_class.c_str() );
#endif
        forbidden_classes[ f_class ] = true;
      } else {
        fprintf ( stderr, "FATAL ERROR: wrong format for classselection: *,-class0,class1,...\n" );
        exit ( -1 );
      }

    for ( map<int, string>::const_iterator i  = tbl_classno_code.begin();
          i != tbl_classno_code.end();
          i++ )
    {
      int myclassno = i->first;
      const std::string & classname = text ( myclassno );
      if ( forbidden_classes.find ( classname ) != forbidden_classes.end() )
        continue;

      classnos.insert ( myclassno );
    }
  } else {
    for ( vector<string>::const_iterator i  = classlist.begin();
          i != classlist.end();
          i++ )
    {
      const std::string & classname = *i;
      map<string, string>::const_iterator j = tbl_text_code.find ( classname );

      if ( j == tbl_text_code.end() )
      {
        fprintf ( stderr, "ClassNames: FATAL ERROR This is not a selection of a subset: %s [%s]\n",
                  classname.c_str(), classselection.c_str() );
        exit ( -1 );
      }

      const std::string & code = j->second;
      int myclassno = classno ( code );

      if ( myclassno < 0 ) {
        fprintf ( stderr, "ClassNames: FATAL ERROR This is not a selection of a subset\n" );
        exit ( -1 );
      }
      classnos.insert ( myclassno );
    }
  }

}

std::string ClassNames::text ( int classno ) const
{
  map<string, string>::const_iterator i =
    tbl_code_text.find ( code ( classno ) );

  if ( i == tbl_code_text.end() )
  {
    fprintf ( stderr, "ClassNames: no name found for classno %d\n", classno );
    return "unknown";
  } else {
    return i->second;
  }

}

/**
 * Returns the class code of a class number.
 *
 * @param classno class number
 * @return the class code, or "unknown" (after a message on stderr) if the
 *         class number is not registered
 */
std::string ClassNames::code ( int classno ) const
{
  std::map<int, std::string>::const_iterator hit = tbl_classno_code.find ( classno );
  if ( hit != tbl_classno_code.end() )
    return hit->second;

  fprintf ( stderr, "ClassNames: no code found for classno %d\n", classno );
  return "unknown";
}

/**
 * Returns the class number of a class code.
 *
 * @param code class code
 * @return the class number
 * @throws Exception (raised via fthrow) if the code has no class number
 */
int ClassNames::classno ( std::string code ) const
{
  std::map<std::string, int>::const_iterator hit = tbl_code_classno.find ( code );
  if ( hit == tbl_code_classno.end() )
  {
    fthrow ( Exception, "no classno found for code <" << code << ">" );
  }

  return hit->second;
}

/** Returns the number of registered class numbers. */
int ClassNames::numClasses () const
{
  return static_cast<int> ( tbl_classno_code.size() );
}


/**
 * Registers a class in all code/text/number tables, overwriting existing
 * entries for the same keys, and raises maxClassNo if necessary.
 *
 * @param classno class number
 * @param code class code
 * @param text class name
 */
void ClassNames::addClass ( int classno, const std::string & code,
                            const std::string & text )
{
  tbl_code_text[code] = text;
  tbl_text_code[text] = code;
  tbl_code_classno[code] = classno;
  tbl_classno_code[classno] = code;

  if ( maxClassNo < classno )
    maxClassNo = classno;
}

/** Returns true if @p classno is a registered class number. */
bool ClassNames::existsClassno ( int classno ) const
{
  return tbl_classno_code.count ( classno ) > 0;
}

/** Returns true if @p classcode is a class code that has a class number. */
bool ClassNames::existsClassCode ( const std::string & classcode ) const
{
  return tbl_code_classno.count ( classcode ) > 0;
}

bool ClassNames::readFromConfig ( const Config & datasetconf,
                                  // refactor-nice.pl: check this substitution
                                  // old: const string & classselection )
                                  const std::string & classselection )
{

  datasetconf.getAllS ( "classnames", tbl_code_text );

  if ( tbl_code_text.size() <= 0 ) {
    fprintf ( stderr, "ClassNames: no classnames specified\n" );
    return false;
  }

  // reverse map and lower case
  for ( map<string, string>::const_iterator i  = tbl_code_text.begin();
        i != tbl_code_text.end(); i++ )
    tbl_text_code [ i->second ] = i->first;

#if defined DEBUG_ClassNames
  cerr << "ClassNames::read: selection = " << classselection << endl;
#endif

  std::vector<string> classlist;
  StringTools::split ( classselection, ',', classlist );

  if ( classlist.size() <= 0 )
  {
    fprintf ( stderr, "FATAL ERROR: wrong format for classselection\n" );
    exit ( -1 );
  } else if ( classlist[0] == "*" )
  {
    map<string, bool> forbidden_classes;
    for ( size_t k = 1 ; k < classlist.size() ; k++ )
      if ( classlist[k].substr ( 0, 1 ) == "-" )
      {
        // refactor-nice.pl: check this substitution
        // old: string f_class = classlist[k].substr(1);
        std::string f_class = classlist[k].substr ( 1 );
#if defined DEBUG_ClassNames
        fprintf ( stderr, "ClassNames: class %s excluded !\n", f_class.c_str() );
#endif
        forbidden_classes[ f_class ] = true;
      } else {
        fprintf ( stderr, "FATAL ERROR: wrong format for classselection: *,-class0,class1,...\n" );
        exit ( -1 );
      }

    int classno_seq = 0;
    for ( map<string, string>::const_iterator i  = tbl_code_text.begin();
          i != tbl_code_text.end();
          i++, classno_seq++ )
    {
      const std::string & classname = i->second;
      if ( forbidden_classes.find ( classname ) != forbidden_classes.end() )
        continue;

      // refactor-nice.pl: check this substitution
      // old: string code = tbl_text_code [ i->second ];
      std::string code = tbl_text_code [ i->second ];
      int classno;
      classno = classno_seq;
      tbl_classno_code[classno] = code;
      tbl_code_classno[code] = classno;
      if ( classno > maxClassNo ) maxClassNo = classno;

#if defined DEBUG_ClassNames
      fprintf ( stderr, "classno %d class code %s class text %s\n", classno, code.c_str(), classname.c_str() );
#endif
    }
  } else {

#if defined DEBUG_ClassNames
    cerr << "list : " << classlist.size() << endl;
#endif
    for ( size_t classno_seq = 0 ; classno_seq < classlist.size() ; classno_seq++ )
    {
      std::string classname = classlist[classno_seq];

      if ( tbl_text_code.find ( classname ) != tbl_text_code.end() )
      {
        std::string code = tbl_text_code [ classname ];
        int classno;
        classno = classno_seq;
        tbl_classno_code[classno] = code;
        tbl_code_classno[code] = classno;
        if ( classno > maxClassNo ) maxClassNo = classno;

#if defined DEBUG_ClassNames
        fprintf ( stderr, "classno %d class code %s class text %s\n", ( int ) classno, code.c_str(), classname.c_str() );
#endif
      } else {
        fprintf ( stderr, "ClassNames::ClassNames: FATAL ERROR class >%s< not found in data set\n", classname.c_str() );
        exit ( -1 );
      }
    }
  }

  /****** after all, try to read color coding *******/
  map<string, string> list;
  datasetconf.getAllS ( "colors", list );

  if ( list.size() > 0 ) {
    for ( map<string, string>::const_iterator i = list.begin();
          i != list.end();
          i++ )
    {
      std::string value = i->second;
      std::string classname = i->first;
      int _classno = classno ( classname );
      vector<string> submatches;
      if ( StringTools::regexMatch ( value, "^ *([[:digit:]]+) *: *([[:digit:]]+) *: *([[:digit:]]+) *$", submatches )
           && ( submatches.size() == 4 ) )
      {
        int r = StringTools::convert<int> ( submatches[1] );
        int g = StringTools::convert<int> ( submatches[2] );
        int b = StringTools::convert<int> ( submatches[3] );
        long index = 256 * ( 256 * r + g ) + b;
        tbl_color_classno[index] = _classno;
        tbl_classno_color[_classno] = index;
      } else {
        fprintf ( stderr, "LabeledFileList: parse error colors >%s<\n", value.c_str() );
        exit ( -1 );
      }
    }
  }

  return true;
}

/** Returns the largest class number recorded so far (the value of maxClassNo). */
int ClassNames::getMaxClassno () const
{
  const int largest = maxClassNo;
  return largest;
}

/**
 * Returns the RGB color of a class.
 *
 * If a color was read for @p classno, that color is decoded and returned.
 * Otherwise a message is printed to stderr and a pseudo color computed from
 * classno / numClasses() (0 for an empty mapping) is returned.
 *
 * @param classno class number
 * @param r receives the red channel
 * @param g receives the green channel
 * @param b receives the blue channel
 */
void ClassNames::getRGBColor ( int classno, int & r, int & g, int & b ) const
{
  std::map<int, long>::const_iterator i = tbl_classno_color.find ( classno );

  if ( i == tbl_classno_color.end() )
  {
    // Must not wait for user input here: labelToRGB() calls this function
    // for every pixel.
    fprintf ( stderr, "ClassNames: no color setting found for class %d\n", classno );

    const int n = numClasses();
    // guard the division: an empty mapping would otherwise produce inf/nan
    const double x = ( n > 0 ) ? classno / static_cast<double> ( n ) : 0.0;
    double rd, gd, bd;
    convertToPseudoColor ( x, rd, gd, bd );
    r = static_cast<int> ( 255 * rd );
    g = static_cast<int> ( 255 * gd );
    b = static_cast<int> ( 255 * bd );
  } else {
    // colors are packed as 256 * ( 256 * r + g ) + b
    long color = i->second;
    b = color % 256;
    color /= 256;
    g = color % 256;
    color /= 256;
    r = color % 256;
  }
}

void ClassNames::getClassnoFromColor ( int & classno, int r, int g, int b ) const
{
  long color = 256 * ( 256 * r + g ) + b;
//  __gnu_cxx::hash_map<long, int>::const_iterator i = tbl_color_classno.find ( color );
  std::tr1::unordered_map<long, int>::const_iterator i = tbl_color_classno.find ( color );

  if ( i == tbl_color_classno.end() )
  {
    classno = -1;
  } else {
    classno = i->second;
  }
}

/**
 * Converts a label image into a color image by calling getRGBColor() for
 * every pixel.
 *
 * @param img label image; each pixel value is used as a class number
 * @param rgb output color image, resized to the size of @p img
 */
void ClassNames::labelToRGB ( const NICE::Image & img, NICE::ColorImage & rgb ) const
{
  rgb.resize ( img.width(), img.height() );

  const int w = img.width();
  const int h = img.height();

  for ( int y = 0 ; y < h ; ++y )
  {
    for ( int x = 0 ; x < w ; ++x )
    {
      int red, green, blue;
      getRGBColor ( img.getPixel ( x, y ), red, green, blue );
      rgb.setPixel ( x, y, 0, red );
      rgb.setPixel ( x, y, 1, green );
      rgb.setPixel ( x, y, 2, blue );
    }
  }
}

/**
 * Returns the class number of the background class: the first existing code
 * among "various", "background" and "clutter", or 0 if none of them exists.
 */
int ClassNames::getBackgroundClass () const
{
  static const char * const candidates[] = { "various", "background", "clutter" };
  const size_t numCandidates = sizeof ( candidates ) / sizeof ( candidates[0] );

  for ( size_t k = 0 ; k < numCandidates ; k++ )
  {
    if ( existsClassCode ( candidates[k] ) )
      return classno ( candidates[k] );
  }
  return 0;
}

/**
 * Reads class entries in the format written by store(): whitespace-separated
 * triples "text code classno", terminated by the word "end", a failed read,
 * or the end of the stream.
 *
 * Entries are added with insert(), so keys already present in this object
 * are not overwritten. Color codings are not read. maxClassNo is reset to -1
 * and then raised to the largest class number read.
 *
 * @param is input stream
 * @param format unused
 */
void ClassNames::restore ( istream & is, int format )
{
  maxClassNo = -1;

  while ( ! is.eof() )
  {
    std::string mytext, mycode;
    int myclassno;

    if ( ! ( is >> mytext ) || mytext == "end" )
      break;
    if ( ! ( is >> mycode >> myclassno ) )
      break;

    tbl_code_text.insert ( std::make_pair ( mycode, mytext ) );
    tbl_text_code.insert ( std::make_pair ( mytext, mycode ) );
    tbl_classno_code.insert ( std::make_pair ( myclassno, mycode ) );
    tbl_code_classno.insert ( std::make_pair ( mycode, myclassno ) );

    if ( maxClassNo < myclassno )
      maxClassNo = myclassno;
  }
}

/**
 * Writes one line "text<TAB>code<TAB>classno" per class number, in ascending
 * class-number order, followed by a line "end" (the format restore() reads).
 * Color codings are not written.
 *
 * @param os output stream
 * @param format unused
 */
void ClassNames::store ( ostream & os, int format ) const
{
  assert ( tbl_classno_code.size() == tbl_code_classno.size() );

  for ( std::map<int, std::string>::const_iterator entry = tbl_classno_code.begin();
        entry != tbl_classno_code.end(); ++entry )
  {
    const std::string name = text ( entry->first );
    os << name << "\t" << entry->second << "\t" << entry->first << endl;
  }

  os << "end" << endl;
}

void ClassNames::clear ()
{
  tbl_code_text.clear();
  tbl_text_code.clear();
  tbl_classno_code.clear();
  tbl_code_classno.clear();
  tbl_color_classno.clear();
  tbl_classno_color.clear();
}