/**
 * @file VCPreRandomForest.cpp
 * @brief Combination of a classifier with a pre-clustering using a random forest
 * @author Erik Rodner
 * @date 06/17/2010
 */
#include "VCPreRandomForest.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <iostream>
#include <limits>
#include <map>
#include <vector>

#include "core/image/ImageT.h"
//#include "core/imagedisplay/ImageDisplay.h"

using namespace OBJREC;
using namespace std;
using namespace NICE;

namespace
{

/**
 * Copies the examples of a leaf into a separate training set.
 *
 * Examples of @p teachSet are numbered in LOOP_ALL iteration order (the
 * same order in which teach() hands them to the forest); every example whose
 * running number occurs in @p indices is added to @p subset.
 *
 * @param teachSet complete training set the forest was grown on
 * @param indices  example numbers stored in the leaf; taken by value because
 *                 it is sorted here
 * @param subset   receives the selected (classno, vector) pairs
 */
void collectLeafExamples ( const LabeledSetVector & teachSet,
                           vector<int> indices,
                           LabeledSetVector & subset )
{
  sort ( indices.begin(), indices.end() );

  size_t c = 0;
  int exampleIndex = 0;
  LOOP_ALL(teachSet)
  {
    EACH(classno, x);
    // The bounds check prevents reading past the end of 'indices' once all
    // stored numbers have been consumed. The while-loop consumes a number
    // that occurs more than once instead of stalling on it forever (each
    // occurrence adds the example once).
    while ( c < indices.size() && indices[c] == exampleIndex )
    {
      subset.add ( classno, x );
      c++;
    }
    exampleIndex++;
  }
}

} // anonymous namespace

/**
 * @param conf                     configuration object
 * @param section                  config section of this classifier; its
 *                                 "cluster_section" key names the section
 *                                 used to configure the random forest
 * @param _leafClassifierPrototype prototype cloned for every impure leaf
 */
VCPreRandomForest::VCPreRandomForest ( const Config *conf,
                                       const std::string & section,
                                       VecClassifier *_leafClassifierPrototype )
    : leafClassifierPrototype ( _leafClassifierPrototype ), fp ( conf )
{
  string cluster_section = conf->gS ( section, "cluster_section", "RandomForest" );

  mEx = conf->gI ( "DTBRandom", "min_examples", numeric_limits<int>::max() );
  // NOTE(review): this unconditionally overrides the configured value read
  // above; kept to preserve existing behavior -- confirm it is intentional.
  mEx = 500;

  randomforest = new FPCRandomForests ( conf, cluster_section );
}

VCPreRandomForest::~VCPreRandomForest()
{
  // deleting a null pointer is a no-op, so no explicit check is required
  delete randomforest;

  // delete all classifiers attached to leaves
  for ( map<DecisionNode *, VecClassifier *>::const_iterator i = leafClassifiers.begin();
        i != leafClassifiers.end(); ++i )
  {
    delete i->second;
  }
}

/**
 * Classifies @p x by averaging over all forest leaves it reaches.
 *
 * A leaf with a trained classifier contributes that classifier's normalized
 * scores. A leaf without one (teach() skips leaves with zero entropy)
 * contributes its class distribution divided by its sum.
 *
 * @param x feature vector
 * @return averaged scores and the class with the maximum score
 */
ClassificationResult VCPreRandomForest::classify ( const NICE::Vector & x ) const
{
  // heap copy handed to the Example; released by example.clean() below
  NICE::Vector *v = new NICE::Vector ( x );
  Example example ( v );

  // traverse the forest and obtain all involved leaf nodes
  vector<DecisionNode *> leafNodes;
  randomforest->getLeafNodes ( example, leafNodes );

  ClassificationResult r ( ClassificationResult::REJECTION_NONE, maxClassNo );
  r.scores.set ( 0.0 );

  for ( vector<DecisionNode *>::const_iterator i = leafNodes.begin();
        i != leafNodes.end(); ++i )
  {
    DecisionNode *node = *i;
    map<DecisionNode *, VecClassifier *>::const_iterator leafClassifierIt =
      leafClassifiers.find ( node );

    if ( leafClassifierIt == leafClassifiers.end() )
    {
      // no classifier for this leaf -> use the random forest "score"
      double sum = node->distribution.sum();
      // an all-zero distribution would otherwise divide by zero
      if ( sum > 0.0 )
      {
        size_t n = static_cast<size_t> ( std::min ( node->distribution.size(), r.scores.size() ) );
        for ( size_t k = 0; k < n; k++ )
          r.scores[k] += node->distribution[k] / sum;
      }
      continue;
    }

    VecClassifier *leafClassifier = leafClassifierIt->second;
    ClassificationResult rSingle = leafClassifier->classify ( x );
    rSingle.scores.normalize();

    size_t n = static_cast<size_t> ( std::min ( rSingle.scores.size(), r.scores.size() ) );
    for ( size_t k = 0; k < n; k++ )
      r.scores[k] += rSingle.scores[k];
  }

  // average over the visited leaves; an empty leaf list would divide by zero
  // and fill the scores with NaN
  if ( !leafNodes.empty() )
    r.scores.multiply ( 1.0 / leafNodes.size() );

  r.classno = r.scores.maxElement();

  // NOTE(review): when the averaged scores do not sum to (about) one, the
  // score of class 0 is forced to 1.0 -- kept as-is, confirm the intent.
  if ( fabs ( r.scores.sum() - 1.0 ) > 1e-2 )
  {
    r.scores[0] = 1.0;
  }

  example.clean();
  return r;
}

/**
 * Trains the random forest on @p teachSet.
 *
 * Afterwards, one clone of the leaf classifier prototype is trained for every
 * leaf that has positive entropy. It is trained on exactly the examples that
 * reached that leaf.
 *
 * @param teachSet labeled training vectors
 */
void VCPreRandomForest::teach ( const LabeledSetVector & teachSet )
{
  maxClassNo = teachSet.getMaxClassno();

  // copy the training vectors into the Examples structure used by the forest
  Examples examples;
  LOOP_ALL(teachSet)
  {
    EACH(classno, x);
    NICE::Vector *v = new Vector ( x );
    examples.push_back ( pair<int, Example> ( classno, Example ( v ) ) );
  }

  // build the feature pool from a vector feature of the input dimension
  uint dimension = teachSet.dimension();
  fp.clear();
  Feature *f = new VectorFeature ( dimension );
  f->explode ( fp );
  // NOTE(review): 'f' is never deleted; whether explode() keeps a reference
  // to it is not visible here -- confirm before freeing it.

  // train the forest
  randomforest->setMaxClassNo ( teachSet.getMaxClassno() );
  randomforest->train ( fp, examples );

  // the forest is trained; this data structure is no longer needed
  examples.clean();

  vector<DecisionNode *> leafNodes;
  randomforest->getAllLeafNodes ( leafNodes );
  int lsize = leafNodes.size();
  cout << "leafnodes: " << lsize << endl;

  int leafNo = 0;
#pragma omp parallel for
  for ( int l = 0; l < lsize; l++ )
  {
    // shared counter and output stream: serialize access across threads
#pragma omp critical (VCPreRandomForest_progress)
    {
      cerr << "Training classifier for leaf " << leafNo << endl;
      leafNo++;
    }

    DecisionNode *node = leafNodes[l];

    // pure leaves need no classifier; classify() uses their distribution
    if ( node->distribution.entropy() <= 0.0 )
      continue;
    if ( ! node->isLeaf() )
      continue;

    assert ( node->trainExamplesIndices.size() > 0 );

    LabeledSetVector trainSubSet;
    collectLeafExamples ( teachSet, node->trainExamplesIndices, trainSubSet );

    VecClassifier *lc = leafClassifierPrototype->clone();
    lc->teach ( trainSubSet );

    // std::map is not safe for concurrent insertion
#pragma omp critical (VCPreRandomForest_leafClassifiers)
    leafClassifiers.insert ( pair<DecisionNode *, VecClassifier *> ( node, lc ) );
  }
}

/**
 * Releases all leaf classifiers and clears the random forest.
 *
 * The leaf classifiers are deleted rather than just reset. Their map is keyed
 * by nodes of the forest, and those keys presumably dangle once the forest is
 * cleared. If the entries were kept, a later teach() could meet a stale entry
 * at a reused address, and map::insert would then keep the stale classifier
 * and leak the new one.
 */
void VCPreRandomForest::clear()
{
  for ( map<DecisionNode *, VecClassifier *>::iterator iter = leafClassifiers.begin();
        iter != leafClassifiers.end(); ++iter )
  {
    delete iter->second;
  }
  leafClassifiers.clear();

  randomforest->clear();
}