ClassificationResults.cpp 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. /**
  2. * @file ClassificationResults.cpp
  3. // refactor-nice.pl: check this substitution
  4. // old: * @brief vector of ClassificationResult
  5. * @brief std::vector of ClassificationResult
  6. * @author Erik Rodner
  7. * @date 02/13/2008
  8. */
  9. #include "core/image/ImageT.h"
  10. #include "core/vector/VectorT.h"
  11. #include "core/vector/MatrixT.h"
  12. #include <iostream>
  13. #include <fstream>
  14. #include <iomanip>
  15. #include "vislearning/cbaselib/ClassificationResults.h"
  16. #include "vislearning/cbaselib/LocalizationAnalysis.h"
  17. using namespace OBJREC;
  18. using namespace std;
  19. using namespace NICE;
  20. ClassificationResults::ClassificationResults()
  21. {
  22. }
  23. ClassificationResults::~ClassificationResults()
  24. {
  25. }
  26. void ClassificationResults::writeWEKA ( const std::string & filename, int classno ) const
  27. {
  28. ofstream ofs ( filename.c_str(), ios::out );
  29. int instno = 0;
  30. for ( const_iterator i = begin(); i != end() ; i++, instno++ )
  31. {
  32. const ClassificationResult & r = *i;
  33. double confidence = r.scores.get(classno);
  34. ofs << instno << ", " << r.classno_groundtruth << ", " <<setiosflags(ios::fixed)<< setprecision(20)<<confidence << ", " << r.classno << endl;
  35. }
  36. ofs.close();
  37. }
  38. double ClassificationResults::getBinaryClassPerformance ( int type ) const
  39. {
  40. LocalizationAnalysis la;
  41. vector< pair<double, int> > resultsFlat;
  42. uint countPositives = 0;
  43. uint countNegatives = 0;
  44. for ( const_iterator i = begin(); i != end(); i++ )
  45. {
  46. const ClassificationResult & r = *i;
  47. double confidence = r.scores.get(1);
  48. uint classno_groundtruth = r.classno_groundtruth;
  49. resultsFlat.push_back ( pair<double, int> ( confidence, classno_groundtruth ) );
  50. if ( classno_groundtruth == 1 )
  51. countPositives++;
  52. if ( classno_groundtruth == 0 )
  53. countNegatives++;
  54. }
  55. if ( countPositives <= 0 )
  56. fthrow(Exception, "No positive ground truth examples");
  57. if ( countNegatives <= 0 )
  58. fthrow(Exception, "No negative ground truth examples");
  59. vector<double> thresholds, x, y;
  60. if ( type == PERF_AUC )
  61. la.calcROCCurve ( resultsFlat, countPositives, countNegatives, thresholds, x, y );
  62. else
  63. la.calcRecallPrecisionCurve ( resultsFlat, countPositives, thresholds, x, y );
  64. if ( type == PERF_AUC )
  65. return la.calcAreaUnderROC ( x, y );
  66. else if ( type == PERF_AVG_PRECISION_11_POINT )
  67. return la.calcAveragePrecision ( x, y );
  68. else
  69. return la.calcAveragePrecisionPrecise ( x, y );
  70. }