splitLabeledSetVector.cpp 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. /**
  2. * @file splitLabeledSetVector.cpp
  3. * @brief Split a labeled vector set into train.vec/test.vec files
  4. * @author Erik Rodner
  5. * @date 03/23/2010
  6. */
#include <cstddef>
#include <exception>
#include <map>
#include <sstream>
#include <string>
#include <vector>

#include "core/basics/Config.h"
#include "core/basics/StringTools.h"
#include "vislearning/cbaselib/LabeledSet.h"
#include "vislearning/cbaselib/LabeledSetSelection.h"
#include "core/basics/numerictools.h"

using namespace std;
using namespace OBJREC;
using namespace NICE;
  15. //#stupid test for git
  16. void normalizeLabeledSetVector(const LabeledSetVector &teachSet,
  17. LabeledSetVector &transformedSet)
  18. {
  19. transformedSet.clear();
  20. Vector vector_max, vector_min, vector_span;
  21. int maxClassNo = teachSet.getMaxClassno();
  22. int n = teachSet.count();
  23. int d = teachSet.dimension();
  24. vector_max.resize(d);
  25. vector_min.resize(d);
  26. vector_span.resize(d);
  27. //get input data
  28. uint featurecount = 0;
  29. LOOP_ALL(teachSet)
  30. {
  31. EACH(classno,x);
  32. for (uint k = 0; k < x.size(); ++k)
  33. {
  34. double value = x[k];
  35. if (featurecount == 0)
  36. {
  37. vector_max[k] = value;
  38. vector_min[k] = value;
  39. }
  40. else
  41. {
  42. if (value > vector_max[k])
  43. {
  44. vector_max[k] = value;
  45. }
  46. if (value < vector_min[k])
  47. {
  48. vector_min[k] = value;
  49. }
  50. }
  51. }
  52. ++featurecount;
  53. }
  54. vector_span = vector_max - vector_min;
  55. //save transformed Vectors
  56. LOOP_ALL(teachSet)
  57. {
  58. EACH(classno,x);
  59. NICE::Vector transformed_vector(x.size());
  60. for (uint k = 0; k < vector_min.size(); ++k)
  61. {
  62. if (vector_span[k] > 1e-10)
  63. {
  64. transformed_vector[k] = (x[k] - vector_min[k])
  65. / vector_span[k];
  66. }
  67. else
  68. {
  69. transformed_vector[k] = 1.0;
  70. }
  71. }
  72. transformedSet.add(classno, transformed_vector);
  73. }
  74. }
  75. /**
  76. split train.vec/test.vec files
  77. */
  78. int main(int argc, char **argv)
  79. {
  80. #ifndef __clang__
  81. #ifndef __llvm__
  82. std::set_terminate(__gnu_cxx::__verbose_terminate_handler);
  83. #endif
  84. #endif
  85. Config conf(argc, argv);
  86. int format = conf.gI("main", "format", 2);
  87. LabeledSetVector all;
  88. string setfn = conf.gS("main", "set");
  89. all.read(setfn, format);
  90. bool normalize = conf.gB("main", "normalize", false);
  91. if (normalize)
  92. {
  93. LabeledSetVector all_tmp;
  94. normalizeLabeledSetVector(all, all_tmp);
  95. all = all_tmp;
  96. }
  97. bool random = conf.gB("main", "random", false);
  98. if (random)
  99. initRand();
  100. map<int, int> fixedPositiveExamples;
  101. string exampleList = conf.gS("main", "examples");
  102. vector<string> list;
  103. StringTools::split(exampleList, ';', list);
  104. for (vector<string>::const_iterator i = list.begin(); i != list.end(); i++)
  105. {
  106. string e = *i;
  107. vector<string> f;
  108. StringTools::split(e, ':', f);
  109. if (f.size() != 2)
  110. fthrow(Exception, "Specify -examples classno:number;classno:number;...\n");
  111. int classno;
  112. int examples;
  113. StringTools::convert<int>(f[0], classno);
  114. StringTools::convert<int>(f[1], examples);
  115. fixedPositiveExamples.insert(pair<int, int> (classno, examples));
  116. }
  117. LabeledSetVector train;
  118. LabeledSetVector test;
  119. LabeledSetSelection<LabeledSetVector>::selectRandom(fixedPositiveExamples,
  120. all, train, test);
  121. string trainfn = conf.gS("main", "train", "train.vec");
  122. string testfn = conf.gS("main", "test", "test.vec");
  123. train.save(trainfn, format);
  124. test.save(testfn, format);
  125. return 0;
  126. }