# standalone_metric.py
################################################################################
# Copyright (c) 2021 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 01-12-2020                                                             #
# Author(s): Andrea Cossu                                                      #
# E-mail: contact@continualai.org                                              #
# Website: avalanche.continualai.org                                           #
################################################################################
  11. """
  12. This example shows how to use Standalone metrics like `Accuracy`, `Loss`,
  13. `ConfusionMatrix` and others. Please, refer to `eval_plugin` example to
  14. dig deeper into the use of Plugin metrics, which are already integrated
  15. in the Avalanche training and evaluation loops.
  16. """
  17. # import a standalone metric
  18. import torch
  19. from avalanche.evaluation.metrics import Accuracy
  20. # all standalone metrics
  21. from avalanche.evaluation.metrics import Accuracy
  22. # , Loss, # Loss
  23. # Forgetting, # Forgetting
  24. # ConfusionMatrix, # Confusion Matrix
  25. # CPUUsage, # CPU Usage
  26. # DiskUsage, # Disk Usage
  27. # MaxGPU, # Max GPU Usage
  28. # MAC, # Multiply and Accumulate
  29. # MaxRAM, # Max RAM Usage
  30. # ElapsedTime # Timing metrics
  31. # create an instance of the standalone Accuracy metric
  32. # initial accuracy is 0
  33. acc_metric = Accuracy()
  34. print("Initial Accuracy: ", acc_metric.result()) # output 0
  35. # update method allows to keep the running average accuracy
  36. # result method returns the current average accuracy
  37. real_y = torch.tensor([1, 2]).long()
  38. predicted_y = torch.tensor([1, 0]).float()
  39. acc_metric.update(real_y, predicted_y)
  40. acc = acc_metric.result()
  41. print("Average Accuracy: ", acc) # output 0.5
  42. # you can continue to update the metric with new values
  43. predicted_y = torch.tensor([1, 2]).float()
  44. acc_metric.update(real_y, predicted_y)
  45. acc = acc_metric.result()
  46. print("Average Accuracy: ", acc) # output 0.75
  47. # reset accuracy to 0
  48. acc_metric.reset()
  49. print("After reset: ", acc_metric.result()) # output 0