@@ -0,0 +1,321 @@
+/**
+ * @file testClassifier.cpp
+ * @brief main program for classifier evaluation
+ * @author Erik Rodner
+ * @date 2007-10-12
+ */
+
+#include <objrec/nice_nonvis.h>
+
+#include <fstream>
+#include <iostream>
+
+#include <objrec/cbaselib/MultiDataset.h>
+#include <objrec/iclassifier/icgeneric/CSGeneric.h>
+#include <objrec/cbaselib/ClassificationResults.h>
+#include <objrec/iclassifier/codebook/MutualInformation.h>
+
+#include "objrec/classifier/classifierbase/FeaturePoolClassifier.h"
+#include <objrec/classifier/fpclassifier/randomforest/FPCRandomForestTransfer.h>
+#include <objrec/classifier/classifierinterfaces/VCFeaturePool.h>
+
+#include <objrec/baselib/Config.h>
+#include <objrec/baselib/Preprocess.h>
+#include <objrec/baselib/StringTools.h>
+
+#include "objrec/math/cluster/GMM.h"
+
+#undef DEBUG
+
+using namespace OBJREC;
+using namespace NICE;
+using namespace std;
+
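+// Binarize a feature vector: every component whose absolute value exceeds the
+// corresponding per-dimension threshold becomes 1.0, all others 0.0.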
+void binarizeVector ( NICE::Vector & xout, const NICE::Vector & x, const NICE::Vector & thresholds )
+{
+  xout.resize(x.size());
+  for ( size_t i = 0 ; i < x.size() ; i++ )
+    if ( fabs(x[i]) > thresholds[i] )
+      xout[i] = 1.0;
+    else
+      xout[i] = 0.0;
+}
+
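+// Binarize every labeled vector of src with binarizeVector and add the results,
+// keeping the original class labels, to dst.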
+void binarizeSet ( LabeledSetVector & dst, const LabeledSetVector & src, const NICE::Vector & thresholds )
+{
+  LOOP_ALL(src)
+  {
+    EACH(classno,x);
+    NICE::Vector dstv;
+    binarizeVector ( dstv, x, thresholds );
+    dst.add ( classno, dstv );
+  }
+}
+
+int main (int argc, char **argv)
+{
+  fprintf (stderr, "testClassifier: init\n");
+
+  std::set_terminate(__gnu_cxx::__verbose_terminate_handler);
+
+  Config conf ( argc, argv );
+
+  string wekafile = conf.gS("main", "weka", "");
+  string trainfn = conf.gS("main", "train", "train.vec");
+  string testfn = conf.gS("main", "test", "test.vec");
+  int format = conf.gI("main", "format", 0 );
+  bool binarize = conf.gB("main", "binarize", false );
+  int wekaclass = conf.gI("main", "wekaclass", 1 );
+  string classifier_cache = conf.gS("main", "classifiercache", "");
+  string classifier_cache_in = conf.gS("main", "classifierin", "");
+  int numRuns = conf.gI("main", "runs", 1);
+
+  // classno:text,classno:text,...
+  string classes = conf.gS("main", "classes", "");
+  int classesnb = conf.gI("main", "classes", 0);
+  string classesconf = conf.gS("main", "classesconf", "");
+
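+  // All of these settings are read from the "main" section of the configuration.
+  // A minimal sketch (assuming the usual ini-style section syntax of the objrec
+  // Config files; values are only illustrative):
+  //   [main]
+  //   train = train.vec
+  //   test = test.vec
+  //   classifier = random_forest_transfer
+  //   runs = 1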
+  fprintf (stderr, "testClassifier: reading config\n");
+  Preprocess::Init ( &conf );
+
+  fprintf (stderr, "testClassifier: reading multi dataset\n");
+  int testMaxClassNo;
+  int trainMaxClassNo;
+
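+  // Class names can come from several sources, checked in this order: a plain
+  // class count in "classes" (auto-numbered names), an explicit "classno:text"
+  // list in "classes", a separate config file given by "classesconf", or the
+  // class names of the training part of a MultiDataset.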
+  ClassNames *classNames;
+  if ( classes.size() == 0 && classesnb != 0 )
+  {
+    classNames = new ClassNames ();
+    for ( int classno = 0 ; classno < classesnb ; classno++ )
+    {
+      classNames->addClass ( classno, StringTools::convertToString<int> ( classno ), StringTools::convertToString<int> ( classno ) );
+    }
+    trainMaxClassNo = classNames->getMaxClassno();
+    testMaxClassNo = trainMaxClassNo;
+  }
+  else if ( classes.size() > 0 )
+  {
+    classNames = new ClassNames ();
+
+    vector<string> classes_sub;
+    StringTools::split ( string(classes), ',', classes_sub );
+
+    for ( vector<string>::const_iterator i = classes_sub.begin();
+          i != classes_sub.end(); i++ )
+    {
+      vector<string> desc;
+      StringTools::split ( *i, ':', desc );
+      if ( desc.size() != 2 )
+        break;
+      int classno = StringTools::convert<int> ( desc[0] );
+      classNames->addClass ( classno, desc[1], desc[1] );
+    }
+
+    trainMaxClassNo = classNames->getMaxClassno();
+    testMaxClassNo = trainMaxClassNo;
+
+    classNames->store(cout);
+  }
+  else if ( classesconf.size() > 0 )
+  {
+    classNames = new ClassNames ();
+    Config cConf ( classesconf );
+    classNames->readFromConfig ( cConf, "*" );
+    trainMaxClassNo = classNames->getMaxClassno();
+    testMaxClassNo = trainMaxClassNo;
+  }
+  else
+  {
+    MultiDataset md ( &conf );
+    classNames = new ClassNames ( md.getClassNames("train"), "*" );
+    testMaxClassNo = md.getClassNames("test").getMaxClassno();
+    trainMaxClassNo = md.getClassNames("train").getMaxClassno();
+  }
+
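+  // The training set is only needed when a classifier is trained in this run;
+  // if a serialized classifier is given via "classifierin", reading it is skipped.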
+  LabeledSetVector train;
+  if ( classifier_cache_in.size() <= 0 )
+  {
+    fprintf (stderr, "testClassifier: Reading training dataset from %s\n", trainfn.c_str() );
+    train.read ( trainfn, format );
+    train.printInformation();
+  } else {
+    fprintf (stderr, "testClassifier: skipping training set %s\n", trainfn.c_str() );
+  }
+
+  LabeledSetVector test;
+  fprintf (stderr, "testClassifier: Reading test dataset from %s\n", testfn.c_str() );
+  test.read ( testfn, format );
+
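+  // Optional GMM preprocessing: if "gmm" is set to a positive number of
+  // components, a Gaussian mixture model is fitted to the training vectors and
+  // every training and test vector is replaced in place by its vector of
+  // mixture probabilities (gmm->getProbs).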
+  GMM *gmm = NULL;
+  int nbgmm = conf.gI("main", "gmm", 0);
+  if ( nbgmm > 0 )
+  {
+    gmm = new GMM(&conf, nbgmm);
+    VVector vset;
+    Vector l;
+    train.getFlatRepresentation(vset, l);
+    gmm->computeMixture(vset);
+
+    map<int, vector<NICE::Vector *> >::iterator iter;
+    for ( iter = train.begin(); iter != train.end(); ++iter )
+    {
+      for ( uint i = 0; i < iter->second.size(); ++i )
+      {
+        gmm->getProbs(*(iter->second[i]), *(iter->second[i]));
+      }
+    }
+
+    for ( iter = test.begin(); iter != test.end(); ++iter )
+    {
+      for ( uint i = 0; i < iter->second.size(); ++i )
+      {
+        gmm->getProbs(*(iter->second[i]), *(iter->second[i]));
+      }
+    }
+  }
+
+  ClassificationResults cresults;
+
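+  // Each run builds a classifier (a FPCRandomForestTransfer wrapped in a
+  // VCFeaturePool, or whatever CSGeneric::selectVecClassifier selects from the
+  // config), trains it or loads it from a cache, classifies the test set and
+  // prints per-class and overall statistics.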
+  for ( int runs = 0 ; runs < numRuns ; runs++ ) {
+    VecClassifier *vec_classifier = NULL;
+
+    if ( conf.gS("main", "classifier") == "random_forest_transfer" )
+    {
+      FeaturePoolClassifier *fpc = new FPCRandomForestTransfer ( &conf, classNames );
+      vec_classifier = new VCFeaturePool ( &conf, fpc );
+    } else {
+      vec_classifier = CSGeneric::selectVecClassifier ( &conf, "main" );
+    }
+
+    NICE::Vector thresholds;
+
+    if ( classifier_cache_in.size() <= 0 )
+    {
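+      // With binarization enabled, per-dimension thresholds are estimated with
+      // MutualInformation::computeThresholdsOverall on the training set; the
+      // same thresholds are applied to the test vectors further below.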
+      if ( binarize ) {
+        LabeledSetVector trainbin;
+        NICE::Vector mis;
+        MutualInformation mi;
+        fprintf (stderr, "testClassifier: computing mutual information\n");
+        mi.computeThresholdsOverall ( train, thresholds, mis );
+        fprintf (stderr, "testClassifier: done!\n");
+        binarizeSet ( trainbin, train, thresholds );
+        vec_classifier->teach ( trainbin );
+      } else {
+        vec_classifier->teach ( train );
+      }
+
+      vec_classifier->finishTeaching();
+
+      if ( classifier_cache.size() > 0 )
+        vec_classifier->save ( classifier_cache );
+    } else {
+      vec_classifier->setMaxClassNo ( classNames->getMaxClassno() );
+      vec_classifier->read ( classifier_cache_in );
+    }
+
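+    // Evaluation: classify every test vector, count per-class hits in "correct"
+    // and totals in "count" (both indexed by the ground-truth class) and fill
+    // the confusion matrix (rows: ground truth, columns: predicted class).
+    // Rejected results (classno < 0) are ignored.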
+    ProgressBar pb ("Classification");
+    pb.show();
+
+    std::vector<int> count ( testMaxClassNo+1, 0 );
+    std::vector<int> correct ( testMaxClassNo+1, 0 );
+    MatrixT<int> confusionMatrix ( testMaxClassNo+1, trainMaxClassNo+1, 0 );
+
+    int n = test.count();
+    LOOP_ALL(test)
+    {
+      EACH(classno,v);
+      pb.update ( n );
+
+      fprintf (stderr, "\tclassification\n" );
+      ClassificationResult r;
+
+      if ( binarize )
+      {
+        NICE::Vector vout;
+        binarizeVector ( vout, v, thresholds );
+        r = vec_classifier->classify ( vout );
+      } else {
+        r = vec_classifier->classify ( v );
+      }
+      r.classno_groundtruth = classno;
+      r.classname = classNames->text( r.classno );
+
+#ifdef DEBUG
+      if ( r.classno == classno )
+        fprintf (stderr, "+ classification %d (\"%s\") <-> %d (\"%s\") score=%f\n", classno,
+                 classNames->text(classno).c_str(), r.classno, r.classname.c_str(), r.scores[r.classno]);
+      else
+        fprintf (stderr, "- classification %d (\"%s\") <-> %d (\"%s\") score=%f\n", classno,
+                 classNames->text(classno).c_str(), r.classno, r.classname.c_str(), r.scores[r.classno] );
+#endif
+
+      r.scores.store ( cerr );
+      if ( r.classno >= 0 )
+      {
+        if ( classno == r.classno ) correct[classno]++;
+        count[classno]++;
+        if ( r.ok() ) {
+          confusionMatrix(classno, r.classno)++;
+        }
+        cresults.push_back ( r );
+      }
+    }
+    pb.hide();
+
+    if ( wekafile.size() > 0 )
+    {
+      string wekafile_s = wekafile;
+      if ( numRuns > 1 )
+        wekafile_s = wekafile_s + "." + StringTools::convertToString<int>(runs) + ".txt";
+      cresults.writeWEKA ( wekafile_s, wekaclass );
+    }
+
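+    // Summary statistics: the overall recognition rate weights every test
+    // example equally, while the average recognition rate is the mean of the
+    // per-class rates. "lower bound 1" is the chance level (100 / #classes),
+    // "lower bound 2" the rate obtained by always predicting the most frequent
+    // test class.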
+    int count_total = 0;
+    int correct_total = 0;
+    int classes_tested = 0;
+    double avg_recognition = 0.0;
+    for ( size_t classno = 0; classno < correct.size(); classno++ )
+    {
+      if ( count[classno] == 0 ) {
+        fprintf (stdout, "class %d not tested !!\n", (int)classno);
+      } else {
+        fprintf (stdout, "classification result class %d (\"%s\") : %5.2f %%\n",
+                 (int)classno, classNames->text(classno).c_str(), correct[classno]*100.0/count[classno] );
+        avg_recognition += correct[classno]/(double)count[classno];
+        classes_tested++;
+      }
+
+      count_total += count[classno];
+      correct_total += correct[classno];
+    }
+    avg_recognition /= classes_tested;
+
+    fprintf (stdout, "overall recognition rate : %-5.3f %%\n", correct_total*100.0/count_total );
+    fprintf (stdout, "average recognition rate : %-5.3f %%\n", avg_recognition*100 );
+    fprintf (stdout, "total:%d misclassified:%d\n", count_total, count_total - correct_total );
+
+    int max_count = *(max_element( count.begin(), count.end() ));
+    fprintf (stdout, "no of classes : %d\n", classNames->numClasses() );
+    fprintf (stdout, "lower bound 1 : %f\n", 100.0/(classNames->numClasses()));
+    fprintf (stdout, "lower bound 2 : %f\n", max_count * 100.0 / (double) count_total);
+
+    cout << confusionMatrix << endl;
+
+    delete vec_classifier;
+  }
+
+  delete classNames;
+
+  return 0;
+}