|
@@ -1,7 +1,8 @@
|
|
|
import abc
|
|
|
+import gc
|
|
|
import logging
|
|
|
import pyaml
|
|
|
-import gc
|
|
|
+import typing as T
|
|
|
|
|
|
from bdb import BdbQuit
|
|
|
from chainer.serializers import save_npz
|
|
@@ -9,9 +10,13 @@ from chainer.training import extension
|
|
|
from chainer.training import extensions
|
|
|
from chainer.training import updaters
|
|
|
from cvdatasets.utils import pretty_print_dict
|
|
|
+from functools import partial
|
|
|
from pathlib import Path
|
|
|
|
|
|
+
|
|
|
from cvfinetune.finetuner.mixins.base import BaseMixin
|
|
|
+from cvfinetune.training.extensions import SacredReport
|
|
|
+from cvfinetune.utils.sacred import Experiment
|
|
|
|
|
|
@extension.make_extension(default_name="ManualGC", trigger=(1, "iteration"))
|
|
|
def gc_collect(trainer):
|
|
@@ -25,9 +30,13 @@ class _TrainerMixin(BaseMixin):
|
|
|
def __init__(self, *args,
|
|
|
updater_cls=updaters.StandardUpdater,
|
|
|
updater_kwargs: dict = {},
|
|
|
+
|
|
|
only_eval: bool = False,
|
|
|
init_eval: bool = False,
|
|
|
+
|
|
|
+ experiment_name: T.Optional[str] = None,
|
|
|
no_snapshot: bool = False,
|
|
|
+ no_sacred: bool = False,
|
|
|
|
|
|
manual_gc: bool = True,
|
|
|
**kwargs):
|
|
@@ -38,8 +47,37 @@ class _TrainerMixin(BaseMixin):
|
|
|
self.only_eval = only_eval
|
|
|
self.init_eval = init_eval
|
|
|
self.no_snapshot = no_snapshot
|
|
|
+ self.no_sacred = no_sacred
|
|
|
+ self.experiment_name = experiment_name
|
|
|
self.manual_gc = manual_gc
|
|
|
|
|
|
+ self.ex = None
|
|
|
+
|
|
|
+ @property
|
|
|
+ def no_observe(self):
|
|
|
+ return self.no_sacred
|
|
|
+
|
|
|
+ def init_experiment(self, *, config: dict):
|
|
|
+ """ creates a sacred experiment that is later used by the trainer's sacred extension """
|
|
|
+
|
|
|
+ self.config = config
|
|
|
+
|
|
|
+ if self.no_sacred:
|
|
|
+ logging.warning("Default sacred workflow is disabled by the --no_sacred option!")
|
|
|
+ return
|
|
|
+
|
|
|
+ self.ex = Experiment(
|
|
|
+ name=self.experiment_name,
|
|
|
+ config=self.config,
|
|
|
+ no_observe=self.no_observe)
|
|
|
+
|
|
|
+ # self.trainer will be initialized later
|
|
|
+ def run(*args, **kwargs):
|
|
|
+ self._check_attr("trainer")
|
|
|
+ return self.trainer.run(*args, **kwargs)
|
|
|
+
|
|
|
+ self.ex.main(run)
|
|
|
+
|
|
|
|
|
|
def init_updater(self):
|
|
|
"""Creates an updater from training iterator and the optimizer."""
|
|
@@ -90,35 +128,52 @@ class _TrainerMixin(BaseMixin):
|
|
|
|
|
|
def run(self, trainer_cls, opts, *args, **kwargs):
|
|
|
|
|
|
- trainer = self._new_trainer(trainer_cls, opts, *args, **kwargs)
|
|
|
+ self.trainer = self._new_trainer(trainer_cls, opts, *args, **kwargs)
|
|
|
|
|
|
if self.manual_gc:
|
|
|
- trainer.extend(gc_collect)
|
|
|
+ self.trainer.extend(gc_collect)
|
|
|
|
|
|
- self.save_meta_info(opts, folder=Path(trainer.out, "meta"))
|
|
|
+ self.save_meta_info()
|
|
|
|
|
|
logging.info("Snapshotting is {}abled".format("dis" if self.no_snapshot else "en"))
|
|
|
|
|
|
- def dump(suffix):
|
|
|
- if self.only_eval or self.no_snapshot:
|
|
|
- return
|
|
|
-
|
|
|
- save_npz(Path(trainer.out, f"clf_{suffix}.npz"), self.clf)
|
|
|
- save_npz(Path(trainer.out, f"model_{suffix}.npz"), self.model)
|
|
|
-
|
|
|
try:
|
|
|
- trainer.run(self.init_eval or self.only_eval)
|
|
|
+ self.run_experiment(self.init_eval or self.only_eval)
|
|
|
except (KeyboardInterrupt, BdbQuit) as e:
|
|
|
raise e
|
|
|
except Exception as e:
|
|
|
- dump("exception")
|
|
|
+ self.dump("exception")
|
|
|
raise e
|
|
|
else:
|
|
|
- dump("final")
|
|
|
+ self.dump("final")
|
|
|
+
|
|
|
+
|
|
|
+ def run_experiment(self, *args, **kwargs):
|
|
|
+
|
|
|
+ if self.ex is None:
|
|
|
+ return self.trainer.run(*args, **kwargs)
|
|
|
+
|
|
|
+ sacred_reporter = SacredReport(ex=self.ex, trigger=(1, "epoch"))
|
|
|
+ self.trainer.extend(sacred_reporter)
|
|
|
+ return self.ex(*args, **kwargs)
|
|
|
+
|
|
|
+ def _trainer_output(self, name: str = ""):
|
|
|
+ return Path(self.trainer.out, name)
|
|
|
+
|
|
|
+ def dump(self, suffix):
|
|
|
+ self._check_attr("config")
|
|
|
+ if self.only_eval or self.no_snapshot:
|
|
|
+ return
|
|
|
+
|
|
|
+ save_npz(self._trainer_output(f"clf_{suffix}.npz"), self.clf)
|
|
|
+ save_npz(self._trainer_output(f"model_{suffix}.npz"), self.model)
|
|
|
+
|
|
|
+ def save_meta_info(self, meta_folder: str = "meta"):
|
|
|
+ self._check_attr("config")
|
|
|
|
|
|
- def save_meta_info(self, opts, folder: Path):
|
|
|
+ folder = self._trainer_output(meta_folder)
|
|
|
folder.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
with open(folder / "args.yml", "w") as f:
|
|
|
- pyaml.dump(opts.__dict__, f, sort_keys=True)
|
|
|
+ pyaml.dump(self.config, f, sort_keys=True)
|
|
|
|