LabeledSetSelection.h 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. /**
  2. * @file LabeledSetSelection.h
  3. * @brief Select a subset of a LabeledSet
  4. * @author Erik Rodner
  5. * @date 02/08/2008
  6. */
  7. #ifndef LABELEDSETSELECTIONINCLUDE
  8. #define LABELEDSETSELECTIONINCLUDE
  #include <vislearning/nice_nonvis.h>
  #include <cstdlib>
  #include <map>
  #include <set>
  #include <vector>
  12. namespace OBJREC {
  13. /** Select a subset of a LabeledSet */
  14. template <class F>
  15. class LabeledSetSelection
  16. {
  17. protected:
  18. public:
  19. static void selectRandom ( const std::map<int,int> & fixedPositiveExamples, const F & base, F & positive, F & negative )
  20. {
  21. F & base_nonconst = const_cast< F & >(base);
  22. //srand(time(NULL));
  23. for ( std::map<int,int>::const_iterator i = fixedPositiveExamples.begin();
  24. i != fixedPositiveExamples.end() ;
  25. i++ )
  26. {
  27. int classno = i->first;
  28. int fixedPositiveExamples = i->second;
  29. std::set<int> memory;
  30. int count = base_nonconst[classno].size();
  31. if ( (count < fixedPositiveExamples) ) {
  32. fthrow ( Exception, "LabeledSetSelection::selectRandom: unable to select " << fixedPositiveExamples
  33. << " of " << count << " examples" );
  34. }
  35. for ( int j = 0 ; j < fixedPositiveExamples ; j++ )
  36. {
  37. int k;
  38. do {
  39. k = rand() % count;
  40. } while ( memory.find(k) != memory.end() );
  41. memory.insert(k);
  42. positive.add_reference ( classno, base_nonconst[classno][k]);
  43. }
  44. for ( int k = 0 ; k < count ; k++ )
  45. if ( memory.find(k) == memory.end() )
  46. negative.add_reference ( classno, base_nonconst[classno][k]);
  47. }
  48. }
  49. static void selectRandomMax ( const std::map<int,int> & fixedPositiveExamples, const F & base, F & positive, F & negative )
  50. {
  51. F & base_nonconst = const_cast< F & >(base);
  52. //srand(time(NULL));
  53. for ( std::map<int,int>::const_iterator i = fixedPositiveExamples.begin();
  54. i != fixedPositiveExamples.end() ;
  55. i++ )
  56. {
  57. int classno = i->first;
  58. int fixedPositiveExamples = i->second;
  59. std::set<int> memory;
  60. int count = base_nonconst[classno].size();
  61. int m = fixedPositiveExamples < count ? fixedPositiveExamples : count;
  62. for ( int j = 0 ; j < m ; j++ )
  63. {
  64. int k;
  65. do {
  66. k = rand() % count;
  67. } while ( memory.find(k) != memory.end() );
  68. memory.insert(k);
  69. positive.add_reference ( classno, base_nonconst[classno][k]);
  70. }
  71. for ( int k = 0 ; k < count ; k++ )
  72. if ( memory.find(k) == memory.end() )
  73. negative.add_reference ( classno, base_nonconst[classno][k]);
  74. }
  75. }
  76. static void selectSequentialStep ( const std::map<int, int> & fixedPositiveExamples, const F & base, F & positive, F & negative )
  77. {
  78. F & base_nonconst = const_cast< F & >(base);
  79. for ( std::map<int,int>::const_iterator i = fixedPositiveExamples.begin();
  80. i != fixedPositiveExamples.end() ;
  81. i++ )
  82. {
  83. int classno = i->first;
  84. int fixedPositiveExamples = i->second;
  85. int count = base_nonconst[classno].size();
  86. int step = fixedPositiveExamples ? (count / fixedPositiveExamples) : 0;
  87. if ( count < fixedPositiveExamples ) {
  88. fthrow ( Exception, "LabeledSetSelection::selectSequentialStep: unable to select " << fixedPositiveExamples
  89. << " of " << count << " examples (classno " << classno << ")" );
  90. }
  91. int k = 0;
  92. for ( int j = 0 ; j < count ; j++ )
  93. {
  94. if ( (step == 0) || (k >= fixedPositiveExamples) || (j % step != 0) )
  95. negative.add_reference ( classno, base_nonconst[classno][j] );
  96. else {
  97. k++;
  98. positive.add_reference ( classno, base_nonconst[classno][j] );
  99. }
  100. }
  101. }
  102. };
  103. static void selectSequential ( const std::map<int, int> & fixedPositiveExamples, const F & base, F & positive, F & negative )
  104. {
  105. F & base_nonconst = const_cast< F & >(base);
  106. for ( std::map<int,int>::const_iterator i = fixedPositiveExamples.begin();
  107. i != fixedPositiveExamples.end() ;
  108. i++ )
  109. {
  110. int classno = i->first;
  111. int fixedPositiveExamples = i->second;
  112. int count = base_nonconst[classno].size();
  113. if ( count < fixedPositiveExamples ) {
  114. fthrow ( Exception, "LabeledSetSelection::selectSequential: unable to select " << fixedPositiveExamples
  115. << " of " << count << " examples" );
  116. }
  117. for ( int j = 0 ; j < fixedPositiveExamples ; j++ )
  118. positive.add_reference ( classno, base_nonconst[classno][j] );
  119. for ( int j = fixedPositiveExamples ; j < count ; j++ )
  120. negative.add_reference ( classno, base_nonconst[classno][j] );
  121. }
  122. };
  123. static void selectClasses ( const std::set<int> & classnos, const F & base, F & positive, F & negative )
  124. {
  125. F & base_nonconst = const_cast< F & >(base);
  126. std::vector<int> classes;
  127. base.getClasses(classes);
  128. for ( std::set<int>::const_iterator i = classnos.begin();
  129. i != classnos.end() ;
  130. i++ )
  131. {
  132. int classno = *i;
  133. int count = base_nonconst[classno].size();
  134. for ( int j = 0 ; j < count ; j++ )
  135. positive.add_reference ( classno, base_nonconst[classno][j]);
  136. }
  137. for ( std::vector<int>::const_iterator i = classes.begin();
  138. i != classes.end() ;
  139. i++ )
  140. {
  141. int classno = *i;
  142. if ( classnos.find(classno) != classnos.end() ) continue;
  143. int count = base_nonconst[classno].size();
  144. for ( int j = 0 ; j < count ; j++ )
  145. negative.add_reference ( classno, base_nonconst[classno][j]);
  146. }
  147. };
  148. };
  149. } // namespace
  150. #endif