MultiDataset.cpp 12 KB


/**
* @file MultiDataset.cpp
* @brief Loading, example selection and optional dumping of the labeled datasets described by a configuration file
* @author Erik Rodner
* @date 02/08/2008
*/
#include <sys/types.h>
#include <sys/stat.h>

#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <iterator>
#include <map>
#include <set>
#include <string>
#include <vector>

#ifdef WIN32
#ifdef NICE_USELIB_BOOST
#include "boost/filesystem.hpp"
#endif
#endif

#include "vislearning/cbaselib/ClassNames.h"
#include "core/basics/StringTools.h"
#include "core/basics/FileMgt.h"
#include "vislearning/cbaselib/MultiDataset.h"

using namespace OBJREC;
using namespace std;
using namespace NICE;

#undef DEBUG_MultiDataset
  23. void MultiDataset::selectExamples ( const std::string & examples_command,
  24. const LabeledSet & base,
  25. LabeledSet & positives,
  26. LabeledSet & negatives,
  27. const ClassNames & cn ) const
  28. {
  29. vector<string> examples;
  30. StringTools::split ( examples_command, ';', examples );
  31. set<int> processed_classes;
  32. for ( vector<string>::const_iterator i = examples.begin();
  33. i != examples.end();
  34. i++ )
  35. {
  36. const std::string & cmd = *i;
  37. vector<string> parts;
  38. StringTools::split ( cmd, ' ', parts );
  39. if ( (parts.size() != 3) && ((parts.size() != 2) || (parts[0] != "all")) )
  40. fthrow( Exception, "Syntax error " << examples_command );
  41. const std::string & mode = parts[0];
  42. const std::string & csel = parts[1];
  43. double parameter = (parts.size() == 3 ) ? atof(parts[2].c_str()) : 0.0;
  44. map<int, int> fpe;
  45. set<int> selection;
  46. cn.getSelection ( csel, selection );
  47. for ( set<int>::const_iterator j = selection.begin();
  48. j != selection.end();
  49. j++ )
  50. {
  51. int classno = *j;
  52. if ( processed_classes.find(classno) == processed_classes.end() )
  53. {
  54. #ifdef DEBUG_MultiDataset
  55. fprintf (stderr, "class %s: %s %d\n", cn.text(classno).c_str(),
  56. mode.c_str(), (int)parameter );
  57. #endif
  58. fpe[*j] = (int)parameter;
  59. processed_classes.insert(classno);
  60. } else {
  61. if ( csel != "*" ) {
  62. fthrow ( Exception, "Example selection method for class %s has multiple specifications" << cn.text(classno) );
  63. }
  64. }
  65. }
  66. if ( mode == "seq" ) {
  67. LabeledSetSelection<LabeledSet>::selectSequential (
  68. fpe, base, positives, negatives );
  69. #ifdef DEBUG_MultiDataset
  70. fprintf (stderr, "MultiDataset: after special seq selection: %d\n", positives.count() );
  71. #endif
  72. } else if ( mode == "step" ) {
  73. LabeledSetSelection<LabeledSet>::selectSequentialStep (
  74. fpe, base, positives, negatives );
  75. #ifdef DEBUG_MultiDataset
  76. fprintf (stderr, "MultiDataset: after special step selection: %d\n", positives.count() );
  77. #endif
  78. } else if ( mode == "random" ) {
  79. LabeledSetSelection<LabeledSet>::selectRandom (
  80. fpe, base, positives, negatives );
  81. #ifdef DEBUG_MultiDataset
  82. fprintf (stderr, "MultiDataset: after special random selection: %d\n", positives.count() );
  83. #endif
  84. } else if ( mode == "all" ) {
  85. if ( (int)selection.size() == cn.numClasses() )
  86. {
  87. // preserve permutation
  88. LabeledSet::Permutation permutation;
  89. base.getPermutation ( permutation );
  90. for ( LabeledSet::Permutation::iterator i = permutation.begin(); i != permutation.end(); i++ )
  91. {
  92. int classno = i->first;
  93. ImageInfo *element = const_cast< ImageInfo * > ( i->second );
  94. positives.add_reference ( classno, element );
  95. }
  96. } else {
  97. LabeledSetSelection<LabeledSet>::selectClasses ( selection, base, positives, negatives );
  98. }
  99. #ifdef DEBUG_MultiDataset
  100. fprintf (stderr, "MultiDataset: after special class selection: %d\n", positives.count() );
  101. #endif
  102. } else {
  103. fthrow ( Exception, "Wrong value for parameter example\n");
  104. }
  105. }
  106. #ifdef DEBUG_MultiDataset
  107. fprintf (stderr, "MultiDataset: after special selection operations: %d\n", positives.count() );
  108. #endif
  109. set<int> allclasses;
  110. cn.getSelection ( "*", allclasses );
  111. set<int> allnegative_classes;
  112. // add all examples from allclasses \setminus processed_classes
  113. set_difference(allclasses.begin(), allclasses.end(), processed_classes.begin(), processed_classes.end(),
  114. inserter(allnegative_classes, allnegative_classes.end()));
  115. LabeledSet dummy;
  116. LabeledSetSelection<LabeledSet>::selectClasses ( allnegative_classes,
  117. base, negatives, dummy );
  118. }
  119. /** MultiDataset ------- constructor */
  120. MultiDataset::MultiDataset( const Config *conf , LabeledSetFactory *pSetFactory)
  121. {
  122. std::set<string> blocks;
  123. conf->getAllBlocks ( blocks );
  124. lfl.setFactory( pSetFactory );
  125. map<string, Config> dsconfs;
  126. map<string, string> dirs;
  127. for ( set<string>::iterator i = blocks.begin();
  128. i != blocks.end(); )
  129. {
  130. if ( conf->gB(*i, "disable", false) )
  131. {
  132. i++;
  133. continue;
  134. }
  135. std::string dataset = conf->gS( *i, "dataset", "unknown" );
  136. if ( dataset == "unknown" )
  137. blocks.erase(i++);
  138. else {
  139. #ifdef DEBUG_MultiDataset
  140. fprintf (stderr, "Reading dataset config for block [%s]\n", i->c_str() );
  141. #endif
  142. Config dsconf ( (dataset + "/dataset.conf").c_str() );
  143. dirs[*i] = dataset;
  144. dsconfs[*i] = dsconf;
  145. i++;
  146. }
  147. }
  148. if ( blocks.find("traintest") != blocks.end() )
  149. {
  150. LabeledSet ls_base;
  151. LabeledSet ls_train (true);
  152. LabeledSet ls_nontrain (true);
  153. LabeledSet ls_test (true);
  154. LabeledSet dummy (true);
  155. LabeledSet temp (true);
  156. bool localizationInfoDisabled = conf->gB("traintest", "disable_localization_info", false );
  157. std::string classselection_train = conf->gS("traintest", "classselection_train", "*");
  158. std::string classselection_test = conf->gS("traintest", "classselection_test", "*");
  159. classnames["traintest"] = ClassNames();
  160. std::string classNamesTxt = dirs["traintest"] + "/classnames.txt";
  161. if ( FileMgt::fileExists ( classNamesTxt ) )
  162. {
  163. classnames["traintest"].read ( classNamesTxt );
  164. } else {
  165. classnames["traintest"].readFromConfig ( dsconfs["traintest"], classselection_train );
  166. }
  167. lfl.get ( dirs["traintest"], dsconfs["traintest"], classnames["traintest"], ls_base,
  168. localizationInfoDisabled, conf->gB("traintest", "debug_dataset", false ) );
  169. std::string examples_train = conf->gS("traintest", "examples_train" );
  170. selectExamples ( examples_train, ls_base, ls_train, ls_nontrain, classnames["traintest"] );
  171. set<int> selection_test;
  172. classnames["traintest"].getSelection ( classselection_test, selection_test );
  173. std::string examples_test = conf->gS("traintest", "examples_test" );
  174. if ( examples_test == "reclassification" )
  175. {
  176. LabeledSetSelection<LabeledSet>::selectClasses
  177. ( selection_test, ls_train, ls_test, dummy );
  178. } else {
  179. selectExamples ( examples_test, ls_nontrain, temp, dummy, classnames["traintest"] );
  180. LabeledSetSelection<LabeledSet>::selectClasses
  181. ( selection_test, temp, ls_test, dummy );
  182. }
  183. classnames["train"] = classnames["traintest"];
  184. classnames["test"] = ClassNames ( classnames["traintest"], classselection_test );
  185. datasets["test"] = ls_test;
  186. datasets["train"] = ls_train;
  187. }
  188. for ( set<string>::const_iterator i = blocks.begin();
  189. i != blocks.end();
  190. i++ )
  191. {
  192. std::string name = *i;
  193. if ( classnames.find(name) != classnames.end() )
  194. continue;
  195. if ( conf->gB(name, "disable", false) == true )
  196. continue;
  197. if ( dsconfs.find(name) == dsconfs.end() )
  198. continue;
  199. LabeledSet ls_base;
  200. LabeledSet ls (true);
  201. LabeledSet dummy (true);
  202. LabeledSet temp (true);
  203. bool localizationInfoDisabled = conf->gB(name, "disable_localization_info", false );
  204. std::string classselection = conf->gS(name, "classselection", "*");
  205. classnames[name] = ClassNames();
  206. std::string classNamesTxt = dirs[name] + "/classnames.txt";
  207. if ( FileMgt::fileExists ( classNamesTxt ) )
  208. {
  209. #ifdef DEBUG_MultiDataset
  210. fprintf (stderr, "MultiDataset: reading class names from %s\n", classNamesTxt.c_str() );
  211. #endif
  212. classnames[name].read ( classNamesTxt );
  213. } else {
  214. #ifdef DEBUG_MultiDataset
  215. fprintf (stderr, "MultiDataset: reading class names from dataset config file\n" );
  216. #endif
  217. classnames[name].readFromConfig ( dsconfs[name], classselection );
  218. }
  219. lfl.get ( dirs[name],
  220. dsconfs[name],
  221. classnames[name],
  222. ls_base,
  223. localizationInfoDisabled,
  224. conf->gB(name, "debug_dataset", false ) );
  225. #ifdef DEBUG_MultiDataset
  226. fprintf (stderr, "MultiDataset: class names -->\n" );
  227. classnames[name].store ( cerr );
  228. fprintf (stderr, "MultiDataset: all information about %s set obtained ! (size %d)\n", name.c_str(), ls_base.count() );
  229. #endif
  230. std::string examples = conf->gS(name, "examples", "all *" );
  231. selectExamples ( examples, ls_base, ls, dummy, classnames[name] );
  232. #ifdef DEBUG_MultiDataset
  233. fprintf (stderr, "MultiDataset: size after selection %d\n", ls.count() );
  234. #endif
  235. datasets[name] = ls;
  236. }
  237. bool dumpSelections = conf->gB("datasets", "dump_selection", false);
  238. if ( dumpSelections )
  239. {
  240. for ( map<string, LabeledSet>::const_iterator i = datasets.begin();
  241. i != datasets.end(); i++ )
  242. {
  243. const std::string & name = i->first;
  244. const LabeledSet & ls = i->second;
  245. const ClassNames & classNames = classnames[name];
  246. #ifdef WIN32
  247. #ifdef NICE_USELIB_BOOST
  248. boost::filesystem::path t_path(name.c_str());
  249. boost::filesystem::create_directory(t_path);
  250. #else
  251. fthrow(Exception,"mkdir function not defined on system. try using boost lib and rebuild library");
  252. #endif
  253. #else
  254. mkdir ( name.c_str(), 0755 );
  255. #endif
  256. std::string filelist = name + "/files.txt";
  257. ofstream olist ( filelist.c_str(), ios::out );
  258. if ( !olist.good() )
  259. fthrow (IOException, "Unable to dump selections to " << filelist );
  260. LOOP_ALL_S(ls)
  261. {
  262. EACH_S(classno, file);
  263. std::string classtext = classNames.code(classno);
  264. olist << classtext << " " << file << endl;
  265. }
  266. olist.close();
  267. std::string datasetconf = name + "/dataset.conf";
  268. ofstream oconf ( datasetconf.c_str(), ios::out );
  269. if ( !oconf.good() )
  270. fthrow (IOException, "Unable to dump selections to " << datasetconf );
  271. set<int> classnos;
  272. classNames.getSelection ( "*", classnos);
  273. oconf << "[main]" << endl;
  274. oconf << "filelist = \"files.txt\"" << endl << endl;
  275. oconf << "[classnames]" << endl;
  276. for ( set<int>::const_iterator i = classnos.begin();
  277. i != classnos.end(); i++ )
  278. {
  279. const std::string & code = classNames.code(*i);
  280. const std::string & text = classNames.text(*i);
  281. oconf << code << " = \"" << text << "\"" << endl;
  282. }
  283. oconf.close();
  284. classNames.save ( name + "/classnames.txt" );
  285. }
  286. }
  287. }
  288. MultiDataset::~MultiDataset()
  289. {
  290. }
  291. const ClassNames & MultiDataset::getClassNames ( const std::string & key ) const
  292. {
  293. map<string, ClassNames>::const_iterator i = classnames.find(key);
  294. if ( i == classnames.end() )
  295. {
  296. fprintf (stderr, "MultiDataSet::getClassNames() FATAL ERROR: dataset <%s> not found !\n", key.c_str() );
  297. exit(-1);
  298. }
  299. return (i->second);
  300. }
  301. const LabeledSet *MultiDataset::operator[] ( const std::string & key ) const
  302. {
  303. map<string, LabeledSet>::const_iterator i = datasets.find(key);
  304. if ( i == datasets.end() )
  305. {
  306. fprintf (stderr, "MultiDataSet: FATAL ERROR: dataset <%s> not found !\n", key.c_str() );
  307. exit(-1);
  308. }
  309. return &(i->second);
  310. }
  311. const LabeledSet *MultiDataset::at ( const std::string & key ) const
  312. {
  313. map<string, LabeledSet>::const_iterator i = datasets.find(key);
  314. if ( i == datasets.end() )
  315. return NULL;
  316. else
  317. return &(i->second);
  318. }