splitLabeledSetVector.cpp 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. /**
  2. * @file splitLabeledSetVector.cpp
  3. * @brief split train.vec/test.vec files
  4. * @author Erik Rodner
  5. * @date 03/23/2010
  6. */
  7. #include "core/basics/Config.h"
  8. #include "core/basics/StringTools.h"
  9. #include "vislearning/cbaselib/LabeledSet.h"
  10. #include "vislearning/cbaselib/LabeledSetSelection.h"
  11. #include "core/basics/numerictools.h"
  12. using namespace std;
  13. using namespace OBJREC;
  14. using namespace NICE;
  15. //#stupid test for git
  16. void normalizeLabeledSetVector(const LabeledSetVector &teachSet,
  17. LabeledSetVector &transformedSet)
  18. {
  19. transformedSet.clear();
  20. Vector vector_max, vector_min, vector_span;
  21. int maxClassNo = teachSet.getMaxClassno();
  22. int n = teachSet.count();
  23. int d = teachSet.dimension();
  24. vector_max.resize(d);
  25. vector_min.resize(d);
  26. vector_span.resize(d);
  27. //get input data
  28. uint featurecount = 0;
  29. LOOP_ALL(teachSet)
  30. {
  31. EACH(classno,x);
  32. for (uint k = 0; k < x.size(); ++k)
  33. {
  34. double value = x[k];
  35. if (featurecount == 0)
  36. {
  37. vector_max[k] = value;
  38. vector_min[k] = value;
  39. }
  40. else
  41. {
  42. if (value > vector_max[k])
  43. {
  44. vector_max[k] = value;
  45. }
  46. if (value < vector_min[k])
  47. {
  48. vector_min[k] = value;
  49. }
  50. }
  51. }
  52. ++featurecount;
  53. }
  54. vector_span = vector_max - vector_min;
  55. //save transformed Vectors
  56. LOOP_ALL(teachSet)
  57. {
  58. EACH(classno,x);
  59. NICE::Vector transformed_vector(x.size());
  60. for (uint k = 0; k < vector_min.size(); ++k)
  61. {
  62. if (vector_span[k] > 1e-10)
  63. {
  64. transformed_vector[k] = (x[k] - vector_min[k])
  65. / vector_span[k];
  66. }
  67. else
  68. {
  69. transformed_vector[k] = 1.0;
  70. }
  71. }
  72. transformedSet.add(classno, transformed_vector);
  73. }
  74. }
  75. /**
  76. split train.vec/test.vec files
  77. */
  78. int main(int argc, char **argv)
  79. {
  80. std::set_terminate(__gnu_cxx::__verbose_terminate_handler);
  81. Config conf(argc, argv);
  82. int format = conf.gI("main", "format", 2);
  83. LabeledSetVector all;
  84. string setfn = conf.gS("main", "set");
  85. all.read(setfn, format);
  86. bool normalize = conf.gB("main", "normalize", false);
  87. if (normalize)
  88. {
  89. LabeledSetVector all_tmp;
  90. normalizeLabeledSetVector(all, all_tmp);
  91. all = all_tmp;
  92. }
  93. bool random = conf.gB("main", "random", false);
  94. if (random)
  95. initRand();
  96. map<int, int> fixedPositiveExamples;
  97. string exampleList = conf.gS("main", "examples");
  98. vector<string> list;
  99. StringTools::split(exampleList, ';', list);
  100. for (vector<string>::const_iterator i = list.begin(); i != list.end(); i++)
  101. {
  102. string e = *i;
  103. vector<string> f;
  104. StringTools::split(e, ':', f);
  105. if (f.size() != 2)
  106. fthrow(Exception, "Specify -examples classno:number;classno:number;...\n");
  107. int classno;
  108. int examples;
  109. StringTools::convert<int>(f[0], classno);
  110. StringTools::convert<int>(f[1], examples);
  111. fixedPositiveExamples.insert(pair<int, int> (classno, examples));
  112. }
  113. LabeledSetVector train;
  114. LabeledSetVector test;
  115. LabeledSetSelection<LabeledSetVector>::selectRandom(fixedPositiveExamples,
  116. all, train, test);
  117. string trainfn = conf.gS("main", "train", "train.vec");
  118. string testfn = conf.gS("main", "test", "test.vec");
  119. train.save(trainfn, format);
  120. test.save(testfn, format);
  121. return 0;
  122. }