CascadeOptimization.cpp 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. /**
  2. * @file CascadeOptimization.cpp
  3. * @brief optimization of a previously built cascade
  4. * @author Erik Rodner
  5. * @date 11/13/2008
  6. */
  7. #include <iostream>
  8. #include <algorithm>
  9. #include <assert.h>
  10. #include "CascadeOptimization.h"
  11. using namespace OBJREC;
  12. using namespace std;
  13. // refactor-nice.pl: check this substitution
  14. // old: using namespace ice;
  15. using namespace NICE;
  16. CascadeOptimization::CascadeOptimization()
  17. {
  18. }
  19. CascadeOptimization::~CascadeOptimization()
  20. {
  21. }
  22. bool CascadeOptimization::calcOptimalCascade ( const vector<vector< triplet<double, double, double> > > & matrix,
  23. list<int> & path,
  24. double & besttprate,
  25. double tprate,
  26. double fprate,
  27. double minimumFPRate,
  28. uint round )
  29. {
  30. if ( round < matrix.size() )
  31. {
  32. int index = 0;
  33. bool solutionFoundLoop = false;
  34. const vector<triplet<double, double, double> > & statistics = matrix[round];
  35. for ( vector<triplet<double, double, double> >::const_iterator i = statistics.begin();
  36. i != statistics.end();
  37. i++, index++ )
  38. {
  39. double tp = i->first;
  40. double fp = i->second;
  41. if ( tp*tprate > besttprate ) {
  42. bool solutionFound = false;
  43. if ( fp*fprate < minimumFPRate) {
  44. path.clear();
  45. solutionFound = true;
  46. } else if ( calcOptimalCascade ( matrix, path, besttprate,
  47. tprate*tp, fprate*fp, minimumFPRate, round+1 ) )
  48. {
  49. solutionFound = true;
  50. }
  51. if ( solutionFound )
  52. {
  53. besttprate = tprate*tp;
  54. path.push_front ( index );
  55. solutionFoundLoop = true;
  56. }
  57. }
  58. }
  59. return solutionFoundLoop;
  60. } else {
  61. assert ( tprate > besttprate );
  62. if ( fprate < minimumFPRate )
  63. {
  64. path.clear();
  65. return true;
  66. } else {
  67. return false;
  68. }
  69. }
  70. }
  71. void CascadeOptimization::calcOptimalCascade ( const vector<vector< triplet<double, double, double> > > & matrix,
  72. double minimumFPRate,
  73. vector<double> & thresholds )
  74. {
  75. list<int> path;
  76. double besttprate = 0.0;
  77. calcOptimalCascade ( matrix, path, besttprate, 1.0, 1.0, minimumFPRate, 0 );
  78. int index = 0;
  79. for ( list<int>::const_iterator i = path.begin(); i != path.end();
  80. i++, index++ )
  81. {
  82. int entry = *i;
  83. const triplet<double, double, double> & vals = matrix[index][entry];
  84. double tprate = vals.first;
  85. double fprate = vals.second;
  86. double threshold = vals.third;
  87. fprintf (stderr, "cascade (%d): tp %f fp %f threshold %f\n",
  88. index+1, tprate, fprate, threshold);
  89. thresholds.push_back(threshold);
  90. }
  91. }
  92. bool CascadeOptimization::evaluateCascade ( vector<pair<double, int> > & results,
  93. long N, long P,
  94. int negativeClassDST,
  95. double requiredDetectionRate,
  96. double & bestthreshold,
  97. double & besttprate,
  98. double & bestfprate,
  99. vector< triplet<double, double, double> > & statistics )
  100. {
  101. sort ( results.begin(), results.end() );
  102. long positives_count = 0;
  103. long count = 1;
  104. bool solutionFound = false;
  105. int bestEntry = 0;
  106. int secondBestEntry = 0;
  107. for ( vector<pair<double, int> >::const_iterator j = results.begin();
  108. j+1 != results.end();
  109. j++, count++ )
  110. {
  111. int classno = j->second;
  112. double threshold = j->first;
  113. //fprintf (stderr, "CascadeOptimization: classno %d, threshold %f\n", classno, threshold );
  114. if ( classno != negativeClassDST )
  115. positives_count++;
  116. double tprate = positives_count / (double)P;
  117. double fprate = ( count - positives_count ) / (double) N;
  118. if ( (classno != negativeClassDST) && ((j+1)->second == negativeClassDST)
  119. && ((j+1)->first != threshold) )
  120. {
  121. statistics.push_back ( triplet<double, double, double> ( tprate, fprate, threshold ) );
  122. fprintf (stderr, "CascadeOptimization: tprate %f fprate %f threshold %f (required tprate %f)\n", tprate, fprate, j->first, requiredDetectionRate );
  123. if ( (!solutionFound) && (tprate >= requiredDetectionRate) )
  124. {
  125. bestEntry = statistics.size() - 1;
  126. fprintf (stderr, "CascadeOptimization: suitable entry found !\n");
  127. solutionFound = true;
  128. } else {
  129. secondBestEntry = statistics.size() - 1;
  130. }
  131. }
  132. }
  133. if ( ! solutionFound ) {
  134. fprintf (stderr, "CascadeOptimization: Using second best solution !!\n");
  135. besttprate = statistics[secondBestEntry].first;
  136. bestfprate = statistics[secondBestEntry].second;
  137. bestthreshold = statistics[secondBestEntry].third;
  138. fprintf (stderr, "CascadeOptimization: threshold %f detection rate %f fp rate %f\n",
  139. bestthreshold, besttprate, bestfprate );
  140. } else {
  141. besttprate = statistics[bestEntry].first;
  142. bestfprate = statistics[bestEntry].second;
  143. bestthreshold = statistics[bestEntry].third;
  144. }
  145. if ( besttprate == 0.0 ) {
  146. fprintf (stderr, "!!!! WORST CLASSIFIER I'VE EVER SEEN !!!!\n");
  147. }
  148. return solutionFound;
  149. }