# icarl.py — iCaRL continual-learning example on split CIFAR-100 (Avalanche).
import random
from os.path import expanduser

import numpy as np
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import MultiStepLR
from torchvision import transforms

from avalanche.benchmarks.datasets import CIFAR100
from avalanche.benchmarks.generators import nc_benchmark
from avalanche.benchmarks.utils import AvalancheDataset
from avalanche.evaluation.metrics import ExperienceAccuracy, StreamAccuracy, \
    EpochAccuracy
from avalanche.logging.interactive_logging import InteractiveLogger
from avalanche.models import IcarlNet, make_icarl_net, initialize_icarl_net
from avalanche.training.plugins import EvaluationPlugin
from avalanche.training.plugins.lr_scheduling import LRSchedulerPlugin
from avalanche.training.strategies.icarl import ICaRL
  18. def get_dataset_per_pixel_mean(dataset):
  19. result = None
  20. patterns_count = 0
  21. for img_pattern, _ in dataset:
  22. if result is None:
  23. result = torch.zeros_like(img_pattern, dtype=torch.float)
  24. result += img_pattern
  25. patterns_count += 1
  26. if result is None:
  27. result = torch.empty(0, dtype=torch.float)
  28. else:
  29. result = result / patterns_count
  30. return result
  31. def icarl_cifar100_augment_data(img):
  32. img = img.numpy()
  33. padded = np.pad(img, ((0, 0), (4, 4), (4, 4)), mode='constant')
  34. random_cropped = np.zeros(img.shape, dtype=np.float32)
  35. crop = np.random.randint(0, high=8 + 1, size=(2,))
  36. # Cropping and possible flipping
  37. if np.random.randint(2) > 0:
  38. random_cropped[:, :, :] = \
  39. padded[:, crop[0]:(crop[0]+32), crop[1]:(crop[1]+32)]
  40. else:
  41. random_cropped[:, :, :] = \
  42. padded[:, crop[0]:(crop[0]+32), crop[1]:(crop[1]+32)][:, :, ::-1]
  43. t = torch.tensor(random_cropped)
  44. return t
def run_experiment(config):
    """Run iCaRL on class-incremental (split) CIFAR-100.

    :param config: attribute-style mapping providing ``batch_size``,
        ``nb_exp``, ``memory_size``, ``epochs``, ``lr_base``,
        ``lr_milestones``, ``lr_factor``, ``wght_decay``,
        ``fixed_class_order`` and ``seed``.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Seed every RNG in play and disable cuDNN autotuning for reproducibility.
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed(config.seed)
    np.random.seed(config.seed)
    random.seed(config.seed)
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.deterministic = True

    # Per-pixel mean of the training set, used below to zero-center images.
    per_pixel_mean = get_dataset_per_pixel_mean(
        CIFAR100(expanduser("~") + "/.avalanche/data/cifar100/",
                 train=True, download=True,
                 transform=transforms.Compose([transforms.ToTensor()])))

    # Transform groups: 'eval' only centers; 'train' also applies random
    # crop/flip augmentation. The second tuple element is the target
    # transform (None for both groups).
    transforms_group = dict(
        eval=(transforms.Compose([
            transforms.ToTensor(),
            lambda img_pattern: img_pattern - per_pixel_mean]), None),
        train=(transforms.Compose([
            transforms.ToTensor(),
            lambda img_pattern: img_pattern - per_pixel_mean,
            icarl_cifar100_augment_data]), None),
    )

    # Raw datasets; transforms are attached via AvalancheDataset below.
    train_set = CIFAR100(expanduser("~") + "/.avalanche/data/cifar100/",
                         train=True, download=True, )
    test_set = CIFAR100(expanduser("~") + "/.avalanche/data/cifar100/",
                        train=False, download=True, )
    train_set = AvalancheDataset(train_set,
                                 transform_groups=transforms_group,
                                 initial_transform_group='train')
    test_set = AvalancheDataset(test_set,
                                transform_groups=transforms_group,
                                initial_transform_group='eval')

    # New-classes benchmark: nb_exp experiences over the fixed class order,
    # without task labels (class-incremental setting).
    scenario = nc_benchmark(
        train_dataset=train_set,
        test_dataset=test_set,
        n_experiences=config.nb_exp,
        task_labels=False, seed=config.seed,
        shuffle=False,
        fixed_class_order=config.fixed_class_order)

    evaluator = EvaluationPlugin(EpochAccuracy(), ExperienceAccuracy(),
                                 StreamAccuracy(),
                                 loggers=[InteractiveLogger()])

    model: IcarlNet = make_icarl_net(num_classes=100)
    model.apply(initialize_icarl_net)

    optim = SGD(model.parameters(), lr=config.lr_base,
                weight_decay=config.wght_decay, momentum=0.9)
    # gamma = 1/lr_factor, i.e. the LR is divided by lr_factor at each
    # milestone epoch.
    sched = LRSchedulerPlugin(
        MultiStepLR(optim, config.lr_milestones, gamma=1.0 / config.lr_factor))

    # iCaRL strategy with a fixed-size exemplar memory; the buffer_transform
    # re-applies the same augmentation to replayed exemplars.
    strategy = ICaRL(
        model.feature_extractor, model.classifier, optim,
        config.memory_size,
        buffer_transform=transforms.Compose([icarl_cifar100_augment_data]),
        fixed_memory=True, train_mb_size=config.batch_size,
        train_epochs=config.epochs, eval_mb_size=config.batch_size,
        plugins=[sched], device=device, evaluator=evaluator
    )

    # After each training experience, evaluate on all test experiences seen
    # so far (experiences 0..i).
    for i, exp in enumerate(scenario.train_stream):
        eval_exps = [e for e in scenario.test_stream][:i + 1]
        strategy.train(exp, num_workers=4)
        strategy.eval(eval_exps, num_workers=4)
  104. class Config(dict):
  105. def __getattribute__(self, key):
  106. try:
  107. return self[key]
  108. except KeyError:
  109. raise AttributeError(key)
  110. def __setattr__(self, key, value):
  111. self[key] = value
  112. if __name__ == "__main__":
  113. config = Config()
  114. config.batch_size = 128
  115. config.nb_exp = 10
  116. config.memory_size = 2000
  117. config.epochs = 70
  118. config.lr_base = 2.
  119. config.lr_milestones = [49, 63]
  120. config.lr_factor = 5.
  121. config.wght_decay = 0.00001
  122. config.fixed_class_order = [87, 0, 52, 58, 44, 91, 68, 97, 51, 15,
  123. 94, 92, 10, 72, 49, 78, 61, 14, 8, 86,
  124. 84, 96, 18, 24, 32, 45, 88, 11, 4, 67,
  125. 69, 66, 77, 47, 79, 93, 29, 50, 57, 83,
  126. 17, 81, 41, 12, 37, 59, 25, 20, 80, 73,
  127. 1, 28, 6, 46, 62, 82, 53, 9, 31, 75,
  128. 38, 63, 33, 74, 27, 22, 36, 3, 16, 21,
  129. 60, 19, 70, 90, 89, 43, 5, 42, 65, 76,
  130. 40, 30, 23, 85, 2, 95, 56, 48, 71, 64,
  131. 98, 13, 99, 7, 34, 55, 54, 26, 35, 39]
  132. config.seed = 2222
  133. run_experiment(config)