/** * @file MultiDataset.cpp * @brief multiple datasets * @author Erik Rodner * @date 02/08/2008 */ #include #include #include #include "vislearning/cbaselib/ClassNames.h" #include "core/basics/StringTools.h" #include "core/basics/FileMgt.h" #include "vislearning/cbaselib/MultiDataset.h" using namespace OBJREC; using namespace std; using namespace NICE; #undef DEBUG_MultiDataset void MultiDataset::selectExamples ( const std::string & examples_command, const LabeledSet & base, LabeledSet & positives, LabeledSet & negatives, const ClassNames & cn ) const { vector examples; StringTools::split ( examples_command, ';', examples ); set processed_classes; for ( vector::const_iterator i = examples.begin(); i != examples.end(); i++ ) { const std::string & cmd = *i; vector parts; StringTools::split ( cmd, ' ', parts ); if ( (parts.size() != 3) && ((parts.size() != 2) || (parts[0] != "all")) ) fthrow( Exception, "Syntax error " << examples_command ); const std::string & mode = parts[0]; const std::string & csel = parts[1]; double parameter = (parts.size() == 3 ) ? atof(parts[2].c_str()) : 0.0; map fpe; set selection; cn.getSelection ( csel, selection ); for ( set::const_iterator j = selection.begin(); j != selection.end(); j++ ) { int classno = *j; if ( processed_classes.find(classno) == processed_classes.end() ) { #ifdef DEBUG_MultiDataset fprintf (stderr, "class %s: %s %d\n", cn.text(classno).c_str(), mode.c_str(), (int)parameter ); #endif fpe[*j] = (int)parameter; processed_classes.insert(classno); } else { if ( csel != "*" ) { fthrow ( Exception, "Example selection method for class %s has multiple specifications" << cn.text(classno) ); } } } if ( mode == "seq" ) { LabeledSetSelection::selectSequential ( fpe, base, positives, negatives ); #ifdef DEBUG_MultiDataset fprintf (stderr, "MultiDataset: after special seq selection: %d\n", positives.count() ); #endif } else if ( mode == "step" ) { LabeledSetSelection::selectSequentialStep ( fpe, base, positives, negatives ); #ifdef DEBUG_MultiDataset fprintf (stderr, "MultiDataset: after special step selection: %d\n", positives.count() ); #endif } else if ( mode == "random" ) { LabeledSetSelection::selectRandom ( fpe, base, positives, negatives ); #ifdef DEBUG_MultiDataset fprintf (stderr, "MultiDataset: after special random selection: %d\n", positives.count() ); #endif } else if ( mode == "all" ) { if ( (int)selection.size() == cn.numClasses() ) { // preserve permutation LabeledSet::Permutation permutation; base.getPermutation ( permutation ); for ( LabeledSet::Permutation::iterator i = permutation.begin(); i != permutation.end(); i++ ) { int classno = i->first; ImageInfo *element = const_cast< ImageInfo * > ( i->second ); positives.add_reference ( classno, element ); } } else { LabeledSetSelection::selectClasses ( selection, base, positives, negatives ); } #ifdef DEBUG_MultiDataset fprintf (stderr, "MultiDataset: after special class selection: %d\n", positives.count() ); #endif } else { fthrow ( Exception, "Wrong value for parameter example\n"); } } #ifdef DEBUG_MultiDataset fprintf (stderr, "MultiDataset: after special selection operations: %d\n", positives.count() ); #endif set allclasses; cn.getSelection ( "*", allclasses ); set allnegative_classes; // add all examples from allclasses \setminus processed_classes set_difference(allclasses.begin(), allclasses.end(), processed_classes.begin(), processed_classes.end(), inserter(allnegative_classes, allnegative_classes.end())); LabeledSet dummy; LabeledSetSelection::selectClasses ( allnegative_classes, base, negatives, dummy ); } /** MultiDataset ------- constructor */ MultiDataset::MultiDataset( const Config *conf ) { std::set blocks; conf->getAllBlocks ( blocks ); map dsconfs; map dirs; for ( set::iterator i = blocks.begin(); i != blocks.end(); ) { if ( conf->gB(*i, "disable", false) ) { i++; continue; } std::string dataset = conf->gS( *i, "dataset", "unknown" ); if ( dataset == "unknown" ) blocks.erase(i++); else { #ifdef DEBUG_MultiDataset fprintf (stderr, "Reading dataset config for block [%s]\n", i->c_str() ); #endif Config dsconf ( (dataset + "/dataset.conf").c_str() ); dirs[*i] = dataset; dsconfs[*i] = dsconf; i++; } } if ( blocks.find("traintest") != blocks.end() ) { LabeledSet ls_base; LabeledSet ls_train (true); LabeledSet ls_nontrain (true); LabeledSet ls_test (true); LabeledSet dummy (true); LabeledSet temp (true); bool localizationInfoDisabled = conf->gB("traintest", "disable_localization_info", false ); std::string classselection_train = conf->gS("traintest", "classselection_train", "*"); std::string classselection_test = conf->gS("traintest", "classselection_test", "*"); classnames["traintest"] = ClassNames(); std::string classNamesTxt = dirs["traintest"] + "/classnames.txt"; if ( FileMgt::fileExists ( classNamesTxt ) ) { classnames["traintest"].read ( classNamesTxt ); } else { classnames["traintest"].readFromConfig ( dsconfs["traintest"], classselection_train ); } lfl.get ( dirs["traintest"], dsconfs["traintest"], classnames["traintest"], ls_base, localizationInfoDisabled, conf->gB("traintest", "debug_dataset", false ) ); std::string examples_train = conf->gS("traintest", "examples_train" ); selectExamples ( examples_train, ls_base, ls_train, ls_nontrain, classnames["traintest"] ); set selection_test; classnames["traintest"].getSelection ( classselection_test, selection_test ); std::string examples_test = conf->gS("traintest", "examples_test" ); if ( examples_test == "reclassification" ) { LabeledSetSelection::selectClasses ( selection_test, ls_train, ls_test, dummy ); } else { selectExamples ( examples_test, ls_nontrain, temp, dummy, classnames["traintest"] ); LabeledSetSelection::selectClasses ( selection_test, temp, ls_test, dummy ); } classnames["train"] = classnames["traintest"]; classnames["test"] = ClassNames ( classnames["traintest"], classselection_test ); datasets["test"] = ls_test; datasets["train"] = ls_train; } for ( set::const_iterator i = blocks.begin(); i != blocks.end(); i++ ) { std::string name = *i; if ( classnames.find(name) != classnames.end() ) continue; if ( conf->gB(name, "disable", false) == true ) continue; if ( dsconfs.find(name) == dsconfs.end() ) continue; LabeledSet ls_base; LabeledSet ls (true); LabeledSet dummy (true); LabeledSet temp (true); bool localizationInfoDisabled = conf->gB(name, "disable_localization_info", false ); std::string classselection = conf->gS(name, "classselection", "*"); classnames[name] = ClassNames(); std::string classNamesTxt = dirs[name] + "/classnames.txt"; if ( FileMgt::fileExists ( classNamesTxt ) ) { #ifdef DEBUG_MultiDataset fprintf (stderr, "MultiDataset: reading class names from %s\n", classNamesTxt.c_str() ); #endif classnames[name].read ( classNamesTxt ); } else { #ifdef DEBUG_MultiDataset fprintf (stderr, "MultiDataset: reading class names from dataset config file\n" ); #endif classnames[name].readFromConfig ( dsconfs[name], classselection ); } lfl.get ( dirs[name], dsconfs[name], classnames[name], ls_base, localizationInfoDisabled, conf->gB(name, "debug_dataset", false ) ); #ifdef DEBUG_MultiDataset fprintf (stderr, "MultiDataset: class names -->\n" ); classnames[name].store ( cerr ); fprintf (stderr, "MultiDataset: all information about %s set obtained ! (size %d)\n", name.c_str(), ls_base.count() ); #endif std::string examples = conf->gS(name, "examples", "all *" ); selectExamples ( examples, ls_base, ls, dummy, classnames[name] ); #ifdef DEBUG_MultiDataset fprintf (stderr, "MultiDataset: size after selection %d\n", ls.count() ); #endif datasets[name] = ls; } bool dumpSelections = conf->gB("datasets", "dump_selection", false); if ( dumpSelections ) { for ( map::const_iterator i = datasets.begin(); i != datasets.end(); i++ ) { const std::string & name = i->first; const LabeledSet & ls = i->second; const ClassNames & classNames = classnames[name]; mkdir ( name.c_str(), 0755 ); std::string filelist = name + "/files.txt"; ofstream olist ( filelist.c_str(), ios::out ); if ( !olist.good() ) fthrow (IOException, "Unable to dump selections to " << filelist ); LOOP_ALL_S(ls) { EACH_S(classno, file); std::string classtext = classNames.code(classno); olist << classtext << " " << file << endl; } olist.close(); std::string datasetconf = name + "/dataset.conf"; ofstream oconf ( datasetconf.c_str(), ios::out ); if ( !oconf.good() ) fthrow (IOException, "Unable to dump selections to " << datasetconf ); set classnos; classNames.getSelection ( "*", classnos); oconf << "[main]" << endl; oconf << "filelist = \"files.txt\"" << endl << endl; oconf << "[classnames]" << endl; for ( set::const_iterator i = classnos.begin(); i != classnos.end(); i++ ) { const std::string & code = classNames.code(*i); const std::string & text = classNames.text(*i); oconf << code << " = \"" << text << "\"" << endl; } oconf.close(); classNames.save ( name + "/classnames.txt" ); } } } MultiDataset::~MultiDataset() { } const ClassNames & MultiDataset::getClassNames ( const std::string & key ) const { map::const_iterator i = classnames.find(key); if ( i == classnames.end() ) { fprintf (stderr, "MultiDataSet::getClassNames() FATAL ERROR: dataset <%s> not found !\n", key.c_str() ); exit(-1); } return (i->second); } const LabeledSet *MultiDataset::operator[] ( const std::string & key ) const { map::const_iterator i = datasets.find(key); if ( i == datasets.end() ) { fprintf (stderr, "MultiDataSet: FATAL ERROR: dataset <%s> not found !\n", key.c_str() ); exit(-1); } return &(i->second); } const LabeledSet *MultiDataset::at ( const std::string & key ) const { map::const_iterator i = datasets.find(key); if ( i == datasets.end() ) return NULL; else return &(i->second); }