compressObjectBankFeatures.cpp 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. /**
  2. * @file compressObjectBankFeatures.cpp
  3. * @brief convert ObjectBank features to a sparse histogram representation
  4. * @author Erik Rodner
  5. * @date 01/23/2012
  6. */
  7. #include <algorithm>
  8. #include "core/basics/Config.h"
  9. #include "vislearning/cbaselib/MultiDataset.h"
  10. #include "vislearning/cbaselib/MutualInformation.h"
  11. #include "vislearning/baselib/Globals.h"
  12. using namespace std;
  13. using namespace NICE;
  14. using namespace OBJREC;
  15. const bool use_standard = false; // experimental setting
  16. Vector transformFeature ( const vector<double> & src )
  17. {
  18. Vector dst;
  19. if ( use_standard ) {
  20. dst = Vector(src);
  21. } else {
  22. if ( src.size() != 44604 )
  23. fthrow(Exception, "This is not a ObjectBank feature! The size is: " << src.size());
  24. dst.resize ( 177 );
  25. dst.set(0.0);
  26. // This was a bad idea: taking the maximum
  27. /*
  28. for ( uint i = 0 ; i < 177; i++ )
  29. dst[i] = *max_element( src.begin() + i*252, src.begin() + (i+1)*252 );
  30. // even a worse idea: summation
  31. for ( uint i = 0 ; i < 177; i++ )
  32. for ( uint j = 0 ; j < 252 ; j++ )
  33. dst[i] += src[j + i*252];
  34. */
  35. }
  36. return dst;
  37. }
  38. void readPlainData ( const Config & conf, const LabeledSet & ls, LabeledSetVector & X, string extension = ".txt" )
  39. {
  40. string cacheroot = conf.gS("cache", "root");
  41. X.clear();
  42. LOOP_ALL_S ( ls )
  43. {
  44. EACH_S(classno, imgfn);
  45. Globals::setCurrentImgFN ( imgfn );
  46. string cachefn = Globals::getCacheFilename ( cacheroot, Globals::SORT_CATEGORIES ) + extension;
  47. cerr << "fn: " << imgfn << " cachefn: " << cachefn << endl;
  48. vector<double> x;
  49. ifstream ifs ( cachefn.c_str(), ios::in );
  50. if ( ! ifs.good() )
  51. fthrow(Exception, "File not found: " << cachefn );
  52. while ( !ifs.eof() )
  53. {
  54. double val = 0.0;
  55. if ( ifs >> val )
  56. x.push_back(val);
  57. }
  58. ifs.close();
  59. X.add ( classno, transformFeature( x ) );
  60. }
  61. }
  62. void saveFeatures ( const Config & conf, const map<double, int> & features, const Vector & thresholds,
  63. const LabeledSet & ls, const string & srcExtension, const string & dstExtension )
  64. {
  65. string cacheroot = conf.gS("cache", "root");
  66. LOOP_ALL_S ( ls )
  67. {
  68. EACH_S(classno, imgfn);
  69. Globals::setCurrentImgFN ( imgfn );
  70. string cachefn = Globals::getCacheFilename ( cacheroot, Globals::SORT_CATEGORIES ) + srcExtension;
  71. cerr << "processing " << cachefn << endl;
  72. vector<double> x;
  73. ifstream ifs ( cachefn.c_str(), ios::in );
  74. if ( ! ifs.good() )
  75. fthrow(Exception, "File not found: " << cachefn );
  76. while ( !ifs.eof() ) {
  77. double val = 0.0;
  78. if ( ifs >> val )
  79. x.push_back(val);
  80. }
  81. ifs.close();
  82. Vector xt = transformFeature(x);
  83. Vector xnew ( features.size() );
  84. int index = 0;
  85. for ( map<double, int>::const_iterator j = features.begin(); j != features.end(); j++, index++ )
  86. {
  87. int srcIndex = j->second;
  88. if ( srcIndex >= xt.size() )
  89. fthrow(Exception, "Bad bug in saveFeatures(...)" );
  90. xnew[index] = (xt[srcIndex] > thresholds[srcIndex]) ? 1.0 : 0.0;
  91. }
  92. // If we do not normalize our features, we pretty much get into
  93. // trouble with the minimum kernel...because the vector with only values of "1" is very
  94. // much similar to every vector
  95. xnew.normalizeL1();
  96. string dst_cachefn = Globals::getCacheFilename ( cacheroot, Globals::SORT_CATEGORIES ) + dstExtension;
  97. ofstream ofs ( dst_cachefn.c_str(), ios::out );
  98. if ( ! ofs.good() )
  99. fthrow(Exception, "Unable to write to " << dst_cachefn );
  100. ofs << xnew << endl;
  101. ofs.close ();
  102. }
  103. }
  104. /**
  105. convert ObjectBank features to a sparse histogram representation
  106. */
  107. int main (int argc, char **argv)
  108. {
  109. std::set_terminate(__gnu_cxx::__verbose_terminate_handler);
  110. Config conf ( argc, argv );
  111. MultiDataset md ( &conf );
  112. Vector y;
  113. const LabeledSet *train = md["train"];
  114. LabeledSetVector trainData;
  115. readPlainData ( conf, *train, trainData, ".jpg.feat" );
  116. // compute optimal thresholds for thresholding
  117. MutualInformation mi ( true /*verbose*/ );
  118. Vector thresholds;
  119. Vector mis;
  120. mi.computeThresholdsOverall ( trainData, thresholds, mis );
  121. cerr << mis << endl;
  122. int numFeatures = conf.gI("main", "d", mis.size() );
  123. cerr << "Retaining " << numFeatures << " features ..." << endl;
  124. map<double, int> features;
  125. for ( uint i = 0 ; i < mis.size(); i++ )
  126. features.insert ( pair<double, int> ( - mis[i], i ) );
  127. // all features should be now sorted with features[0] being the most informative one
  128. // remove boring features
  129. map<double, int>::iterator j = features.begin();
  130. advance( j, numFeatures );
  131. features.erase ( j, features.end() );
  132. const LabeledSet *test = md["test"];
  133. string dstExtention = conf.gS("main", "dstext", ".txt");
  134. saveFeatures ( conf, features, thresholds, *train, ".jpg.feat", dstExtention );
  135. saveFeatures ( conf, features, thresholds, *test, ".jpg.feat", dstExtention );
  136. return 0;
  137. }