/** 
* @file VCLogisticRegression.cpp
* @brief Logistic regression (sigmoid fit) classifier for one-dimensional features
* @author Erik Rodner
* @date 08/03/2009

*/
#include <iostream>

#include "vislearning/classifier/vclassifier/VCLogisticRegression.h"
#include "vislearning/math/fit/FitSigmoid.h"

using namespace OBJREC;

using namespace std;
using namespace NICE;



/** Creates an untrained classifier.
 *  The sigmoid parameters start at sigmoidA = 1.0 and sigmoidB = 0.0 until
 *  teach() fits them; maximum-likelihood estimation is disabled. */
VCLogisticRegression::VCLogisticRegression()
{
	// placeholder sigmoid parameters, overwritten by teach()
	sigmoidB = 0.0;
	sigmoidA = 1.0;

	mlestimation = false;
}

/** Creates an untrained classifier whose estimation mode is read from a configuration.
 *  @param conf configuration object; the boolean key "mlestimation" in section
 *              "VCLogisticRegression" (default false) is later passed to
 *              FitSigmoid::fitProbabilities() by teach(). A NULL pointer is
 *              treated like a configuration without that key.
 *  The sigmoid parameters start at sigmoidA = 1.0 and sigmoidB = 0.0. */
VCLogisticRegression::VCLogisticRegression( const Config *conf ) 
{
	// do not dereference a missing configuration; fall back to the documented default
	mlestimation = conf ? conf->gB("VCLogisticRegression", "mlestimation", false ) : false;
	sigmoidA = 1.0;
	sigmoidB = 0.0;
}

/** Destructor; simply delegates to clear(). */
VCLogisticRegression::~VCLogisticRegression()
{
	this->clear();
}



/** Classifies a one-dimensional feature vector with the fitted sigmoid.
 *  The score of class 1 is 1 / (1 + exp(sigmoidA * x[0] + sigmoidB)); the score
 *  of class 0 is its complement. The predicted class is the index of the larger score.
 *  @param x feature vector, must have exactly one element
 *  @return classification result holding the winning class and both scores
 *  @throws Exception if x does not have size 1 */
ClassificationResult VCLogisticRegression::classify ( const NICE::Vector & x ) const
{
	// the model is only defined on scalar features
	if ( x.size() != 1 )
		fthrow( Exception, "VCLogisticRegression: this classifier is only suitable for one dimensional feature vectors\n" );

	const double activation = sigmoidA * x[0] + sigmoidB;
	const double probClassOne = 1.0 / ( 1.0 + exp ( activation ) );

	FullVector scores ( 2 );
	scores.set ( 0.0 );
	scores[0] = 1.0 - probClassOne;
	scores[1] = probClassOne;

	return ClassificationResult ( scores.maxElement(), scores );
}

/** Trains the classifier by fitting sigmoidA and sigmoidB with
 *  FitSigmoid::fitProbabilities() (using the configured mlestimation flag).
 *  @param _teachSet labeled training examples; must be labeled with exactly the
 *                   two classes 0 and 1, and every feature vector must have size 1
 *  @throws Exception if the labels are not exactly 0/1 or a feature vector is
 *          not one-dimensional */
void VCLogisticRegression::teach ( const LabeledSetVector & _teachSet )
{
	maxClassNo = _teachSet.getMaxClassno();
	if ( (_teachSet.size() != 2) || (maxClassNo != 1) ) 
		fthrow ( Exception, "VCLogisticRegression: the training set is not correctly labeled with 0/1" );

	std::vector < pair< int, double > > results;
	LOOP_ALL(_teachSet)
	{
		EACH(classno,x);
		// two classes with maximum label 1 still admit e.g. {-1, 1}; classify()
		// only ever reports classes 0 and 1, so enforce a true 0/1 labeling here
		if ( (classno != 0) && (classno != 1) )
			fthrow ( Exception, "VCLogisticRegression: the training set is not correctly labeled with 0/1" );
		if ( x.size() != 1 )
			fthrow( Exception, "VCLogisticRegression: this classifier is only suitable for one dimensional feature vectors\n" );
		// (label, scalar feature) pairs are the input format of the sigmoid fit
		results.push_back ( pair<int, double> ( classno, x[0] ) );
	}
	FitSigmoid::fitProbabilities ( results, sigmoidA, sigmoidB, mlestimation );
}

/** Resets the classifier. Currently a no-op: the fitted sigmoid parameters
 *  are left unchanged. */
void VCLogisticRegression::clear ()
{
	// intentionally empty
}

/** Serialization is not supported by this classifier.
 *  @throws Exception always */
void VCLogisticRegression::store ( std::ostream & os, int format ) const
{
	// parameters are intentionally unused until persistence is implemented
	(void) os;
	(void) format;
	fthrow ( Exception, "Persistent interface not yet implemented !" );
}

/** Deserialization is not supported by this classifier.
 *  @throws Exception always */
void VCLogisticRegression::restore ( std::istream & is, int format )
{
	// parameters are intentionally unused until persistence is implemented
	(void) is;
	(void) format;
	fthrow ( Exception, "Persistent interface not yet implemented !" );
}

/** Creates a heap-allocated copy of this classifier, including the estimation
 *  mode, the fitted sigmoid parameters and the maxClassNo value set by teach().
 *  @return newly allocated copy; the caller is responsible for deleting it */
VCLogisticRegression *VCLogisticRegression::clone(void) const
{
	VCLogisticRegression *classifier = new VCLogisticRegression();

	classifier->mlestimation = this->mlestimation;
	classifier->sigmoidA = this->sigmoidA;
	classifier->sigmoidB = this->sigmoidB;
	// teach() records maxClassNo; without copying it a clone of a trained
	// classifier would silently lose that state
	classifier->maxClassNo = this->maxClassNo;

	return classifier;
}