123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166 |
- from os.path import expanduser
- import torch
- from avalanche.benchmarks.datasets import CIFAR100
- from avalanche.benchmarks.utils import AvalancheDataset
- from avalanche.models import IcarlNet, make_icarl_net, initialize_icarl_net
- from avalanche.training.plugins.lr_scheduling import LRSchedulerPlugin
- from torch.optim import SGD
- from torchvision import transforms
- from avalanche.benchmarks.generators import nc_benchmark
- from avalanche.training.plugins import EvaluationPlugin
- from avalanche.evaluation.metrics import ExperienceAccuracy, StreamAccuracy, \
- EpochAccuracy
- from avalanche.logging.interactive_logging import InteractiveLogger
- import random
- import numpy as np
- from torch.optim.lr_scheduler import MultiStepLR
- from avalanche.training.strategies.icarl import ICaRL
def get_dataset_per_pixel_mean(dataset):
    """Compute the element-wise (per-pixel) mean image over *dataset*.

    The dataset is any iterable yielding ``(image_tensor, target)`` pairs.
    Returns a float tensor with the same shape as a single image, or an
    empty float tensor when the dataset yields nothing.
    """
    accumulator = None
    n_patterns = 0
    for image, _ in dataset:
        if accumulator is None:
            # Lazily allocate so we match whatever shape the images have.
            accumulator = torch.zeros_like(image, dtype=torch.float)
        accumulator += image
        n_patterns += 1
    if accumulator is None:
        # Empty dataset: mirror the original contract of an empty tensor.
        return torch.empty(0, dtype=torch.float)
    return accumulator / n_patterns
def icarl_cifar100_augment_data(img):
    """iCaRL-style CIFAR augmentation: pad, random-crop, maybe h-flip.

    *img* is a CHW image tensor (32x32 spatial for CIFAR). The image is
    zero-padded by 4 pixels on each spatial side, a random 32x32 window is
    cropped, and with probability 1/2 the window is mirrored horizontally.
    Returns a new float32 tensor of the same shape as the input.
    """
    array = img.numpy()
    padded = np.pad(array, ((0, 0), (4, 4), (4, 4)), mode='constant')
    # Two RNG draws, in the same order as the reference implementation:
    # first the crop offsets in [0, 8], then the flip coin.
    offsets = np.random.randint(0, high=8 + 1, size=(2,))
    coin = np.random.randint(2)
    top, left = offsets[0], offsets[1]
    window = padded[:, top:top + 32, left:left + 32]
    if coin == 0:
        # Mirror along the width axis.
        window = window[:, :, ::-1]
    # Materialize a contiguous float32 copy (matches the original's
    # np.zeros + assignment) before handing it to torch.
    return torch.tensor(np.ascontiguousarray(window, dtype=np.float32))
def run_experiment(config):
    """Train and evaluate iCaRL on class-incremental CIFAR-100.

    *config* is an attribute-accessible mapping providing: seed, nb_exp,
    fixed_class_order, memory_size, batch_size, epochs, lr_base,
    lr_milestones, lr_factor, wght_decay.

    Side effects: downloads CIFAR-100 under ~/.avalanche/data/cifar100/
    the first time, trains on GPU 0 when available, and logs metrics via
    an interactive logger. Returns None.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Seed every RNG and force deterministic cuDNN for reproducibility.
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed(config.seed)
    np.random.seed(config.seed)
    random.seed(config.seed)
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.deterministic = True

    # Single source of truth for the dataset location (was repeated 3x).
    dataset_root = expanduser("~") + "/.avalanche/data/cifar100/"

    # Per-pixel mean of the raw training images, used for mean-centering.
    per_pixel_mean = get_dataset_per_pixel_mean(
        CIFAR100(dataset_root, train=True, download=True,
                 transform=transforms.Compose([transforms.ToTensor()])))

    # Train adds random crop/flip augmentation on top of mean-centering;
    # eval only mean-centers. NOTE(review): the lambdas are not picklable,
    # which matters if DataLoader workers need to pickle transforms on
    # some platforms (e.g. Windows spawn) — confirm before porting.
    transforms_group = dict(
        eval=(transforms.Compose([
            transforms.ToTensor(),
            lambda img_pattern: img_pattern - per_pixel_mean]), None),
        train=(transforms.Compose([
            transforms.ToTensor(),
            lambda img_pattern: img_pattern - per_pixel_mean,
            icarl_cifar100_augment_data]), None),
    )

    train_set = CIFAR100(dataset_root, train=True, download=True)
    test_set = CIFAR100(dataset_root, train=False, download=True)

    train_set = AvalancheDataset(train_set,
                                 transform_groups=transforms_group,
                                 initial_transform_group='train')
    test_set = AvalancheDataset(test_set,
                                transform_groups=transforms_group,
                                initial_transform_group='eval')

    # Class-incremental benchmark: nb_exp experiences, classes presented
    # in the externally fixed order (shuffle disabled on purpose).
    scenario = nc_benchmark(
        train_dataset=train_set,
        test_dataset=test_set,
        n_experiences=config.nb_exp,
        task_labels=False, seed=config.seed,
        shuffle=False,
        fixed_class_order=config.fixed_class_order)

    evaluator = EvaluationPlugin(EpochAccuracy(), ExperienceAccuracy(),
                                 StreamAccuracy(),
                                 loggers=[InteractiveLogger()])

    model: IcarlNet = make_icarl_net(num_classes=100)
    model.apply(initialize_icarl_net)

    optim = SGD(model.parameters(), lr=config.lr_base,
                weight_decay=config.wght_decay, momentum=0.9)
    # gamma = 1/lr_factor: divide the LR by lr_factor at each milestone.
    sched = LRSchedulerPlugin(
        MultiStepLR(optim, config.lr_milestones, gamma=1.0 / config.lr_factor))

    strategy = ICaRL(
        model.feature_extractor, model.classifier, optim,
        config.memory_size,
        buffer_transform=transforms.Compose([icarl_cifar100_augment_data]),
        fixed_memory=True, train_mb_size=config.batch_size,
        train_epochs=config.epochs, eval_mb_size=config.batch_size,
        plugins=[sched], device=device, evaluator=evaluator
    )

    for i, exp in enumerate(scenario.train_stream):
        # Evaluate on every experience seen so far (growing test set).
        # list(...) instead of the redundant [e for e in ...] copy.
        eval_exps = list(scenario.test_stream)[:i + 1]
        strategy.train(exp, num_workers=4)
        strategy.eval(eval_exps, num_workers=4)
class Config(dict):
    """A dict whose entries are also readable/writable as attributes.

    ``config.lr`` reads ``config["lr"]`` and ``config.lr = x`` stores
    ``config["lr"] = x``. Missing keys raise :class:`AttributeError` so
    attribute-protocol users (e.g. ``hasattr``) behave as expected.
    """

    def __getattr__(self, key):
        # Bug fix: the original overrode __getattribute__, which hijacks
        # EVERY attribute lookup — so real dict methods (keys, items,
        # copy, ...) raised AttributeError unless a same-named key was
        # stored. __getattr__ runs only after normal lookup fails, which
        # keeps dict's own API intact while still mapping misses to keys.
        try:
            return self[key]
        except KeyError:
            # Suppress the KeyError context: callers expect a clean
            # AttributeError from the attribute protocol.
            raise AttributeError(key) from None

    def __setattr__(self, key, value):
        # Every attribute assignment becomes a dict entry.
        self[key] = value
if __name__ == "__main__":
    # Hyper-parameters and the fixed class order follow the original
    # iCaRL CIFAR-100 protocol; Config is a dict, so keyword construction
    # stores exactly the same entries as attribute-by-attribute assignment.
    config = Config(
        batch_size=128,
        nb_exp=10,
        memory_size=2000,
        epochs=70,
        lr_base=2.,
        lr_milestones=[49, 63],
        lr_factor=5.,
        wght_decay=0.00001,
        fixed_class_order=[87, 0, 52, 58, 44, 91, 68, 97, 51, 15,
                           94, 92, 10, 72, 49, 78, 61, 14, 8, 86,
                           84, 96, 18, 24, 32, 45, 88, 11, 4, 67,
                           69, 66, 77, 47, 79, 93, 29, 50, 57, 83,
                           17, 81, 41, 12, 37, 59, 25, 20, 80, 73,
                           1, 28, 6, 46, 62, 82, 53, 9, 31, 75,
                           38, 63, 33, 74, 27, 22, 36, 3, 16, 21,
                           60, 19, 70, 90, 89, 43, 5, 42, 65, 76,
                           40, 30, 23, 85, 2, 95, 56, 48, 71, 64,
                           98, 13, 99, 7, 34, 55, 54, 26, 35, 39],
        seed=2222,
    )
    run_experiment(config)
|