PSSLocalizationPrior.cpp 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. /**
  2. * @file PSSLocalizationPrior.cpp
  3. * @brief incorporate prior from localization results
  4. * @author Erik Rodner
  5. * @date 03/19/2009
  6. */
  #include <algorithm>
  #include <cmath>
  #include <cstdio>
  #include <iostream>
  #include <limits>
  #include <map>
  #include <string>
  #include <vector>
  #include "PSSLocalizationPrior.h"
  #include "objrec/baselib/StringTools.h"
  #include "objrec/baselib/Globals.h"
  #include "objrec/baselib/FileMgt.h"
  #include "objrec/cbaselib/PascalResults.h"
  using namespace OBJREC;
  using namespace std;
  using namespace NICE;
  17. PSSLocalizationPrior::PSSLocalizationPrior( const std::string & detectiondir,
  18. const ClassNames *classNames,
  19. double alphaDetectionPrior,
  20. int subsamplex, int subsampley )
  21. {
  22. this->subsamplex = subsampley;
  23. this->subsampley = subsamplex;
  24. this->alphaDetectionPrior = alphaDetectionPrior;
  25. loadDetectionResults ( detectiondir, detresults, classNames );
  26. }
  27. PSSLocalizationPrior::~PSSLocalizationPrior()
  28. {
  29. }
  30. void PSSLocalizationPrior::loadDetectionResults ( const std::string & dir,
  31. map<string, LocalizationResult *> & results,
  32. const ClassNames *classNames )
  33. {
  34. vector<string> files;
  35. FileMgt::DirectoryRecursive ( files, dir );
  36. int backgroundClassNo = classNames->getBackgroundClass();
  37. for ( vector<string>::const_iterator i = files.begin();
  38. i != files.end(); i++ )
  39. {
  40. std::string file = *i;
  41. std::string classtext = StringTools::baseName ( file, false );
  42. int classno = classNames->classno(classtext);
  43. if ( classno < 0 ) {
  44. fprintf (stderr, "Unable to find class %s\n", classtext.c_str() );
  45. fprintf (stderr, "dir %s file %s classtext %s\n", dir.c_str(),
  46. file.c_str(), classtext.c_str() );
  47. }
  48. PascalResults::read ( results, file, classno, backgroundClassNo, true /*calibrate*/ );
  49. }
  50. }
  51. void PSSLocalizationPrior::postprocess ( NICE::Image & result, NICE::MultiChannelImageT<double> & probabilities )
  52. {
  53. std::string currentFilename = Globals::getCurrentImgFN();
  54. std::string base = StringTools::baseName ( currentFilename, false );
  55. map<string, LocalizationResult *>::const_iterator i = detresults.find ( base );
  56. if ( i == detresults.end() )
  57. {
  58. fprintf (stderr, "NO detection results found for %s !\n", base.c_str());
  59. return;
  60. }
  61. fprintf (stderr, "Infering detection prior\n");
  62. LocalizationResult *ldet = i->second;
  63. int maxClassNo = probabilities.numChannels - 1;
  64. int xsize = probabilities.xsize;
  65. int ysize = probabilities.ysize;
  66. FullVector *priormap = new FullVector [ xsize * ysize ];
  67. for ( long k = 0 ; k < xsize * ysize ; k++ )
  68. priormap[k].reinit(maxClassNo);
  69. for ( LocalizationResult::const_iterator j = ldet->begin();
  70. j != ldet->end();
  71. j++ )
  72. {
  73. const SingleLocalizationResult *slr = *j;
  74. int xi, yi, xa, ya;
  75. const NICE::Region & r = slr->getRegion();
  76. int classno = slr->r->classno;
  77. double confidence = slr->r->confidence();
  78. r.getRect ( xi, yi, xa, ya );
  79. for ( int y = yi; y <= ya; y++ )
  80. for ( int x = xi; x <= xa; x++ )
  81. {
  82. if ( (y<0) || (x<0) || (x>xsize-1) || (y>ysize-1) )
  83. continue;
  84. if ( r.inside ( x*subsamplex, y*subsampley ) )
  85. priormap[x + y*xsize][classno] += confidence;
  86. }
  87. long k = 0;
  88. for ( int y = 0 ; y < ysize ; y++ )
  89. for ( int x = 0 ; x < xsize ; x++,k++ )
  90. {
  91. FullVector & prior = priormap[k];
  92. if ( prior.sum() < 10e-6 )
  93. continue;
  94. prior.normalize();
  95. double sum = 0.0;
  96. for ( int i = 0 ; i < (int)probabilities.numChannels; i++ )
  97. {
  98. probabilities.data[i][k] *= pow ( prior[i], alphaDetectionPrior );
  99. sum += probabilities.data[i][k];
  100. }
  101. if ( sum < 10e-6 )
  102. continue;
  103. int maxindex = 0;
  104. double maxvalue = - numeric_limits<double>::max();
  105. for ( int i = 0 ; i < (int)probabilities.numChannels; i++ )
  106. {
  107. probabilities.data[i][k] /= sum;
  108. if ( probabilities.data[i][k] > maxvalue )
  109. {
  110. maxindex = i;
  111. maxvalue = probabilities.data[i][k];
  112. }
  113. }
  114. result.setPixel(x,y,maxindex);
  115. }
  116. }
  117. delete [] priormap;
  118. }