# test_metrics.py
"""Metrics tests: standalone metric objects and full plugin-metric runs."""
import unittest
import torch
from torch.utils.data import TensorDataset
import numpy as np
import random
import pickle
import os
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_classification
from copy import deepcopy
from avalanche.evaluation.metrics import Accuracy, Loss, ConfusionMatrix, \
    DiskUsage, MAC, CPUUsage, MaxGPU, MaxRAM, Mean, Sum, ElapsedTime, \
    Forgetting, ForwardTransfer
from avalanche.training.strategies.base_strategy import BaseStrategy
import pathlib
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from avalanche.benchmarks.utils import AvalancheTensorDataset, \
    AvalancheDatasetType
from avalanche.benchmarks import nc_benchmark, dataset_benchmark
from avalanche.evaluation.metrics import forgetting_metrics, \
    accuracy_metrics, loss_metrics, cpu_usage_metrics, timing_metrics, \
    ram_usage_metrics, disk_usage_metrics, MAC_metrics, \
    bwt_metrics, confusion_matrix_metrics, forward_transfer_metrics
from avalanche.models import SimpleMLP
from avalanche.logging import TextLogger
from avalanche.training.plugins import EvaluationPlugin

#################################
#################################
#    STANDALONE METRIC TEST     #
#################################
#################################
class GeneralMetricTests(unittest.TestCase):
    """Unit tests exercising each metric object directly.

    Each test follows the same update/result/reset life cycle, with no
    training strategy or evaluation plugin involved.
    """

    def setUp(self) -> None:
        # Fake model outputs (random logits), targets and per-pattern
        # task labels shared by the tests below.
        self.batch_size = 3
        self.input_size = 10
        self.n_classes = 3
        self.n_tasks = 2
        self.out = torch.randn(self.batch_size, self.input_size)
        self.y = torch.randint(0, self.n_classes, (self.batch_size,))
        self.task_labels = torch.randint(0, self.n_tasks, (self.batch_size,))

    def test_accuracy(self):
        """Accuracy starts empty, stays within [0, 1], and resets to empty."""
        metric = Accuracy()
        self.assertEqual(metric.result(), {})
        # Single update under task label 0.
        metric.update(self.out, self.y, 0)
        self.assertLessEqual(metric.result(0)[0], 1)
        self.assertGreaterEqual(metric.result(0)[0], 0)
        metric.reset()
        self.assertEqual(metric.result(), {})

    def test_accuracy_task_per_pattern(self):
        """Accuracy with a task label per pattern keys results by task."""
        metric = Accuracy()
        self.assertEqual(metric.result(), {})
        metric.update(self.out, self.y, self.task_labels)
        out = metric.result()
        for k, v in out.items():
            # Every reported key must be one of the supplied task labels.
            self.assertIn(k, self.task_labels.tolist())
            self.assertLessEqual(v, 1)
            self.assertGreaterEqual(v, 0)
        metric.reset()
        self.assertEqual(metric.result(), {})

    def test_loss(self):
        """Loss for a task starts at 0, is non-negative, resets to empty."""
        metric = Loss()
        self.assertEqual(metric.result(0)[0], 0)
        # update(loss_value, num_patterns, task_label)
        metric.update(torch.tensor(1.), self.batch_size, 0)
        self.assertGreaterEqual(metric.result(0)[0], 0)
        metric.reset()
        self.assertEqual(metric.result(), {})

    def test_loss_multi_task(self):
        """Loss tracks separate running averages per task label."""
        metric = Loss()
        self.assertEqual(metric.result(), {})
        metric.update(torch.tensor(1.), 1, 0)
        metric.update(torch.tensor(2.), 1, 1)
        out = metric.result()
        for k, v in out.items():
            self.assertIn(k, [0, 1])
            if k == 0:
                self.assertEqual(v, 1)
            else:
                self.assertEqual(v, 2)
        metric.reset()
        self.assertEqual(metric.result(), {})

    def test_cm(self):
        """ConfusionMatrix starts all-zero, stays non-negative, resets."""
        metric = ConfusionMatrix()
        cm = metric.result()
        self.assertTrue((cm == 0).all().item())
        # Note the argument order: (true_y, predicted_y/logits).
        metric.update(self.y, self.out)
        cm = metric.result()
        self.assertTrue((cm >= 0).all().item())
        metric.reset()
        cm = metric.result()
        self.assertTrue((cm == 0).all().item())

    def test_ram(self):
        """MaxRAM samples usage on a background thread between start/stop."""
        metric = MaxRAM()
        self.assertEqual(metric.result(), 0)
        metric.start_thread()  # start thread
        self.assertGreaterEqual(metric.result(), 0)
        metric.stop_thread()  # stop thread
        metric.reset()  # stop thread
        self.assertEqual(metric.result(), 0)

    def test_gpu(self):
        """MaxGPU mirrors MaxRAM; skipped implicitly when CUDA is absent."""
        if torch.cuda.is_available():
            metric = MaxGPU(0)
            self.assertEqual(metric.result(), 0)
            metric.start_thread()  # start thread
            self.assertGreaterEqual(metric.result(), 0)
            metric.stop_thread()  # stop thread
            metric.reset()  # stop thread
            self.assertEqual(metric.result(), 0)

    def test_cpu(self):
        """CPUUsage is 0 before any update and non-negative afterwards."""
        metric = CPUUsage()
        self.assertEqual(metric.result(), 0)
        metric.update()
        self.assertGreaterEqual(metric.result(), 0)
        metric.reset()
        self.assertEqual(metric.result(), 0)

    def test_disk(self):
        """DiskUsage is 0 before any update and non-negative afterwards."""
        metric = DiskUsage()
        self.assertEqual(metric.result(), 0)
        metric.update()
        self.assertGreaterEqual(metric.result(), 0)
        metric.reset()
        self.assertEqual(metric.result(), 0)

    def test_timing(self):
        """ElapsedTime needs two updates before reporting a duration."""
        metric = ElapsedTime()
        self.assertEqual(metric.result(), 0)
        metric.update()  # need two update calls
        self.assertEqual(metric.result(), 0)
        metric.update()
        self.assertGreaterEqual(metric.result(), 0)
        metric.reset()
        self.assertEqual(metric.result(), 0)

    def test_mac(self):
        """MAC counts multiply-accumulate ops of a model forward pass."""
        model = torch.nn.Linear(self.input_size, 2)
        metric = MAC()
        self.assertEqual(metric.result(), 0)
        metric.update(model, self.out)
        self.assertGreaterEqual(metric.result(), 0)

    def test_mean(self):
        """Mean is a weighted running average; reset restores 0."""
        metric = Mean()
        self.assertEqual(metric.result(), 0)
        metric.update(0.1, 1)
        self.assertEqual(metric.result(), 0.1)
        metric.reset()
        self.assertEqual(metric.result(), 0)

    def test_sum(self):
        """Sum accumulates values; reset restores 0."""
        metric = Sum()
        self.assertEqual(metric.result(), 0)
        metric.update(5)
        self.assertEqual(metric.result(), 5)
        metric.reset()
        self.assertEqual(metric.result(), 0)

    def test_forgetting(self):
        """Forgetting = initial accuracy - current accuracy for a key."""
        metric = Forgetting()
        f = metric.result()
        self.assertEqual(f, {})
        f = metric.result(k=0)
        self.assertIsNone(f)
        # Only the initial value is set: no forgetting can be computed yet.
        metric.update(0, 1, initial=True)
        f = metric.result(k=0)
        self.assertIsNone(f)
        metric.update(0, 0.4)
        f = metric.result(k=0)
        self.assertEqual(f, 0.6)  # 1 - 0.4
        metric.reset()
        self.assertEqual(metric.result(), {})

    def test_forward_transfer(self):
        """ForwardTransfer = current accuracy - initial accuracy for a key."""
        metric = ForwardTransfer()
        f = metric.result()
        self.assertEqual(f, {})
        f = metric.result(k=0)
        self.assertIsNone(f)
        metric.update(0, 1, initial=True)
        f = metric.result(k=0)
        self.assertIsNone(f)
        metric.update(0, 0.4)
        f = metric.result(k=0)
        self.assertEqual(f, -0.6)  # 0.4 - 1
        metric.reset()
        self.assertEqual(metric.result(), {})
  181. #################################
  182. #################################
  183. # PLUGIN METRIC TEST #
  184. #################################
  185. #################################
  186. DEVICE = 'cpu'
  187. DELTA = 0.01
  188. def filter_dict(d, name):
  189. out = {}
  190. for k, v in sorted(d.items()):
  191. if name in k:
  192. out[k] = deepcopy(v)
  193. return out
class PluginMetricTests(unittest.TestCase):
    """End-to-end metric check on a single-incremental-task benchmark.

    Trains a small MLP on a synthetic 6-class benchmark split into 3
    experiences with every plugin metric enabled, then compares the
    collected metric curves against the reference values pickled in
    ``target_metrics/sit.pickle``.
    """

    @classmethod
    def setUpClass(cls) -> None:
        # Seed every RNG so the run matches the pickled reference exactly.
        # NOTE(review): statement order below affects RNG consumption and
        # therefore the reference comparison -- do not reorder.
        torch.manual_seed(0)
        np.random.seed(0)
        random.seed(0)
        n_samples_per_class = 100
        dataset = make_classification(
            n_samples=6 * n_samples_per_class,
            n_classes=6,
            n_features=4, n_informative=4, n_redundant=0)
        X = torch.from_numpy(dataset[0]).float()
        y = torch.from_numpy(dataset[1]).long()
        train_X, test_X, train_y, test_y = train_test_split(
            X, y, train_size=0.5, shuffle=True, stratify=y)
        tr_d = TensorDataset(train_X, train_y)
        ts_d = TensorDataset(test_X, test_y)
        benchmark = nc_benchmark(train_dataset=tr_d, test_dataset=ts_d,
                                 n_experiences=3,
                                 task_labels=False, shuffle=False, seed=0)
        model = SimpleMLP(input_size=4, num_classes=benchmark.n_classes)
        f = open('log.txt', 'w')
        text_logger = TextLogger(f)
        # Every metric family at every granularity it supports.
        # NOTE(review): num_classes=10 although the benchmark has 6 classes;
        # the sibling multi-task test uses num_classes=6 -- presumably just
        # extra empty rows/cols, but confirm this mismatch is intended.
        eval_plugin = EvaluationPlugin(
            accuracy_metrics(
                minibatch=True, epoch=True, epoch_running=True,
                experience=True, stream=True, trained_experience=True),
            loss_metrics(minibatch=True, epoch=True, epoch_running=True,
                         experience=True, stream=True),
            forgetting_metrics(experience=True, stream=True),
            forward_transfer_metrics(experience=True, stream=True),
            confusion_matrix_metrics(num_classes=10, save_image=False,
                                     normalize='all', stream=True),
            bwt_metrics(experience=True, stream=True),
            cpu_usage_metrics(
                minibatch=True, epoch=True, epoch_running=True,
                experience=True, stream=True),
            timing_metrics(
                minibatch=True, epoch=True, epoch_running=True,
                experience=True, stream=True),
            ram_usage_metrics(
                every=0.5, minibatch=True, epoch=True,
                experience=True, stream=True),
            disk_usage_metrics(
                minibatch=True, epoch=True, experience=True, stream=True),
            MAC_metrics(
                minibatch=True, epoch=True, experience=True),
            loggers=[text_logger],
            collect_all=True)  # collect all metrics (set to True by default)
        cl_strategy = BaseStrategy(
            model, SGD(model.parameters(), lr=0.001, momentum=0.9),
            CrossEntropyLoss(), train_mb_size=10, train_epochs=2,
            eval_mb_size=10, device=DEVICE, evaluator=eval_plugin,
            eval_every=1)
        # Train on each experience, evaluating on the full test stream
        # after each one.
        for i, experience in enumerate(benchmark.train_stream):
            cl_strategy.train(experience,
                              eval_streams=[benchmark.test_stream],
                              shuffle=False)
            cl_strategy.eval(benchmark.test_stream)
        cls.all_metrics = cl_strategy.evaluator.get_all_metrics()
        f.close()
        # # Uncomment me to regenerate the reference metrics. Make sure
        # # the old tests were passing for all unchanged metrics
        # with open(os.path.join(pathlib.Path(__file__).parent.absolute(),
        #                        'target_metrics',
        #                        'sit.pickle'), 'wb') as f:
        #     pickle.dump(dict(cls.all_metrics), f,
        #                 protocol=4)
        with open(os.path.join(pathlib.Path(__file__).parent.absolute(),
                               'target_metrics',
                               'sit.pickle'), 'rb') as f:
            cls.ref = pickle.load(f)

    def metric_check(self, name):
        """Compare every collected curve containing ``name`` to the reference.

        Asserts that x-values are strictly increasing and equal to the
        reference, and that y-values match within ``DELTA``.
        """
        d = filter_dict(self.all_metrics, name)
        d_ref = filter_dict(self.ref, name)
        for (k, v), (kref, vref) in zip(d.items(), d_ref.items()):
            self.assertEqual(k, kref)
            init = -1
            # x-values (logging step counters) must be strictly increasing.
            for el in v[0]:
                self.assertTrue(el > init)
                init = el
            for el, elref in zip(v[0], vref[0]):
                self.assertEqual(el, elref)
            for el, elref in zip(v[1], vref[1]):
                self.assertAlmostEqual(el, elref, delta=DELTA)

    def test_accuracy(self):
        self.metric_check('Acc')

    def test_loss(self):
        self.metric_check('Loss')

    def test_mac(self):
        self.metric_check('MAC')

    def test_forgetting_bwt(self):
        """Forgetting and BWT must match the reference and be negatives."""
        df = filter_dict(self.all_metrics, 'Forgetting')
        db = filter_dict(self.all_metrics, 'BWT')
        self.metric_check('Forgetting')
        self.metric_check('BWT')
        for (kf, vf), (kb, vb) in zip(df.items(), db.items()):
            # Matching entries must come from the same granularity.
            self.assertTrue(
                (kf.startswith('Stream') and kb.startswith('Stream')) or
                (kf.startswith('Experience') and kb.startswith('Experience')))
            # BWT is defined as the negative of forgetting.
            for f, b in zip(vf[1], vb[1]):
                self.assertEqual(f, -b)

    def test_fwt(self):
        self.metric_check('ForwardTransfer')

    def test_cm(self):
        """Confusion matrices are compared element-wise (exact equality)."""
        d = filter_dict(self.all_metrics, 'ConfusionMatrix')
        d_ref = filter_dict(self.ref, 'ConfusionMatrix')
        for (k, v), (kref, vref) in zip(d.items(), d_ref.items()):
            self.assertEqual(k, kref)
            for el, elref in zip(v[0], vref[0]):
                self.assertEqual(el, elref)
            for el, elref in zip(v[1], vref[1]):
                self.assertTrue((el == elref).all())
class PluginMetricMultiTaskTests(unittest.TestCase):
    """End-to-end metric check on a multi-task benchmark.

    Same protocol as ``PluginMetricTests`` but with ``task_labels=True``
    (one task per experience); reference values are pickled in
    ``target_metrics/mt.pickle``.
    """

    @classmethod
    def setUpClass(cls) -> None:
        # Seed every RNG so the run matches the pickled reference exactly.
        # NOTE(review): statement order below affects RNG consumption and
        # therefore the reference comparison -- do not reorder.
        torch.manual_seed(0)
        np.random.seed(0)
        random.seed(0)
        n_samples_per_class = 100
        dataset = make_classification(
            n_samples=6 * n_samples_per_class,
            n_classes=6,
            n_features=4, n_informative=4, n_redundant=0)
        X = torch.from_numpy(dataset[0]).float()
        y = torch.from_numpy(dataset[1]).long()
        train_X, test_X, train_y, test_y = train_test_split(
            X, y, train_size=0.5, shuffle=True, stratify=y)
        tr_d = TensorDataset(train_X, train_y)
        ts_d = TensorDataset(test_X, test_y)
        # task_labels=True: each experience gets its own task label.
        benchmark = nc_benchmark(train_dataset=tr_d, test_dataset=ts_d,
                                 n_experiences=3,
                                 task_labels=True, shuffle=False, seed=0)
        model = SimpleMLP(input_size=4, num_classes=benchmark.n_classes)
        f = open('log.txt', 'w')
        text_logger = TextLogger(f)
        # Every metric family at every granularity it supports.
        eval_plugin = EvaluationPlugin(
            accuracy_metrics(
                minibatch=True, epoch=True, epoch_running=True,
                experience=True, stream=True, trained_experience=True),
            loss_metrics(minibatch=True, epoch=True, epoch_running=True,
                         experience=True, stream=True),
            forgetting_metrics(experience=True, stream=True),
            confusion_matrix_metrics(num_classes=6, save_image=False,
                                     normalize='all', stream=True),
            bwt_metrics(experience=True, stream=True),
            forward_transfer_metrics(experience=True, stream=True),
            cpu_usage_metrics(
                minibatch=True, epoch=True, epoch_running=True,
                experience=True, stream=True),
            timing_metrics(
                minibatch=True, epoch=True, epoch_running=True,
                experience=True, stream=True),
            ram_usage_metrics(
                every=0.5, minibatch=True, epoch=True,
                experience=True, stream=True),
            disk_usage_metrics(
                minibatch=True, epoch=True, experience=True, stream=True),
            MAC_metrics(
                minibatch=True, epoch=True, experience=True),
            loggers=[text_logger],
            collect_all=True)  # collect all metrics (set to True by default)
        cl_strategy = BaseStrategy(
            model, SGD(model.parameters(), lr=0.001, momentum=0.9),
            CrossEntropyLoss(), train_mb_size=10, train_epochs=2,
            eval_mb_size=10, device=DEVICE, evaluator=eval_plugin,
            eval_every=1)
        # Train on each experience, evaluating on the full test stream
        # after each one.
        for i, experience in enumerate(benchmark.train_stream):
            cl_strategy.train(experience,
                              eval_streams=[benchmark.test_stream],
                              shuffle=False)
            cl_strategy.eval(benchmark.test_stream)
        cls.all_metrics = cl_strategy.evaluator.get_all_metrics()
        f.close()
        # # Uncomment me to regenerate the reference metrics. Make sure
        # # the old tests were passing for all unchanged metrics
        # with open(os.path.join(pathlib.Path(__file__).parent.absolute(),
        #                        'target_metrics',
        #                        'mt.pickle'), 'wb') as f:
        #     pickle.dump(dict(cls.all_metrics), f,
        #                 protocol=4)
        with open(os.path.join(pathlib.Path(__file__).parent.absolute(),
                               'target_metrics',
                               'mt.pickle'), 'rb') as f:
            cls.ref = pickle.load(f)

    def metric_check(self, name):
        """Compare every collected curve containing ``name`` to the reference.

        Asserts that x-values are strictly increasing and equal to the
        reference, and that y-values match within ``DELTA``.
        """
        d = filter_dict(self.all_metrics, name)
        d_ref = filter_dict(self.ref, name)
        for (k, v), (kref, vref) in zip(d.items(), d_ref.items()):
            self.assertEqual(k, kref)
            init = -1
            # x-values (logging step counters) must be strictly increasing.
            for el in v[0]:
                self.assertTrue(el > init)
                init = el
            for el, elref in zip(v[0], vref[0]):
                self.assertEqual(el, elref)
            for el, elref in zip(v[1], vref[1]):
                self.assertAlmostEqual(el, elref, delta=DELTA)

    def test_accuracy(self):
        self.metric_check('Acc')

    def test_loss(self):
        self.metric_check('Loss')

    def test_mac(self):
        self.metric_check('MAC')

    def test_fwt(self):
        self.metric_check('ForwardTransfer')

    def test_forgetting_bwt(self):
        """Forgetting and BWT must match the reference and be negatives."""
        df = filter_dict(self.all_metrics, 'Forgetting')
        db = filter_dict(self.all_metrics, 'BWT')
        self.metric_check('Forgetting')
        self.metric_check('BWT')
        for (kf, vf), (kb, vb) in zip(df.items(), db.items()):
            # Matching entries must come from the same granularity.
            self.assertTrue(
                (kf.startswith('Stream') and kb.startswith('Stream')) or
                (kf.startswith('Experience') and kb.startswith('Experience')))
            # BWT is defined as the negative of forgetting.
            for f, b in zip(vf[1], vb[1]):
                self.assertEqual(f, -b)

    def test_cm(self):
        """Confusion matrices are compared element-wise (exact equality)."""
        d = filter_dict(self.all_metrics, 'ConfusionMatrix')
        d_ref = filter_dict(self.ref, 'ConfusionMatrix')
        for (k, v), (kref, vref) in zip(d.items(), d_ref.items()):
            self.assertEqual(k, kref)
            for el, elref in zip(v[0], vref[0]):
                self.assertEqual(el, elref)
            for el, elref in zip(v[1], vref[1]):
                self.assertTrue((el == elref).all())
class PluginMetricTaskLabelPerPatternTests(unittest.TestCase):
    """End-to-end metric check with a random task label per pattern.

    Builds three synthetic 3-class datasets wrapped in
    ``AvalancheTensorDataset`` with per-pattern task labels and trains
    through ``dataset_benchmark``; reference values are pickled in
    ``target_metrics/tpp.pickle``.
    """

    @classmethod
    def setUpClass(cls) -> None:
        # Seed every RNG so the run matches the pickled reference exactly.
        # NOTE(review): statement order below affects RNG consumption and
        # therefore the reference comparison -- do not reorder.
        torch.manual_seed(0)
        np.random.seed(0)
        random.seed(0)
        n_samples_per_class = 100
        datasets = []
        # One synthetic 3-class dataset per future experience.
        for i in range(3):
            dataset = make_classification(
                n_samples=3 * n_samples_per_class,
                n_classes=3,
                n_features=3, n_informative=3, n_redundant=0)
            X = torch.from_numpy(dataset[0]).float()
            y = torch.from_numpy(dataset[1]).long()
            train_X, test_X, train_y, test_y = train_test_split(
                X, y, train_size=0.5, shuffle=True, stratify=y)
            datasets.append((train_X, train_y, test_X, test_y))
        # Each pattern is assigned a random task label in {0, 1, 2}
        # (150 = half of 3 * n_samples_per_class per split).
        tr_ds = [AvalancheTensorDataset(
            tr_X, tr_y,
            dataset_type=AvalancheDatasetType.CLASSIFICATION,
            task_labels=torch.randint(0, 3, (150,)).tolist())
            for tr_X, tr_y, _, _ in datasets]
        ts_ds = [AvalancheTensorDataset(
            ts_X, ts_y,
            dataset_type=AvalancheDatasetType.CLASSIFICATION,
            task_labels=torch.randint(0, 3, (150,)).tolist())
            for _, _, ts_X, ts_y in datasets]
        benchmark = dataset_benchmark(train_datasets=tr_ds, test_datasets=ts_ds)
        model = SimpleMLP(num_classes=3, input_size=3)
        f = open('log.txt', 'w')
        text_logger = TextLogger(f)
        # Every metric family at every granularity it supports.
        eval_plugin = EvaluationPlugin(
            accuracy_metrics(
                minibatch=True, epoch=True, epoch_running=True,
                experience=True, stream=True, trained_experience=True),
            loss_metrics(minibatch=True, epoch=True, epoch_running=True,
                         experience=True, stream=True),
            forgetting_metrics(experience=True, stream=True),
            confusion_matrix_metrics(num_classes=3, save_image=False,
                                     normalize='all', stream=True),
            bwt_metrics(experience=True, stream=True),
            forward_transfer_metrics(experience=True, stream=True),
            cpu_usage_metrics(
                minibatch=True, epoch=True, epoch_running=True,
                experience=True, stream=True),
            timing_metrics(
                minibatch=True, epoch=True, epoch_running=True,
                experience=True, stream=True),
            ram_usage_metrics(
                every=0.5, minibatch=True, epoch=True,
                experience=True, stream=True),
            disk_usage_metrics(
                minibatch=True, epoch=True, experience=True, stream=True),
            MAC_metrics(
                minibatch=True, epoch=True, experience=True),
            loggers=[text_logger],
            collect_all=True)  # collect all metrics (set to True by default)
        cl_strategy = BaseStrategy(
            model, SGD(model.parameters(), lr=0.001, momentum=0.9),
            CrossEntropyLoss(), train_mb_size=2, train_epochs=2,
            eval_mb_size=2, device=DEVICE,
            evaluator=eval_plugin, eval_every=1)
        # Train on each experience, evaluating on the full test stream
        # after each one.
        for i, experience in enumerate(benchmark.train_stream):
            cl_strategy.train(experience,
                              eval_streams=[benchmark.test_stream],
                              shuffle=False)
            cl_strategy.eval(benchmark.test_stream)
        cls.all_metrics = cl_strategy.evaluator.get_all_metrics()
        f.close()
        # # Uncomment me to regenerate the reference metrics. Make sure
        # # the old tests were passing for all unchanged metrics
        # with open(os.path.join(pathlib.Path(__file__).parent.absolute(),
        #                        'target_metrics',
        #                        'tpp.pickle'), 'wb') as f:
        #     pickle.dump(dict(cls.all_metrics), f,
        #                 protocol=4)
        with open(os.path.join(pathlib.Path(__file__).parent.absolute(),
                               'target_metrics',
                               'tpp.pickle'), 'rb') as f:
            cls.ref = pickle.load(f)

    def metric_check(self, name):
        """Compare every collected curve containing ``name`` to the reference.

        Asserts that x-values are strictly increasing and equal to the
        reference, and that y-values match within ``DELTA``.
        """
        d = filter_dict(self.all_metrics, name)
        d_ref = filter_dict(self.ref, name)
        for (k, v), (kref, vref) in zip(d.items(), d_ref.items()):
            self.assertEqual(k, kref)
            init = -1
            # x-values (logging step counters) must be strictly increasing.
            for el in v[0]:
                self.assertTrue(el > init)
                init = el
            for el, elref in zip(v[0], vref[0]):
                self.assertEqual(el, elref)
            for el, elref in zip(v[1], vref[1]):
                self.assertAlmostEqual(el, elref, delta=DELTA)

    def test_accuracy(self):
        self.metric_check('Acc')

    def test_loss(self):
        self.metric_check('Loss')

    def test_mac(self):
        self.metric_check('MAC')

    def test_fwt(self):
        self.metric_check('ForwardTransfer')

    def test_forgetting_bwt(self):
        """Forgetting and BWT must match the reference and be negatives."""
        df = filter_dict(self.all_metrics, 'Forgetting')
        db = filter_dict(self.all_metrics, 'BWT')
        self.metric_check('Forgetting')
        self.metric_check('BWT')
        for (kf, vf), (kb, vb) in zip(df.items(), db.items()):
            # Matching entries must come from the same granularity.
            self.assertTrue(
                (kf.startswith('Stream') and kb.startswith('Stream')) or
                (kf.startswith('Experience') and kb.startswith('Experience')))
            # BWT is defined as the negative of forgetting.
            for f, b in zip(vf[1], vb[1]):
                self.assertEqual(f, -b)

    def test_cm(self):
        """Confusion matrices are compared element-wise (exact equality)."""
        d = filter_dict(self.all_metrics, 'ConfusionMatrix')
        d_ref = filter_dict(self.ref, 'ConfusionMatrix')
        for (k, v), (kref, vref) in zip(d.items(), d_ref.items()):
            self.assertEqual(k, kref)
            for el, elref in zip(v[0], vref[0]):
                self.assertEqual(el, elref)
            for el, elref in zip(v[1], vref[1]):
                self.assertTrue((el == elref).all())
# Run the full test suite when executed as a script.
if __name__ == '__main__':
    unittest.main()