csv_logger.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438
  1. ################################################################################
  2. # Copyright (c) 2021 ContinualAI. #
  3. # Copyrights licensed under the MIT License. #
  4. # See the accompanying LICENSE file for terms. #
  5. # #
  6. # Date: 2020-01-25 #
  7. # Author(s): Andrea Cossu #
  8. # E-mail: contact@continualai.org #
  9. # Website: avalanche.continualai.org #
  10. ################################################################################
  11. from typing import List, TYPE_CHECKING
  12. import torch
  13. import os
  14. import datetime
  15. from avalanche.evaluation.metric_results import MetricValue
  16. from avalanche.logging import StrategyLogger
  17. if TYPE_CHECKING:
  18. from avalanche.training import BaseStrategy
  19. class CSVLogger(StrategyLogger):
  20. """
  21. The `CSVLogger` logs accuracy and loss metrics into a csv file.
  22. Metrics are logged separately for training and evaluation in files
  23. training_results.csv and eval_results.csv, respectively.
  24. This Logger assumes that the user is evaluating on only one experience
  25. during training (see below for an example of a `train` call).
  26. Trough the `EvaluationPlugin`, the user should monitor at least
  27. EpochAccuracy/Loss and ExperienceAccuracy/Loss.
  28. If monitored, the logger will also record Experience Forgetting.
  29. In order to monitor the performance on held-out experience
  30. associated to the current training experience, set
  31. `eval_every=1` (or larger value) in the strategy constructor
  32. and pass the eval experience to the `train` method:
  33. `for i, exp in enumerate(benchmark.train_stream):`
  34. `strategy.train(exp, eval_streams=[benchmark.test_stream[i]])`
  35. When not provided, validation loss and validation accuracy
  36. will be logged as zero.
  37. The training file header is composed of:
  38. training_exp_id, epoch, training_accuracy, val_accuracy,
  39. training_loss, val_loss.
  40. The evaluation file header is composed of:
  41. eval_exp, training_exp, eval_accuracy, eval_loss, forgetting
  42. """
  43. def __init__(self, log_folder=None):
  44. """
  45. Creates an instance of `CSVLogger` class.
  46. :param log_folder: folder in which to create log files.
  47. If None, `csvlogs` folder in the default current directory
  48. will be used.
  49. """
  50. super().__init__()
  51. self.log_folder = log_folder if log_folder is not None else "csvlogs"
  52. os.makedirs(self.log_folder, exist_ok=True)
  53. self.training_file = open(os.path.join(self.log_folder,
  54. 'training_results.csv'), 'w')
  55. self.eval_file = open(os.path.join(self.log_folder,
  56. 'eval_results.csv'), 'w')
  57. os.makedirs(self.log_folder, exist_ok=True)
  58. # current training experience id
  59. self.training_exp_id = None
  60. # if we are currently training or evaluating
  61. # evaluation within training will not change this flag
  62. self.in_train_phase = None
  63. # validation metrics computed during training
  64. self.val_acc, self.val_loss = 0, 0
  65. # print csv headers
  66. print('training_exp', 'epoch', 'training_accuracy', 'val_accuracy',
  67. 'training_loss', 'val_loss', sep=',', file=self.training_file,
  68. flush=True)
  69. print('eval_exp', 'training_exp', 'eval_accuracy', 'eval_loss',
  70. 'forgetting', sep=',', file=self.eval_file, flush=True)
  71. def log_metric(self, metric_value: 'MetricValue', callback: str) -> None:
  72. pass
  73. def _val_to_str(self, m_val):
  74. if isinstance(m_val, torch.Tensor):
  75. return '\n' + str(m_val)
  76. elif isinstance(m_val, float):
  77. return f'{m_val:.4f}'
  78. else:
  79. return str(m_val)
  80. def print_train_metrics(self, training_exp, epoch, train_acc,
  81. val_acc, train_loss, val_loss):
  82. print(training_exp, epoch, self._val_to_str(train_acc),
  83. self._val_to_str(val_acc), self._val_to_str(train_loss),
  84. self._val_to_str(val_loss), sep=',',
  85. file=self.training_file, flush=True)
  86. def print_eval_metrics(self, eval_exp, training_exp, eval_acc,
  87. eval_loss, forgetting):
  88. print(eval_exp, training_exp, self._val_to_str(eval_acc),
  89. self._val_to_str(eval_loss), self._val_to_str(forgetting),
  90. sep=',', file=self.eval_file, flush=True)
  91. def after_training_epoch(self, strategy: 'BaseStrategy',
  92. metric_values: List['MetricValue'], **kwargs):
  93. super().after_training_epoch(strategy, metric_values, **kwargs)
  94. train_acc, val_acc, train_loss, val_loss = 0, 0, 0, 0
  95. for val in metric_values:
  96. if 'train_stream' in val.name:
  97. if val.name.startswith('Top1_Acc_Epoch'):
  98. train_acc = val.value
  99. elif val.name.startswith('Loss_Epoch'):
  100. train_loss = val.value
  101. self.print_train_metrics(self.training_exp_id, strategy.epoch,
  102. train_acc, self.val_acc, train_loss,
  103. self.val_loss)
  104. def after_eval_exp(self, strategy: 'BaseStrategy',
  105. metric_values: List['MetricValue'], **kwargs):
  106. super().after_eval_exp(strategy, metric_values, **kwargs)
  107. acc, loss, forgetting = 0, 0, 0
  108. for val in metric_values:
  109. if self.in_train_phase: # validation within training
  110. if val.name.startswith('Top1_Acc_Exp'):
  111. self.val_acc = val.value
  112. elif val.name.startswith('Loss_Exp'):
  113. self.val_loss = val.value
  114. else:
  115. if val.name.startswith('Top1_Acc_Exp'):
  116. acc = val.value
  117. elif val.name.startswith('Loss_Exp'):
  118. loss = val.value
  119. elif val.name.startswith('ExperienceForgetting'):
  120. forgetting = val.value
  121. if not self.in_train_phase:
  122. self.print_eval_metrics(strategy.experience.current_experience,
  123. self.training_exp_id, acc, loss,
  124. forgetting)
  125. def before_training_exp(self, strategy: 'BaseStrategy',
  126. metric_values: List['MetricValue'], **kwargs):
  127. super().before_training(strategy, metric_values, **kwargs)
  128. self.training_exp_id = strategy.experience.current_experience
  129. def before_eval(self, strategy: 'BaseStrategy',
  130. metric_values: List['MetricValue'], **kwargs):
  131. """
  132. Manage the case in which `eval` is first called before `train`
  133. """
  134. if self.in_train_phase is None:
  135. self.in_train_phase = False
  136. def before_training(self, strategy: 'BaseStrategy',
  137. metric_values: List['MetricValue'], **kwargs):
  138. self.in_train_phase = True
  139. def after_training(self, strategy: 'BaseStrategy',
  140. metric_values: List['MetricValue'], **kwargs):
  141. self.in_train_phase = False
  142. def close(self):
  143. self.training_file.close()
  144. self.eval_file.close()
  145. class GenericCSVLogger(StrategyLogger):
  146. # A more comprehensive CSV Logger capable of logging new metrics
  147. def __init__(self, log_folder=None):
  148. super().__init__()
  149. self.log_folder = log_folder if log_folder is not None else "csvlogs"
  150. os.makedirs(self.log_folder, exist_ok=True)
  151. datetime_stap = str(datetime.datetime.now()).replace(' ','_').replace(':','-').replace('.','-')[:-4]
  152. self.training_epoch_file = open(os.path.join(self.log_folder, '_training_epochs.csv'), 'w')
  153. self.test_stream_file = open(os.path.join(self.log_folder,'_test_stream.csv'), 'w')
  154. self.validation_stream_file = open(os.path.join(self.log_folder,'_validation_stream.csv'), 'w')
  155. self.training_stream_file = open(os.path.join(self.log_folder, '_training_stream.csv'), 'w')
  156. self.transfer_file = open(os.path.join(self.log_folder,'_training_transfer.csv'), 'w')
  157. os.makedirs(self.log_folder, exist_ok=True)
  158. # current training experience id
  159. self.training_exp_id = None
  160. # if we are currently training or evaluating
  161. # evaluation within training will not change this flag
  162. self.in_train_phase = None
  163. # validation metrics computed during training
  164. self.val_acc, self.val_loss = 0, 0
  165. def log_metric(self, metric_value: 'MetricValue', callback: str) -> None:
  166. pass
  167. def _val_to_str(self, m_val):
  168. if isinstance(m_val, torch.Tensor):
  169. return '\n' + str(m_val)
  170. elif isinstance(m_val, float):
  171. return f'{m_val:.4f}'
  172. else:
  173. return str(m_val)
  174. def print_vals_to_file(self, val_list, mode, exp, epoch, strategy):
  175. add_classes = False
  176. if mode =='test':
  177. classes_in_last_exp = [str(i)+' ' for i in self.classes_in_last_exp]
  178. classes_in_last_exp = ''.join(classes_in_last_exp)
  179. classes_in_last_exp = '['+classes_in_last_exp+']'
  180. add_classes = True
  181. log_file = self.test_stream_file
  182. elif mode =='validation':
  183. classes_in_last_exp = [str(i)+' ' for i in self.classes_in_last_exp]
  184. classes_in_last_exp = ''.join(classes_in_last_exp)
  185. classes_in_last_exp = '['+classes_in_last_exp+']'
  186. add_classes = True
  187. log_file = self.validation_stream_file
  188. elif mode =='train_epoch':
  189. log_file = self.training_epoch_file
  190. elif mode== 'train_stream':
  191. classes_in_last_exp = [str(i)+' ' for i in self.classes_in_last_exp]
  192. classes_in_last_exp = ''.join(classes_in_last_exp)
  193. classes_in_last_exp = '['+classes_in_last_exp+']'
  194. add_classes = True
  195. log_file = self.training_stream_file
  196. elif mode == 'transfer':
  197. log_file = self.transfer_file
  198. if exp == 0 and (epoch+1)==strategy.train_epochs:
  199. nr_exp = len(val_list)
  200. exp_list = torch.arange(nr_exp)
  201. exp_list = ['Exp '+ str(i.item()) for i in exp_list]
  202. exp_list = ['train_exp', 'epoch']+ exp_list
  203. print(*exp_list, sep=',', file=log_file, flush=True)
  204. val_list = [self._val_to_str(i) for i in val_list]
  205. if add_classes:
  206. log_list = [exp, epoch, classes_in_last_exp ]+ val_list
  207. else:
  208. log_list = [exp, epoch ]+ val_list
  209. print(*log_list, sep=',', file=log_file, flush=True)
  210. def after_training_epoch(self, strategy: 'BaseStrategy',
  211. metric_values: List['MetricValue'], **kwargs):
  212. super().after_training_epoch(strategy, metric_values, **kwargs)
  213. train_vals = [None] * len(self.train_metrics)
  214. train_exp = False
  215. metrics_found = False
  216. for val in metric_values:
  217. if 'train_stream' in val.name:
  218. train_exp = True
  219. metrics_found = True
  220. val_index = self.train_metrics.index(val.name.split('/')[0])
  221. train_vals[val_index] = val.value
  222. if train_exp:
  223. self.print_vals_to_file(train_vals, mode='train_epoch', exp = self.training_exp_id,epoch= strategy.epoch, strategy=strategy)
  224. def after_eval_exp(self, strategy: 'BaseStrategy',
  225. metric_values: List['MetricValue'], **kwargs):
  226. super().after_eval_exp(strategy, metric_values, **kwargs)
  227. metrics_found = False
  228. if metrics_found:
  229. if not train_eval_exp :
  230. self.print_vals_to_file(eval_vals, mode ='eval', exp = self.training_exp_id,epoch= strategy.epoch, strategy=strategy)
  231. else:
  232. self.print_vals_to_file(train_vals, mode ='train', exp = self.training_exp_id,epoch= strategy.epoch, strategy=strategy)
  233. def after_eval(self, strategy: 'BaseStrategy',
  234. metric_values: List['MetricValue'], **kwargs):
  235. super().after_eval_exp(strategy, metric_values, **kwargs)
  236. train_metrics_found = False
  237. validation_metrics_found = False
  238. test_metrics_found = False
  239. train_stream_accs = []
  240. train_stream_vals = [None]*len(self.eval_metrics)
  241. test_stream_vals = [None]*len(self.eval_metrics)
  242. validation_stream_vals = [None]*len(self.eval_metrics)
  243. classes_acc = []
  244. seq_classes_acc = []
  245. for val in metric_values:
  246. if 'eval_phase' in val.name: # validation within training
  247. if 'Transfer' in val.name :
  248. if 'train_stream' in val.name:
  249. train_metrics_found = True
  250. train_stream_accs.append(val.value)
  251. else:
  252. if 'train_stream' in val.name:
  253. train_metrics_found = True
  254. if 'SeqClasswise' in val.name:
  255. seq_classes_acc.append(val.value)
  256. elif 'Classwise' in val.name:
  257. classes_acc.append(val.value)
  258. else:
  259. val_index = self.eval_metrics.index(val.name.split('/')[0])
  260. train_stream_vals[val_index] = val.value
  261. if 'test_stream' in val.name:
  262. test_metrics_found = True
  263. if 'SeqClasswise' in val.name:
  264. seq_classes_acc.append(val.value)
  265. elif 'Classwise' in val.name:
  266. classes_acc.append(val.value)
  267. else:
  268. val_index = self.eval_metrics.index(val.name.split('/')[0])
  269. test_stream_vals[val_index] = val.value
  270. if 'validation_stream' in val.name:
  271. validation_metrics_found = True
  272. if 'SeqClasswise' in val.name:
  273. seq_classes_acc.append(val.value)
  274. elif 'Classwise' in val.name:
  275. classes_acc.append(val.value)
  276. else:
  277. val_index = self.eval_metrics.index(val.name.split('/')[0])
  278. validation_stream_vals[val_index] = val.value
  279. if test_metrics_found:
  280. if classes_acc != []:
  281. test_stream_vals = test_stream_vals+classes_acc
  282. if seq_classes_acc != []:
  283. test_stream_vals = test_stream_vals+seq_classes_acc
  284. self.print_vals_to_file(test_stream_vals, mode ='test', exp = self.training_exp_id,epoch= strategy.epoch, strategy=strategy)
  285. if validation_metrics_found:
  286. if classes_acc != []:
  287. validation_stream_vals = validation_stream_vals+classes_acc
  288. if seq_classes_acc != []:
  289. validation_stream_vals = validation_stream_vals+seq_classes_acc
  290. self.print_vals_to_file(validation_stream_vals, mode ='validation', exp = self.training_exp_id,epoch= strategy.epoch, strategy=strategy)
  291. if train_metrics_found:
  292. if train_stream_accs != []:
  293. self.print_vals_to_file(train_stream_accs, mode ='transfer', exp=self.training_exp_id, epoch=strategy.epoch, strategy=strategy)
  294. if not None in train_stream_vals:
  295. if classes_acc != []:
  296. train_stream_vals = train_stream_vals+classes_acc
  297. if seq_classes_acc != []:
  298. train_stream_vals = train_stream_vals+seq_classes_acc
  299. self.print_vals_to_file(train_stream_vals, mode='train_stream', exp=self.training_exp_id, epoch=strategy.epoch, strategy=strategy)
  300. def before_training_exp(self, strategy: 'BaseStrategy',
  301. metric_values: List['MetricValue'], **kwargs):
  302. super().before_training(strategy, metric_values, **kwargs)
  303. self.training_exp_id = strategy.experience.current_experience
  304. strategy.current_train_exp_seen = strategy.experience.current_experience
  305. self.classes_in_last_exp = strategy.experience.classes_in_this_experience
  306. def before_eval(self, strategy: 'BaseStrategy',
  307. metric_values: List['MetricValue'], **kwargs):
  308. """
  309. Manage the case in which `eval` is first called before `train`
  310. """
  311. if self.in_train_phase is None:
  312. self.in_train_phase = False
  313. def before_training(self, strategy: 'BaseStrategy',
  314. metric_values: List['MetricValue'], **kwargs):
  315. if strategy.experience == None:
  316. self.train_metrics =[]
  317. self.eval_metrics = []
  318. logging_classes_accs = False
  319. seq_logging_classes_accs = False
  320. for metric in self.all_metrics:
  321. if metric._mode == 'train':
  322. self.train_metrics.append(str(metric))
  323. else:
  324. if 'SeqClasswise' in str(metric):
  325. seq_logging_classes_accs =True
  326. elif 'Classwise' in str(metric):
  327. logging_classes_accs =True
  328. elif not 'Transfer' in str(metric) :
  329. self.eval_metrics.append(str(metric))
  330. train_head = ['train_exp', 'epoch']+ self.train_metrics
  331. eval_head = ['train_exp', 'epoch', 'train exp classes']+ self.eval_metrics
  332. if logging_classes_accs:
  333. classes_head = ['Cls'+str(i.item()) for i in torch.arange(strategy.model.num_classes)]
  334. eval_head = eval_head +classes_head
  335. if seq_logging_classes_accs:
  336. classes_head = ['SeqCls'+str(i.item()) for i in torch.arange(strategy.model.num_classes)]
  337. eval_head = eval_head +classes_head
  338. print(*train_head, sep=',', file=self.training_epoch_file, flush=True)
  339. print(*eval_head, sep=',', file=self.validation_stream_file, flush=True)
  340. print(*eval_head, sep=',', file=self.test_stream_file, flush=True)
  341. print(*eval_head, sep=',', file=self.training_stream_file, flush=True)
  342. self.in_train_phase = True
  343. def after_training(self, strategy: 'BaseStrategy',
  344. metric_values: List['MetricValue'], **kwargs):
  345. self.in_train_phase = False
  346. def close(self):
  347. self.training_epoch_file.close()
  348. self.validation_stream_file.close()
  349. self.test_stream_file.close()
  350. self.training_stream_file.close()
  351. self.transfer_file.close()