/** * @file KCOneVsAll.cpp * @brief One vs. All interface for kernel classifiers * @author Erik Rodner * @date 12/10/2009 */ #include #include "KCOneVsAll.h" #include "core/vector/Algorithms.h" #include "core/algebra/CholeskyRobust.h" #include "core/algebra/CholeskyRobustAuto.h" #include "vislearning/regression/regressionbase/TeachWithInverseKernelMatrix.h" using namespace std; using namespace NICE; using namespace OBJREC; KCOneVsAll::KCOneVsAll( const Config *conf, const KernelClassifier *prototype, const string & section ) : KernelClassifier ( conf, prototype->getKernelFunction() ) { this->prototype = prototype; this->maxClassNo = 0; this->verbose = conf->gB( section, "verbose", false ); } KCOneVsAll::~KCOneVsAll() { } void KCOneVsAll::teach ( KernelData *kernelData, const NICE::Vector & y ) { maxClassNo = (int)y.Max(); classifiers.clear(); for ( int i = 0 ; i <= maxClassNo ; i++ ) { NICE::Vector ySub ( y ); for ( size_t j = 0 ; j < ySub.size() ; j++ ) ySub[j] = ((int)y[j] == i) ? 1 : 0; KernelClassifier *classifier; classifier = prototype->clone(); fprintf (stderr, "KCOneVsAll: training classifier class %d <-> remainder\n", i ); KernelData *kernelDataCopy = kernelData->clone(); classifier->teach ( kernelDataCopy, ySub ); // FIXME: This might be trickier for kernel classifiers which need the kernel data // explicitly, but otherwise we get a huge! memory leak delete kernelDataCopy; classifiers.push_back ( pair (i, classifier) ); } } ClassificationResult KCOneVsAll::classifyKernel ( const NICE::Vector & kernelVector, double kernelSelf ) const { if ( classifiers.size() <= 0 ) fthrow(Exception, "The classifier was not trained with training data (use teach(...))"); FullVector scores ( maxClassNo+1 ); scores.set(0); for ( vector< pair >::const_iterator i = classifiers.begin(); i != classifiers.end(); i++ ) { int classno = i->first; KernelClassifier *classifier = i->second; ClassificationResult r = classifier->classifyKernel(kernelVector, kernelSelf); scores[classno] += r.scores[1]; } return ClassificationResult( scores.maxElement(), scores ); }