calcCurves.cpp 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. /**
 * @file calcCurves.cpp
  3. * @brief print recall/precision curves etc
  4. * @author Erik Rodner
  5. * @date 09/01/2008
  6. */
#include <fstream>
#include <string>

#include "core/vector/VectorT.h"
#include "core/vector/MatrixT.h"
#include "core/image/ImageT.h"
#include <core/basics/Config.h>
#include <vislearning/baselib/Gnuplot.h>
#include <core/basics/StringTools.h>
#include <vislearning/cbaselib/LocalizationAnalysis.h>
  14. using namespace OBJREC;
  15. using namespace NICE;
  16. using namespace std;
  17. void readResults ( const string & resultsfn, vector<pair<double, int> > & results )
  18. {
  19. ifstream ifs ( resultsfn.c_str(), ios::in );
  20. if ( !ifs.good() )
  21. fthrow(IOException, "Unable to open " << resultsfn << "." );
  22. while ( !ifs.eof() )
  23. {
  24. char buf [1024];
  25. ifs.getline( buf, 1024 );
  26. if ( !ifs.good() ) break;
  27. Vector entry;
  28. StringTools::splitVector ( string(buf), ',', entry );
  29. if ( entry.size() != 4 )
  30. fthrow(IOException, "Parse error in " << resultsfn << "." );
  31. results.push_back ( pair<double, int> ( entry[2], (int)entry[1] ) );
  32. }
  33. ifs.close();
  34. }
  35. /**
  36. print recall/precision curves etc
  37. */
  38. int main (int argc, char **argv)
  39. {
  40. #ifndef __clang__
  41. #ifndef __llvm__
  42. std::set_terminate(__gnu_cxx::__verbose_terminate_handler);
  43. #endif
  44. #endif
  45. Config conf ( argc, argv );
  46. string resultssetting = conf.gS("main", "results", "results.txt" );
  47. int graphtype = conf.gI("main", "graphtype", 0 );
  48. string outfn = conf.gS("main", "out", "");
  49. bool displayGraph = conf.gB("main", "display", false );
  50. bool plainMode = conf.gB("main", "plain", false );
  51. vector<string> submatches;
  52. vector<string> resultsfiles;
  53. if ( StringTools::regexMatch ( resultssetting, "^list:(.+)", submatches ) )
  54. {
  55. if ( !plainMode )
  56. cerr << "reading file list" << endl;
  57. string listfn = submatches[1];
  58. ifstream ifs ( listfn.c_str(), ios::in );
  59. if ( ifs.bad() )
  60. fthrow(IOException, "Unable to open " << listfn << "." );
  61. char buf[1024];
  62. while ( !ifs.eof() )
  63. {
  64. ifs.getline(buf, 1024);
  65. if ( strlen(buf) >= 1 )
  66. resultsfiles.push_back ( StringTools::chomp(string(buf)) );
  67. }
  68. ifs.close();
  69. } else {
  70. StringTools::split ( resultssetting, ',', resultsfiles );
  71. }
  72. for ( vector<string>::const_iterator i = resultsfiles.begin();
  73. i != resultsfiles.end(); i++ )
  74. {
  75. vector<pair<double, int> > results;
  76. string resultsfn = *i;
  77. if ( !plainMode )
  78. fprintf (stderr, "file: %s\n", resultsfn.c_str() );
  79. readResults ( resultsfn, results );
  80. if ( !plainMode )
  81. fprintf (stderr, "results.size() = %d\n", (int)results.size() );
  82. int count_positives = 0;
  83. for ( vector<pair<double, int> >::const_iterator i = results.begin();
  84. i != results.end(); i++ )
  85. if ( i->second == 1 ) count_positives++;
  86. vector<double> thresholds;
  87. vector<double> x;
  88. vector<double> y;
  89. LocalizationAnalysis la;
  90. double areaMeasure = 0.0;
  91. if ( (graphtype == 0) || (graphtype == 2) )
  92. {
  93. la.calcRecallPrecisionCurve ( results, count_positives, thresholds, x, y );
  94. fprintf (stderr, "average precision (11-point): %f\n", la.calcAveragePrecision( x, y ) );
  95. areaMeasure = la.calcAveragePrecisionPrecise( x, y );
  96. if ( !plainMode )
  97. fprintf (stderr, "average precision (precise): %f\n", areaMeasure );
  98. } else {
  99. la.calcROCCurve ( results, count_positives, results.size() - count_positives, thresholds, x, y );
  100. areaMeasure = la.calcAreaUnderROC( x, y );
  101. if ( !plainMode )
  102. fprintf (stderr, "area under the ROC curve: %f\n", areaMeasure );
  103. }
  104. cerr << resultsfn << " " << areaMeasure << endl;
  105. if ( displayGraph ) {
  106. Gnuplot gp;
  107. gp.set_xrange ( 0, 1 );
  108. gp.set_yrange ( 0, 1 );
  109. gp.set_style ( "lines" );
  110. if ( graphtype == 0 )
  111. gp.plot_xy ( x, y, "Recall/Precision");
  112. else if ( graphtype == 2 ) {
  113. for ( uint i = 0 ; i < x.size() ; i++ )
  114. {
  115. double recall = x[i];
  116. x[i] = 1.0 - y[i];
  117. y[i] = recall;
  118. }
  119. gp.plot_xy ( x, y, "1-Precision/Recall");
  120. } else
  121. gp.plot_xy ( x, y, "TP/FP (ROC)");
  122. getchar();
  123. }
  124. if ( outfn.size() > 0 ) {
  125. ofstream ofs ( outfn.c_str(), ios::out );
  126. if ( !ofs.good() )
  127. fthrow(IOException, "Unable to write graph results to " << outfn << "." );
  128. for ( unsigned int i = 0 ; i < x.size(); i++ )
  129. ofs << x[i] << " " << y[i] << " " << thresholds[i] << endl;
  130. ofs.close();
  131. }
  132. }
  133. return 0;
  134. }