test_custom_streams.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. import unittest
  2. from os.path import expanduser
  3. import torch
  4. from torchvision.datasets import MNIST
  5. from avalanche.benchmarks.scenarios.new_classes import NCExperience
  6. from avalanche.benchmarks.utils import AvalancheSubset, AvalancheTensorDataset
  7. from avalanche.benchmarks.scenarios.new_classes.nc_utils import \
  8. make_nc_transformation_subset
  9. from avalanche.benchmarks import nc_benchmark, GenericScenarioStream, \
  10. GenericCLScenario
  11. class CustomStreamsTests(unittest.TestCase):
  12. def test_custom_streams_name_and_length(self):
  13. train_exps = []
  14. test_exps = []
  15. valid_exps = []
  16. for _ in range(5):
  17. tensor_x = torch.rand(200, 3, 28, 28)
  18. tensor_y = torch.randint(0, 100, (200,))
  19. tensor_t = torch.randint(0, 5, (200,))
  20. train_exps.append(AvalancheTensorDataset(tensor_x, tensor_y,
  21. task_labels=tensor_t))
  22. for _ in range(3):
  23. tensor_x = torch.rand(150, 3, 28, 28)
  24. tensor_y = torch.randint(0, 100, (150,))
  25. tensor_t = torch.randint(0, 3, (150,))
  26. test_exps.append(AvalancheTensorDataset(tensor_x, tensor_y,
  27. task_labels=tensor_t))
  28. for _ in range(4):
  29. tensor_x = torch.rand(220, 3, 28, 28)
  30. tensor_y = torch.randint(0, 100, (220,))
  31. tensor_t = torch.randint(0, 5, (220,))
  32. valid_exps.append(AvalancheTensorDataset(tensor_x, tensor_y,
  33. task_labels=tensor_t))
  34. valid_origin_dataset = AvalancheTensorDataset(
  35. torch.ones(10, 3, 32, 32), torch.zeros(10))
  36. valid_t_labels = [{9}, {4, 5}, {7, 8}, {0}, {3}]
  37. with self.assertRaises(Exception):
  38. benchmark_instance = GenericCLScenario(
  39. stream_definitions={
  40. 'train': (train_exps,),
  41. 'test': (test_exps,),
  42. 'valid': (valid_exps, valid_t_labels, valid_origin_dataset)
  43. }
  44. )
  45. valid_t_labels = valid_t_labels[:-1]
  46. benchmark_instance = GenericCLScenario(
  47. stream_definitions={
  48. 'train': (train_exps,),
  49. 'test': (test_exps,),
  50. 'valid': (valid_exps, valid_t_labels, valid_origin_dataset)
  51. }
  52. )
  53. self.assertEqual(5, len(benchmark_instance.train_stream))
  54. self.assertEqual(3, len(benchmark_instance.test_stream))
  55. self.assertEqual(4, len(benchmark_instance.valid_stream))
  56. self.assertEqual(None, benchmark_instance.original_train_dataset)
  57. self.assertEqual(None, benchmark_instance.original_test_dataset)
  58. self.assertEqual(valid_origin_dataset,
  59. benchmark_instance.original_valid_dataset)
  60. for i, exp in enumerate(benchmark_instance.train_stream):
  61. expect_x, expect_y, expect_t = train_exps[i][0]
  62. got_x, got_y, got_t = exp.dataset[0]
  63. self.assertTrue(torch.equal(expect_x, got_x))
  64. self.assertTrue(torch.equal(expect_y, got_y))
  65. self.assertEqual(int(expect_t), got_t)
  66. exp_t_labels = set(exp.task_labels)
  67. self.assertLess(max(exp_t_labels), 5)
  68. self.assertGreaterEqual(min(exp_t_labels), 0)
  69. for i, exp in enumerate(benchmark_instance.test_stream):
  70. expect_x, expect_y, expect_t = test_exps[i][0]
  71. got_x, got_y, got_t = exp.dataset[0]
  72. self.assertTrue(torch.equal(expect_x, got_x))
  73. self.assertTrue(torch.equal(expect_y, got_y))
  74. self.assertEqual(int(expect_t), got_t)
  75. exp_t_labels = set(exp.task_labels)
  76. self.assertLess(max(exp_t_labels), 3)
  77. self.assertGreaterEqual(min(exp_t_labels), 0)
  78. for i, exp in enumerate(benchmark_instance.valid_stream):
  79. expect_x, expect_y, expect_t = valid_exps[i][0]
  80. got_x, got_y, got_t = exp.dataset[0]
  81. self.assertTrue(torch.equal(expect_x, got_x))
  82. self.assertTrue(torch.equal(expect_y, got_y))
  83. self.assertEqual(int(expect_t), got_t)
  84. exp_t_labels = set(exp.task_labels)
  85. self.assertEqual(valid_t_labels[i], exp_t_labels)
  86. def test_complete_test_set_only(self):
  87. train_exps = []
  88. test_exps = []
  89. for _ in range(5):
  90. tensor_x = torch.rand(200, 3, 28, 28)
  91. tensor_y = torch.randint(0, 100, (200,))
  92. tensor_t = torch.randint(0, 5, (200,))
  93. train_exps.append(AvalancheTensorDataset(tensor_x, tensor_y,
  94. task_labels=tensor_t))
  95. for _ in range(3):
  96. tensor_x = torch.rand(150, 3, 28, 28)
  97. tensor_y = torch.randint(0, 100, (150,))
  98. tensor_t = torch.randint(0, 5, (150,))
  99. test_exps.append(AvalancheTensorDataset(tensor_x, tensor_y,
  100. task_labels=tensor_t))
  101. with self.assertRaises(Exception):
  102. benchmark_instance = GenericCLScenario(
  103. stream_definitions={
  104. 'train': (train_exps,),
  105. 'test': (test_exps,),
  106. },
  107. complete_test_set_only=True
  108. )
  109. benchmark_instance = GenericCLScenario(
  110. stream_definitions={
  111. 'train': (train_exps,),
  112. 'test': (test_exps[0],),
  113. },
  114. complete_test_set_only=True
  115. )
  116. self.assertEqual(5, len(benchmark_instance.train_stream))
  117. self.assertEqual(1, len(benchmark_instance.test_stream))
  118. if __name__ == '__main__':
  119. unittest.main()