FitSigmoid.cpp 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. /**
  2. * @file FitSigmoid.cpp
  3. * @brief fit a sigmoid function according to the paper of Platt
  4. * @author Erik Rodner
  5. * @date 03/05/2009
  6. */
  7. #include <iostream>
  8. #include <math.h>
  9. #include <limits>
  10. #include "FitSigmoid.h"
  11. using namespace OBJREC;
  12. using namespace std;
  13. void FitSigmoid::fitSigmoid ( const vector<double> & t,
  14. const vector<double> & out,
  15. double startp,
  16. double & A, double & B )
  17. {
  18. const int maxiterations = 100;
  19. double lambda = 1e-3;
  20. double olderr = numeric_limits<double>::max();
  21. vector<double> pp ( out.size(), startp );
  22. double oldA;
  23. double oldB;
  24. int count = 0; // failure count
  25. for (int it = 1; it < maxiterations ; it++ )
  26. {
  27. double a = 0;
  28. double b = 0;
  29. double c = 0;
  30. double d = 0;
  31. double e = 0;
  32. // First, compute Hessian & gradient of error function
  33. // with respect to A & B
  34. int index = 0;
  35. for (vector<double>::const_iterator j = out.begin();
  36. j != out.end(); j++, index++ )
  37. {
  38. double pi = *j;
  39. double d1 = pp[index]-t[index];
  40. double d2 = pp[index]*(1-pp[index]);
  41. a += pi*pi*d2;
  42. b += d2;
  43. c += pi*d2;
  44. d += pi*d1;
  45. e += d1;
  46. }
  47. // If gradient is really tiny, then stop
  48. if (fabs(d) < 1e-9 && fabs(e) < 1e-9)
  49. break;
  50. oldA = A;
  51. oldB = B;
  52. double err = 0;
  53. // Loop until goodness of fit increases
  54. while (1)
  55. {
  56. double det = (a+lambda)*(b+lambda)-c*c;
  57. if (fabs(det) < 1e-12) {
  58. // if determinant of Hessian is zero,
  59. // increase stabilizer
  60. lambda *= 10;
  61. continue;
  62. }
  63. A = oldA + ((b+lambda)*d-c*e)/det;
  64. B = oldB + ((a+lambda)*e-c*d)/det;
  65. // Now, compute the goodness of fit
  66. err = 0;
  67. int index = 0;
  68. for (vector<double>::const_iterator j = out.begin();
  69. j != out.end(); j++, index++ )
  70. {
  71. double pi = *j;
  72. double p = 1.0/(1.0+exp(pi*A+B));
  73. pp[index] = p;
  74. // At this step, make sure log(0) returns -200
  75. double logp = p < 1e-12 ? -200 : log(p);
  76. double lognp = 1-p < 1e-12 ? -200 : log(1-p);
  77. err -= t[index]*logp+(1-t[index])*lognp;
  78. }
  79. if (err < olderr*(1+1e-7)) {
  80. lambda *= 0.1;
  81. break;
  82. }
  83. // error did not decrease: increase stabilizer by factor of 10
  84. // & try again
  85. lambda *= 10;
  86. if (lambda > 1e6) // something is broken. Give up
  87. break;
  88. }
  89. double diff = err-olderr;
  90. double scale = 0.5*(err+olderr+1);
  91. if ((diff > -1e-3*scale) && (diff < 1e-7*scale))
  92. count++;
  93. else
  94. count = 0;
  95. olderr = err;
  96. if (count == 3)
  97. break;
  98. }
  99. }
  100. void FitSigmoid::fitProbabilities ( const vector<pair<int, double> > & results,
  101. double & A, double & B, double mlestimation )
  102. {
  103. int prior0 = 0;
  104. int prior1 = 0;
  105. for (vector<pair<int, double> >::const_iterator j = results.begin();
  106. j != results.end(); j++ )
  107. if ( j->first ) prior1++;
  108. else prior0++;
  109. // corresponds to MAP estimation instead of ML estimation
  110. // -> paper of platt
  111. double hiTarget;
  112. double loTarget;
  113. if ( mlestimation )
  114. {
  115. hiTarget = 1.0;
  116. loTarget = 0.0;
  117. } else {
  118. hiTarget = (prior1+1)/(double)(prior1+2);
  119. loTarget = 1.0/(prior0+2.0);
  120. }
  121. vector<double> t ( results.size() );
  122. vector<double> out ( results.size() );
  123. int index = 0;
  124. for (vector<pair<int, double> >::const_iterator j = results.begin();
  125. j != results.end(); j++, index++ )
  126. {
  127. int yi = j->first;
  128. if ( yi )
  129. {
  130. t[index] = hiTarget;
  131. } else {
  132. t[index] = loTarget;
  133. }
  134. out[index] = j->second;
  135. }
  136. A = 0;
  137. B = log( (float)(prior0+1))-log( (float)(prior1+1) );
  138. double startp = (prior1+1)/(double)(prior0+prior1+2);
  139. fitSigmoid ( t, out, startp, A, B );
  140. }