calcCurves.cpp 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
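  /**
   * @file calcCurves.cpp
   * @brief compute recall/precision or ROC curves (with average precision /
   *        area under the ROC curve) for localization results, optionally
   *        plotting them or writing them to a file
   * @author Erik Rodner
   * @date 09/01/2008
   */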
  7. #include "core/vector/VectorT.h"
  8. #include "core/vector/MatrixT.h"
  9. #include "core/image/ImageT.h"
  10. #include <core/basics/Config.h>
  11. #include <vislearning/baselib/Gnuplot.h>
  12. #include <core/basics/StringTools.h>
  13. #include <vislearning/cbaselib/LocalizationAnalysis.h>
  14. using namespace OBJREC;
  15. using namespace NICE;
  16. using namespace std;
  17. void readResults ( const string & resultsfn, vector<pair<double, int> > & results )
  18. {
  19. ifstream ifs ( resultsfn.c_str(), ios::in );
  20. if ( !ifs.good() )
  21. fthrow(IOException, "Unable to open " << resultsfn << "." );
  22. while ( !ifs.eof() )
  23. {
  24. char buf [1024];
  25. ifs.getline( buf, 1024 );
  26. if ( !ifs.good() ) break;
  27. Vector entry;
  28. StringTools::splitVector ( string(buf), ',', entry );
  29. if ( entry.size() != 4 )
  30. fthrow(IOException, "Parse error in " << resultsfn << "." );
  31. results.push_back ( pair<double, int> ( entry[2], (int)entry[1] ) );
  32. }
  33. ifs.close();
  34. }
  35. /**
  36. print recall/precision curves etc
  37. */
  38. int main (int argc, char **argv)
  39. {
  40. std::set_terminate(__gnu_cxx::__verbose_terminate_handler);
  41. Config conf ( argc, argv );
  42. string resultssetting = conf.gS("main", "results", "results.txt" );
  43. int graphtype = conf.gI("main", "graphtype", 0 );
  44. string outfn = conf.gS("main", "out", "");
  45. bool displayGraph = conf.gB("main", "display", false );
  46. bool plainMode = conf.gB("main", "plain", false );
  47. vector<string> submatches;
  48. vector<string> resultsfiles;
  49. if ( StringTools::regexMatch ( resultssetting, "^list:(.+)", submatches ) )
  50. {
  51. if ( !plainMode )
  52. cerr << "reading file list" << endl;
  53. string listfn = submatches[1];
  54. ifstream ifs ( listfn.c_str(), ios::in );
  55. if ( ifs.bad() )
  56. fthrow(IOException, "Unable to open " << listfn << "." );
  57. char buf[1024];
  58. while ( !ifs.eof() )
  59. {
  60. ifs.getline(buf, 1024);
  61. if ( strlen(buf) >= 1 )
  62. resultsfiles.push_back ( StringTools::chomp(string(buf)) );
  63. }
  64. ifs.close();
  65. } else {
  66. StringTools::split ( resultssetting, ',', resultsfiles );
  67. }
  68. for ( vector<string>::const_iterator i = resultsfiles.begin();
  69. i != resultsfiles.end(); i++ )
  70. {
  71. vector<pair<double, int> > results;
  72. string resultsfn = *i;
  73. if ( !plainMode )
  74. fprintf (stderr, "file: %s\n", resultsfn.c_str() );
  75. readResults ( resultsfn, results );
  76. if ( !plainMode )
  77. fprintf (stderr, "results.size() = %d\n", (int)results.size() );
  78. int count_positives = 0;
  79. for ( vector<pair<double, int> >::const_iterator i = results.begin();
  80. i != results.end(); i++ )
  81. if ( i->second == 1 ) count_positives++;
  82. vector<double> thresholds;
  83. vector<double> x;
  84. vector<double> y;
  85. LocalizationAnalysis la;
  86. double areaMeasure = 0.0;
  87. if ( (graphtype == 0) || (graphtype == 2) )
  88. {
  89. la.calcRecallPrecisionCurve ( results, count_positives, thresholds, x, y );
  90. fprintf (stderr, "average precision (11-point): %f\n", la.calcAveragePrecision( x, y ) );
  91. areaMeasure = la.calcAveragePrecisionPrecise( x, y );
  92. if ( !plainMode )
  93. fprintf (stderr, "average precision (precise): %f\n", areaMeasure );
  94. } else {
  95. la.calcROCCurve ( results, count_positives, results.size() - count_positives, thresholds, x, y );
  96. areaMeasure = la.calcAreaUnderROC( x, y );
  97. if ( !plainMode )
  98. fprintf (stderr, "area under the ROC curve: %f\n", areaMeasure );
  99. }
  100. cerr << resultsfn << " " << areaMeasure << endl;
  101. if ( displayGraph ) {
  102. Gnuplot gp;
  103. gp.set_xrange ( 0, 1 );
  104. gp.set_yrange ( 0, 1 );
  105. gp.set_style ( "lines" );
  106. if ( graphtype == 0 )
  107. gp.plot_xy ( x, y, "Recall/Precision");
  108. else if ( graphtype == 2 ) {
  109. for ( uint i = 0 ; i < x.size() ; i++ )
  110. {
  111. double recall = x[i];
  112. x[i] = 1.0 - y[i];
  113. y[i] = recall;
  114. }
  115. gp.plot_xy ( x, y, "1-Precision/Recall");
  116. } else
  117. gp.plot_xy ( x, y, "TP/FP (ROC)");
  118. getchar();
  119. }
  120. if ( outfn.size() > 0 ) {
  121. ofstream ofs ( outfn.c_str(), ios::out );
  122. if ( !ofs.good() )
  123. fthrow(IOException, "Unable to write graph results to " << outfn << "." );
  124. for ( unsigned int i = 0 ; i < x.size(); i++ )
  125. ofs << x[i] << " " << y[i] << " " << thresholds[i] << endl;
  126. ofs.close();
  127. }
  128. }
  129. return 0;
  130. }