test_nc_mt_scenario.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291
  1. import unittest
  2. from os.path import expanduser
  3. from torchvision.datasets import MNIST
  4. from avalanche.benchmarks.scenarios.new_classes import NCExperience
  5. from avalanche.benchmarks.utils import AvalancheSubset
  6. from avalanche.benchmarks.scenarios.new_classes.nc_utils import \
  7. make_nc_transformation_subset
  8. from avalanche.benchmarks import nc_benchmark, GenericScenarioStream
  9. class MultiTaskTests(unittest.TestCase):
  10. def test_mt_single_dataset(self):
  11. mnist_train = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  12. train=True, download=True)
  13. mnist_test = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  14. train=False, download=True)
  15. my_nc_benchmark = nc_benchmark(
  16. mnist_train, mnist_test, 5, task_labels=True, shuffle=True,
  17. seed=1234, class_ids_from_zero_in_each_exp=True)
  18. self.assertEqual(5, my_nc_benchmark.n_experiences)
  19. self.assertEqual(10, my_nc_benchmark.n_classes)
  20. for task_id in range(5):
  21. self.assertEqual(
  22. 2, len(my_nc_benchmark.classes_in_experience['train'][task_id])
  23. )
  24. all_classes = set()
  25. all_original_classes = set()
  26. for task_id in range(5):
  27. all_classes.update(
  28. my_nc_benchmark.classes_in_experience['train'][task_id])
  29. all_original_classes.update(
  30. my_nc_benchmark.original_classes_in_exp[task_id])
  31. self.assertEqual(2, len(all_classes))
  32. self.assertEqual(10, len(all_original_classes))
  33. def test_mt_single_dataset_without_class_id_remap(self):
  34. mnist_train = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  35. train=True, download=True)
  36. mnist_test = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  37. train=False, download=True)
  38. my_nc_benchmark = nc_benchmark(
  39. mnist_train, mnist_test, 5, task_labels=True, shuffle=True,
  40. seed=1234, class_ids_from_zero_in_each_exp=False)
  41. self.assertEqual(5, my_nc_benchmark.n_experiences)
  42. self.assertEqual(10, my_nc_benchmark.n_classes)
  43. for task_id in range(5):
  44. self.assertEqual(
  45. 2, len(my_nc_benchmark.classes_in_experience['train'][task_id])
  46. )
  47. all_classes = set()
  48. for task_id in range(my_nc_benchmark.n_experiences):
  49. all_classes.update(
  50. my_nc_benchmark.classes_in_experience['train'][task_id])
  51. self.assertEqual(10, len(all_classes))
  52. def test_mt_single_dataset_fixed_order(self):
  53. order = [2, 3, 5, 7, 8, 9, 0, 1, 4, 6]
  54. mnist_train = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  55. train=True, download=True)
  56. mnist_test = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  57. train=False, download=True)
  58. my_nc_benchmark = nc_benchmark(
  59. mnist_train, mnist_test, 5, task_labels=True,
  60. fixed_class_order=order, class_ids_from_zero_in_each_exp=False)
  61. all_classes = []
  62. for task_id in range(5):
  63. all_classes.extend(
  64. my_nc_benchmark.classes_in_experience['train'][task_id])
  65. self.assertEqual(order, all_classes)
  66. def test_sit_single_dataset_fixed_order_subset(self):
  67. order = [2, 5, 7, 8, 9, 0, 1, 4]
  68. mnist_train = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  69. train=True, download=True)
  70. mnist_test = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  71. train=False, download=True)
  72. my_nc_benchmark = nc_benchmark(
  73. mnist_train, mnist_test, 4, task_labels=True,
  74. fixed_class_order=order, class_ids_from_zero_in_each_exp=True)
  75. self.assertEqual(4, len(my_nc_benchmark.classes_in_experience['train']))
  76. all_classes = []
  77. for task_id in range(4):
  78. self.assertEqual(
  79. 2, len(my_nc_benchmark.classes_in_experience['train'][task_id])
  80. )
  81. self.assertEqual(
  82. set(order[task_id*2:(task_id+1)*2]),
  83. my_nc_benchmark.original_classes_in_exp[task_id])
  84. all_classes.extend(
  85. my_nc_benchmark.classes_in_experience['train'][task_id])
  86. self.assertEqual([0, 1] * 4, all_classes)
  87. def test_sit_single_dataset_fixed_subset_no_remap_idx(self):
  88. order = [2, 5, 7, 8, 9, 0, 1, 4]
  89. mnist_train = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  90. train=True, download=True)
  91. mnist_test = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  92. train=False, download=True)
  93. my_nc_benchmark = nc_benchmark(
  94. mnist_train, mnist_test, 2, task_labels=True,
  95. fixed_class_order=order, class_ids_from_zero_in_each_exp=False)
  96. self.assertEqual(2, len(my_nc_benchmark.classes_in_experience['train']))
  97. all_classes = set()
  98. for task_id in range(2):
  99. self.assertEqual(
  100. 4, len(my_nc_benchmark.classes_in_experience['train'][task_id])
  101. )
  102. self.assertEqual(
  103. set(order[task_id*4:(task_id+1)*4]),
  104. my_nc_benchmark.original_classes_in_exp[task_id])
  105. all_classes.update(
  106. my_nc_benchmark.classes_in_experience['train'][task_id])
  107. self.assertEqual(set(order), all_classes)
  108. def test_mt_single_dataset_reproducibility_data(self):
  109. mnist_train = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  110. train=True, download=True)
  111. mnist_test = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  112. train=False, download=True)
  113. nc_benchmark_ref = nc_benchmark(
  114. mnist_train, mnist_test, 5, task_labels=True, shuffle=True,
  115. seed=5678)
  116. my_nc_benchmark = nc_benchmark(
  117. mnist_train, mnist_test, -1, task_labels=True,
  118. reproducibility_data=nc_benchmark_ref.get_reproducibility_data())
  119. self.assertEqual(nc_benchmark_ref.train_exps_patterns_assignment,
  120. my_nc_benchmark.train_exps_patterns_assignment)
  121. self.assertEqual(nc_benchmark_ref.test_exps_patterns_assignment,
  122. my_nc_benchmark.test_exps_patterns_assignment)
  123. def test_mt_single_dataset_task_size(self):
  124. mnist_train = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  125. train=True, download=True)
  126. mnist_test = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  127. train=False, download=True)
  128. my_nc_benchmark = nc_benchmark(
  129. mnist_train, mnist_test, 3, task_labels=True,
  130. per_exp_classes={0: 5, 2: 2},
  131. class_ids_from_zero_in_each_exp=True)
  132. self.assertEqual(3, my_nc_benchmark.n_experiences)
  133. self.assertEqual(10, my_nc_benchmark.n_classes)
  134. all_classes = set()
  135. for task_id in range(3):
  136. all_classes.update(
  137. my_nc_benchmark.classes_in_experience['train'][task_id])
  138. self.assertEqual(5, len(all_classes))
  139. self.assertEqual(
  140. 5, len(my_nc_benchmark.classes_in_experience['train'][0]))
  141. self.assertEqual(
  142. 3, len(my_nc_benchmark.classes_in_experience['train'][1]))
  143. self.assertEqual(
  144. 2, len(my_nc_benchmark.classes_in_experience['train'][2]))
  145. def test_mt_multi_dataset_one_task_per_set(self):
  146. split_mapping = [0, 1, 2, 0, 1, 2, 3, 4, 5, 6]
  147. mnist_train = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  148. train=True, download=True)
  149. mnist_test = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  150. train=False, download=True)
  151. train_part1 = make_nc_transformation_subset(
  152. mnist_train, None, None, range(3))
  153. train_part2 = make_nc_transformation_subset(
  154. mnist_train, None, None, range(3, 10))
  155. train_part2 = AvalancheSubset(
  156. train_part2, class_mapping=split_mapping)
  157. test_part1 = make_nc_transformation_subset(
  158. mnist_test, None, None, range(3))
  159. test_part2 = make_nc_transformation_subset(
  160. mnist_test, None, None, range(3, 10))
  161. test_part2 = AvalancheSubset(test_part2,
  162. class_mapping=split_mapping)
  163. my_nc_benchmark = nc_benchmark(
  164. [train_part1, train_part2], [test_part1, test_part2], 2,
  165. task_labels=True, seed=1234,
  166. class_ids_from_zero_in_each_exp=True, one_dataset_per_exp=True)
  167. self.assertEqual(2, my_nc_benchmark.n_experiences)
  168. self.assertEqual(10, my_nc_benchmark.n_classes)
  169. self.assertEqual(2, len(my_nc_benchmark.train_stream))
  170. self.assertEqual(2, len(my_nc_benchmark.test_stream))
  171. exp_classes_train = []
  172. exp_classes_test = []
  173. all_classes_train = set()
  174. all_classes_test = set()
  175. task_info: NCExperience
  176. for task_id, task_info in enumerate(my_nc_benchmark.train_stream):
  177. self.assertLessEqual(task_id, 1)
  178. all_classes_train.update(
  179. my_nc_benchmark.classes_in_experience['train'][task_id]
  180. )
  181. exp_classes_train.append(task_info.classes_in_this_experience)
  182. self.assertEqual(7, len(all_classes_train))
  183. for task_id, task_info in enumerate(my_nc_benchmark.test_stream):
  184. self.assertLessEqual(task_id, 1)
  185. all_classes_test.update(
  186. my_nc_benchmark.classes_in_experience['test'][task_id]
  187. )
  188. exp_classes_test.append(task_info.classes_in_this_experience)
  189. self.assertEqual(7, len(all_classes_test))
  190. self.assertTrue(
  191. (my_nc_benchmark.classes_in_experience['train'][0] == {0, 1, 2} and
  192. my_nc_benchmark.classes_in_experience['train'][1] ==
  193. set(range(0, 7))) or
  194. (my_nc_benchmark.classes_in_experience['train'][0] ==
  195. set(range(0, 7)) and
  196. my_nc_benchmark.classes_in_experience['train'][1] == {0, 1, 2}))
  197. exp_classes_ref1 = [list(range(3)), list(range(7))]
  198. exp_classes_ref2 = [list(range(7)), list(range(3))]
  199. self.assertTrue(exp_classes_train == exp_classes_ref1 or
  200. exp_classes_train == exp_classes_ref2)
  201. if exp_classes_train == exp_classes_ref1:
  202. self.assertTrue(exp_classes_test == exp_classes_ref1)
  203. else:
  204. self.assertTrue(exp_classes_test == exp_classes_ref2)
  205. def test_nc_mt_slicing(self):
  206. mnist_train = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  207. train=True, download=True)
  208. mnist_test = MNIST(root=expanduser("~") + "/.avalanche/data/mnist/",
  209. train=False, download=True)
  210. my_nc_benchmark = nc_benchmark(
  211. mnist_train, mnist_test, 5, task_labels=True, shuffle=True,
  212. seed=1234)
  213. experience: NCExperience
  214. for batch_id, experience in enumerate(my_nc_benchmark.train_stream):
  215. self.assertEqual(batch_id, experience.current_experience)
  216. self.assertIsInstance(experience, NCExperience)
  217. for batch_id, experience in enumerate(my_nc_benchmark.test_stream):
  218. self.assertEqual(batch_id, experience.current_experience)
  219. self.assertIsInstance(experience, NCExperience)
  220. iterable_slice = [3, 4, 1]
  221. sliced_stream = my_nc_benchmark.train_stream[iterable_slice]
  222. self.assertIsInstance(sliced_stream, GenericScenarioStream)
  223. self.assertEqual(len(iterable_slice), len(sliced_stream))
  224. self.assertEqual('train', sliced_stream.name)
  225. for batch_id, experience in enumerate(sliced_stream):
  226. self.assertEqual(
  227. iterable_slice[batch_id], experience.current_experience
  228. )
  229. self.assertIsInstance(experience, NCExperience)
  230. sliced_stream = my_nc_benchmark.test_stream[iterable_slice]
  231. self.assertIsInstance(sliced_stream, GenericScenarioStream)
  232. self.assertEqual(len(iterable_slice), len(sliced_stream))
  233. self.assertEqual('test', sliced_stream.name)
  234. for batch_id, experience in enumerate(sliced_stream):
  235. self.assertEqual(
  236. iterable_slice[batch_id], experience.current_experience)
  237. self.assertIsInstance(experience, NCExperience)
  238. if __name__ == '__main__':
  239. unittest.main()