/** 
* @file VCCrossGeneralization.cpp
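* @brief Combination of classifiers: a second-stage classifier (nnclassifier) trained on the per-class votings of a Gaussian classifier (gauss)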
* @author Erik Rodner
* @date 12/05/2007

*/
#include <iostream>

#include "vislearning/classifier/vclassifier/VCCrossGeneralization.h"

#ifdef NICE_USELIB_ICE

using namespace OBJREC;

using namespace std;
using namespace NICE;

/**
 * Build the combined classifier from a configuration.
 *
 * Both sub-classifiers are constructed from the same Config object;
 * nnclassifier additionally receives NULL as its second constructor argument.
 * The boolean option "use_voting_normalization" is read from the section
 * "CrossGeneralization" and defaults to true.
 *
 * @param conf configuration object (dereferenced here, so it must not be NULL)
 */
VCCrossGeneralization::VCCrossGeneralization( const Config *conf )
    : VCLearnFromSC(conf),
      gauss(conf), nnclassifier( conf, NULL )
{
    useVotingNormalization = conf->gB( "CrossGeneralization", "use_voting_normalization", true );

    // the Gaussian stage has not been finalized yet; teach() does this on first use
    simpleGaussianFinished = false;
}

/** Destructor. Nothing is released explicitly; members are destroyed automatically. */
VCCrossGeneralization::~VCCrossGeneralization()
{
    // intentionally empty
}

/**
 * Transform a voting vector in place.
 *
 * Step 1: every entry is replaced by its reciprocal. Entries whose magnitude
 * is below 10e-7 (= 1e-6) are mapped to the large constant 10e7 (= 1e8)
 * instead, which avoids dividing by (nearly) zero.
 * Step 2: the reciprocals are divided by their sum so that they add up to one.
 * This step is skipped when the magnitude of the sum is not above 10e-5
 * (= 1e-4), including the case where the sum is NaN.
 *
 * NOTE(review): presumably the votings are distance-like scores (smaller is
 * better), so the reciprocal turns them into similarity weights — confirm
 * against gauss.getVotings().
 *
 * @param v voting vector, modified in place
 */
void VCCrossGeneralization::normalizeVotings ( NICE::Vector & v ) const
{
    double total = 0.0;

    for ( size_t k = 0 ; k < v.size() ; ++k )
    {
        const double value = v[k];
        v[k] = ( fabs(value) < 10e-7 ) ? 10e7 : 1.0 / value;
        total += v[k];
    }

    // leave the reciprocals unscaled when the sum is (close to) zero or not a number
    if ( ! ( fabs(total) > 10e-5 ) )
        return;

    for ( size_t k = 0 ; k < v.size() ; ++k )
        v[k] /= total;
}

/**
 * Classify a single feature vector.
 *
 * The vector is first evaluated by the Gaussian classifier (gauss). Its
 * per-class votings are copied into a new vector in ascending key order of
 * the returned std::map. When use_voting_normalization is enabled, that
 * vector is transformed by normalizeVotings(). The final decision is made by
 * nnclassifier on the resulting voting vector.
 *
 * @param x feature vector to classify
 * @return the ClassificationResult produced by nnclassifier
 */
ClassificationResult VCCrossGeneralization::classify ( const NICE::Vector & x ) const
{
    std::map<int, double> classVotes;
    gauss.getVotings( x, classVotes );

    // flatten the map values (ordered by key) into a plain vector
    NICE::Vector voteFeatures;
    typedef std::map<int, double>::const_iterator VoteIterator;
    VoteIterator it = classVotes.begin();
    while ( it != classVotes.end() )
    {
        voteFeatures.append( it->second );
        ++it;
    }

    if ( useVotingNormalization )
        normalizeVotings( voteFeatures );

    return nnclassifier.classify( voteFeatures );
}

void VCCrossGeneralization::preTeach ( const LabeledSetVector & teachSet )
{
    gauss.teach(teachSet);
}

/**
 * Second training stage: train nnclassifier on the votings of the Gaussian
 * classifier.
 *
 * Records the largest class number of teachSet in maxClassNo. It then
 * finalizes gauss, at most once per object (tracked by simpleGaussianFinished).
 * Next, every training example is mapped to the vector of gauss votings, in
 * ascending key order of the std::map returned by getVotings(). That vector is
 * given to nnclassifier together with the example's class number.
 *
 * The voting vectors are transformed exactly as in classify(). They are
 * normalized only when use_voting_normalization is enabled, so training and
 * test features live in the same space.
 *
 * @param teachSet labeled training vectors
 */
void VCCrossGeneralization::teach ( const LabeledSetVector & teachSet )
{
    maxClassNo = teachSet.getMaxClassno();

    // the Gaussian model must be finalized before its votings can be queried
    if ( ! simpleGaussianFinished )
    {
        gauss.finishTeaching();
        simpleGaussianFinished = true;
    }

    LOOP_ALL(teachSet)
    {
        EACH(classno,x);

        std::map<int, double> votings;
        gauss.getVotings(x, votings);

        NICE::Vector votings_vec;
        for ( map<int, double>::const_iterator j = votings.begin();
              j != votings.end();
              ++j )
        {
            votings_vec.append(j->second);
        }

        // Must mirror classify(). Previously the training votings were always
        // normalized, so disabling use_voting_normalization made the training
        // and query features incomparable.
        if ( useVotingNormalization )
            normalizeVotings( votings_vec );

        nnclassifier.teach( classno, votings_vec );
    }
}

void VCCrossGeneralization::finishTeaching()
{
    gauss.finishTeaching();
}

/**
 * Deserialization is not implemented. This prints a message to stderr and
 * terminates the process via exit(-1). Both parameters are ignored.
 */
void VCCrossGeneralization::restore ( std::istream & is, int format )
{
    fputs( "NOT YET IMPLEMENTED !!\n", stderr );
    exit( -1 );
}

/**
 * Serialization is not implemented. This prints a message to stderr and
 * terminates the process via exit(-1). Both parameters are ignored; note that
 * the output stream parameter is named `is`.
 */
void VCCrossGeneralization::store ( std::ostream & is, int format ) const
{
    fputs( "NOT YET IMPLEMENTED !!\n", stderr );
    exit( -1 );
}

/**
 * Resetting the classifier is not implemented. This prints a message to
 * stderr and terminates the process via exit(-1).
 */
void VCCrossGeneralization::clear ()
{
    fputs( "NOT YET IMPLEMENTED !!\n", stderr );
    exit( -1 );
}

#endif