ClassificationResults.cpp 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
/**
* @file ClassificationResults.cpp
* @brief Implementation of ClassificationResults, a std::vector of ClassificationResult objects
* @author Erik Rodner
* @date 02/13/2008
*/
  9. #include <vislearning/nice_nonvis.h>
  10. #include <iostream>
  11. #include <fstream>
  12. #include <iomanip>
  13. #include "vislearning/cbaselib/ClassificationResults.h"
  14. #include "vislearning/cbaselib/LocalizationAnalysis.h"
  15. using namespace OBJREC;
  16. using namespace std;
  17. using namespace NICE;
  18. ClassificationResults::ClassificationResults()
  19. {
  20. }
  21. ClassificationResults::~ClassificationResults()
  22. {
  23. }
  24. void ClassificationResults::writeWEKA ( const std::string & filename, int classno ) const
  25. {
  26. ofstream ofs ( filename.c_str(), ios::out );
  27. int instno = 0;
  28. for ( const_iterator i = begin(); i != end() ; i++, instno++ )
  29. {
  30. const ClassificationResult & r = *i;
  31. double confidence = r.scores.get(classno);
  32. ofs << instno << ", " << r.classno_groundtruth << ", " <<setiosflags(ios::fixed)<< setprecision(20)<<confidence << ", " << r.classno << endl;
  33. }
  34. ofs.close();
  35. }
  36. double ClassificationResults::getBinaryClassPerformance ( int type ) const
  37. {
  38. LocalizationAnalysis la;
  39. vector< pair<double, int> > resultsFlat;
  40. uint countPositives = 0;
  41. uint countNegatives = 0;
  42. for ( const_iterator i = begin(); i != end(); i++ )
  43. {
  44. const ClassificationResult & r = *i;
  45. double confidence = r.scores.get(1);
  46. uint classno_groundtruth = r.classno_groundtruth;
  47. resultsFlat.push_back ( pair<double, int> ( confidence, classno_groundtruth ) );
  48. if ( classno_groundtruth == 1 )
  49. countPositives++;
  50. if ( classno_groundtruth == 0 )
  51. countNegatives++;
  52. }
  53. if ( countPositives <= 0 )
  54. fthrow(Exception, "No positive ground truth examples");
  55. if ( countNegatives <= 0 )
  56. fthrow(Exception, "No negative ground truth examples");
  57. vector<double> thresholds, x, y;
  58. if ( type == PERF_AUC )
  59. la.calcROCCurve ( resultsFlat, countPositives, countNegatives, thresholds, x, y );
  60. else
  61. la.calcRecallPrecisionCurve ( resultsFlat, countPositives, thresholds, x, y );
  62. if ( type == PERF_AUC )
  63. return la.calcAreaUnderROC ( x, y );
  64. else if ( type == PERF_AVG_PRECISION_11_POINT )
  65. return la.calcAveragePrecision ( x, y );
  66. else
  67. return la.calcAveragePrecisionPrecise ( x, y );
  68. }