/** 
* @file VCNearestNeighbour.cpp
* @brief Simple Nearest Neighbour Implementation
* @author Erik Rodner
* @date 10/25/2007

*/
#include <iostream>
#include <limits>
#include <queue>

#include "vislearning/classifier/vclassifier/VCNearestNeighbour.h"

using namespace OBJREC;

using namespace std;
using namespace NICE;


#undef DEBUG_VCN


/** Construct a (K-)nearest-neighbour classifier.
 *
 *  @param _conf         configuration; section "VCNearestNeighbour", key "K"
 *                       (default 1) selects how many neighbours vote
 *  @param _distancefunc distance used to compare vectors; NULL selects a
 *                       newly allocated EuclidianDistance<double>. The pointer
 *                       is stored as-is, not copied.
 *
 *  NOTE(review): the default distance allocated here is never deleted (the
 *  destructor is empty), and copies share this pointer.
 */
VCNearestNeighbour::VCNearestNeighbour ( const Config *_conf, NICE::VectorDistance<double> *_distancefunc )
    : VecClassifier ( _conf ),
      distancefunc ( _distancefunc )
{
    // number of voting neighbours; values <= 1 mean plain nearest neighbour
    K = _conf->gI ( "VCNearestNeighbour", "K", 1 );

    // fall back to the Euclidean metric when the caller supplied none
    const bool useDefaultMetric = ( distancefunc == NULL );
    if ( useDefaultMetric )
        distancefunc = new EuclidianDistance<double>();
}

/** Copy constructor; only untrained instances can be copied.
 *
 *  Copies K, maxClassNo and the distance-function pointer (shallow: both
 *  objects then share the same distance object).
 *
 *  @throws Exception if @p src already holds training examples
 */
VCNearestNeighbour::VCNearestNeighbour ( const VCNearestNeighbour & src ) : VecClassifier()
{
    const bool alreadyTrained = ( src.teachSet.size() != 0 );
    if ( alreadyTrained )
        fthrow ( Exception, "It is not yet possible to clone an already trained nearest neighbour classifier." );

    K = src.K;
    maxClassNo = src.maxClassNo;
    distancefunc = src.distancefunc;
}

/** Destructor.
 *
 *  NOTE(review): distancefunc is not deleted here. The config constructor may
 *  allocate a default EuclidianDistance<double>, which therefore leaks, but
 *  the copy constructor shares the same pointer between instances, so
 *  deleting it here would need an ownership flag (header change) first.
 */
VCNearestNeighbour::~VCNearestNeighbour()
{
    // nothing is released explicitly; see the note above
}

/** Classify a single feature vector.
 *
 *  The distance from @p x to every stored training example is computed with
 *  distancefunc.
 *  - K <= 1: returns the label of the closest example; the scores are the raw
 *    per-class minimum distances (smaller is better, classes without any
 *    example keep std::numeric_limits<double>::max()). They are not
 *    probabilities.
 *  - K > 1: the (at most) K closest examples vote; the normalized vote
 *    histogram is returned with its arg-max as the label.
 *
 *  Examples whose distance evaluates to NaN are reported on stderr and
 *  ignored.
 *
 *  @param x feature vector to classify
 *  @return classification result as described above
 *  Terminates the process via exit(-1) if the classifier has no training data.
 */
ClassificationResult VCNearestNeighbour::classify ( const NICE::Vector & x ) const
{
    double mindist = std::numeric_limits<double>::max();
    int    minclass = 0;
    FullVector mindists ( maxClassNo + 1 );
    mindists.set ( mindist );

    if ( teachSet.count() <= 0 ) {
        fprintf (stderr, "VCNearestNeighbour: please train this classifier before classifying\n");
        exit(-1);
    }

    // Max-heap keyed on the negated distance: top() is the closest example.
    priority_queue< pair<double, int> > examples;
    LOOP_ALL(teachSet)
    {
        EACH(classno,y)

        double distance = distancefunc->calculate ( x, y );

        if ( NICE::isNaN(distance) )
        {
            // A NaN key would violate the strict weak ordering required by
            // priority_queue, and it can never win the minimum comparisons,
            // so the example is skipped after reporting it.
            fprintf (stderr, "VCNearestNeighbour::classify: NAN value found !!\n");
            cerr << x << endl;
            cerr << y << endl;
        }
        else
        {
            if ( mindists[classno] > distance )
                mindists[classno] = distance;

            if ( mindist > distance )
            {
                minclass = classno;
                mindist  = distance;
            }
            if ( K > 1 )
                examples.push ( pair<double, int> ( -distance, classno ) );
        }
    }

    if ( mindist == 0.0 )
        fprintf (stderr, "VCNearestNeighbour::classify WARNING distance is zero, reclassification?\n");

#ifdef DEBUG_VCN
    for ( int i = 0 ; i < mindists.size() ; i++ )
        fprintf (stderr, "class %d : %f\n", i, mindists.get(i) );
#endif

    if ( K > 1 )
    {
        FullVector votes ( maxClassNo + 1 );
        votes.set(0.0);
        // K may exceed the number of usable examples; top()/pop() on an empty
        // priority_queue is undefined behaviour, so stop when it runs dry.
        for ( int k = 0 ; k < K && !examples.empty() ; k++ )
        {
            const pair<double, int> & t = examples.top();
            votes[ t.second ]++;
            examples.pop();
        }
        // NOTE(review): if every distance was NaN, votes is all zero here and
        // the result of normalize() depends on FullVector; confirm.
        votes.normalize();
        return ClassificationResult ( votes.maxElement(), votes );
    }
    else
    {
        // Scores are raw distances, not probability-like values.
        return ClassificationResult ( minclass, mindists );
    }
}

/** classify using a simple vector */
void VCNearestNeighbour::teach ( const LabeledSetVector & _teachSet )
{
    fprintf (stderr, "teach using all !\n");
    maxClassNo = _teachSet.getMaxClassno();
    //NOTE this is crucial if we clear _teachSet afterwards!
    //therefore, take care NOT to call _techSet.clear() somewhere out of this method
    this->teachSet = _teachSet;
    
    std::cerr << "number of known training samples: " << this->teachSet.begin()->second.size() << std::endl;
    
//     //just for testing - remove everything but the first element
//     map< int, vector<NICE::Vector *> >::iterator it = this->teachSet.begin();
//     it++;
//     this->teachSet.erase(it, this->teachSet.end());
//     std::cerr << "keep " << this->teachSet.size() << " elements" << std::endl;
    
    
}


/** Add a single labelled training example.
 *
 *  Aborts the process (exit(-1)) if @p x contains a NaN entry. Otherwise the
 *  example is added to teachSet, maxClassNo is raised if @p classno exceeds
 *  it, and progress information is written to stderr.
 *
 *  @param classno class label of the example
 *  @param x       feature vector of the example
 */
void VCNearestNeighbour::teach ( int classno, const NICE::Vector & x )
{
    std::cerr << "VCNearestNeighbour::teach one new example" << std::endl;

    // find the first NaN entry, if any; a NaN feature is fatal
    size_t nanIndex = 0;
    while ( nanIndex < x.size() && !NICE::isNaN ( x[nanIndex] ) )
        ++nanIndex;

    if ( nanIndex < x.size() )
    {
        fprintf (stderr, "There is a NAN value in within this vector: x[%d] = %f\n", static_cast<int>(nanIndex), x[nanIndex]);
        cerr << x << endl;
        exit(-1);
    }

    maxClassNo = ( classno > maxClassNo ) ? classno : maxClassNo;
    teachSet.add ( classno, x );

    std::cerr << "adden class " << classno << " with feature " << x << std::endl;

    int numSamples = 0;
    LabeledSetVector::const_iterator cls = this->teachSet.begin();
    for ( ; cls != this->teachSet.end(); ++cls )
        numSamples += cls->second.size();
    std::cerr << "number of known training samples: " << numSamples << std::endl;
}

/** No-op: the examples are stored as-is and all distances are computed in
 *  classify(), so there is no model to finalize. */
void VCNearestNeighbour::finishTeaching()
{
    // nothing to do
}

/** Return a heap-allocated copy made by the copy constructor (which throws if
 *  this instance already holds training data). The returned object is not
 *  stored here, so the caller is responsible for deleting it. */
VCNearestNeighbour *VCNearestNeighbour::clone() const
{
    return new VCNearestNeighbour ( *this );
}

void VCNearestNeighbour::clear ()
{
    teachSet.clear();
}

void VCNearestNeighbour::store ( std::ostream & os, int format ) const
{
    teachSet.store ( os, format );
}

void VCNearestNeighbour::restore ( std::istream & is, int format )
{
    teachSet.restore ( is, format );
    maxClassNo = teachSet.getMaxClassno();
}