CascadeOptimization.cpp 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. /**
  2. * @file CascadeOptimization.cpp
  3. * @brief optimization of a previously built cascade
  4. * @author Erik Rodner
  5. * @date 11/13/2008
  6. */
  7. #ifdef NOVISUAL
  8. #include <vislearning/nice_nonvis.h>
  9. #else
  10. #include <vislearning/nice.h>
  11. #endif
  12. #include <iostream>
  13. #include <algorithm>
  14. #include <assert.h>
  15. #include "CascadeOptimization.h"
  16. using namespace OBJREC;
  17. using namespace std;
  18. // refactor-nice.pl: check this substitution
  19. // old: using namespace ice;
  20. using namespace NICE;
  21. CascadeOptimization::CascadeOptimization()
  22. {
  23. }
  24. CascadeOptimization::~CascadeOptimization()
  25. {
  26. }
  27. bool CascadeOptimization::calcOptimalCascade ( const vector<vector< triplet<double, double, double> > > & matrix,
  28. list<int> & path,
  29. double & besttprate,
  30. double tprate,
  31. double fprate,
  32. double minimumFPRate,
  33. uint round )
  34. {
  35. if ( round < matrix.size() )
  36. {
  37. int index = 0;
  38. bool solutionFoundLoop = false;
  39. const vector<triplet<double, double, double> > & statistics = matrix[round];
  40. for ( vector<triplet<double, double, double> >::const_iterator i = statistics.begin();
  41. i != statistics.end();
  42. i++, index++ )
  43. {
  44. double tp = i->first;
  45. double fp = i->second;
  46. if ( tp*tprate > besttprate ) {
  47. bool solutionFound = false;
  48. if ( fp*fprate < minimumFPRate) {
  49. path.clear();
  50. solutionFound = true;
  51. } else if ( calcOptimalCascade ( matrix, path, besttprate,
  52. tprate*tp, fprate*fp, minimumFPRate, round+1 ) )
  53. {
  54. solutionFound = true;
  55. }
  56. if ( solutionFound )
  57. {
  58. besttprate = tprate*tp;
  59. path.push_front ( index );
  60. solutionFoundLoop = true;
  61. }
  62. }
  63. }
  64. return solutionFoundLoop;
  65. } else {
  66. assert ( tprate > besttprate );
  67. if ( fprate < minimumFPRate )
  68. {
  69. path.clear();
  70. return true;
  71. } else {
  72. return false;
  73. }
  74. }
  75. }
  76. void CascadeOptimization::calcOptimalCascade ( const vector<vector< triplet<double, double, double> > > & matrix,
  77. double minimumFPRate,
  78. vector<double> & thresholds )
  79. {
  80. list<int> path;
  81. double besttprate = 0.0;
  82. calcOptimalCascade ( matrix, path, besttprate, 1.0, 1.0, minimumFPRate, 0 );
  83. int index = 0;
  84. for ( list<int>::const_iterator i = path.begin(); i != path.end();
  85. i++, index++ )
  86. {
  87. int entry = *i;
  88. const triplet<double, double, double> & vals = matrix[index][entry];
  89. double tprate = vals.first;
  90. double fprate = vals.second;
  91. double threshold = vals.third;
  92. fprintf (stderr, "cascade (%d): tp %f fp %f threshold %f\n",
  93. index+1, tprate, fprate, threshold);
  94. thresholds.push_back(threshold);
  95. }
  96. }
  97. bool CascadeOptimization::evaluateCascade ( vector<pair<double, int> > & results,
  98. long N, long P,
  99. int negativeClassDST,
  100. double requiredDetectionRate,
  101. double & bestthreshold,
  102. double & besttprate,
  103. double & bestfprate,
  104. vector< triplet<double, double, double> > & statistics )
  105. {
  106. sort ( results.begin(), results.end() );
  107. long positives_count = 0;
  108. long count = 1;
  109. bool solutionFound = false;
  110. int bestEntry = 0;
  111. int secondBestEntry = 0;
  112. for ( vector<pair<double, int> >::const_iterator j = results.begin();
  113. j+1 != results.end();
  114. j++, count++ )
  115. {
  116. int classno = j->second;
  117. double threshold = j->first;
  118. //fprintf (stderr, "CascadeOptimization: classno %d, threshold %f\n", classno, threshold );
  119. if ( classno != negativeClassDST )
  120. positives_count++;
  121. double tprate = positives_count / (double)P;
  122. double fprate = ( count - positives_count ) / (double) N;
  123. if ( (classno != negativeClassDST) && ((j+1)->second == negativeClassDST)
  124. && ((j+1)->first != threshold) )
  125. {
  126. statistics.push_back ( triplet<double, double, double> ( tprate, fprate, threshold ) );
  127. fprintf (stderr, "CascadeOptimization: tprate %f fprate %f threshold %f (required tprate %f)\n", tprate, fprate, j->first, requiredDetectionRate );
  128. if ( (!solutionFound) && (tprate >= requiredDetectionRate) )
  129. {
  130. bestEntry = statistics.size() - 1;
  131. fprintf (stderr, "CascadeOptimization: suitable entry found !\n");
  132. solutionFound = true;
  133. } else {
  134. secondBestEntry = statistics.size() - 1;
  135. }
  136. }
  137. }
  138. if ( ! solutionFound ) {
  139. fprintf (stderr, "CascadeOptimization: Using second best solution !!\n");
  140. besttprate = statistics[secondBestEntry].first;
  141. bestfprate = statistics[secondBestEntry].second;
  142. bestthreshold = statistics[secondBestEntry].third;
  143. fprintf (stderr, "CascadeOptimization: threshold %f detection rate %f fp rate %f\n",
  144. bestthreshold, besttprate, bestfprate );
  145. } else {
  146. besttprate = statistics[bestEntry].first;
  147. bestfprate = statistics[bestEntry].second;
  148. bestthreshold = statistics[bestEntry].third;
  149. }
  150. if ( besttprate == 0.0 ) {
  151. fprintf (stderr, "!!!! WORST CLASSIFIER I'VE EVER SEEN !!!!\n");
  152. }
  153. return solutionFound;
  154. }