test_models.py

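"""Unit tests for avalanche.models: pytorchcv wrappers, dynamic optimizer
helpers, dynamically expanding classifiers (incremental and multi-head),
TrainEvalModel and NCMClassifier."""
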
import sys
import unittest

import pytorchcv.models.pyramidnet_cifar
import torch
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.utils.data import DataLoader

from avalanche.logging import TextLogger
from avalanche.models import MTSimpleMLP, SimpleMLP, IncrementalClassifier, \
    MultiHeadClassifier, SimpleCNN, NCMClassifier, TrainEvalModel
from avalanche.models.dynamic_optimizers import add_new_params_to_optimizer, \
    update_optimizer
from avalanche.training.strategies import Naive
from avalanche.models.pytorchcv_wrapper import vgg, resnet, densenet, \
    pyramidnet, get_model
from tests.unit_tests_utils import common_setups, get_fast_benchmark


class PytorchcvWrapperTests(unittest.TestCase):
    def setUp(self):
        common_setups()

    def test_vgg(self):
        model = vgg(depth=19, batch_normalization=True, pretrained=False)
        # Batch norm is activated
        self.assertIsInstance(model.features.stage1.unit1.bn,
                              torch.nn.BatchNorm2d)
        # Check correct depth is loaded
        self.assertEqual(len(model.features.stage5), 5)

    def test_resnet(self):
        model = resnet("cifar10", depth=20)
        # Test input/output sizes
        self.assertEqual(model.in_size, (32, 32))
        self.assertEqual(model.num_classes, 10)
        # Test input/output sizes
        model = resnet("imagenet", depth=12)
        self.assertEqual(model.in_size, (224, 224))
        self.assertEqual(model.num_classes, 1000)

    def test_pyramidnet(self):
        model = pyramidnet("cifar10", depth=110)
        self.assertIsInstance(
            model, pytorchcv.models.pyramidnet_cifar.CIFARPyramidNet)
        model = pyramidnet("imagenet", depth=101)
        self.assertIsInstance(model, pytorchcv.models.pyramidnet.PyramidNet)

    def test_densenet(self):
        model = densenet("svhn", depth=40)
        self.assertIsInstance(
            model, pytorchcv.models.densenet_cifar.CIFARDenseNet)

    def test_get_model(self):
        # Check general wrapper and whether downloading pretrained model works
        model = get_model('simplepose_resnet18_coco', pretrained=True)
        self.assertIsInstance(
            model, pytorchcv.models.simplepose_coco.SimplePose)
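

# Tests for add_new_params_to_optimizer / update_optimizer, which keep the
# optimizer's parameter groups in sync when model parameters are added or
# replaced over time.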
class DynamicOptimizersTests(unittest.TestCase):
    def setUp(self):
        common_setups()

    def _is_param_in_optimizer(self, param, optimizer):
        for group in optimizer.param_groups:
            for curr_p in group['params']:
                if hash(curr_p) == hash(param):
                    return True
        return False

    def test_optimizer_update(self):
        model = SimpleMLP()
        optimizer = SGD(model.parameters(), lr=1e-3)
        strategy = Naive(model, optimizer, None)

        # check add_param_group
        p = torch.nn.Parameter(torch.zeros(10, 10))
        add_new_params_to_optimizer(optimizer, p)
        assert self._is_param_in_optimizer(p, strategy.optimizer)

        # check new_param is in optimizer
        # check old_param is NOT in optimizer
        p_new = torch.nn.Parameter(torch.zeros(10, 10))
        update_optimizer(optimizer, [p], [p_new])
        assert self._is_param_in_optimizer(p_new, strategy.optimizer)
        assert not self._is_param_in_optimizer(p, strategy.optimizer)
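

# Tests for dynamically expanding modules: IncrementalClassifier grows its
# output layer as new classes arrive, MultiHeadClassifier keeps a separate
# head per task label.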
class DynamicModelsTests(unittest.TestCase):
    def setUp(self):
        common_setups()
        self.benchmark = get_fast_benchmark(
            use_task_labels=False, shuffle=False)

    def test_incremental_classifier(self):
        model = SimpleMLP(input_size=6, hidden_size=10)
        model.classifier = IncrementalClassifier(in_features=10)
        optimizer = SGD(model.parameters(), lr=1e-3)
        criterion = CrossEntropyLoss()
        benchmark = self.benchmark

        strategy = Naive(model, optimizer, criterion,
                         train_mb_size=100, train_epochs=1,
                         eval_mb_size=100, device='cpu')
        strategy.evaluator.loggers = [TextLogger(sys.stdout)]
        print("Current Classes: ",
              benchmark.train_stream[0].classes_in_this_experience)
        print("Current Classes: ",
              benchmark.train_stream[4].classes_in_this_experience)

        # train on first task
        strategy.train(benchmark.train_stream[0])
        w_ptr = model.classifier.classifier.weight.data_ptr()
        b_ptr = model.classifier.classifier.bias.data_ptr()
        opt_params_ptrs = [w.data_ptr() for group in optimizer.param_groups
                           for w in group['params']]
        # classifier params should be optimized
        assert w_ptr in opt_params_ptrs
        assert b_ptr in opt_params_ptrs

        # train again on the same task.
        strategy.train(benchmark.train_stream[0])
        # parameters should not change.
        assert w_ptr == model.classifier.classifier.weight.data_ptr()
        assert b_ptr == model.classifier.classifier.bias.data_ptr()
        # the same classifier params should still be optimized
        assert w_ptr in opt_params_ptrs
        assert b_ptr in opt_params_ptrs

        # update classifier with new classes.
        old_w_ptr, old_b_ptr = w_ptr, b_ptr
        strategy.train(benchmark.train_stream[4])
        opt_params_ptrs = [w.data_ptr() for group in optimizer.param_groups
                           for w in group['params']]
        new_w_ptr = model.classifier.classifier.weight.data_ptr()
        new_b_ptr = model.classifier.classifier.bias.data_ptr()
        # weights should change.
        assert old_w_ptr != new_w_ptr
        assert old_b_ptr != new_b_ptr
        # Old params should not be optimized. New params should be optimized.
        assert old_w_ptr not in opt_params_ptrs
        assert old_b_ptr not in opt_params_ptrs
        assert new_w_ptr in opt_params_ptrs
        assert new_b_ptr in opt_params_ptrs

    def test_incremental_classifier_weight_update(self):
        model = IncrementalClassifier(in_features=10)
        optimizer = SGD(model.parameters(), lr=1e-3)
        criterion = CrossEntropyLoss()
        benchmark = self.benchmark

        strategy = Naive(model, optimizer, criterion,
                         train_mb_size=100, train_epochs=1,
                         eval_mb_size=100, device='cpu')
        strategy.evaluator.loggers = [TextLogger(sys.stdout)]

        # save weights before adaptation
        w_old = model.classifier.weight.clone()
        b_old = model.classifier.bias.clone()
        # adaptation. Increase number of classes
        dataset = benchmark.train_stream[4].dataset
        model.adaptation(dataset)
        w_new = model.classifier.weight.clone()
        b_new = model.classifier.bias.clone()
        # old weights should be copied correctly.
        assert torch.equal(w_old, w_new[:w_old.shape[0]])
        assert torch.equal(b_old, b_new[:w_old.shape[0]])
        # shape should be correct.
        assert w_new.shape[0] == max(dataset.targets) + 1
        assert b_new.shape[0] == max(dataset.targets) + 1

    def test_multihead_head_creation(self):
        # Check if the optimizer is updated correctly
        # when heads are created and updated.
        model = MTSimpleMLP(input_size=6, hidden_size=10)
        optimizer = SGD(model.parameters(), lr=1e-3)
        criterion = CrossEntropyLoss()
        benchmark = get_fast_benchmark(use_task_labels=True, shuffle=False)

        strategy = Naive(model, optimizer, criterion,
                         train_mb_size=100, train_epochs=1,
                         eval_mb_size=100, device='cpu')
        strategy.evaluator.loggers = [TextLogger(sys.stdout)]
        print("Current Classes: ",
              benchmark.train_stream[4].classes_in_this_experience)
        print("Current Classes: ",
              benchmark.train_stream[0].classes_in_this_experience)

        # head creation
        strategy.train(benchmark.train_stream[0])
        w_ptr = model.classifier.classifiers['0'].classifier.weight.data_ptr()
        b_ptr = model.classifier.classifiers['0'].classifier.bias.data_ptr()
        opt_params_ptrs = [w.data_ptr() for group in optimizer.param_groups
                           for w in group['params']]
        assert w_ptr in opt_params_ptrs
        assert b_ptr in opt_params_ptrs

        # head update
        strategy.train(benchmark.train_stream[4])
        w_ptr_t0 = model.classifier.classifiers[
            '0'].classifier.weight.data_ptr()
        b_ptr_t0 = model.classifier.classifiers['0'].classifier.bias.data_ptr()
        w_ptr_new = model.classifier.classifiers[
            '4'].classifier.weight.data_ptr()
        b_ptr_new = model.classifier.classifiers['4'].classifier.bias.data_ptr()
        opt_params_ptrs = [w.data_ptr() for group in optimizer.param_groups
                           for w in group['params']]
        assert w_ptr not in opt_params_ptrs  # head0 has been updated
        assert b_ptr not in opt_params_ptrs  # head0 has been updated
        assert w_ptr_t0 in opt_params_ptrs
        assert b_ptr_t0 in opt_params_ptrs
        assert w_ptr_new in opt_params_ptrs
        assert b_ptr_new in opt_params_ptrs

    def test_multihead_head_selection(self):
        # Check that forwarding with a task label uses the head
        # belonging to that task.
        model = MultiHeadClassifier(in_features=6)
        optimizer = SGD(model.parameters(), lr=1e-3)
        criterion = CrossEntropyLoss()
        benchmark = get_fast_benchmark(use_task_labels=True, shuffle=False)

        strategy = Naive(model, optimizer, criterion,
                         train_mb_size=100, train_epochs=1,
                         eval_mb_size=100, device='cpu')
        strategy.evaluator.loggers = [TextLogger(sys.stdout)]

        # initialize heads
        strategy.train(benchmark.train_stream[0])
        strategy.train(benchmark.train_stream[4])

        # create models with fixed head
        model_t0 = model.classifiers['0']
        model_t4 = model.classifiers['4']

        # check head task0
        for x, y, t in DataLoader(benchmark.train_stream[0].dataset):
            y_mh = model(x, t)
            y_t = model_t0(x)
            assert ((y_mh - y_t) ** 2).sum() < 1.e-7
            break

        # check head task4
        for x, y, t in DataLoader(benchmark.train_stream[4].dataset):
            y_mh = model(x, t)
            y_t = model_t4(x)
            assert ((y_mh - y_t) ** 2).sum() < 1.e-7
            break
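

# TrainEvalModel pairs a feature extractor with two classifiers, one used in
# training mode and one in eval mode; adaptation selects the active one.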
class TrainEvalModelTests(unittest.TestCase):
    def test_classifier_selection(self):
        base_model = SimpleCNN()
        feature_extractor = base_model.features
        classifier1 = base_model.classifier
        classifier2 = NCMClassifier()

        model = TrainEvalModel(feature_extractor,
                               train_classifier=classifier1,
                               eval_classifier=classifier2)
        model.eval()
        model.adaptation()
        assert model.classifier is classifier2

        model.train()
        model.adaptation()
        assert model.classifier is classifier1

        model.eval_adaptation()
        assert model.classifier is classifier2

        model.train_adaptation()
        assert model.classifier is classifier1
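

# NCMClassifier assigns each input to the class with the nearest stored mean.
# With one-hot class means, row i of mb_x has its largest component at index i
# and is therefore closest to mean i, so the predictions should match mb_y.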
class NCMClassifierTest(unittest.TestCase):
    def test_ncm_classification(self):
        class_means = torch.tensor([
            [1, 0, 0, 0],
            [0, 1, 0, 0],
            [0, 0, 1, 0],
            [0, 0, 0, 1]
        ], dtype=torch.float)

        mb_x = torch.tensor([
            [4, 3, 2, 1],
            [3, 4, 2, 1],
            [3, 2, 4, 1],
            [3, 2, 1, 4]
        ], dtype=torch.float)

        mb_y = torch.tensor([0, 1, 2, 3], dtype=torch.float)

        classifier = NCMClassifier(class_means)
        pred = classifier(mb_x)
        assert torch.all(torch.max(pred, 1)[1] == mb_y)