# cont_ava.py

import pandas as pd
from skimage import io
import numpy as np
from PIL import Image
import torch
import torch.nn as nn  # All neural network modules: nn.Linear, nn.Conv2d, BatchNorm, loss functions
import torch.optim as optim  # All optimization algorithms: SGD, Adam, etc.
import torchvision.transforms as transforms  # Transformations we can perform on our dataset
from torchvision.io import read_image
import os
from torch.utils.data import (
    Dataset,
    DataLoader,
)
from typing import Optional, Sequence, Union, List
from avalanche.benchmarks.generators import nc_benchmark, ni_benchmark
from avalanche.models import SimpleCNN
from avalanche.models import pytorchcv_wrapper
from avalanche.benchmarks.utils.datasets_from_filelists import SeqPathsDataset, PathsDataset
from avalanche.training.strategies import Naive, Cumulative, JointTraining
#os.environ['CUDA_LAUNCH_BLOCKING'] = "0"
from avalanche.benchmarks.generators import paths_benchmark
#from avalanche.training.plugins.sequence_data import _get_random_indicies, _get_cls_acc_based_indicies
from avalanche.training.plugins import (SeqDataPlugin,
                                        ReplayPlugin,
                                        EvaluationPlugin,
                                        MaximallyInterferedRetrievalRehersalPlugin,
                                        ClassErrorRehersalPlugin,
                                        ClassErrorRehersalTemperaturePlugin,
                                        ClassFrequencyRehearsalPlugin,
                                        RandomRehersal,
                                        ClassBalancingReservoirMemoryRehersalPlugin,
                                        ReservoirMemoryRehearsalPlugin,
                                        WeightedeMovingClassErrorAverageRehersalPlugin,
                                        ClassErrorFrequencyAvgRehearsalPlugin,
                                        FillExpBasedRehearsalPlugin,
                                        FillExpOversampleBasedRehearsalPlugin)
from avalanche.evaluation.metrics import (ClassTop_nAvgAcc,
                                          SeasonClassTop_nAcc,
                                          Top1AccTransfer,
                                          ClasswiseTop_nAcc,
                                          Top_nAcc,
                                          SeasonTop_nAcc,
                                          SeqMaxAcc,
                                          SeqClassTop_nAvgAcc,
                                          SeqClassAvgMaxAcc,
                                          SeasonSeqClassTop_nAcc,
                                          SeqClasswiseTop_nAcc,
                                          SeqAnyAcc,
                                          SeasonSeqTop_nAcc,
                                          SeqTop_nAcc,
                                          BinaryAnimalAcc,
                                          QuatrupleAnimalAcc,
                                          SeqBinaryAnimalAcc,
                                          IgnoreNonAnimalOutputsTop_nAcc,
                                          SeqIgnoreNonAnimalOutputsAcc,
                                          ClassIgnoreNonAnimalOutputsTop_nAvgAcc
                                          )
from avalanche.logging import StrategyLogger, InteractiveLogger, TextLogger, TensorboardLogger, CSVLogger, GenericCSVLogger
from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics
import matplotlib.pyplot as plt
import pickle as pkl
import time
import argparse
import threading
from avalanche.models.utils import LabelSmoothingCrossEntropy
from pathlib import Path
import datetime
import sys
import random
parser = argparse.ArgumentParser()
parser.add_argument('--file', '-f', type=str, help='file name prefix from which train/test/val_stream.pkl should be loaded')
parser.add_argument('--data_root', '-roots', type=str, default='/home/boehlke/AMMOD/continual_learning/data/boehlke_ssd_Species/', help='directory from which train/test_stream.pkl should be loaded')
parser.add_argument('--log_root', '-log', type=str, default='/home/boehlke/AMMOD/continual_learning/results_ever_eval_test/', help='directory where results shall be logged')
parser.add_argument('--strategy', '-s', type=str, default='naive', help='strategy name')
parser.add_argument('--epochs', '-e', type=int, default=3, help='number of times the finetuning set (rehearsal+experience) is iterated over')
parser.add_argument('--batch_size', '-bs', type=int, default=48)
parser.add_argument('--eval_after_n_exp', '-ne', type=int, default=1, help='interval of evaluation on the validation set, in experiences; in case of joint training, interval in epochs')
parser.add_argument('--test_eval_after_n_exp', '-tne', type=int, default=100, help='interval of evaluation on the test set, in experiences; in case of joint training, interval in epochs')
parser.add_argument('--exp_size', '-es', type=int, default=128, help='experience size, also used as the eval batch size')
parser.add_argument('--memory_filling', '-memf', type=str, default='inf', help='how the memory shall be filled; default is inf, other options: cbrs, stdrs')
parser.add_argument('--memory_size', '-mems', type=float, default=0.3, help='either the ratio of the entire train stream kept in memory or, if larger than 1, the number of instances the memory can hold')
parser.add_argument('--temperature', '-temp', type=float, default=None, help='sharpening/softening temperature value for class-error-based rehearsal')
parser.add_argument('--rehearsal_method', '-rm', type=str, default=None, help='abbreviation for one of the implemented rehearsal methods ("ce", "cf", "rr", "mir", "wce", "cefa", "fe", "feo"); see get_strategy()')
parser.add_argument('--nr_of_steps_to_avg', '-nexp_avg', type=int, default=None, help='needs to be set for WeightedeMovingClassErrorAverageRehersalPlugin; defines how many past experiences are used in the weighted average')
parser.add_argument('--sigma', '-sig', type=float, default=None, help='needs to be set for WeightedeMovingClassErrorAverageRehersalPlugin; sets the rate of decay for the weights in the moving weighted average. The bigger sigma, the more past error rates are weighted')
parser.add_argument('--buffer_data_ratio', '-dr', type=float, default=8, help='how many images from memory should be used in the rehearsal set for each image in the experience')
parser.add_argument('--shuffle_stream', '-ss', action='store_true', default=False, help='whether train_stream should be shuffled on an image level, to obtain a more iid stream for comparison')
parser.add_argument('--num_classes', '-nc', type=int, default=16, help='number of classes in the dataset')
parser.add_argument('--non_animal_cls', '-nac', default=[13, 14, 15], nargs='+', type=int, help='classes that do not contain species, which is relevant for the "binary" accuracies')
parser.add_argument('--images_root', '-ir', type=str, default='/home/AMMOD_data/camera_traps/BayerWald/G-Fallen/original/', help='directory where image files are stored')
parser.add_argument('--label_dict', '-ld', type=str, default="/home/boehlke/AMMOD/cam_trap_classification/data/csv_files/BIRDS.pkl", help='used to create a confusion matrix at regular intervals')
args = parser.parse_args()
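# Example invocation (illustrative values only; paths and the stream-file prefix
# passed via --file depend on your local setup):
#   python cont_ava.py -f <stream_prefix> -s naive -e 3 -bs 48 -rm ce -dr 8 \
#       -memf cbrs -mems 0.3 -roots /path/to/data/ -log /path/to/results/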
print("PyTorch Version: ", torch.__version__)
use_cuda = torch.cuda.is_available()
args.detected_device = torch.device("cuda:0" if use_cuda else "cpu")
print(args.detected_device)
number_workers = 8
# Train and test data. Even when training with a stream of experience size 384,
# the 128 version needs to be loaded first to match the winter stream files.
with open(args.data_root+args.file.replace('384', '128')+'_train_stream.pkl', 'rb') as handle:
    args.train_stream = pkl.load(handle)
with open(args.data_root+args.file.replace('384', '128')+'_test_stream.pkl', 'rb') as handle:
    args.test_stream = pkl.load(handle)  #[:5000]
with open(args.data_root+args.file.replace('384', '128')+'_val_stream.pkl', 'rb') as handle:
    args.val_stream = pkl.load(handle)  #[:5000]
with open(args.data_root+args.file.replace('384', '128')+'_winter_val_stream.pkl', 'rb') as handle:
    args.validation_stream_winter = pkl.load(handle)
with open(args.data_root+args.file.replace('384', '128')+'_winter_test_stream.pkl', 'rb') as handle:
    args.test_stream_winter = pkl.load(handle)
with open(args.data_root+args.file.replace('_crop', '').replace('384', '128')+'_exp_season_split_dict.pkl', 'rb') as handle:
    args.exp_season_split_dict = pkl.load(handle)
# Backbone model settings
model_depth = 18
args.model = pytorchcv_wrapper.resnet('imagenet', depth=model_depth, pretrained=True)
args.model.num_classes = args.num_classes
num_input_ftrs = args.model.output.in_features
args.model.output = nn.Linear(num_input_ftrs, args.model.num_classes)  # replace the ImageNet head with a num_classes-way classifier
args.validation_stream_winter_files = np.array(args.validation_stream_winter)[:, 0]
args.validation_stream_winter_files = [args.images_root+s for s in args.validation_stream_winter_files]
args.test_stream_winter_files = np.array(args.test_stream_winter)[:, 0]
args.test_stream_winter_files = [args.images_root+s for s in args.test_stream_winter_files]
def flattened(list_w_sublists):
    '''Flattens one level of nesting: items of sublists are spliced into the result.'''
    flat = []
    for item in list_w_sublists:
        if isinstance(item, list):
            for val in item:
                flat.append(val)
        else:
            flat.append(item)
    return flat
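# e.g. flattened([[1, 2], 3, [4, 5]]) -> [1, 2, 3, 4, 5]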
# creating the winter list from the loaded data
winter_exp = args.exp_season_split_dict['winter']
summer_exp = args.exp_season_split_dict['summer']
winter_train_exp = [args.train_stream[i] for i in winter_exp]
all_winter_train_data = [i for sublist in winter_train_exp for i in sublist]
all_winter_train_files = np.array(all_winter_train_data)[:, 0]
all_winter_train_files = [args.images_root+s for s in all_winter_train_files]
all_test_files = np.array(args.test_stream)[:, 0]
all_test_files = [args.images_root+s for s in all_test_files]
all_val_files = np.array(args.val_stream)[:, 0]
all_val_files = [args.images_root+s for s in all_val_files]
args.all_winter_files = all_winter_train_files + args.test_stream_winter_files + args.validation_stream_winter_files
# Optimizer settings
eps = 1e-4
weight_decay = 75e-3
args.optimizer = optim.AdamW(args.model.parameters(), eps=eps, weight_decay=weight_decay)
args.criterion = LabelSmoothingCrossEntropy()
# reload the train stream at the requested experience size (the 128 version above
# was only needed to build the winter file lists)
with open(args.data_root+args.file+'_train_stream.pkl', 'rb') as handle:
    args.train_stream = pkl.load(handle)
if args.shuffle_stream:
    train_exp_list_flatten = flattened(args.train_stream)
    random.shuffle(train_exp_list_flatten)
    shuffled_train_stream = []
    exp_size = len(args.train_stream[0])
    for i in range(len(args.train_stream)):
        shuffled_train_stream.append(train_exp_list_flatten[i*exp_size:(i+1)*exp_size])
    args.train_stream = shuffled_train_stream
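# e.g. a stream of 40 experiences of 128 images each becomes one flat list of
# 5120 images, which is reshuffled and re-chunked into 40 experiences of 128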
strategy = args.strategy
# use the parameters to assemble a name for the logging folder
data_ratio = ''
rehersal = ''
if args.rehearsal_method is not None:
    strategy = 'rehersal'
    data_ratio = '_dataratio'+str(args.buffer_data_ratio)
    rehersal = '_'+args.rehearsal_method
    if (args.rehearsal_method == 'ce' or args.rehearsal_method == 'cf') and args.temperature is not None:
        rehersal = rehersal+'_temp'+str(args.temperature)
    if args.rehearsal_method == 'wce':
        if args.nr_of_steps_to_avg is None or args.sigma is None:
            print('ATTENTION: weighted moving average variables not set')
            print(args.nr_of_steps_to_avg)
            print(args.sigma)
        else:
            rehersal = rehersal+'_navg'+str(args.nr_of_steps_to_avg)+'_sigma'+str(args.sigma)
if args.memory_filling != 'inf':
    total_train_data = len(args.train_stream[0])*len(args.train_stream)
    print(args.memory_size)
    print(total_train_data)
    if args.memory_size < 1:
        # interpret memory_size as a ratio of the entire train stream
        memory_size = int(args.memory_size*total_train_data)
        print('r', memory_size)
    else:
        # interpret memory_size as an absolute number of instances
        memory_size = int(args.memory_size)
        print('i', memory_size)
    strategy = 'rehersal_'+args.memory_filling+'_memsize_'+str(args.memory_size)
    print('total memory size ', memory_size)
else:
    print('inf_mem')
    print(args.memory_filling)
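# e.g. with memory_size=0.3 and 40 experiences of 128 images,
# the buffer holds int(0.3 * 5120) = 1536 instances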
if args.shuffle_stream:
    strategy = strategy+'_shuffled'
# Logging path
log_root = Path('/home/boehlke/AMMOD/continual_learning/results/')
if args.log_root is not None:
    log_root = Path(args.log_root)
model_settings = 'resnet'+str(model_depth)+'_smoothloss_adamW_eps'+str(eps)+'wd'+str(weight_decay)+'_bs'+str(args.batch_size)
strategy_settings = strategy+rehersal+data_ratio+'_ep'+str(args.epochs)+'_ne'+str(args.eval_after_n_exp)
args.log_dir_name = log_root / model_settings / args.file / strategy_settings
print(args.log_dir_name)
Path.mkdir(args.log_dir_name, parents=True, exist_ok=True)
cls_in_teststream = np.unique(np.array(args.test_stream)[:, 1])
if cls_in_teststream.shape[0] != args.num_classes:
    print('ATTENTION: args.num_classes does not match the number of classes in test_stream')
    print(args.num_classes)
    #exit(0)
cls_in_valstream = np.unique(np.array(args.val_stream)[:, 1])
if cls_in_valstream.shape[0] != args.num_classes:
    print('ATTENTION: args.num_classes does not match the number of classes in val_stream')
    print(args.num_classes)
    #exit(0)
# Label dict: flip it if the class names are the keys and the integer labels are the values
with open(args.label_dict, 'rb') as p:
    label_dict = pkl.load(p)
print(label_dict.items())
if 0 not in label_dict.keys():
    print('label_dict wrong way around')
    flip_label_dict = {}
    for key, value in label_dict.items():
        flip_label_dict[value] = key
    label_dict = flip_label_dict
# augmentations for train and val/test data
resize = int(224/0.875)  # = 256: the usual "resize to 256, crop to 224" ImageNet convention
args.data_transforms = {
    'train': transforms.Compose([
        transforms.Resize((resize, resize)),
        transforms.RandomCrop(size=(224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        #transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}
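# Sketch of the train-time flow, assuming the dataset yields PIL images:
#   PIL image -> 256x256 resize -> random 224x224 crop -> random horizontal flip
#   -> tensor in [0, 1] -> ImageNet mean/std normalization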
def get_scenario(args):
    '''Returns a benchmark instance given a sequence of lists of files. A separate
    dataset will be created for each list, and each of those datasets will be
    considered a separate experience.
    The function paths_benchmark() is defined in /avalanche/scenarios/generic_benchmark_creation.py
    as "create_generic_benchmark_from_paths()". It has been extended relative to the
    original avalanche version with a path_dataset_class variable, which allows handling
    data with more than path and label information. In this case, the sequence code is
    passed as well to allow for sequence-wise evaluation.
    '''
    if args.strategy == 'joint':
        # merge all experiences into a single training experience
        joint_train_stream = [item for sublist in args.train_stream for item in sublist]
        args.train_stream = [joint_train_stream]
    task_labels = np.zeros((1, len(args.train_stream))).astype(int).tolist()[0]
    nac = torch.tensor(args.non_animal_cls, device=args.detected_device)
    nc = args.num_classes
    metrics = [
        #TrainingTime(reset_at='epoch', emit_at='epoch', mode='train'),
        #accuracy_metrics(minibatch=False, epoch=True, experience=False, stream=True),
        Top1AccTransfer(nr_train_exp=len(args.train_stream), reset_at='stream', emit_at='stream', mode='eval'),
        ClassTop_nAvgAcc(nr_classes=nc, n=1, reset_at='stream', emit_at='stream', mode='eval'),
        #ClassTop_nAvgAcc(nr_classes=nc, n=1, reset_at='epoch', emit_at='epoch', mode='train'),
        ClassTop_nAvgAcc(nr_classes=nc, n=2, reset_at='stream', emit_at='stream', mode='eval'),
        SeasonClassTop_nAcc(nr_classes=nc, n=1, winter_files_list=args.all_winter_files, season='summer', reset_at='stream', emit_at='stream', mode='eval'),
        SeasonClassTop_nAcc(nr_classes=nc, n=1, winter_files_list=args.all_winter_files, season='winter', reset_at='stream', emit_at='stream', mode='eval'),
        #ClassTop_nAvgAcc(nr_classes=nc, n=2, reset_at='epoch', emit_at='epoch', mode='train'),
        ClasswiseTop_nAcc(nr_classes=nc, n=1, reset_at='stream', emit_at='stream', mode='eval'),
        Top_nAcc(n=1, reset_at='stream', emit_at='stream', mode='eval'),
        SeasonTop_nAcc(n=1, winter_files_list=args.all_winter_files, season='summer', reset_at='stream', emit_at='stream', mode='eval'),
        SeasonTop_nAcc(n=1, winter_files_list=args.all_winter_files, season='winter', reset_at='stream', emit_at='stream', mode='eval'),
        #Top_nAcc(n=2, reset_at='stream', emit_at='stream', mode='eval'),
        #Top_nAcc(n=1, reset_at='epoch', emit_at='epoch', mode='train'),
        #Top_nAcc(n=2, reset_at='epoch', emit_at='epoch', mode='train'),
        SeqTop_nAcc(n=1, reset_at='stream', emit_at='stream', mode='eval'),
        SeqMaxAcc(reset_at='stream', emit_at='stream', mode='eval'),
        SeasonSeqTop_nAcc(num=1, winter_files_list=args.all_winter_files, season='summer', reset_at='stream', emit_at='stream', mode='eval'),
        SeasonSeqTop_nAcc(num=1, winter_files_list=args.all_winter_files, season='winter', reset_at='stream', emit_at='stream', mode='eval'),
        #SeqTop_nAcc(n=2, reset_at='stream', emit_at='stream', mode='eval'),
        #SeqTop_nAcc(n=3, reset_at='stream', emit_at='stream', mode='eval'),
        #SeqTop_nAcc(n=1, reset_at='epoch', emit_at='epoch', mode='train'),
        #SeqTop_nAcc(n=2, reset_at='epoch', emit_at='epoch', mode='train'),
        #SeqTop_nAcc(n=3, reset_at='epoch', emit_at='epoch', mode='train'),
        SeqClasswiseTop_nAcc(nr_classes=nc, n=1, reset_at='stream', emit_at='stream', mode='eval'),
        SeqClassTop_nAvgAcc(nr_classes=nc, n=1, reset_at='stream', emit_at='stream', mode='eval'),
        SeqClassAvgMaxAcc(nr_classes=nc, reset_at='stream', emit_at='stream', mode='eval'),
        SeasonSeqClassTop_nAcc(nr_classes=nc, n=1, winter_files_list=args.all_winter_files, season='summer', reset_at='stream', emit_at='stream', mode='eval'),
        SeasonSeqClassTop_nAcc(nr_classes=nc, n=1, winter_files_list=args.all_winter_files, season='winter', reset_at='stream', emit_at='stream', mode='eval'),
        #SeqClassTop_nAvgAcc(nr_classes=nc, n=1, reset_at='epoch', emit_at='epoch', mode='train'),
        #SeqClassTop_nAvgAcc(nr_classes=nc, n=2, reset_at='stream', emit_at='stream', mode='eval'),
        #SeqClassTop_nAvgAcc(nr_classes=nc, n=2, reset_at='epoch', emit_at='epoch', mode='train'),
        SeqAnyAcc(reset_at='stream', emit_at='stream', mode='eval'),
        #SeqAnyAcc(reset_at='epoch', emit_at='epoch', mode='train'),
        #BinaryAnimalAcc(non_animal_class=nac, reset_at='epoch', emit_at='epoch', mode='train'),
        #BinaryAnimalAcc(non_animal_class=nac, reset_at='stream', emit_at='stream', mode='eval'),
        #QuatrupleAnimalAcc(non_animal_class=nac, reset_at='epoch', emit_at='epoch', mode='train'),
        #QuatrupleAnimalAcc(non_animal_class=nac, reset_at='stream', emit_at='stream', mode='eval'),
        #SeqBinaryAnimalAcc(non_animal_class=nac, reset_at='epoch', emit_at='epoch', mode='train'),
        #SeqBinaryAnimalAcc(non_animal_class=nac, reset_at='stream', emit_at='stream', mode='eval'),
        #SeqIgnoreNonAnimalOutputsAcc(non_animal_class=nac, nr_classes=nc, reset_at='epoch', emit_at='epoch', mode='train'),
        #SeqIgnoreNonAnimalOutputsAcc(non_animal_class=nac, nr_classes=nc, reset_at='stream', emit_at='stream', mode='eval'),
        #ClassIgnoreNonAnimalOutputsTop_nAvgAcc(nr_classes=nc, n=1, non_animal_class=nac, reset_at='epoch', emit_at='epoch', mode='train'),
        #ClassIgnoreNonAnimalOutputsTop_nAvgAcc(nr_classes=nc, n=1, non_animal_class=nac, reset_at='stream', emit_at='stream', mode='eval'),
        #IgnoreNonAnimalOutputsTop_nAcc(non_animal_class=nac, nr_classes=nc, n=1, reset_at='epoch', emit_at='epoch', mode='train'),
        #IgnoreNonAnimalOutputsTop_nAcc(non_animal_class=nac, nr_classes=nc, n=1, reset_at='stream', emit_at='stream', mode='eval'),
        #NrStepsBatchwise(reset_at='epoch', emit_at='epoch', mode='train'),
        #NrStepsImagewise(reset_at='epoch', emit_at='epoch', mode='train')
    ]
    metrics = flattened(metrics)
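    # flattened() is needed because helpers like accuracy_metrics() return lists
    # of metric objects, while the entries above are single metrics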
    datetime_stamp = str(datetime.datetime.now()).replace(' ', '_').replace(':', '-').replace('.', '-')[:-4]
    tb_logger = TensorboardLogger(tb_log_dir=str(args.log_dir_name)+"/tb_data_"+datetime_stamp, filename_suffix='test_run')
    interactive_logger = InteractiveLogger()
    gen_csv_logger = GenericCSVLogger(log_folder=str(args.log_dir_name)+'/gen_csvlogs_'+datetime_stamp)
    args.combined_logger = EvaluationPlugin(
        metrics,
        loggers=[tb_logger, gen_csv_logger, interactive_logger],
        suppress_warnings=False)
    scenario = paths_benchmark(
        args.train_stream,
        [args.test_stream],
        other_streams_lists_of_files={'validation': [args.val_stream]},
        task_labels=task_labels,
        complete_test_set_only=True,
        train_transform=args.data_transforms['train'],
        eval_transform=args.data_transforms['val'],
        other_streams_transforms={'validation': args.data_transforms['val']},
        path_dataset_class=SeqPathsDataset,
        common_root=args.images_root
    )
    return scenario
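# The returned benchmark exposes scenario.train_stream, scenario.test_stream and
# scenario.validation_stream, which the training loops below iterate over.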
def get_strategy(args):
    '''Returns the avalanche training strategy with the necessary plugins.'''
    plugin = SeqDataPlugin()
    plugins = [plugin]
    if args.memory_filling == 'inf':
        if args.rehearsal_method == 'ce':
            if args.temperature is None:
                plugin_rehearsal = ClassErrorRehersalPlugin(buffer_data_ratio=args.buffer_data_ratio)
                plugins.append(plugin_rehearsal)
            else:
                plugin_rehearsal = ClassErrorRehersalTemperaturePlugin(buffer_data_ratio=args.buffer_data_ratio, temperature=args.temperature)
                plugins.append(plugin_rehearsal)
        elif args.rehearsal_method == 'cf':
            plugin_rehearsal = ClassFrequencyRehearsalPlugin(buffer_data_ratio=args.buffer_data_ratio)
            plugins.append(plugin_rehearsal)
        elif args.rehearsal_method == 'mir':
            plugin_rehearsal = MaximallyInterferedRetrievalRehersalPlugin(buffer_data_ratio=args.buffer_data_ratio)
            plugins.append(plugin_rehearsal)
        elif args.rehearsal_method == 'rr':
            plugin_rehearsal = RandomRehersal(buffer_data_ratio=args.buffer_data_ratio)
            plugins.append(plugin_rehearsal)
        elif args.rehearsal_method == 'wce':
            plugin_rehearsal = WeightedeMovingClassErrorAverageRehersalPlugin(buffer_data_ratio=args.buffer_data_ratio, nr_of_steps_to_avg=args.nr_of_steps_to_avg, nr_classes=args.num_classes, sigma=args.sigma)
            plugins.append(plugin_rehearsal)
        elif args.rehearsal_method == 'cefa':
            plugin_rehearsal = ClassErrorFrequencyAvgRehearsalPlugin(buffer_data_ratio=args.buffer_data_ratio, nr_classes=args.num_classes)
            plugins.append(plugin_rehearsal)
        elif args.rehearsal_method == 'fe':
            plugin_rehearsal = FillExpBasedRehearsalPlugin(buffer_data_ratio=args.buffer_data_ratio, nr_classes=args.num_classes)
            plugins.append(plugin_rehearsal)
        elif args.rehearsal_method == 'feo':
            plugin_rehearsal = FillExpOversampleBasedRehearsalPlugin(buffer_data_ratio=args.buffer_data_ratio, nr_classes=args.num_classes)
            plugins.append(plugin_rehearsal)
        else:
            print('no rehearsal method selected')
    elif args.memory_filling == 'cbrs':
        plugin_rehearsal = ClassBalancingReservoirMemoryRehersalPlugin(buffer_data_ratio=args.buffer_data_ratio, memory_size=memory_size, nr_classes=args.num_classes, rehearsal_selection_strategy=args.rehearsal_method, temperature=args.temperature)
        plugins.append(plugin_rehearsal)
    elif args.memory_filling == 'stdrs':
        plugin_rehearsal = ReservoirMemoryRehearsalPlugin(buffer_data_ratio=args.buffer_data_ratio, memory_size=memory_size, nr_classes=args.num_classes, rehearsal_selection_strategy=args.rehearsal_method, temperature=args.temperature)
        plugins.append(plugin_rehearsal)
    if args.strategy == 'joint':
        # In the joint baseline scenario the Naive base strategy defined in
        # /avalanche/training/strategies/base_strategies.py is used.
        # train_epochs=1 because the epoch loop for joint training lives in the
        # main script below.
        print('getting joint strategy')
        strategy = Naive(
            args.model,
            args.optimizer,
            args.criterion,
            train_mb_size=args.batch_size,
            train_epochs=1,
            eval_mb_size=args.exp_size,
            device=args.detected_device,
            plugins=[plugin],
            evaluator=args.combined_logger,
            label_dict=label_dict
        )
    elif args.strategy == 'cumulative':
        print('getting cumulative strategy')
        strategy = Cumulative(
            args.model,
            args.optimizer,
            args.criterion,
            train_mb_size=args.batch_size,
            train_epochs=args.epochs,
            eval_mb_size=args.exp_size,
            device=args.detected_device,
            plugins=[plugin],
            evaluator=args.combined_logger,
            label_dict=label_dict
        )
    else:
        print('getting naive strategy')
        strategy = Naive(
            args.model,
            args.optimizer,
            args.criterion,
            train_mb_size=args.batch_size,
            train_epochs=args.epochs,
            eval_mb_size=args.exp_size,
            device=args.detected_device,
            plugins=plugins,
            evaluator=args.combined_logger,
            label_dict=label_dict
        )
    return strategy
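# Note: the joint and cumulative strategies receive only the SeqDataPlugin;
# rehearsal plugins are attached only in the naive/rehearsal branch above.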
scenario = get_scenario(args)
cl_strategy = get_strategy(args)
# Different loops are necessary for different strategies.
# When using a class-error-based method, the evaluation on the validation data
# should occur more frequently, i.e. eval_after_n_exp (-ne) should be less than 10.
if args.strategy == 'joint':
    for i in range(args.epochs):
        cl_strategy.train(scenario.train_stream[0], num_workers=number_workers)
        print('Training completed')
        last_eval = i == args.epochs-1
        if i % args.test_eval_after_n_exp == 0 or last_eval:
            print('Computing accuracy on the whole test set')
            cl_strategy.eval(scenario.test_stream, num_workers=number_workers, last_eval=True)
        if i % args.eval_after_n_exp == 0 or last_eval:
            print('Computing accuracy on the whole validation set')
            cl_strategy.eval(scenario.validation_stream, num_workers=number_workers, last_eval=True)
        if last_eval:
            cl_strategy.eval(scenario.train_stream, num_workers=number_workers, last_eval=False)
else:
    train_stream_len = len(scenario.train_stream)
    for i, experience in enumerate(scenario.train_stream):
        cl_strategy.train(experience, num_workers=number_workers)
        print('Training completed')
        last_eval = i == train_stream_len-1
        if i % args.eval_after_n_exp == 0 or last_eval:
            print('Evaluation after experience: ', i)
            #cl_strategy.eval(scenario.train_stream, num_workers=number_workers, last_eval=True)
            print('Computing accuracy on the whole validation set')
            cl_strategy.eval(scenario.validation_stream, num_workers=number_workers, last_eval=True)
        if i % args.test_eval_after_n_exp == 0 or last_eval:
            print('Evaluation after experience: ', i)
            print('Computing accuracy on the whole test set')
            cl_strategy.eval(scenario.test_stream, num_workers=number_workers, last_eval=True)
        if last_eval:
            cl_strategy.eval(scenario.train_stream, num_workers=number_workers, last_eval=True)