MultiDataset.cpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369
  1. /**
  2. * @file MultiDataset.cpp
  3. * @brief multiple datasets
  4. * @author Erik Rodner
  5. * @date 02/08/2008
  6. */
  #include <algorithm>
  #include <cstdio>
  #include <cstdlib>
  #include <fstream>
  #include <iostream>
  #include <iterator>
  #include <map>
  #include <set>
  #include <string>
  #include <vector>
  #include <sys/stat.h>
  #include <sys/types.h>
  #include "vislearning/cbaselib/ClassNames.h"
  #include "core/basics/StringTools.h"
  #include "core/basics/FileMgt.h"
  #include "vislearning/cbaselib/MultiDataset.h"
  14. using namespace OBJREC;
  15. using namespace std;
  16. using namespace NICE;
  17. #undef DEBUG_MultiDataset
  18. void MultiDataset::selectExamples ( const std::string & examples_command,
  19. const LabeledSet & base,
  20. LabeledSet & positives,
  21. LabeledSet & negatives,
  22. const ClassNames & cn ) const
  23. {
  24. vector<string> examples;
  25. StringTools::split ( examples_command, ';', examples );
  26. set<int> processed_classes;
  27. for ( vector<string>::const_iterator i = examples.begin();
  28. i != examples.end();
  29. i++ )
  30. {
  31. const std::string & cmd = *i;
  32. vector<string> parts;
  33. StringTools::split ( cmd, ' ', parts );
  34. if ( (parts.size() != 3) && ((parts.size() != 2) || (parts[0] != "all")) )
  35. fthrow( Exception, "Syntax error " << examples_command );
  36. const std::string & mode = parts[0];
  37. const std::string & csel = parts[1];
  38. double parameter = (parts.size() == 3 ) ? atof(parts[2].c_str()) : 0.0;
  39. map<int, int> fpe;
  40. set<int> selection;
  41. cn.getSelection ( csel, selection );
  42. for ( set<int>::const_iterator j = selection.begin();
  43. j != selection.end();
  44. j++ )
  45. {
  46. int classno = *j;
  47. if ( processed_classes.find(classno) == processed_classes.end() )
  48. {
  49. #ifdef DEBUG_MultiDataset
  50. fprintf (stderr, "class %s: %s %d\n", cn.text(classno).c_str(),
  51. mode.c_str(), (int)parameter );
  52. #endif
  53. fpe[*j] = (int)parameter;
  54. processed_classes.insert(classno);
  55. } else {
  56. if ( csel != "*" ) {
  57. fthrow ( Exception, "Example selection method for class %s has multiple specifications" << cn.text(classno) );
  58. }
  59. }
  60. }
  61. if ( mode == "seq" ) {
  62. LabeledSetSelection<LabeledSet>::selectSequential (
  63. fpe, base, positives, negatives );
  64. #ifdef DEBUG_MultiDataset
  65. fprintf (stderr, "MultiDataset: after special seq selection: %d\n", positives.count() );
  66. #endif
  67. } else if ( mode == "step" ) {
  68. LabeledSetSelection<LabeledSet>::selectSequentialStep (
  69. fpe, base, positives, negatives );
  70. #ifdef DEBUG_MultiDataset
  71. fprintf (stderr, "MultiDataset: after special step selection: %d\n", positives.count() );
  72. #endif
  73. } else if ( mode == "random" ) {
  74. LabeledSetSelection<LabeledSet>::selectRandom (
  75. fpe, base, positives, negatives );
  76. #ifdef DEBUG_MultiDataset
  77. fprintf (stderr, "MultiDataset: after special random selection: %d\n", positives.count() );
  78. #endif
  79. } else if ( mode == "all" ) {
  80. if ( (int)selection.size() == cn.numClasses() )
  81. {
  82. // preserve permutation
  83. LabeledSet::Permutation permutation;
  84. base.getPermutation ( permutation );
  85. for ( LabeledSet::Permutation::iterator i = permutation.begin(); i != permutation.end(); i++ )
  86. {
  87. int classno = i->first;
  88. ImageInfo *element = const_cast< ImageInfo * > ( i->second );
  89. positives.add_reference ( classno, element );
  90. }
  91. } else {
  92. LabeledSetSelection<LabeledSet>::selectClasses ( selection, base, positives, negatives );
  93. }
  94. #ifdef DEBUG_MultiDataset
  95. fprintf (stderr, "MultiDataset: after special class selection: %d\n", positives.count() );
  96. #endif
  97. } else {
  98. fthrow ( Exception, "Wrong value for parameter example\n");
  99. }
  100. }
  101. #ifdef DEBUG_MultiDataset
  102. fprintf (stderr, "MultiDataset: after special selection operations: %d\n", positives.count() );
  103. #endif
  104. set<int> allclasses;
  105. cn.getSelection ( "*", allclasses );
  106. set<int> allnegative_classes;
  107. // add all examples from allclasses \setminus processed_classes
  108. set_difference(allclasses.begin(), allclasses.end(), processed_classes.begin(), processed_classes.end(),
  109. inserter(allnegative_classes, allnegative_classes.end()));
  110. LabeledSet dummy;
  111. LabeledSetSelection<LabeledSet>::selectClasses ( allnegative_classes,
  112. base, negatives, dummy );
  113. }
  114. /** MultiDataset ------- constructor */
  115. MultiDataset::MultiDataset( const Config *conf , LabeledSetFactory *pSetFactory)
  116. {
  117. std::set<string> blocks;
  118. conf->getAllBlocks ( blocks );
  119. lfl.setFactory( pSetFactory );
  120. map<string, Config> dsconfs;
  121. map<string, string> dirs;
  122. for ( set<string>::iterator i = blocks.begin();
  123. i != blocks.end(); )
  124. {
  125. if ( conf->gB(*i, "disable", false) )
  126. {
  127. i++;
  128. continue;
  129. }
  130. std::string dataset = conf->gS( *i, "dataset", "unknown" );
  131. if ( dataset == "unknown" )
  132. blocks.erase(i++);
  133. else {
  134. #ifdef DEBUG_MultiDataset
  135. fprintf (stderr, "Reading dataset config for block [%s]\n", i->c_str() );
  136. #endif
  137. Config dsconf ( (dataset + "/dataset.conf").c_str() );
  138. dirs[*i] = dataset;
  139. dsconfs[*i] = dsconf;
  140. i++;
  141. }
  142. }
  143. if ( blocks.find("traintest") != blocks.end() )
  144. {
  145. LabeledSet ls_base;
  146. LabeledSet ls_train (true);
  147. LabeledSet ls_nontrain (true);
  148. LabeledSet ls_test (true);
  149. LabeledSet dummy (true);
  150. LabeledSet temp (true);
  151. bool localizationInfoDisabled = conf->gB("traintest", "disable_localization_info", false );
  152. std::string classselection_train = conf->gS("traintest", "classselection_train", "*");
  153. std::string classselection_test = conf->gS("traintest", "classselection_test", "*");
  154. classnames["traintest"] = ClassNames();
  155. std::string classNamesTxt = dirs["traintest"] + "/classnames.txt";
  156. if ( FileMgt::fileExists ( classNamesTxt ) )
  157. {
  158. classnames["traintest"].read ( classNamesTxt );
  159. } else {
  160. classnames["traintest"].readFromConfig ( dsconfs["traintest"], classselection_train );
  161. }
  162. lfl.get ( dirs["traintest"], dsconfs["traintest"], classnames["traintest"], ls_base,
  163. localizationInfoDisabled, conf->gB("traintest", "debug_dataset", false ) );
  164. std::string examples_train = conf->gS("traintest", "examples_train" );
  165. selectExamples ( examples_train, ls_base, ls_train, ls_nontrain, classnames["traintest"] );
  166. set<int> selection_test;
  167. classnames["traintest"].getSelection ( classselection_test, selection_test );
  168. std::string examples_test = conf->gS("traintest", "examples_test" );
  169. if ( examples_test == "reclassification" )
  170. {
  171. LabeledSetSelection<LabeledSet>::selectClasses
  172. ( selection_test, ls_train, ls_test, dummy );
  173. } else {
  174. selectExamples ( examples_test, ls_nontrain, temp, dummy, classnames["traintest"] );
  175. LabeledSetSelection<LabeledSet>::selectClasses
  176. ( selection_test, temp, ls_test, dummy );
  177. }
  178. classnames["train"] = classnames["traintest"];
  179. classnames["test"] = ClassNames ( classnames["traintest"], classselection_test );
  180. datasets["test"] = ls_test;
  181. datasets["train"] = ls_train;
  182. }
  183. for ( set<string>::const_iterator i = blocks.begin();
  184. i != blocks.end();
  185. i++ )
  186. {
  187. std::string name = *i;
  188. if ( classnames.find(name) != classnames.end() )
  189. continue;
  190. if ( conf->gB(name, "disable", false) == true )
  191. continue;
  192. if ( dsconfs.find(name) == dsconfs.end() )
  193. continue;
  194. LabeledSet ls_base;
  195. LabeledSet ls (true);
  196. LabeledSet dummy (true);
  197. LabeledSet temp (true);
  198. bool localizationInfoDisabled = conf->gB(name, "disable_localization_info", false );
  199. std::string classselection = conf->gS(name, "classselection", "*");
  200. classnames[name] = ClassNames();
  201. std::string classNamesTxt = dirs[name] + "/classnames.txt";
  202. if ( FileMgt::fileExists ( classNamesTxt ) )
  203. {
  204. #ifdef DEBUG_MultiDataset
  205. fprintf (stderr, "MultiDataset: reading class names from %s\n", classNamesTxt.c_str() );
  206. #endif
  207. classnames[name].read ( classNamesTxt );
  208. } else {
  209. #ifdef DEBUG_MultiDataset
  210. fprintf (stderr, "MultiDataset: reading class names from dataset config file\n" );
  211. #endif
  212. classnames[name].readFromConfig ( dsconfs[name], classselection );
  213. }
  214. lfl.get ( dirs[name],
  215. dsconfs[name],
  216. classnames[name],
  217. ls_base,
  218. localizationInfoDisabled,
  219. conf->gB(name, "debug_dataset", false ) );
  220. #ifdef DEBUG_MultiDataset
  221. fprintf (stderr, "MultiDataset: class names -->\n" );
  222. classnames[name].store ( cerr );
  223. fprintf (stderr, "MultiDataset: all information about %s set obtained ! (size %d)\n", name.c_str(), ls_base.count() );
  224. #endif
  225. std::string examples = conf->gS(name, "examples", "all *" );
  226. selectExamples ( examples, ls_base, ls, dummy, classnames[name] );
  227. #ifdef DEBUG_MultiDataset
  228. fprintf (stderr, "MultiDataset: size after selection %d\n", ls.count() );
  229. #endif
  230. datasets[name] = ls;
  231. }
  232. bool dumpSelections = conf->gB("datasets", "dump_selection", false);
  233. if ( dumpSelections )
  234. {
  235. for ( map<string, LabeledSet>::const_iterator i = datasets.begin();
  236. i != datasets.end(); i++ )
  237. {
  238. const std::string & name = i->first;
  239. const LabeledSet & ls = i->second;
  240. const ClassNames & classNames = classnames[name];
  241. mkdir ( name.c_str(), 0755 );
  242. std::string filelist = name + "/files.txt";
  243. ofstream olist ( filelist.c_str(), ios::out );
  244. if ( !olist.good() )
  245. fthrow (IOException, "Unable to dump selections to " << filelist );
  246. LOOP_ALL_S(ls)
  247. {
  248. EACH_S(classno, file);
  249. std::string classtext = classNames.code(classno);
  250. olist << classtext << " " << file << endl;
  251. }
  252. olist.close();
  253. std::string datasetconf = name + "/dataset.conf";
  254. ofstream oconf ( datasetconf.c_str(), ios::out );
  255. if ( !oconf.good() )
  256. fthrow (IOException, "Unable to dump selections to " << datasetconf );
  257. set<int> classnos;
  258. classNames.getSelection ( "*", classnos);
  259. oconf << "[main]" << endl;
  260. oconf << "filelist = \"files.txt\"" << endl << endl;
  261. oconf << "[classnames]" << endl;
  262. for ( set<int>::const_iterator i = classnos.begin();
  263. i != classnos.end(); i++ )
  264. {
  265. const std::string & code = classNames.code(*i);
  266. const std::string & text = classNames.text(*i);
  267. oconf << code << " = \"" << text << "\"" << endl;
  268. }
  269. oconf.close();
  270. classNames.save ( name + "/classnames.txt" );
  271. }
  272. }
  273. }
  274. MultiDataset::~MultiDataset()
  275. {
  276. }
  277. const ClassNames & MultiDataset::getClassNames ( const std::string & key ) const
  278. {
  279. map<string, ClassNames>::const_iterator i = classnames.find(key);
  280. if ( i == classnames.end() )
  281. {
  282. fprintf (stderr, "MultiDataSet::getClassNames() FATAL ERROR: dataset <%s> not found !\n", key.c_str() );
  283. exit(-1);
  284. }
  285. return (i->second);
  286. }
  287. const LabeledSet *MultiDataset::operator[] ( const std::string & key ) const
  288. {
  289. map<string, LabeledSet>::const_iterator i = datasets.find(key);
  290. if ( i == datasets.end() )
  291. {
  292. fprintf (stderr, "MultiDataSet: FATAL ERROR: dataset <%s> not found !\n", key.c_str() );
  293. exit(-1);
  294. }
  295. return &(i->second);
  296. }
  297. const LabeledSet *MultiDataset::at ( const std::string & key ) const
  298. {
  299. map<string, LabeledSet>::const_iterator i = datasets.find(key);
  300. if ( i == datasets.end() )
  301. return NULL;
  302. else
  303. return &(i->second);
  304. }