# test_generic_cl_scenario.py


  1. import unittest
  2. import weakref
  3. import gc
  4. import torch
  5. from avalanche.benchmarks import dataset_benchmark, \
  6. GenericExperience, GenericCLScenario
  7. from avalanche.benchmarks.utils import AvalancheTensorDataset, \
  8. AvalancheDatasetType
  9. class GenericCLScenarioTests(unittest.TestCase):
  10. def test_classes_in_exp(self):
  11. train_exps = []
  12. tensor_x = torch.rand(200, 3, 28, 28)
  13. tensor_y = torch.randint(0, 70, (200,))
  14. tensor_t = torch.randint(0, 5, (200,))
  15. train_exps.append(AvalancheTensorDataset(
  16. tensor_x, tensor_y, task_labels=tensor_t))
  17. tensor_x = torch.rand(200, 3, 28, 28)
  18. tensor_y = torch.randint(0, 100, (200,))
  19. tensor_t = torch.randint(0, 5, (200,))
  20. train_exps.append(AvalancheTensorDataset(
  21. tensor_x, tensor_y, task_labels=tensor_t))
  22. test_exps = []
  23. test_x = torch.rand(200, 3, 28, 28)
  24. test_y = torch.randint(100, 200, (200,))
  25. test_t = torch.randint(0, 5, (200,))
  26. test_exps.append(AvalancheTensorDataset(
  27. test_x, test_y, task_labels=test_t))
  28. other_stream_exps = []
  29. other_x = torch.rand(200, 3, 28, 28)
  30. other_y = torch.randint(400, 600, (200,))
  31. other_t = torch.randint(0, 5, (200,))
  32. other_stream_exps.append(AvalancheTensorDataset(
  33. other_x, other_y, task_labels=other_t))
  34. benchmark_instance = dataset_benchmark(
  35. train_datasets=train_exps,
  36. test_datasets=test_exps,
  37. other_streams_datasets={'other': other_stream_exps})
  38. train_0_classes = benchmark_instance.classes_in_experience['train'][0]
  39. train_1_classes = benchmark_instance.classes_in_experience['train'][1]
  40. train_0_classes_min = min(train_0_classes)
  41. train_1_classes_min = min(train_1_classes)
  42. train_0_classes_max = max(train_0_classes)
  43. train_1_classes_max = max(train_1_classes)
  44. self.assertGreaterEqual(train_0_classes_min, 0)
  45. self.assertLess(train_0_classes_max, 70)
  46. self.assertGreaterEqual(train_1_classes_min, 0)
  47. self.assertLess(train_1_classes_max, 100)
  48. # Test deprecated behavior
  49. train_0_classes = benchmark_instance.classes_in_experience[0]
  50. train_1_classes = benchmark_instance.classes_in_experience[1]
  51. train_0_classes_min = min(train_0_classes)
  52. train_1_classes_min = min(train_1_classes)
  53. train_0_classes_max = max(train_0_classes)
  54. train_1_classes_max = max(train_1_classes)
  55. self.assertGreaterEqual(train_0_classes_min, 0)
  56. self.assertLess(train_0_classes_max, 70)
  57. self.assertGreaterEqual(train_1_classes_min, 0)
  58. self.assertLess(train_1_classes_max, 100)
  59. # End test deprecated behavior
  60. test_0_classes = benchmark_instance.classes_in_experience['test'][0]
  61. test_0_classes_min = min(test_0_classes)
  62. test_0_classes_max = max(test_0_classes)
  63. self.assertGreaterEqual(test_0_classes_min, 100)
  64. self.assertLess(test_0_classes_max, 200)
  65. other_0_classes = benchmark_instance.classes_in_experience['other'][0]
  66. other_0_classes_min = min(other_0_classes)
  67. other_0_classes_max = max(other_0_classes)
  68. self.assertGreaterEqual(other_0_classes_min, 400)
  69. self.assertLess(other_0_classes_max, 600)
  70. def test_classes_in_this_experience(self):
  71. train_exps = []
  72. tensor_x = torch.rand(200, 3, 28, 28)
  73. tensor_y = torch.randint(0, 70, (200,))
  74. tensor_t = torch.randint(0, 5, (200,))
  75. train_exps.append(AvalancheTensorDataset(
  76. tensor_x, tensor_y, task_labels=tensor_t))
  77. tensor_x = torch.rand(200, 3, 28, 28)
  78. tensor_y = torch.randint(0, 100, (200,))
  79. tensor_t = torch.randint(0, 5, (200,))
  80. train_exps.append(AvalancheTensorDataset(
  81. tensor_x, tensor_y, task_labels=tensor_t))
  82. test_exps = []
  83. test_x = torch.rand(200, 3, 28, 28)
  84. test_y = torch.randint(100, 200, (200,))
  85. test_t = torch.randint(0, 5, (200,))
  86. test_exps.append(AvalancheTensorDataset(
  87. test_x, test_y, task_labels=test_t))
  88. other_stream_exps = []
  89. other_x = torch.rand(200, 3, 28, 28)
  90. other_y = torch.randint(400, 600, (200,))
  91. other_t = torch.randint(0, 5, (200,))
  92. other_stream_exps.append(AvalancheTensorDataset(
  93. other_x, other_y, task_labels=other_t))
  94. benchmark_instance = dataset_benchmark(
  95. train_datasets=train_exps,
  96. test_datasets=test_exps,
  97. other_streams_datasets={'other': other_stream_exps})
  98. train_exp_0: GenericExperience = benchmark_instance.train_stream[0]
  99. train_exp_1: GenericExperience = benchmark_instance.train_stream[1]
  100. train_0_classes = train_exp_0.classes_in_this_experience
  101. train_1_classes = train_exp_1.classes_in_this_experience
  102. train_0_classes_min = min(train_0_classes)
  103. train_1_classes_min = min(train_1_classes)
  104. train_0_classes_max = max(train_0_classes)
  105. train_1_classes_max = max(train_1_classes)
  106. self.assertGreaterEqual(train_0_classes_min, 0)
  107. self.assertLess(train_0_classes_max, 70)
  108. self.assertGreaterEqual(train_1_classes_min, 0)
  109. self.assertLess(train_1_classes_max, 100)
  110. test_exp_0: GenericExperience = benchmark_instance.test_stream[0]
  111. test_0_classes = test_exp_0.classes_in_this_experience
  112. test_0_classes_min = min(test_0_classes)
  113. test_0_classes_max = max(test_0_classes)
  114. self.assertGreaterEqual(test_0_classes_min, 100)
  115. self.assertLess(test_0_classes_max, 200)
  116. other_exp_0: GenericExperience = benchmark_instance.other_stream[0]
  117. other_0_classes = other_exp_0.classes_in_this_experience
  118. other_0_classes_min = min(other_0_classes)
  119. other_0_classes_max = max(other_0_classes)
  120. self.assertGreaterEqual(other_0_classes_min, 400)
  121. self.assertLess(other_0_classes_max, 600)
  122. def test_lazy_benchmark(self):
  123. train_exps, test_exps, other_stream_exps = self._make_tensor_datasets()
  124. def train_gen():
  125. # Lazy generator of the training stream
  126. for dataset in train_exps:
  127. yield dataset
  128. def test_gen():
  129. # Lazy generator of the test stream
  130. for dataset in test_exps:
  131. yield dataset
  132. def other_gen():
  133. # Lazy generator of the "other" stream
  134. for dataset in other_stream_exps:
  135. yield dataset
  136. benchmark_instance = GenericCLScenario(
  137. stream_definitions=dict(
  138. train=((train_gen(), len(train_exps)), [
  139. train_exps[0].targets_task_labels,
  140. train_exps[1].targets_task_labels
  141. ]),
  142. test=((test_gen(), len(test_exps)), [
  143. test_exps[0].targets_task_labels
  144. ]),
  145. other=((other_gen(), len(other_stream_exps)), [
  146. other_stream_exps[0].targets_task_labels
  147. ])))
  148. # --- START: Test classes timeline before first experience ---
  149. current_classes, prev_classes, cumulative_classes, future_classes = \
  150. benchmark_instance.get_classes_timeline(0)
  151. self.assertIsNone(current_classes)
  152. self.assertSetEqual(set(), set(prev_classes))
  153. self.assertIsNone(cumulative_classes)
  154. self.assertIsNone(future_classes)
  155. # --- END: Test classes timeline before first experience ---
  156. train_exp_0: GenericExperience = benchmark_instance.train_stream[0]
  157. # --- START: Test classes timeline at first experience ---
  158. current_classes, prev_classes, cumulative_classes, future_classes = \
  159. benchmark_instance.get_classes_timeline(0)
  160. self.assertSetEqual(set(train_exps[0].targets), set(current_classes))
  161. self.assertSetEqual(set(), set(prev_classes))
  162. self.assertSetEqual(set(train_exps[0].targets), set(cumulative_classes))
  163. self.assertIsNone(future_classes)
  164. current_classes, prev_classes, cumulative_classes, future_classes = \
  165. benchmark_instance.get_classes_timeline(1)
  166. self.assertIsNone(current_classes)
  167. self.assertSetEqual(set(train_exps[0].targets), set(prev_classes))
  168. # None because we didn't load exp 0 yet
  169. self.assertIsNone(cumulative_classes)
  170. self.assertSetEqual(set(), set(future_classes))
  171. # --- END: Test classes timeline at first experience ---
  172. train_exp_1: GenericExperience = benchmark_instance.train_stream[1]
  173. # --- START: Test classes timeline at second experience ---
  174. # Check if get_classes_timeline(0) is consistent
  175. current_classes, prev_classes, cumulative_classes, future_classes = \
  176. benchmark_instance.get_classes_timeline(0)
  177. self.assertSetEqual(set(train_exps[0].targets), set(current_classes))
  178. self.assertSetEqual(set(), set(prev_classes))
  179. self.assertSetEqual(set(train_exps[0].targets), set(cumulative_classes))
  180. # We now have access to future classes!
  181. self.assertSetEqual(set(train_exps[1].targets), set(future_classes))
  182. current_classes, prev_classes, cumulative_classes, future_classes = \
  183. benchmark_instance.get_classes_timeline(1)
  184. self.assertSetEqual(set(train_exps[1].targets), set(current_classes))
  185. self.assertSetEqual(set(train_exps[0].targets), set(prev_classes))
  186. self.assertSetEqual(
  187. set(train_exps[0].targets).union(set(train_exps[1].targets)),
  188. set(cumulative_classes))
  189. self.assertSetEqual(set(), set(future_classes))
  190. # --- END: Test classes timeline at second experience ---
  191. train_0_classes = train_exp_0.classes_in_this_experience
  192. train_1_classes = train_exp_1.classes_in_this_experience
  193. train_0_classes_min = min(train_0_classes)
  194. train_1_classes_min = min(train_1_classes)
  195. train_0_classes_max = max(train_0_classes)
  196. train_1_classes_max = max(train_1_classes)
  197. self.assertGreaterEqual(train_0_classes_min, 0)
  198. self.assertLess(train_0_classes_max, 70)
  199. self.assertGreaterEqual(train_1_classes_min, 0)
  200. self.assertLess(train_1_classes_max, 100)
  201. with self.assertRaises(IndexError):
  202. train_exp_2: GenericExperience = benchmark_instance.train_stream[2]
  203. test_exp_0: GenericExperience = benchmark_instance.test_stream[0]
  204. test_0_classes = test_exp_0.classes_in_this_experience
  205. test_0_classes_min = min(test_0_classes)
  206. test_0_classes_max = max(test_0_classes)
  207. self.assertGreaterEqual(test_0_classes_min, 100)
  208. self.assertLess(test_0_classes_max, 200)
  209. with self.assertRaises(IndexError):
  210. test_exp_1: GenericExperience = benchmark_instance.test_stream[1]
  211. other_exp_0: GenericExperience = benchmark_instance.other_stream[0]
  212. other_0_classes = other_exp_0.classes_in_this_experience
  213. other_0_classes_min = min(other_0_classes)
  214. other_0_classes_max = max(other_0_classes)
  215. self.assertGreaterEqual(other_0_classes_min, 400)
  216. self.assertLess(other_0_classes_max, 600)
  217. with self.assertRaises(IndexError):
  218. other_exp_1: GenericExperience = benchmark_instance.other_stream[1]
  219. def test_lazy_benchmark_drop_old_ones(self):
  220. train_exps, test_exps, other_stream_exps = self._make_tensor_datasets()
  221. train_dataset_exp_0_weak_ref = weakref.ref(train_exps[0])
  222. train_dataset_exp_1_weak_ref = weakref.ref(train_exps[1])
  223. train_gen = GenericCLScenarioTests._generate_stream(train_exps)
  224. test_gen = GenericCLScenarioTests._generate_stream(test_exps)
  225. other_gen = GenericCLScenarioTests._generate_stream(other_stream_exps)
  226. benchmark_instance = GenericCLScenario(
  227. stream_definitions=dict(
  228. train=((train_gen, len(train_exps)), [
  229. train_exps[0].targets_task_labels,
  230. train_exps[1].targets_task_labels
  231. ]),
  232. test=((test_gen, len(test_exps)), [
  233. test_exps[0].targets_task_labels
  234. ]),
  235. other=((other_gen, len(other_stream_exps)), [
  236. other_stream_exps[0].targets_task_labels
  237. ])))
  238. # --- START: Test classes timeline before first experience ---
  239. current_classes, prev_classes, cumulative_classes, future_classes = \
  240. benchmark_instance.get_classes_timeline(0)
  241. self.assertIsNone(current_classes)
  242. self.assertSetEqual(set(), set(prev_classes))
  243. self.assertIsNone(cumulative_classes)
  244. self.assertIsNone(future_classes)
  245. # --- END: Test classes timeline before first experience ---
  246. train_exp_0: GenericExperience = benchmark_instance.train_stream[0]
  247. # --- START: Test classes timeline at first experience ---
  248. current_classes, prev_classes, cumulative_classes, future_classes = \
  249. benchmark_instance.get_classes_timeline(0)
  250. self.assertSetEqual(set(train_exps[0].targets), set(current_classes))
  251. self.assertSetEqual(set(), set(prev_classes))
  252. self.assertSetEqual(set(train_exps[0].targets), set(cumulative_classes))
  253. self.assertIsNone(future_classes)
  254. current_classes, prev_classes, cumulative_classes, future_classes = \
  255. benchmark_instance.get_classes_timeline(1)
  256. self.assertIsNone(current_classes)
  257. self.assertSetEqual(set(train_exps[0].targets), set(prev_classes))
  258. # None because we didn't load exp 0 yet
  259. self.assertIsNone(cumulative_classes)
  260. self.assertSetEqual(set(), set(future_classes))
  261. # --- END: Test classes timeline at first experience ---
  262. # Check if it works when the previous experience is dropped
  263. benchmark_instance.train_stream.drop_previous_experiences(0)
  264. train_exp_1: GenericExperience = benchmark_instance.train_stream[1]
  265. # --- START: Test classes timeline at second experience ---
  266. # Check if get_classes_timeline(0) is consistent
  267. current_classes, prev_classes, cumulative_classes, future_classes = \
  268. benchmark_instance.get_classes_timeline(0)
  269. self.assertSetEqual(set(train_exps[0].targets), set(current_classes))
  270. self.assertSetEqual(set(), set(prev_classes))
  271. self.assertSetEqual(set(train_exps[0].targets), set(cumulative_classes))
  272. # We now have access to future classes!
  273. self.assertSetEqual(set(train_exps[1].targets), set(future_classes))
  274. current_classes, prev_classes, cumulative_classes, future_classes = \
  275. benchmark_instance.get_classes_timeline(1)
  276. self.assertSetEqual(set(train_exps[1].targets), set(current_classes))
  277. self.assertSetEqual(set(train_exps[0].targets), set(prev_classes))
  278. self.assertSetEqual(
  279. set(train_exps[0].targets).union(set(train_exps[1].targets)),
  280. set(cumulative_classes))
  281. self.assertSetEqual(set(), set(future_classes))
  282. # --- END: Test classes timeline at second experience ---
  283. train_0_classes = train_exp_0.classes_in_this_experience
  284. train_1_classes = train_exp_1.classes_in_this_experience
  285. train_0_classes_min = min(train_0_classes)
  286. train_1_classes_min = min(train_1_classes)
  287. train_0_classes_max = max(train_0_classes)
  288. train_1_classes_max = max(train_1_classes)
  289. self.assertGreaterEqual(train_0_classes_min, 0)
  290. self.assertLess(train_0_classes_max, 70)
  291. self.assertGreaterEqual(train_1_classes_min, 0)
  292. self.assertLess(train_1_classes_max, 100)
  293. with self.assertRaises(IndexError):
  294. train_exp_2: GenericExperience = benchmark_instance.train_stream[2]
  295. test_exp_0: GenericExperience = benchmark_instance.test_stream[0]
  296. test_0_classes = test_exp_0.classes_in_this_experience
  297. test_0_classes_min = min(test_0_classes)
  298. test_0_classes_max = max(test_0_classes)
  299. self.assertGreaterEqual(test_0_classes_min, 100)
  300. self.assertLess(test_0_classes_max, 200)
  301. with self.assertRaises(IndexError):
  302. test_exp_1: GenericExperience = benchmark_instance.test_stream[1]
  303. other_exp_0: GenericExperience = benchmark_instance.other_stream[0]
  304. other_0_classes = other_exp_0.classes_in_this_experience
  305. other_0_classes_min = min(other_0_classes)
  306. other_0_classes_max = max(other_0_classes)
  307. self.assertGreaterEqual(other_0_classes_min, 400)
  308. self.assertLess(other_0_classes_max, 600)
  309. with self.assertRaises(IndexError):
  310. other_exp_1: GenericExperience = benchmark_instance.other_stream[1]
  311. train_exps = None
  312. train_exp_0 = None
  313. train_exp_1 = None
  314. train_0_classes = None
  315. train_1_classes = None
  316. train_gen = None
  317. # The generational GC is needed, ref-count is not enough here
  318. gc.collect()
  319. # This will check that the train dataset of exp0 has been garbage
  320. # collected correctly
  321. self.assertIsNone(train_dataset_exp_0_weak_ref())
  322. self.assertIsNotNone(train_dataset_exp_1_weak_ref())
  323. benchmark_instance.train_stream.drop_previous_experiences(1)
  324. gc.collect()
  325. # This will check that exp1 has been garbage collected correctly
  326. self.assertIsNone(train_dataset_exp_0_weak_ref())
  327. self.assertIsNone(train_dataset_exp_1_weak_ref())
  328. with self.assertRaises(Exception):
  329. exp_0 = benchmark_instance.train_stream[0]
  330. with self.assertRaises(Exception):
  331. exp_1 = benchmark_instance.train_stream[1]
  332. def _make_tensor_datasets(self):
  333. train_exps = []
  334. tensor_x = torch.rand(200, 3, 28, 28)
  335. tensor_y = torch.randint(0, 70, (200,))
  336. tensor_t = torch.randint(0, 5, (200,))
  337. train_exps.append(AvalancheTensorDataset(
  338. tensor_x, tensor_y, task_labels=tensor_t,
  339. dataset_type=AvalancheDatasetType.CLASSIFICATION))
  340. tensor_x = torch.rand(200, 3, 28, 28)
  341. tensor_y = torch.randint(0, 100, (200,))
  342. tensor_t = torch.randint(0, 5, (200,))
  343. train_exps.append(AvalancheTensorDataset(
  344. tensor_x, tensor_y, task_labels=tensor_t,
  345. dataset_type=AvalancheDatasetType.CLASSIFICATION))
  346. test_exps = []
  347. test_x = torch.rand(200, 3, 28, 28)
  348. test_y = torch.randint(100, 200, (200,))
  349. test_t = torch.randint(0, 5, (200,))
  350. test_exps.append(AvalancheTensorDataset(
  351. test_x, test_y, task_labels=test_t,
  352. dataset_type=AvalancheDatasetType.CLASSIFICATION))
  353. other_stream_exps = []
  354. other_x = torch.rand(200, 3, 28, 28)
  355. other_y = torch.randint(400, 600, (200,))
  356. other_t = torch.randint(0, 5, (200,))
  357. other_stream_exps.append(AvalancheTensorDataset(
  358. other_x, other_y, task_labels=other_t,
  359. dataset_type=AvalancheDatasetType.CLASSIFICATION))
  360. return train_exps, test_exps, other_stream_exps
  361. @staticmethod
  362. def _generate_stream(dataset_list):
  363. # Lazy generator of a stream
  364. for dataset in dataset_list:
  365. yield dataset
  366. if __name__ == '__main__':
  367. unittest.main()