MultiDataset.cpp 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447
  1. /**
  2. * @file MultiDataset.cpp
  3. * @brief multiple datasets
  4. * @author Erik Rodner
  5. * @date 02/08/2008
  6. */
  7. #include <iostream>
  8. #include <sys/stat.h>
  9. #include <sys/types.h>
  10. #ifdef WIN32
  11. #ifdef NICE_USELIB_BOOST
  12. #include "boost/filesystem.hpp"
  13. #endif
  14. #endif
  15. #include "vislearning/cbaselib/ClassNames.h"
  16. #include "core/basics/StringTools.h"
  17. #include "core/basics/FileMgt.h"
  18. #include "core/basics/FileName.h"
  19. #include "vislearning/cbaselib/MultiDataset.h"
  20. using namespace OBJREC;
  21. using namespace std;
  22. using namespace NICE;
  23. #undef DEBUG_MultiDataset
  24. void MultiDataset::selectExamples ( const std::string & examples_command,
  25. const LabeledSet & base,
  26. LabeledSet & positives,
  27. LabeledSet & negatives,
  28. const ClassNames & cn ) const
  29. {
  30. vector<string> examples;
  31. StringTools::split ( examples_command, ';', examples );
  32. set<int> processed_classes;
  33. for ( vector<string>::const_iterator i = examples.begin();
  34. i != examples.end();
  35. i++ )
  36. {
  37. const std::string & cmd = *i;
  38. vector<string> parts;
  39. StringTools::split ( cmd, ' ', parts );
  40. if ( (parts.size() != 3) && ((parts.size() != 2) || (parts[0].compare("all") != 0)) )
  41. {
  42. std::cerr << "parts.size(): "<<parts.size()<<std::endl;
  43. std::cerr << " parts[0]: " << parts[0] << std::endl;
  44. fthrow( Exception, "Syntax error " << examples_command );
  45. }
  46. const std::string & mode = parts[0];
  47. const std::string & csel = parts[1];
  48. double parameter = (parts.size() == 3 ) ? atof(parts[2].c_str()) : 0.0;
  49. map<int, int> fpe;
  50. set<int> selection;
  51. cn.getSelection ( csel, selection );
  52. for ( set<int>::const_iterator j = selection.begin();
  53. j != selection.end();
  54. j++ )
  55. {
  56. int classno = *j;
  57. if ( processed_classes.find(classno) == processed_classes.end() )
  58. {
  59. #ifdef DEBUG_MultiDataset
  60. fprintf (stderr, "class %s: %s %d\n", cn.text(classno).c_str(),
  61. mode.c_str(), (int)parameter );
  62. #endif
  63. fpe[*j] = (int)parameter;
  64. processed_classes.insert(classno);
  65. } else {
  66. if ( csel != "*" ) {
  67. fthrow ( Exception, "Example selection method for class %s has multiple specifications" << cn.text(classno) );
  68. }
  69. }
  70. }
  71. if ( mode == "seq" ) {
  72. LabeledSetSelection<LabeledSet>::selectSequential (
  73. fpe, base, positives, negatives );
  74. #ifdef DEBUG_MultiDataset
  75. fprintf (stderr, "MultiDataset: after special seq selection: %d\n", positives.count() );
  76. #endif
  77. } else if ( mode == "step" ) {
  78. LabeledSetSelection<LabeledSet>::selectSequentialStep (
  79. fpe, base, positives, negatives );
  80. #ifdef DEBUG_MultiDataset
  81. fprintf (stderr, "MultiDataset: after special step selection: %d\n", positives.count() );
  82. #endif
  83. } else if ( mode == "random" ) {
  84. LabeledSetSelection<LabeledSet>::selectRandom (
  85. fpe, base, positives, negatives );
  86. #ifdef DEBUG_MultiDataset
  87. fprintf (stderr, "MultiDataset: after special random selection: %d\n", positives.count() );
  88. #endif
  89. } else if ( mode == "all" ) {
  90. if ( (int)selection.size() == cn.numClasses() )
  91. {
  92. // preserve permutation
  93. LabeledSet::Permutation permutation;
  94. base.getPermutation ( permutation );
  95. for ( LabeledSet::Permutation::iterator i = permutation.begin(); i != permutation.end(); i++ )
  96. {
  97. int classno = i->first;
  98. ImageInfo *element = const_cast< ImageInfo * > ( i->second );
  99. positives.add_reference ( classno, element );
  100. }
  101. } else {
  102. LabeledSetSelection<LabeledSet>::selectClasses ( selection, base, positives, negatives );
  103. }
  104. #ifdef DEBUG_MultiDataset
  105. fprintf (stderr, "MultiDataset: after special class selection: %d\n", positives.count() );
  106. #endif
  107. } else {
  108. fthrow ( Exception, "Wrong value for parameter example\n");
  109. }
  110. }
  111. #ifdef DEBUG_MultiDataset
  112. fprintf (stderr, "MultiDataset: after special selection operations: %d\n", positives.count() );
  113. #endif
  114. set<int> allclasses;
  115. cn.getSelection ( "*", allclasses );
  116. set<int> allnegative_classes;
  117. // add all examples from allclasses \setminus processed_classes
  118. set_difference(allclasses.begin(), allclasses.end(), processed_classes.begin(), processed_classes.end(),
  119. inserter(allnegative_classes, allnegative_classes.end()));
  120. LabeledSet dummy;
  121. LabeledSetSelection<LabeledSet>::selectClasses ( allnegative_classes,
  122. base, negatives, dummy );
  123. }
  124. /** MultiDataset ------- constructor */
  125. MultiDataset::MultiDataset( const Config *conf , LabeledSetFactory *pSetFactory)
  126. {
  127. //read all blocks from our config file
  128. std::set<string> blocks;
  129. conf->getAllBlocks ( blocks );
  130. #ifdef DEBUG_MultiDataset
  131. std::cerr << "found the following config blocks: " << std::endl;
  132. for ( std::set<string>::const_iterator blockIt = blocks.begin(); blockIt != blocks.end(); blockIt++)
  133. {
  134. std::cerr << *blockIt << " ";
  135. }
  136. std::cerr << std::endl;
  137. #endif
  138. lfl.setFactory( pSetFactory );
  139. //for every dataset (e.g., train and test), we store a single confog file
  140. map<string, Config> dsconfs;
  141. //for every dataset (e.g., train and test), we store the position of the file directory
  142. map<string, string> dirs;
  143. //first of all, remove all blocks which do correspond to specified datasets, i.e., that do not contain a "dataset" entry
  144. for ( set<string>::iterator i = blocks.begin();
  145. i != blocks.end(); )
  146. {
  147. if ( conf->gB(*i, "disable", false) )
  148. {
  149. i++;
  150. continue;
  151. }
  152. std::string dataset = conf->gS( *i, "dataset", "unknown" );
  153. if ( dataset == "unknown" )
  154. blocks.erase(i++);
  155. else {
  156. #ifdef DEBUG_MultiDataset
  157. fprintf (stderr, "Reading dataset config for block [%s]\n", i->c_str() );
  158. #endif
  159. NICE::FileName t_ConfigFilename( conf->getFilename() );
  160. NICE::FileName t_DatasetFilename( dataset );
  161. // check for relative path correction:
  162. // if dataset is a relative path, then make it absolute using the
  163. // given config's directory
  164. if( t_DatasetFilename.isRelative() )
  165. {
  166. t_DatasetFilename.set( t_ConfigFilename.extractPath().str() + dataset );
  167. }
  168. t_DatasetFilename.convertToRealPath();
  169. std::string sDatasetConfFilename = t_DatasetFilename.str() + "/dataset.conf";
  170. Config dsconf ( sDatasetConfFilename.c_str() );
  171. dirs[*i] = dataset;
  172. dsconfs[*i] = dsconf;
  173. i++;
  174. }
  175. }
  176. #ifdef DEBUG_MultiDataset
  177. std::cerr << "found the following datasets within all config blocks: " << std::endl;
  178. for ( std::set<string>::const_iterator blockIt = blocks.begin(); blockIt != blocks.end(); blockIt++)
  179. {
  180. std::cerr << *blockIt << " ";
  181. }
  182. std::cerr << std::endl;
  183. #endif
  184. //is there a dataset specified that contains images for both, training and testing?
  185. if ( blocks.find("traintest") != blocks.end() )
  186. {
  187. LabeledSet ls_base;
  188. LabeledSet ls_train (true);
  189. LabeledSet ls_nontrain (true);
  190. LabeledSet ls_test (true);
  191. LabeledSet dummy (true);
  192. LabeledSet temp (true);
  193. bool localizationInfoDisabled = conf->gB("traintest", "disable_localization_info", false );
  194. std::string classselection_train = conf->gS("traintest", "classselection_train", "*");
  195. std::string classselection_test = conf->gS("traintest", "classselection_test", "*");
  196. classnames["traintest"] = ClassNames();
  197. std::string classNamesTxt = dirs["traintest"] + "/classnames.txt";
  198. if ( FileMgt::fileExists ( classNamesTxt ) )
  199. {
  200. classnames["traintest"].read ( classNamesTxt );
  201. } else {
  202. classnames["traintest"].readFromConfig ( dsconfs["traintest"], classselection_train );
  203. }
  204. lfl.get ( dirs["traintest"], dsconfs["traintest"], classnames["traintest"], ls_base,
  205. localizationInfoDisabled, conf->gB("traintest", "debug_dataset", false ) );
  206. std::string examples_train = conf->gS("traintest", "examples_train" );
  207. selectExamples ( examples_train, ls_base, ls_train, ls_nontrain, classnames["traintest"] );
  208. set<int> selection_test;
  209. classnames["traintest"].getSelection ( classselection_test, selection_test );
  210. std::string examples_test = conf->gS("traintest", "examples_test" );
  211. if ( examples_test == "reclassification" )
  212. {
  213. LabeledSetSelection<LabeledSet>::selectClasses
  214. ( selection_test, ls_train, ls_test, dummy );
  215. } else {
  216. selectExamples ( examples_test, ls_nontrain, temp, dummy, classnames["traintest"] );
  217. LabeledSetSelection<LabeledSet>::selectClasses
  218. ( selection_test, temp, ls_test, dummy );
  219. }
  220. classnames["train"] = classnames["traintest"];
  221. classnames["test"] = ClassNames ( classnames["traintest"], classselection_test );
  222. datasets["test"] = ls_test;
  223. datasets["train"] = ls_train;
  224. }
  225. //now read files for every specified dataset (e.g., train and test)
  226. for ( set<string>::const_iterator i = blocks.begin();
  227. i != blocks.end();
  228. i++ )
  229. {
  230. std::string name = *i;
  231. std::cerr << "read: " << name << std::endl;
  232. if ( classnames.find(name) != classnames.end() )
  233. continue;
  234. if ( conf->gB(name, "disable", false) == true )
  235. continue;
  236. if ( dsconfs.find(name) == dsconfs.end() )
  237. continue;
  238. LabeledSet ls_base;
  239. LabeledSet ls (true);
  240. LabeledSet dummy (true);
  241. LabeledSet temp (true);
  242. bool localizationInfoDisabled = conf->gB(name, "disable_localization_info", false );
  243. std::string classselection = conf->gS(name, "classselection", "*");
  244. classnames[name] = ClassNames();
  245. std::string classNamesTxt = dirs[name] + "/classnames.txt";
  246. if ( FileMgt::fileExists ( classNamesTxt ) )
  247. {
  248. #ifdef DEBUG_MultiDataset
  249. fprintf (stderr, "MultiDataset: reading class names from %s\n", classNamesTxt.c_str() );
  250. #endif
  251. classnames[name].read ( classNamesTxt );
  252. } else {
  253. #ifdef DEBUG_MultiDataset
  254. fprintf (stderr, "MultiDataset: reading class names from dataset config file\n" );
  255. #endif
  256. classnames[name].readFromConfig ( dsconfs[name], classselection );
  257. }
  258. #ifdef DEBUG_MultiDataset
  259. std::cerr << "we set up everything to read this dataset - so now call lfl.get" << std::endl;
  260. #endif
  261. lfl.get ( dirs[name],
  262. dsconfs[name],
  263. classnames[name],
  264. ls_base,
  265. localizationInfoDisabled,
  266. conf->gB(name, "debug_dataset", false ) );
  267. #ifdef DEBUG_MultiDataset
  268. fprintf (stderr, "MultiDataset: class names -->\n" );
  269. classnames[name].store ( cerr );
  270. fprintf (stderr, "MultiDataset: all information about %s set obtained ! (size %d)\n", name.c_str(), ls_base.count() );
  271. #endif
  272. #ifdef DEBUG_MultiDataset
  273. std::cerr << "we now call selectExamples to pick only a subset if desired" << std::endl;
  274. #endif
  275. std::string examples = conf->gS(name, "examples", "all *" );
  276. selectExamples ( examples, ls_base, ls, dummy, classnames[name] );
  277. #ifdef DEBUG_MultiDataset
  278. fprintf (stderr, "MultiDataset: size after selection %d\n", ls.count() );
  279. #endif
  280. datasets[name] = ls;
  281. }
  282. bool dumpSelections = conf->gB("datasets", "dump_selection", false);
  283. if ( dumpSelections )
  284. {
  285. for ( map<string, LabeledSet>::const_iterator i = datasets.begin();
  286. i != datasets.end(); i++ )
  287. {
  288. const std::string & name = i->first;
  289. const LabeledSet & ls = i->second;
  290. const ClassNames & classNames = classnames[name];
  291. #ifdef WIN32
  292. #ifdef NICE_USELIB_BOOST
  293. boost::filesystem::path t_path(name.c_str());
  294. boost::filesystem::create_directory(t_path);
  295. #else
  296. fthrow(Exception,"mkdir function not defined on system. try using boost lib and rebuild library");
  297. #endif
  298. #else
  299. mkdir ( name.c_str(), 0755 );
  300. #endif
  301. std::string filelist = name + "/files.txt";
  302. ofstream olist ( filelist.c_str(), ios::out );
  303. if ( !olist.good() )
  304. fthrow (IOException, "Unable to dump selections to " << filelist );
  305. LOOP_ALL_S(ls)
  306. {
  307. EACH_S(classno, file);
  308. std::string classtext = classNames.code(classno);
  309. olist << classtext << " " << file << endl;
  310. }
  311. olist.close();
  312. std::string datasetconf = name + "/dataset.conf";
  313. ofstream oconf ( datasetconf.c_str(), ios::out );
  314. if ( !oconf.good() )
  315. fthrow (IOException, "Unable to dump selections to " << datasetconf );
  316. set<int> classnos;
  317. classNames.getSelection ( "*", classnos);
  318. oconf << "[main]" << endl;
  319. oconf << "filelist = \"files.txt\"" << endl << endl;
  320. oconf << "[classnames]" << endl;
  321. for ( set<int>::const_iterator i = classnos.begin();
  322. i != classnos.end(); i++ )
  323. {
  324. const std::string & code = classNames.code(*i);
  325. const std::string & text = classNames.text(*i);
  326. oconf << code << " = \"" << text << "\"" << endl;
  327. }
  328. oconf.close();
  329. classNames.save ( name + "/classnames.txt" );
  330. }
  331. }
  332. }
  333. MultiDataset::~MultiDataset()
  334. {
  335. }
  336. const ClassNames & MultiDataset::getClassNames ( const std::string & key ) const
  337. {
  338. map<string, ClassNames>::const_iterator i = classnames.find(key);
  339. if ( i == classnames.end() )
  340. {
  341. fprintf (stderr, "MultiDataSet::getClassNames() FATAL ERROR: dataset <%s> not found !\n", key.c_str() );
  342. exit(-1);
  343. }
  344. return (i->second);
  345. }
  346. ClassNames & MultiDataset::getClassNamesNonConst ( const std::string & key )
  347. {
  348. std::map<string, ClassNames>::iterator i = classnames.find(key);
  349. if ( i == classnames.end() )
  350. {
  351. fprintf (stderr, "MultiDataSet::getClassNamesNonConst() FATAL ERROR: dataset <%s> not found !\n", key.c_str() );
  352. exit(-1);
  353. }
  354. return (i->second);
  355. }
  356. const LabeledSet *MultiDataset::operator[] ( const std::string & key ) const
  357. {
  358. map<string, LabeledSet>::const_iterator i = datasets.find(key);
  359. if ( i == datasets.end() )
  360. {
  361. fprintf (stderr, "MultiDataSet: FATAL ERROR: dataset <%s> not found !\n", key.c_str() );
  362. exit(-1);
  363. }
  364. return &(i->second);
  365. }
  366. const LabeledSet *MultiDataset::at ( const std::string & key ) const
  367. {
  368. map<string, LabeledSet>::const_iterator i = datasets.find(key);
  369. if ( i == datasets.end() )
  370. return NULL;
  371. else
  372. return &(i->second);
  373. }