KernelUtils.cpp 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. /**
  2. * @file KernelUtils.cpp
  3. * @brief some utilities to select kernel sub matrices etc.
  4. * @author Erik Rodner
  5. * @date 03/01/2010
  6. */
  7. #include <iostream>
  8. #include <set>
  9. #include "KernelUtils.h"
  10. #include "core/basics/StringTools.h"
  11. using namespace std;
  12. using namespace NICE;
  13. using namespace OBJREC;
  14. void KernelUtils::selectExamples ( const Config *conf, const Vector & labels,
  15. vector<int> & trainSelection, vector<int> & testSelection )
  16. {
  17. string selectionType = conf->gS("main", "selection_type");
  18. map<int, int> trainExamplesCount;
  19. map<int, int> testExamplesCount;
  20. if ( selectionType == "seq" ) {
  21. int trainExamples = conf->gI("main", "selection_examples" );
  22. if ( ((int)labels.size() < trainExamples) || (trainExamples <= 0) )
  23. fthrow(Exception, "Unable to select " << trainExamples << " from " << labels.size() << ".");
  24. for ( uint i = 0 ; i < (uint)trainExamples; i++ )
  25. {
  26. int classno = (int)labels[i];
  27. trainSelection.push_back( i );
  28. trainExamplesCount[classno] ++;
  29. }
  30. for ( uint i = (uint)trainExamples ; i < labels.size(); i++ )
  31. {
  32. int classno = (int)labels[i];
  33. testSelection.push_back( i );
  34. testExamplesCount[classno] ++;
  35. }
  36. } else if ( selectionType == "seq_class" )
  37. {
  38. int trainExamplesForEachClassSingle = conf->gI("main", "selection_examples_class", -1 );
  39. Vector trainExamplesForEachClass;
  40. if ( trainExamplesForEachClassSingle <=0 ) {
  41. string trainExamplesForEachClass_s = conf->gS("main", "selection_examples_class" );
  42. StringTools::splitVector ( trainExamplesForEachClass_s, ',', trainExamplesForEachClass );
  43. }
  44. for ( uint i = 0 ; i < labels.size() ; i++ )
  45. {
  46. int classno = (int)labels[i];
  47. if ( (trainExamplesForEachClassSingle <= 0) && (classno >= (int)trainExamplesForEachClass.size()) )
  48. fthrow(Exception, "-selection_examples_class <n0>,<n1>,..." << endl << "Missing data in selection_examples_class!" );
  49. int limit = trainExamplesForEachClassSingle;
  50. if ( limit <= 0 )
  51. limit = trainExamplesForEachClass[classno];
  52. if ( trainExamplesCount[classno] < limit )
  53. {
  54. trainSelection.push_back ( i );
  55. trainExamplesCount[classno] ++;
  56. } else {
  57. testSelection.push_back ( i );
  58. testExamplesCount[classno] ++;
  59. }
  60. }
  61. } else if ( selectionType == "random_class_doaa" )
  62. {
  63. if ( labels.size() != 224*5 )
  64. fthrow(Exception, "This selection only works with the Jena-Range-02 database!\n");
  65. int k = conf->gI("main", "selection_instances", 3 );
  66. map<int, set<int> > trainInstances;
  67. // loop through all classes
  68. for ( int i = 0 ; i < 5 ; i++ )
  69. {
  70. trainInstances.insert ( pair<int, set<int> > ( i, set<int> () ) );
  71. for ( int j = 0 ; j < k ; j++ )
  72. {
  73. int inst;
  74. do {
  75. inst = randInt ( 6 ) + 1;
  76. } while ( trainInstances[i].find(inst) != trainInstances[i].end() );
  77. trainInstances[i].insert ( inst );
  78. }
  79. }
  80. for ( uint i = 0 ; i < labels.size() ; i++ )
  81. {
  82. int classno = (int)labels[i];
  83. int instance = (i % 224) / 32;
  84. if ( instance == 0 )
  85. continue;
  86. // cerr << i << " " << "inst " << instance << " " << trainInstances[classno].size() << endl;
  87. if ( trainInstances[classno].find(instance) != trainInstances[classno].end() ) {
  88. trainSelection.push_back ( i );
  89. trainExamplesCount[classno] ++;
  90. } else {
  91. testSelection.push_back ( i );
  92. testExamplesCount[classno] ++;
  93. }
  94. }
  95. for ( int i = 0 ; i < 5 ; i++ )
  96. if ( trainExamplesCount[i] != k*32 ) {
  97. fthrow(Exception, "Something is wrong here: training examples of class " << i << " = " << trainExamplesCount[i] << " != " << k*32 );
  98. }
  99. } else if ( selectionType == "seq_class_doaa" )
  100. {
  101. int trainExamplesForEachClassSingle = 100;
  102. int testExamplesForEachClassSingle = 60;
  103. for ( uint i = 0 ; i < labels.size() ; i++ )
  104. {
  105. int classno = (int)labels[i];
  106. if ( trainExamplesCount[classno] < trainExamplesForEachClassSingle )
  107. {
  108. trainSelection.push_back ( i );
  109. trainExamplesCount[classno] ++;
  110. } else if ( testExamplesCount[classno] < testExamplesForEachClassSingle ) {
  111. testSelection.push_back ( i );
  112. testExamplesCount[classno] ++;
  113. }
  114. }
  115. } else if ( selectionType == "random_class" )
  116. {
  117. int trainExamplesForEachClassSingle = conf->gI("main", "selection_examples_class", -1 );
  118. Vector trainExamplesForEachClass;
  119. if ( trainExamplesForEachClassSingle <=0 ) {
  120. string trainExamplesForEachClass_s = conf->gS("main", "selection_examples_class" );
  121. StringTools::splitVector ( trainExamplesForEachClass_s, ',', trainExamplesForEachClass );
  122. }
  123. map<uint, uint> counts;
  124. for ( uint i = 0 ; i < labels.size(); i++ )
  125. {
  126. uint classno = (uint)labels[i];
  127. map<uint, uint>::iterator i = counts.find( classno );
  128. if ( i == counts.end() )
  129. counts.insert ( pair<uint, uint> ( classno, 1 ) );
  130. else
  131. i->second += 1;
  132. }
  133. set<int> memory;
  134. for ( map<uint, uint>::const_iterator k = counts.begin(); k != counts.end(); k++ )
  135. {
  136. uint count = k->second;
  137. uint classno = k->first;
  138. if ( (trainExamplesForEachClassSingle <= 0) && (classno >= trainExamplesForEachClass.size()) )
  139. fthrow(Exception, "-selection_examples_class <n0>,<n1>,..." << endl << "Missing data in selection_examples_class!" );
  140. int limit = trainExamplesForEachClassSingle;
  141. if ( limit <= 0 )
  142. limit = trainExamplesForEachClass[classno];
  143. if ( limit > (int)count )
  144. {
  145. cerr << "Class " << classno << " has not enough examples, we will use all of them (" << count << ") !" << endl;
  146. limit = count;
  147. }
  148. for ( int j = 0 ; j < limit ; j++ )
  149. {
  150. int k;
  151. // inefficient random selection
  152. do {
  153. k = rand() % labels.size();
  154. } while ( (memory.find(k) != memory.end()) || ((uint)labels[k] != classno) );
  155. memory.insert(k);
  156. trainSelection.push_back ( k );
  157. }
  158. cerr << classno << " -> " << limit << endl;
  159. }
  160. // put the remainder to the test sets
  161. for ( uint i = 0 ; i < labels.size(); i++ )
  162. {
  163. if ( memory.find(i) == memory.end() ) {
  164. testSelection.push_back ( i );
  165. }
  166. }
  167. } else {
  168. fthrow(Exception, "Selection type " << selectionType << " is unknown." );
  169. }
  170. cerr << "Learning" << endl;
  171. for ( map<int, int>::const_iterator j = trainExamplesCount.begin();
  172. j != trainExamplesCount.end(); j++ )
  173. cerr << "class " << j->first << ": " << j->second << endl;
  174. cerr << "Testing" << endl;
  175. for ( map<int, int>::const_iterator j = testExamplesCount.begin();
  176. j != testExamplesCount.end(); j++ )
  177. cerr << "class " << j->first << ": " << j->second << endl;
  178. }
  179. void KernelUtils::getKernelMatrix ( const vector<int> & trainSelection,
  180. const Matrix & kernelMatrix, const Vector & labels,
  181. Matrix & kernelMatrixTrain, Vector & labelsTrain )
  182. {
  183. kernelMatrixTrain.resize ( trainSelection.size(), trainSelection.size() );
  184. labelsTrain.resize ( trainSelection.size() );
  185. int ik = 0;
  186. for ( vector<int>::const_iterator i = trainSelection.begin();
  187. i != trainSelection.end(); i++,ik++ )
  188. {
  189. int index_i = *i;
  190. labelsTrain[ik] = labels[index_i];
  191. int jk = 0;
  192. for ( vector<int>::const_iterator j = trainSelection.begin();
  193. j != trainSelection.end(); j++,jk++ )
  194. {
  195. int index_j = *j;
  196. kernelMatrixTrain(ik,jk) = kernelMatrix(index_i,index_j);
  197. }
  198. }
  199. }
  200. void KernelUtils::getKernelVector ( const vector<int> & trainSelection,
  201. const Matrix & kernelMatrix, uint index, Vector & kernelVector )
  202. {
  203. kernelVector.resize ( trainSelection.size() );
  204. int ik = 0;
  205. for ( vector<int>::const_iterator i = trainSelection.begin();
  206. i != trainSelection.end(); i++,ik++ )
  207. {
  208. int index_i = *i;
  209. kernelVector(ik) = kernelMatrix(index_i,index);
  210. }
  211. }