MAPMultinomialDirichlet.cpp 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
/**
* @file MAPMultinomialDirichlet.cpp
* @brief MAP (maximum a posteriori) estimation of a multinomial using a Dirichlet prior
* @author Erik Rodner
* @date 10/30/2008
*/
#include <assert.h>
#include <algorithm>
#include <cmath>
#include <stdexcept>
#include "MAPMultinomialDirichlet.h"
using namespace OBJREC;
using namespace NICE;
using namespace std;
  14. void MAPMultinomialDirichlet::estimate ( NICE::Vector & mapEstimate,
  15. const VVector & likelihoodDistributionSamples,
  16. const VVector & priorDistributionSamples,
  17. double priorInfluence )
  18. {
  19. assert ( likelihoodDistributionSamples.size() == 1 );
  20. priorInfluence = 1.0 / priorInfluence;
  21. const NICE::Vector & mlEstimate = likelihoodDistributionSamples[0];
  22. double mlEstimateSum = 0.0;
  23. for ( uint k = 0 ; k < (uint)mlEstimate.size(); k++ )
  24. mlEstimateSum += mlEstimate[k];
  25. NICE::Vector mu;
  26. for ( VVector::const_iterator i = priorDistributionSamples.begin();
  27. i != priorDistributionSamples.end(); i++ )
  28. {
  29. const NICE::Vector & x = *i;
  30. if ( mu.size() == 0 ) mu = x;
  31. else mu = mu + x;
  32. }
  33. mu = mu * (priorInfluence/priorDistributionSamples.size());
  34. // ------ CODE SAFETY
  35. double muSum = 0.0;
  36. for ( uint k = 0 ; k < (uint)mu.size(); k++ )
  37. muSum += mu[k];
  38. assert ( fabs(muSum - priorInfluence) < 10e-9 );
  39. // ------ END CODE SAFETY
  40. // mu is a rough estimate of the parameter vector alpha of a dirichlet distribution
  41. mapEstimate.resize(mlEstimate.size());
  42. mapEstimate.set(0);
  43. double scale = mlEstimateSum + priorInfluence - mlEstimate.size();
  44. assert ( fabs(scale) > 10e-8 );
  45. for ( uint k = 0 ; k < (uint)mapEstimate.size() ; k++ )
  46. mapEstimate[k] = (mlEstimate[k] + mu[k] - 1) / scale;
  47. }