/**
* @file testClassifier.cpp
* @brief main program for classifier evaluation
* @author Erik Rodner
* @date 2007-10-12
*/

#include <cmath>
#include <cstdio>
#include <fstream>
#include <iostream>
#include <algorithm>

#include "vislearning/classifier/genericClassifierSelection.h"
#include "vislearning/classifier/classifierbase/FeaturePoolClassifier.h"
#include "core/basics/Config.h"
// further NICE/vislearning headers are needed here for FPCRandomForestTransfer,
// VCFeaturePool, MultiDataset, ClassNames, ClassificationResults,
// LabeledSetVector, ProgressBar, StringTools, MutualInformation and Preprocess
// (the exact include paths depend on the library layout)

#undef DEBUG

using namespace OBJREC;
using namespace NICE;
using namespace std;

/** set every dimension to 1.0 whose absolute value exceeds its threshold, else 0.0 */
void binarizeVector( NICE::Vector & xout, const NICE::Vector & x, const NICE::Vector & thresholds )
{
  xout.resize( x.size() );
  for ( size_t i = 0 ; i < x.size() ; i++ )
    if ( fabs( x[i] ) > thresholds[i] )
      xout[i] = 1.0;
    else
      xout[i] = 0.0;
}

/** binarize every vector of a labeled set using per-dimension thresholds */
void binarizeSet( LabeledSetVector & dst, const LabeledSetVector & src, const NICE::Vector & thresholds )
{
  LOOP_ALL( src )
  {
    EACH( classno, x );
    NICE::Vector dstv;
    binarizeVector( dstv, x, thresholds );
    dst.add( classno, dstv );
  }
}

int main( int argc, char **argv )
{
  fprintf( stderr, "testClassifier: init\n" );

  std::set_terminate( __gnu_cxx::__verbose_terminate_handler );

  Config conf( argc, argv );

  string wekafile = conf.gS( "main", "weka", "" );
  string trainfn = conf.gS( "main", "train", "train.vec" );
  string testfn = conf.gS( "main", "test", "test.vec" );
  int format = conf.gI( "main", "format", 0 );
  bool binarize = conf.gB( "main", "binarize", false );
  int wekaclass = conf.gI( "main", "wekaclass", 1 );
  string classifier_cache = conf.gS( "main", "classifiercache", "" );
  string classifier_cache_in = conf.gS( "main", "classifierin", "" );
  int numRuns = conf.gI( "main", "runs", 1 );
  string writeImgNet = conf.gS( "main", "imgnet", "" );

  // classno:text,classno:text,...
  string classes = conf.gS( "main", "classes", "" );
  int classesnb = conf.gI( "main", "classes", 0 );
  string classesconf = conf.gS( "main", "classesconf", "" );

  fprintf( stderr, "testClassifier: reading config\n" );
  Preprocess::Init( &conf );

  fprintf( stderr, "testClassifier: reading multi dataset\n" );
  int testMaxClassNo;
  int trainMaxClassNo;

  ClassNames *classNames;
  if ( classes.size() == 0 && classesnb != 0 )
  {
    // only the number of classes is given: use "0", "1", ... as class names
    classNames = new ClassNames();
    for ( int classno = 0 ; classno < classesnb ; classno++ )
    {
      classNames->addClass( classno, StringTools::convertToString ( classno ),
                            StringTools::convertToString ( classno ) );
    }
    trainMaxClassNo = classNames->getMaxClassno();
    testMaxClassNo = trainMaxClassNo;
  }
  else if ( classes.size() > 0 )
  {
    // class list given explicitly as "classno:text,classno:text,..."
    classNames = new ClassNames();
    vector<string> classes_sub;
    StringTools::split( string( classes ), ',', classes_sub );

    for ( vector<string>::const_iterator i = classes_sub.begin();
          i != classes_sub.end(); i++ )
    {
      vector<string> desc;
      StringTools::split( *i, ':', desc );
      if ( desc.size() != 2 )
        break;
      int classno = StringTools::convert<int> ( desc[0] );
      classNames->addClass( classno, desc[1], desc[1] );
    }
    trainMaxClassNo = classNames->getMaxClassno();
    testMaxClassNo = trainMaxClassNo;
    classNames->store( cout );
  }
  else if ( classesconf.size() > 0 )
  {
    // class names read from a separate config file
    classNames = new ClassNames();
    Config cConf( classesconf );
    classNames->readFromConfig( cConf, "*" );
    trainMaxClassNo = classNames->getMaxClassno();
    testMaxClassNo = trainMaxClassNo;
  }
  else
  {
    // fall back to the class names of the multi dataset
    MultiDataset md( &conf );
    classNames = new ClassNames( md.getClassNames( "train" ), "*" );
    testMaxClassNo = md.getClassNames( "test" ).getMaxClassno();
    trainMaxClassNo = md.getClassNames( "train" ).getMaxClassno();
  }

  LabeledSetVector train;
  if ( classifier_cache_in.size() <= 0 )
  {
    fprintf( stderr, "testClassifier: Reading training dataset from %s\n", trainfn.c_str() );
    train.read( trainfn, format );
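    // print a short summary of the training data that was just read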
    train.printInformation();
  }
  else
  {
    fprintf( stderr, "testClassifier: skipping training set %s\n", trainfn.c_str() );
  }

  LabeledSetVector test;
  fprintf( stderr, "testClassifier: Reading test dataset from %s\n", testfn.c_str() );
  test.read( testfn, format );

  ClassificationResults cresults;

  ofstream outinet;
  if ( writeImgNet.length() > 0 )
  {
    outinet.open( writeImgNet.c_str() );
  }

  for ( int runs = 0 ; runs < numRuns ; runs++ )
  {
    VecClassifier *vec_classifier = NULL;
    if ( conf.gS( "main", "classifier" ) == "random_forest_transfer" )
    {
      FeaturePoolClassifier *fpc = new FPCRandomForestTransfer( &conf, classNames );
      vec_classifier = new VCFeaturePool( &conf, fpc );
    }
    else
    {
      string classifierselection = conf.gS( "main", "classifier" );
      vec_classifier = GenericClassifierSelection::selectVecClassifier( &conf, classifierselection );
    }

    NICE::Vector thresholds;
    if ( classifier_cache_in.size() <= 0 )
    {
      if ( binarize )
      {
        // binarize the training data using thresholds derived from mutual information
        LabeledSetVector trainbin;
        NICE::Vector mis;
        MutualInformation mi;
        fprintf( stderr, "testClassifier: computing mutual information\n" );
        mi.computeThresholdsOverall( train, thresholds, mis );
        fprintf( stderr, "testClassifier: done!\n" );
        binarizeSet( trainbin, train, thresholds );
        vec_classifier->teach( trainbin );
      }
      else
      {
        vec_classifier->teach( train );
      }
      vec_classifier->finishTeaching();

      if ( classifier_cache.size() > 0 )
        vec_classifier->save( classifier_cache );
    }
    else
    {
      // reuse a previously trained classifier instead of training again
      vec_classifier->setMaxClassNo( classNames->getMaxClassno() );
      vec_classifier->read( classifier_cache_in );
    }

    ProgressBar pb( "Classification" );
    pb.show();

    std::vector<int> count( testMaxClassNo + 1, 0 );
    std::vector<int> correct( testMaxClassNo + 1, 0 );
    MatrixT<double> confusionMatrix( testMaxClassNo + 1, trainMaxClassNo + 1, 0 );

    int n = test.count();
    LOOP_ALL( test )
    {
      EACH( classno, v );
      pb.update( n );

#ifdef DEBUG
      fprintf( stderr, "\tclassification\n" );
#endif

      ClassificationResult r;
      if ( binarize )
      {
        NICE::Vector vout;
        binarizeVector( vout, v, thresholds );
        r = vec_classifier->classify( vout );
      }
      else
      {
        r = vec_classifier->classify( v );
      }
      r.classno_groundtruth = classno;
      r.classname = classNames->text( r.classno );

#ifdef DEBUG
      if ( r.classno == classno )
        fprintf( stderr, "+ classification %d (\"%s\") <-> %d (\"%s\") score=%f\n",
                 classno, classNames->text( classno ).c_str(),
                 r.classno, r.classname.c_str(), r.scores[r.classno] );
      else
        fprintf( stderr, "- classification %d (\"%s\") <-> %d (\"%s\") score=%f\n",
                 classno, classNames->text( classno ).c_str(),
                 r.classno, r.classname.c_str(), r.scores[r.classno] );
      r.scores.store( cerr );
#endif

      if ( writeImgNet.length() > 0 )
      {
        for ( int z = 1; z < r.scores.size() - 1; z++ )
        {
          outinet << r.scores[z] << " ";
        }
        outinet << r.scores[r.scores.size()-1] << endl;
      }

      if ( r.classno >= 0 )
      {
        if ( classno == r.classno )
          correct[classno]++;
        count[classno]++;

        if ( r.ok() )
        {
          confusionMatrix( classno, r.classno )++;
        }
        cresults.push_back( r );
      }
    }
    pb.hide();
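    // optionally write all classification results of this run in WEKA format;
    // with multiple runs, each run gets its own file suffixed by the run index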
    if ( wekafile.size() > 0 )
    {
      string wekafile_s = wekafile;
      if ( numRuns > 1 )
        wekafile_s = wekafile_s + "." + StringTools::convertToString( runs ) + ".txt";
      cresults.writeWEKA( wekafile_s, wekaclass );
    }

    int count_total = 0;
    int correct_total = 0;
    int classes_tested = 0;
    double avg_recognition = 0.0;
    for ( size_t classno = 0; classno < correct.size(); classno++ )
    {
      if ( count[classno] == 0 )
      {
        fprintf( stdout, "class %d not tested !!\n", ( int )classno );
      }
      else
      {
        fprintf( stdout, "classification result class %d (\"%s\") : %5.2f %%\n",
                 ( int )classno, classNames->text( classno ).c_str(),
                 correct[classno] * 100.0 / count[classno] );
        avg_recognition += correct[classno] / ( double )count[classno];
        classes_tested++;
      }
      count_total += count[classno];
      correct_total += correct[classno];
    }
    avg_recognition /= classes_tested;

    fprintf( stdout, "overall recognition rate : %-5.3f %%\n", correct_total * 100.0 / count_total );
    fprintf( stdout, "average recognition rate : %-5.3f %%\n", avg_recognition * 100 );
    fprintf( stdout, "total:%d misclassified:%d\n", count_total, count_total - correct_total );

    int max_count = *( max_element( count.begin(), count.end() ) );
    fprintf( stdout, "no of classes : %d\n", classNames->numClasses() );
    fprintf( stdout, "lower bound 1 : %f\n", 100.0 / ( classNames->numClasses() ) );
    fprintf( stdout, "lower bound 2 : %f\n", max_count * 100.0 / ( double )count_total );

    cout << confusionMatrix << endl;

    delete vec_classifier;
  }

  delete classNames;

  return 0;
}