Welcome to the "Training" tutorial of the "From Zero to Hero" series. In this part we will present the functionalities offered by the training module.
!pip install git+https://github.com/ContinualAI/avalanche.git
The training module in Avalanche is designed with modularity in mind. Its main goals are to:
- provide ready-to-use implementations of popular continual learning strategies and baselines, so you can compare against them with a couple of lines of code;
- provide the building blocks needed to implement and run your own strategy with minimal boilerplate.

At the moment, the training module includes two main components:
- Strategies: complete continual learning algorithms (Naive, Replay, EWC, ...) that you can instantiate and use directly, or subclass to define your own;
- Plugins: components that augment an existing strategy with additional behavior and can be freely combined with each other.
Keep in mind that Avalanche's components are mostly independent from each other. If you already have your own strategy that does not use Avalanche, you can use the benchmarks and metrics without ever looking at Avalanche's strategies.
If you want to compare your strategy with classic continual learning algorithms or baselines, Avalanche lets you instantiate a strategy with a couple of lines of code.
Most strategies require only 3 mandatory arguments:
- model: a torch.nn.Module;
- optimizer: a torch.optim.Optimizer, already initialized on your model;
- criterion: a loss function, e.g. one from torch.nn.functional.

Additional arguments are optional and allow you to customize training (batch size, number of epochs, ...) or strategy-specific parameters (buffer size, regularization strength, ...).
from torch.optim import SGD
from torch.nn import CrossEntropyLoss
from avalanche.models import SimpleMLP
from avalanche.training.strategies import Naive, CWRStar, Replay, GDumb, Cumulative, LwF, GEM, AGEM, EWC
model = SimpleMLP(num_classes=10)
optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)
criterion = CrossEntropyLoss()
cl_strategy = Naive(
    model, optimizer, criterion,
    train_mb_size=100, train_epochs=4, eval_mb_size=100
)
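As an example of strategy-specific arguments, the Replay and EWC strategies imported above expose their own parameters on top of the mandatory ones. A minimal sketch; argument names such as mem_size and ewc_lambda follow common Avalanche versions and may differ slightly in yours:
# Hypothetical sketch: strategy-specific arguments on top of the
# three mandatory ones (argument names may vary across versions).
replay_strategy = Replay(
    model, optimizer, criterion,
    mem_size=200,      # size of the replay buffer
    train_mb_size=100, train_epochs=4, eval_mb_size=100
)

ewc_strategy = EWC(
    model, optimizer, criterion,
    ewc_lambda=0.001,  # regularization strength
    train_mb_size=100, train_epochs=4, eval_mb_size=100
)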
Each strategy object offers two main methods: train and eval. Both of them accept either a single experience (Experience) or a list of them, for maximum flexibility.
We can train the model continually by iterating over the train_stream provided by the scenario.
from avalanche.benchmarks.classic import SplitMNIST
# scenario
scenario = SplitMNIST(n_experiences=5, seed=1)
# TRAINING LOOP
print('Starting experiment...')
results = []
for experience in scenario.train_stream:
    print("Start of experience: ", experience.current_experience)
    print("Current Classes: ", experience.classes_in_this_experience)

    cl_strategy.train(experience)
    print('Training completed')

    print('Computing accuracy on the whole test set')
    results.append(cl_strategy.eval(scenario.test_stream))
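Because train and eval accept a single experience or a list of them, you can also evaluate on just a part of the test stream. A minimal sketch, assuming the stream can be indexed (it is iterable in any case):
# Evaluate on a single test experience...
cl_strategy.eval(scenario.test_stream[0])
# ...or on an explicit list of experiences.
cl_strategy.eval(list(scenario.test_stream)[:2])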
We noticed that many continual learning strategies follow roughly the same training/evaluation loops and always implement the same boilerplate code. So, it seems natural to define most strategies by specializing only the few methods that need to change. Most strategies simply augment the naive strategy with additional behavior, without changing the basic training and evaluation loops. Such strategies can be implemented with just a couple of methods.
Avalanche's plugins allow you to augment a strategy with additional behavior. Currently, most continual learning strategies are also implemented as plugins, which makes them easy to combine. For example, it is extremely easy to create a hybrid strategy that combines replay and EWC by passing the appropriate plugins list to the BaseStrategy:
from avalanche.training.strategies import BaseStrategy
from avalanche.training.plugins import ReplayPlugin, EWCPlugin
replay = ReplayPlugin(mem_size=100)
ewc = EWCPlugin(ewc_lambda=0.001)
strategy = BaseStrategy(
    model, optimizer, criterion,
    plugins=[replay, ewc])
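The resulting hybrid strategy is used exactly like the pre-built ones, for example by reusing the SplitMNIST scenario defined above:
# Train the replay + EWC hybrid over the stream and evaluate it.
for experience in scenario.train_stream:
    strategy.train(experience)
strategy.eval(scenario.test_stream)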
In Avalanche you can build your own strategy in 2 main ways:
1. Plugins: augment an existing strategy (typically the naive one) by implementing only the callbacks you need. This is the recommended approach in most cases.
2. Subclassing: inherit from BaseStrategy, which provides generic training and evaluation loops. You can safely override most methods to customize your strategy. However, there are some caveats to discuss (see later) and in general this approach is more difficult than plugins.

Keep in mind that if you already have a continual learning strategy that does not use Avalanche, you can always use the benchmarks and evaluation modules without using Avalanche's strategies!
As we already mentioned, Avalanche strategies inherit from BaseStrategy. This class provides the generic training and evaluation loops and, at each step of those loops, the callback points that plugins (and subclasses) can hook into.
The training loop has the following structure (the loop markers indicate which callbacks run once per experience, once per epoch, and once per iteration):

    train
        before_training
        for each experience:
            before_training_exp
            adapt_train_dataset
            make_train_dataloader
            for each epoch:
                before_training_epoch
                for each iteration:
                    before_training_iteration
                    before_forward
                    after_forward
                    before_backward
                    after_backward
                    after_training_iteration
                    before_update
                    after_update
                after_training_epoch
            after_training_exp
        after_training
The evaluation loop is similar:
    eval
        before_eval
        adapt_eval_dataset
        make_eval_dataloader
        for each experience:
            before_eval_exp
            eval_epoch
                for each iteration:
                    before_eval_iteration
                    before_eval_forward
                    after_eval_forward
                    after_eval_iteration
            after_eval_exp
        after_eval
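To make these callback points concrete, here is a toy plugin (purely illustrative; TracePlugin is not part of Avalanche) that prints a message at a few of the points listed above. Attach it through the plugins list, as shown earlier:
from avalanche.training.plugins import StrategyPlugin

class TracePlugin(StrategyPlugin):
    """ Illustrative plugin that traces a few callback points. """

    def before_training_exp(self, strategy, **kwargs):
        print("-> before_training_exp")

    def after_training_epoch(self, strategy, **kwargs):
        print("-> after_training_epoch")

    def before_eval_exp(self, strategy, **kwargs):
        print("-> before_eval_exp")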
Plugins provide a simple solution to define a new strategy by augmenting the behavior of another strategy (typically the naive strategy). This approach reduces overhead and code duplication, improving code readability and prototyping speed.
Creating a plugin is straightforward. You create a class which inherits from StrategyPlugin and implements the callbacks that you need. The exact callbacks to use depend on your strategy. For example, the following ReplayPlugin uses after_training_exp to update the buffer after each training experience, and adapt_train_dataset to concatenate the buffer's data with the current experience:
from torch.utils.data import ConcatDataset, random_split

from avalanche.training.plugins import StrategyPlugin


class ReplayPlugin(StrategyPlugin):
    """
    Experience replay plugin.

    Handles an external memory filled with randomly selected
    patterns and implements the "adapt_train_dataset" callback to add them to
    the training set.

    The :mem_size: attribute controls the number of patterns to be stored in
    the external memory. In multitask scenarios, mem_size is the memory size
    for each task. We assume the training set contains at least :mem_size: data
    points.
    """

    def __init__(self, mem_size=200):
        super().__init__()
        self.mem_size = mem_size
        self.ext_mem = {}  # a Dict<task_id, Dataset>
        self.rm_add = None

    def adapt_train_dataset(self, strategy, **kwargs):
        """
        Expands the current training set with datapoints from
        the external memory before training.
        """
        curr_data = strategy.experience.dataset

        # Additional set of the current batch to be concatenated to the ext.
        # memory at the end of the training
        self.rm_add = None

        # how many patterns to save for next iter
        h = min(self.mem_size // (strategy.training_exp_counter + 1),
                len(curr_data))

        # We recover it using the random_split method and getting rid of the
        # second split.
        self.rm_add, _ = random_split(
            curr_data, [h, len(curr_data) - h]
        )

        if strategy.training_exp_counter > 0:
            # We update the train_dataset concatenating the external memory.
            # We assume the user will shuffle the data when creating the data
            # loader.
            for mem_task_id in self.ext_mem.keys():
                mem_data = self.ext_mem[mem_task_id]
                if mem_task_id in strategy.adapted_dataset:
                    cat_data = ConcatDataset([curr_data, mem_data])
                    strategy.adapted_dataset[mem_task_id] = cat_data
                else:
                    strategy.adapted_dataset[mem_task_id] = mem_data

    def after_training_exp(self, strategy, **kwargs):
        """ After training we update the external memory with the patterns of
        the current training batch/task. """
        curr_task_id = strategy.experience.task_label

        # replace patterns in random memory
        ext_mem = self.ext_mem
        if curr_task_id not in ext_mem:
            ext_mem[curr_task_id] = self.rm_add
        else:
            rem_len = len(ext_mem[curr_task_id]) - len(self.rm_add)
            _, saved_part = random_split(ext_mem[curr_task_id],
                                         [len(self.rm_add), rem_len])
            ext_mem[curr_task_id] = ConcatDataset([saved_part, self.rm_add])
        self.ext_mem = ext_mem
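Once defined, the plugin is attached to a strategy through the plugins argument, just as we did with BaseStrategy above. A minimal sketch, assuming Naive forwards the plugins list to BaseStrategy:
# Naive + ReplayPlugin behaves as a replay strategy.
replay_plugin = ReplayPlugin(mem_size=200)
replay_strategy = Naive(
    model, optimizer, criterion,
    plugins=[replay_plugin],
    train_mb_size=100, train_epochs=4, eval_mb_size=100
)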
Check StrategyPlugin's documentation for a complete list of the available callbacks.
You can always define a custom strategy by overriding BaseStrategy methods.
However, there is an important caveat to keep in mind. If you override a method, you must remember to call all the callback handlers at the appropriate points. For example, train calls before_training and after_training before and after the training loops, respectively. If your strategy does not call them, plugins may not work as expected. The easiest way to avoid mistakes is to start from the BaseStrategy method that you want to override and modify it to your own needs without removing the callback handling.
There is only a single plugin that is always used by default, the EvaluationPlugin (see the evaluation tutorial). This means that if you break the callbacks you must log metrics by yourself. This is entirely possible, but it requires some manual work to update, log, and reset each metric, which the BaseStrategy normally does for you automatically.
BaseStrategy exposes the global state of the loop as attributes of the strategy, which you can safely use when you override a method. As an example, the Cumulative strategy trains a model continually on the union of all the experiences encountered so far. To achieve this, it overrides adapt_train_dataset and updates self.adapted_dataset by concatenating all the previous experiences with the current one.
class Cumulative(BaseStrategy):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dataset = {}  # cumulative dataset

    def adapt_train_dataset(self, **kwargs):
        super().adapt_train_dataset(**kwargs)

        curr_task_id = self.experience.task_label
        curr_data = self.experience.dataset
        if curr_task_id in self.dataset:
            cat_data = ConcatDataset([self.dataset[curr_task_id],
                                      curr_data])
            self.dataset[curr_task_id] = cat_data
        else:
            self.dataset[curr_task_id] = curr_data
        self.adapted_dataset = self.dataset
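The custom strategy is then instantiated and used exactly like the built-in ones:
# Train the cumulative strategy over the stream and evaluate it.
cumulative = Cumulative(
    model, optimizer, criterion,
    train_mb_size=100, train_epochs=4, eval_mb_size=100
)
for experience in scenario.train_stream:
    cumulative.train(experience)
cumulative.eval(scenario.test_stream)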
Easy, isn't it? :-)
In general, we recommend implementing a strategy via plugins whenever possible. This approach is the easiest to use and requires only minimal knowledge of the BaseStrategy. It also allows other people to reuse your plugin and facilitates interoperability among different strategies.
For example, replay can be implemented either as a subclass of BaseStrategy or as a plugin. However, the plugin is the better choice because it allows replay to be used in conjunction with other strategies.
This completes the "Training" chapter for the "From Zero to Hero" series. We hope you enjoyed it!
You can run this chapter and play with it on Google Colaboratory: