#ifdef NICE_USELIB_ICE

#include <iostream>
#include <limits>
#include <stdexcept>

#include "vislearning/classifier/vclassifier/VCNearestClassMean.h"

using namespace OBJREC;

using namespace std;

using namespace NICE;

/**
 * Construct a nearest-class-mean classifier.
 * @param _conf          configuration forwarded to the VecClassifier base
 * @param _distancefunc  distance used to compare a sample with each class mean;
 *                       when NULL, a freshly allocated EuclidianDistance<double>
 *                       is used instead.
 * NOTE(review): the default distance object allocated here is never deleted in
 * this file (the destructor only calls clear()); who owns distancefunc is not
 * established here — confirm against the header before adding a delete.
 */
VCNearestClassMean::VCNearestClassMean( const Config *_conf, NICE::VectorDistance<double> *_distancefunc  )
    : VecClassifier ( _conf ),
      distancefunc ( _distancefunc != NULL ? _distancefunc : new EuclidianDistance<double>() )
{
}

/**
 * Destructor: hands all teardown to clear().
 * NOTE(review): distancefunc is not released here; if the constructor created
 * the default EuclidianDistance, it is leaked — confirm ownership in the header.
 */
VCNearestClassMean::~VCNearestClassMean()
{
    this->clear();
}

/**
 * Classify a single feature vector by its nearest class mean.
 * @param x  feature vector to classify
 * @return ClassificationResult whose label is the class of the closest mean
 *         (or -1 if no class is known), with scores[k] = -distance to the
 *         k-th mean (k follows the order of classNo), so that larger is closer.
 */

ClassificationResult VCNearestClassMean::classify ( const NICE::Vector & x ) const
{
    const size_t numClasses = classNo.size();
    FullVector scores ( numClasses );

    int bestClass = -1;
    double bestDistance = std::numeric_limits<double>::max();

    for ( size_t k = 0; k < numClasses; ++k )
    {
        const double dist = distancefunc->calculate ( x, means[k] );
        // Negate so that a higher score means a closer class mean.
        scores[k] = -dist;

        // Strict '<' keeps the first class on ties.
        if ( dist < bestDistance )
        {
            bestDistance = dist;
            bestClass = classNo[k];
        }
    }

    return ClassificationResult ( bestClass, scores );
}


/**
 * Learn one mean vector per class from a labeled training set.
 *
 * Any previously learned classes/means are discarded first, so calling teach()
 * repeatedly retrains instead of appending to stale state.
 * After the call, means[c] is the average of all training vectors labeled
 * classNo[c].
 * @param _teachSet  labeled training vectors; all are assumed to have
 *                   _teachSet.dimension() components
 */
void VCNearestClassMean::teach ( const LabeledSetVector & _teachSet )
{
    // Reset learned state: means.push_back below (and possibly getClasses)
    // append, which previously corrupted the model on a second teach().
    this->classNo.clear();
    this->means.clear();

    _teachSet.getClasses ( this->classNo );
    const size_t numClasses = this->classNo.size();

    // One zero-initialised accumulator per class, aligned with classNo.
    NICE::Vector zero ( _teachSet.dimension() );
    const size_t dim = zero.size();
    for ( size_t d = 0; d < dim; d++ )
        zero[d] = 0.0;
    means.reserve ( numClasses );
    for ( size_t c = 0; c < numClasses; c++ )
        means.push_back ( zero );

    // Sum up all class-specific vectors.
    LOOP_ALL ( _teachSet )
    {
        EACH ( classno, x );

        // Locate the accumulator for this label.
        size_t index = 0;
        while ( index < numClasses && this->classNo[index] != classno )
            index++;

        // Only accumulate for known labels (the old code silently reused the
        // previous sample's index when the label was not found).
        if ( index < numClasses )
        {
            for ( size_t d = 0; d < dim; d++ )
                means[index][d] += x[d];
        }
    }

    // Turn the sums into averages.
    for ( size_t c = 0; c < numClasses; c++ )
    {
        const double n = static_cast<double> ( _teachSet.count ( this->classNo[c] ) );
        if ( n <= 0.0 )
            continue; // avoid 0/0 -> NaN; an empty class keeps a zero mean
        for ( size_t d = 0; d < dim; d++ )
            means[c][d] /= n;
    }
}

/**
 * Post-training hook. teach() already produces the final normalized class
 * means, so there is no additional work to do here.
 */
void VCNearestClassMean::finishTeaching()
{
}

void VCNearestClassMean::clear ()
{
//nothing to do
}

void VCNearestClassMean::store ( std::ostream & os, int format ) const
{
    fprintf (stderr, "NOT YET IMPLEMENTED\n");
    exit(-1);
}

void VCNearestClassMean::restore ( std::istream & is, int format )
{
    fprintf (stderr, "NOT YET IMPLEMENTED\n");
    exit(-1);
}

#endif