calcCurves.cpp 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  /**
  * @file calcCurves.cpp
  * @brief Print recall/precision or ROC curves, and their area measures,
  *        for localization results.
  * @author Erik Rodner
  * @date 09/01/2008
  */
  7. #include <vislearning/nice_nonvis.h>
  8. #include <core/basics/Config.h>
  9. #include <vislearning/baselib/Gnuplot.h>
  10. #include <core/basics/StringTools.h>
  11. #include <vislearning/cbaselib/LocalizationAnalysis.h>
  12. using namespace OBJREC;
  13. using namespace NICE;
  14. using namespace std;
  15. void readResults ( const string & resultsfn, vector<pair<double, int> > & results )
  16. {
  17. ifstream ifs ( resultsfn.c_str(), ios::in );
  18. if ( !ifs.good() )
  19. fthrow(IOException, "Unable to open " << resultsfn << "." );
  20. while ( !ifs.eof() )
  21. {
  22. char buf [1024];
  23. ifs.getline( buf, 1024 );
  24. if ( !ifs.good() ) break;
  25. Vector entry;
  26. StringTools::splitVector ( string(buf), ',', entry );
  27. if ( entry.size() != 4 )
  28. fthrow(IOException, "Parse error in " << resultsfn << "." );
  29. results.push_back ( pair<double, int> ( entry[2], (int)entry[1] ) );
  30. }
  31. ifs.close();
  32. }
  33. /**
  34. print recall/precision curves etc
  35. */
  36. int main (int argc, char **argv)
  37. {
  38. std::set_terminate(__gnu_cxx::__verbose_terminate_handler);
  39. Config conf ( argc, argv );
  40. string resultssetting = conf.gS("main", "results", "results.txt" );
  41. int graphtype = conf.gI("main", "graphtype", 0 );
  42. string outfn = conf.gS("main", "out", "");
  43. bool displayGraph = conf.gB("main", "display", false );
  44. bool plainMode = conf.gB("main", "plain", false );
  45. vector<string> submatches;
  46. vector<string> resultsfiles;
  47. if ( StringTools::regexMatch ( resultssetting, "^list:(.+)", submatches ) )
  48. {
  49. if ( !plainMode )
  50. cerr << "reading file list" << endl;
  51. string listfn = submatches[1];
  52. ifstream ifs ( listfn.c_str(), ios::in );
  53. if ( ifs.bad() )
  54. fthrow(IOException, "Unable to open " << listfn << "." );
  55. char buf[1024];
  56. while ( !ifs.eof() )
  57. {
  58. ifs.getline(buf, 1024);
  59. if ( strlen(buf) >= 1 )
  60. resultsfiles.push_back ( StringTools::chomp(string(buf)) );
  61. }
  62. ifs.close();
  63. } else {
  64. StringTools::split ( resultssetting, ',', resultsfiles );
  65. }
  66. for ( vector<string>::const_iterator i = resultsfiles.begin();
  67. i != resultsfiles.end(); i++ )
  68. {
  69. vector<pair<double, int> > results;
  70. string resultsfn = *i;
  71. if ( !plainMode )
  72. fprintf (stderr, "file: %s\n", resultsfn.c_str() );
  73. readResults ( resultsfn, results );
  74. if ( !plainMode )
  75. fprintf (stderr, "results.size() = %d\n", (int)results.size() );
  76. int count_positives = 0;
  77. for ( vector<pair<double, int> >::const_iterator i = results.begin();
  78. i != results.end(); i++ )
  79. if ( i->second == 1 ) count_positives++;
  80. vector<double> thresholds;
  81. vector<double> x;
  82. vector<double> y;
  83. LocalizationAnalysis la;
  84. double areaMeasure = 0.0;
  85. if ( (graphtype == 0) || (graphtype == 2) )
  86. {
  87. la.calcRecallPrecisionCurve ( results, count_positives, thresholds, x, y );
  88. fprintf (stderr, "average precision (11-point): %f\n", la.calcAveragePrecision( x, y ) );
  89. areaMeasure = la.calcAveragePrecisionPrecise( x, y );
  90. if ( !plainMode )
  91. fprintf (stderr, "average precision (precise): %f\n", areaMeasure );
  92. } else {
  93. la.calcROCCurve ( results, count_positives, results.size() - count_positives, thresholds, x, y );
  94. areaMeasure = la.calcAreaUnderROC( x, y );
  95. if ( !plainMode )
  96. fprintf (stderr, "area under the ROC curve: %f\n", areaMeasure );
  97. }
  98. cerr << resultsfn << " " << areaMeasure << endl;
  99. if ( displayGraph ) {
  100. Gnuplot gp;
  101. gp.set_xrange ( 0, 1 );
  102. gp.set_yrange ( 0, 1 );
  103. gp.set_style ( "lines" );
  104. if ( graphtype == 0 )
  105. gp.plot_xy ( x, y, "Recall/Precision");
  106. else if ( graphtype == 2 ) {
  107. for ( uint i = 0 ; i < x.size() ; i++ )
  108. {
  109. double recall = x[i];
  110. x[i] = 1.0 - y[i];
  111. y[i] = recall;
  112. }
  113. gp.plot_xy ( x, y, "1-Precision/Recall");
  114. } else
  115. gp.plot_xy ( x, y, "TP/FP (ROC)");
  116. getchar();
  117. }
  118. if ( outfn.size() > 0 ) {
  119. ofstream ofs ( outfn.c_str(), ios::out );
  120. if ( !ofs.good() )
  121. fthrow(IOException, "Unable to write graph results to " << outfn << "." );
  122. for ( unsigned int i = 0 ; i < x.size(); i++ )
  123. ofs << x[i] << " " << y[i] << " " << thresholds[i] << endl;
  124. ofs.close();
  125. }
  126. }
  127. return 0;
  128. }