validation_function.m 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  function ret = validation_function(dec, labels)
  % VALIDATION_FUNCTION  Score decision values against true labels.
  %   ret = validation_function(dec, labels) converts LABELS to +1 (where
  %   labels >= 0) or -1 (where labels < 0), then returns the value of the
  %   selected evaluation metric applied to DEC and the converted labels.
  %
  %   The metric currently selected is auc. To validate with a different
  %   criterion, point METRIC at one of the other local functions:
  %     @precision, @recall, @fscore, @bac, @auc, @accuracy

  % Binarize the labels: non-negative -> +1, negative -> -1.
  signed_labels = (labels >= 0) - (labels < 0);

  % Handle to the metric used for validation.
  metric = @auc;
  ret = metric(dec, signed_labels);
  function ret = precision(dec, label)
  % PRECISION  Fraction of positive predictions that are truly positive.
  %   A decision value >= 0 counts as a positive prediction; a label equal
  %   to 1 counts as a true positive label. Returns 0 (after printing a
  %   warning) when there is no positive prediction. The result is also
  %   printed as a percentage.
  predicted_pos = (dec >= 0);
  n_predicted_pos = sum(predicted_pos);
  n_true_pos = sum(label == 1 & predicted_pos);

  ret = 0;   % value reported when nothing was predicted positive
  if n_predicted_pos == 0
    warn_msg = sprintf('warning: No positive predict label.');
    disp(warn_msg);
  else
    ret = n_true_pos / n_predicted_pos;
  end

  report = sprintf('Precision = %g%% (%d/%d)', 100.0 * ret, n_true_pos, n_predicted_pos);
  disp(report);
  function ret = recall(dec, label)
  % RECALL  Fraction of truly positive instances that are predicted positive.
  %   A decision value >= 0 counts as a positive prediction; a label equal
  %   to 1 counts as a true positive label. Returns 0 (after printing a
  %   warning) when there is no positive label. The result is also printed
  %   as a percentage.
  tp = sum(label == 1 & dec >= 0);
  tp_fn = sum(label == 1);
  if tp_fn == 0
    % Message spelling matches the identical warning issued by bac.
    disp(sprintf('warning: No positive true label.'));
    ret = 0;
  else
    ret = tp / tp_fn;
  end
  disp(sprintf('Recall = %g%% (%d/%d)', 100.0 * ret, tp, tp_fn));
  function ret = fscore(dec, label)
  % FSCORE  Harmonic mean of precision and recall (F1 score).
  %   A decision value >= 0 counts as a positive prediction; a label equal
  %   to 1 counts as a true positive label. Each degenerate case (no
  %   positive prediction, no positive label, precision + recall equal to
  %   zero) prints a warning and contributes 0. The result is also printed.
  tp = sum(label == 1 & dec >= 0);
  tp_fp = sum(dec >= 0);
  tp_fn = sum(label == 1);

  % Locals are named prec/rec so they do not shadow the local functions
  % precision and recall defined in this file.
  if tp_fp == 0
    disp(sprintf('warning: No positive predict label.'));
    prec = 0;
  else
    prec = tp / tp_fp;
  end

  if tp_fn == 0
    disp(sprintf('warning: No positive true label.'));
    rec = 0;
  else
    rec = tp / tp_fn;
  end

  if prec + rec == 0
    disp(sprintf('warning: precision + recall = 0.'));
    ret = 0;
  else
    ret = 2 * prec * rec / (prec + rec);
  end
  disp(sprintf('F-score = %g', ret));
  function ret = bac(dec, label)
  % BAC  Balanced accuracy: the mean of sensitivity and specificity.
  %   A decision value >= 0 counts as a positive prediction and < 0 as a
  %   negative one; labels are compared against 1 and -1. A class with no
  %   true instances prints a warning and contributes 0 to the mean. The
  %   result is also printed.
  is_pos = (label == 1);
  is_neg = (label == -1);
  n_pos = sum(is_pos);
  n_neg = sum(is_neg);
  hit_pos = sum(is_pos & dec >= 0);
  hit_neg = sum(is_neg & dec < 0);

  % Sensitivity: true positive rate.
  sensitivity = 0;
  if n_pos == 0
    disp(sprintf('warning: No positive true label.'));
  else
    sensitivity = hit_pos / n_pos;
  end

  % Specificity: true negative rate.
  specificity = 0;
  if n_neg == 0
    disp(sprintf('warning: No negative true label.'));
  else
    specificity = hit_neg / n_neg;
  end

  ret = (sensitivity + specificity) / 2;
  summary = sprintf('BAC = %g', ret);
  disp(summary);
  function ret = auc(dec, label)
  % AUC  Area under the ROC curve of decision values against +1/-1 labels.
  %   Instances are ranked by decreasing decision value. For every negative
  %   instance the number of positives ranked at or before it is counted,
  %   and the total is divided by (#positives * #negatives). Returns 0
  %   (after printing a warning) when the input is empty or either class
  %   is absent.
  %
  %   NOTE(review): tied decision values are counted in the order returned
  %   by sort; no half credit is given for ties.
  [dec, idx] = sort(dec, 'descend');
  label = label(idx);
  tp = cumsum(label == 1);   % positives seen up to each rank
  fp = sum(label == -1);     % total number of negatives

  % Explicit scalar tests: tp(end) is the total number of positives. The
  % isempty guard comes first so tp(end) is never evaluated on empty input.
  if isempty(tp) || tp(end) == 0 || fp == 0
    disp(sprintf('warning: Too few positive true labels or negative true labels'));
    ret = 0;
  else
    ret = sum(tp(label == -1)) / tp(end) / fp;
  end
  %disp(sprintf('AUC = %g', ret));
  function ret = accuracy(dec, label)
  % ACCURACY  Fraction of instances whose predicted sign matches the label.
  %   A decision value >= 0 is a positive (+1) prediction and < 0 a
  %   negative (-1) prediction, the same threshold used by precision,
  %   recall, fscore and bac in this file. Labels are compared against the
  %   +1/-1 prediction. Returns 0 (after printing a warning) for empty
  %   input. The result is also printed as a percentage.
  predicted = (dec >= 0) - (dec < 0);   % +1 / -1 prediction per instance
  correct = sum(predicted == label);
  total = numel(dec);
  if total == 0
    % Avoid 0/0 = NaN when there is nothing to evaluate.
    disp(sprintf('warning: No instances to evaluate.'));
    ret = 0;
  else
    ret = correct / total;
  end
  disp(sprintf('Accuracy = %g%% (%d/%d)', 100.0 * ret, correct, total));