Ver código fonte

reworked sacred experiment creation

Dimitri Korsch 3 anos atrás
pai
commit
cddcc6dd6c

+ 3 - 1
cvfinetune/finetuner/base.py

@@ -15,9 +15,11 @@ class DefaultFinetuner(
 
 	"""
 
-	def __init__(self, *args, gpu = [-1], **kwargs):
+	def __init__(self, *args, config: dict = {},  gpu = [-1], **kwargs):
 		super().__init__(*args, **kwargs)
 
+		self.init_experiment(config=config)
+
 		self.gpu_config(gpu)
 		self.read_annotations()
 

+ 1 - 1
cvfinetune/finetuner/factory.py

@@ -42,7 +42,7 @@ class FinetunerFactory(object):
     def __call__(self, opts, **kwargs):
         opt_kwargs = self.tuner_cls.extract_kwargs(opts)
         _kwargs = dict(self.kwargs, **kwargs, **opt_kwargs)
-        return self.tuner_cls(**_kwargs)
+        return self.tuner_cls(config=opt_kwargs, **_kwargs)
 
     def get(self, key, default=None):
         return self.kwargs.get(key, default)

+ 71 - 16
cvfinetune/finetuner/mixins/trainer.py

@@ -1,7 +1,8 @@
 import abc
+import gc
 import logging
 import pyaml
-import gc
+import typing as T
 
 from bdb import BdbQuit
 from chainer.serializers import save_npz
@@ -9,9 +10,13 @@ from chainer.training import extension
 from chainer.training import extensions
 from chainer.training import updaters
 from cvdatasets.utils import pretty_print_dict
+from functools import partial
 from pathlib import Path
 
+
 from cvfinetune.finetuner.mixins.base import BaseMixin
+from cvfinetune.training.extensions import SacredReport
+from cvfinetune.utils.sacred import Experiment
 
 @extension.make_extension(default_name="ManualGC", trigger=(1, "iteration"))
 def gc_collect(trainer):
@@ -25,9 +30,13 @@ class _TrainerMixin(BaseMixin):
     def __init__(self, *args,
                  updater_cls=updaters.StandardUpdater,
                  updater_kwargs: dict = {},
+
                  only_eval: bool = False,
                  init_eval: bool = False,
+
+                 experiment_name: T.Optional[str] = None,
                  no_snapshot: bool = False,
+                 no_sacred: bool = False,
 
                  manual_gc: bool = True,
                  **kwargs):
@@ -38,8 +47,37 @@ class _TrainerMixin(BaseMixin):
         self.only_eval = only_eval
         self.init_eval = init_eval
         self.no_snapshot = no_snapshot
+        self.no_sacred = no_sacred
+        self.experiment_name = experiment_name
         self.manual_gc = manual_gc
 
+        self.ex = None
+
+    @property
+    def no_observe(self):
+        return self.no_sacred
+
+    def init_experiment(self, *, config: dict):
+        """ creates a sacred experiment that is later used by the trainer's sacred extension """
+
+        self.config = config
+
+        if self.no_sacred:
+            logging.warning("Default sacred workflow is disabled by the --no_sacred option!")
+            return
+
+        self.ex = Experiment(
+            name=self.experiment_name,
+            config=self.config,
+            no_observe=self.no_observe)
+
+        # self.trainer will be initialized later
+        def run(*args, **kwargs):
+            self._check_attr("trainer")
+            return self.trainer.run(*args, **kwargs)
+
+        self.ex.main(run)
+
 
     def init_updater(self):
         """Creates an updater from training iterator and the optimizer."""
@@ -90,35 +128,52 @@ class _TrainerMixin(BaseMixin):
 
     def run(self, trainer_cls, opts, *args, **kwargs):
 
-        trainer = self._new_trainer(trainer_cls, opts, *args, **kwargs)
+        self.trainer = self._new_trainer(trainer_cls, opts, *args, **kwargs)
 
         if self.manual_gc:
-            trainer.extend(gc_collect)
+            self.trainer.extend(gc_collect)
 
-        self.save_meta_info(opts, folder=Path(trainer.out, "meta"))
+        self.save_meta_info()
 
         logging.info("Snapshotting is {}abled".format("dis" if self.no_snapshot else "en"))
 
-        def dump(suffix):
-            if self.only_eval or self.no_snapshot:
-                return
-
-            save_npz(Path(trainer.out, f"clf_{suffix}.npz"), self.clf)
-            save_npz(Path(trainer.out, f"model_{suffix}.npz"), self.model)
-
         try:
-            trainer.run(self.init_eval or self.only_eval)
+            self.run_experiment(self.init_eval or self.only_eval)
         except (KeyboardInterrupt, BdbQuit) as e:
             raise e
         except Exception as e:
-            dump("exception")
+            self.dump("exception")
             raise e
         else:
-            dump("final")
+            self.dump("final")
+
+
+    def run_experiment(self, *args, **kwargs):
+
+        if self.ex is None:
+            return self.trainer.run(*args, **kwargs)
+
+        sacred_reporter = SacredReport(ex=self.ex, trigger=(1, "epoch"))
+        self.trainer.extend(sacred_reporter)
+        return self.ex(*args, **kwargs)
+
+    def _trainer_output(self, name: str = ""):
+        return Path(self.trainer.out, name)
+
+    def dump(self, suffix):
+        self._check_attr("config")
+        if self.only_eval or self.no_snapshot:
+            return
+
+        save_npz(self._trainer_output(f"clf_{suffix}.npz"), self.clf)
+        save_npz(self._trainer_output(f"model_{suffix}.npz"), self.model)
+
+    def save_meta_info(self, meta_folder: str = "meta"):
+        self._check_attr("config")
 
-    def save_meta_info(self, opts, folder: Path):
+        folder = self._trainer_output(meta_folder)
         folder.mkdir(parents=True, exist_ok=True)
 
         with open(folder / "args.yml", "w") as f:
-            pyaml.dump(opts.__dict__, f, sort_keys=True)
+            pyaml.dump(self.config, f, sort_keys=True)
 

+ 5 - 1
cvfinetune/finetuner/mpi.py

@@ -20,6 +20,11 @@ class MPIFinetuner(DefaultFinetuner):
 	def mpi_main_process(self):
 		return not (self.comm is not None and self.comm.rank != 0)
 
+
+	@property
+	def no_observe(self):
+		return self.no_sacred or not self.mpi_main_process
+
 	def gpu_config(self, devices):
 
 		if not self.mpi:
@@ -71,7 +76,6 @@ class MPIFinetuner(DefaultFinetuner):
 
 	def run(self, trainer_cls, opts, *args, **kwargs):
 		if not self.mpi_main_process:
-			kwargs["no_observe"] = True
 			opts.no_snapshot = True
 			opts.no_progress = True
 			self.evaluator._progress_bar = False

+ 1 - 1
cvfinetune/parser/model_args.py

@@ -17,7 +17,7 @@ def add_model_args(parser: BaseParser) -> None:
 			choices=choices,
 			help="type of the model"),
 
-		Arg("--pre_training", "-pt",
+		Arg("--pretrained_on", "-pt",
 			default="imagenet",
 			choices=["imagenet", "inat"],
 			help="type of model pre-training"),

+ 1 - 0
cvfinetune/training/extensions/__init__.py

@@ -0,0 +1 @@
+from cvfinetune.training.extensions.sacred import SacredReport

+ 37 - 0
cvfinetune/training/extensions/sacred.py

@@ -0,0 +1,37 @@
+from chainer import reporter
+from chainer.training import trigger as trigger_module
+from chainer.training.extension import Extension
+
+class SacredReport(Extension):
+	def __init__(self, *, ex, keys=None, trigger=(1, "epoch")):
+		super(SacredReport, self).__init__()
+		self.ex = ex
+		self._keys = keys
+		self._trigger = trigger_module.get_trigger(trigger)
+
+		self._init_summary()
+
+	def __call__(self, trainer):
+		if self.ex is None or self.ex.current_run is None:
+			return
+
+		obs = trainer.observation
+		keys = self._keys
+
+		if keys is None:
+			self._summary.add(obs)
+		else:
+			self._summary.add({k: obs[k] for k in keys if k in obs})
+
+		if not self._trigger(trainer):
+			return
+
+		stats = self._summary.compute_mean()
+		epoch = trainer.updater.epoch
+		for name in stats:
+			self.ex.log_scalar(name, float(stats[name]), step=epoch)
+
+		self._init_summary()
+
+	def _init_summary(self):
+		self._summary = reporter.DictSummary()

+ 3 - 3
cvfinetune/training/trainer/__init__.py

@@ -1,3 +1,3 @@
-from .base import default_intervals, Trainer
-from .sacred import SacredTrainer
-from .alpha_pooling import AlphaPoolingTrainer
+from cvfinetune.training.trainer.alpha_pooling import AlphaPoolingTrainer
+from cvfinetune.training.trainer.base import Trainer
+from cvfinetune.training.trainer.base import default_intervals

+ 24 - 11
cvfinetune/training/trainer/base.py

@@ -1,15 +1,21 @@
 import logging
-from os.path import join, basename
+
 from datetime import datetime
+from os.path import basename
+from os.path import join
 from typing import Tuple
 
 import chainer
-from chainer.training import extensions, Trainer as T
+
+from chainer.training import Trainer as T
+from chainer.training import extensions
 from chainer.training import trigger as trigger_module
 from chainer_addons.training import lr_shift
-from chainer_addons.training.optimizer import OptimizerType
+from chainer_addons.training.extensions import AlternateTrainable
+from chainer_addons.training.extensions import SwitchTrainables
+from chainer_addons.training.extensions import WarmUp
 from chainer_addons.training.extensions.learning_rate import CosineAnnealingLearningRate
-from chainer_addons.training.extensions import AlternateTrainable, SwitchTrainables, WarmUp
+from chainer_addons.training.optimizer import OptimizerType
 
 from cvdatasets.utils import attr_dict
 
@@ -30,7 +36,8 @@ def _is_adam(opts):
 
 class Trainer(T):
 
-	def __init__(self, opts, updater,
+	def __init__(self, opts,
+		updater,
 		evaluator: extensions.Evaluator = None,
 		intervals: attr_dict = default_intervals,
 		no_observe: bool = False):
@@ -67,21 +74,27 @@ class Trainer(T):
 		### Code below is only for "main" Trainers ###
 		if no_observe: return
 
-		self.extend(extensions.observe_lr(), trigger=intervals.log)
-		self.extend(extensions.LogReport(trigger=intervals.log))
-
 		### Snapshotting ###
 		self.setup_snapshots(opts, self.model, intervals.snapshot)
 
+		self.setup_reporter(opts, intervals.log, intervals.print)
+		self.setup_progress_bar(opts)
+
+	def setup_reporter(self, opts, log_trigger, print_trigger):
+
+		self.extend(extensions.observe_lr(), trigger=log_trigger)
+		self.extend(extensions.LogReport(trigger=log_trigger))
+
 		### Reports and Plots ###
 		print_values, plot_values = self.reportables(opts)
-		self.extend(extensions.PrintReport(print_values),
-			trigger=intervals.print)
+
+		self.extend(extensions.PrintReport(print_values), trigger=print_trigger)
+
 		for name, values in plot_values.items():
 			ext = extensions.PlotReport(values, 'epoch', file_name='{}.png'.format(name))
 			self.extend(ext)
 
-		### Progress bar ###
+	def setup_progress_bar(self, opts):
 		if not opts.no_progress:
 			self.extend(extensions.ProgressBar(update_interval=1))
 

+ 0 - 8
cvfinetune/training/trainer/sacred.py

@@ -1,8 +0,0 @@
-from cvfinetune.training.trainer.base import Trainer, default_intervals
-from chainer_addons.training.sacred import SacredTrainerMixin
-
-class SacredTrainer(SacredTrainerMixin, Trainer):
-
-	def __init__(self, intervals=default_intervals, *args, **kwargs):
-		super(SacredTrainer, self).__init__(
-			intervals=intervals, sacred_trigger=intervals.log, *args, **kwargs)

+ 6 - 0
cvfinetune/utils/sacred/__init__.py

@@ -0,0 +1,6 @@
+from sacred import SETTINGS
+
+SETTINGS.DISCOVER_SOURCES = "dir"
+
+from cvfinetune.utils.sacred.experiment import Experiment
+from cvfinetune.utils.sacred.plotter import SacredPlotter

+ 112 - 0
cvfinetune/utils/sacred/experiment.py

@@ -0,0 +1,112 @@
+import logging
+import munch
+import os
+import re
+import sys
+import typing as T
+
+from sacred import Experiment as BaseExperiment
+from sacred.observers import MongoObserver
+from sacred.utils import apply_backspaces_and_linefeeds
+
+from pathlib import Path
+from urllib.parse import quote_plus
+
+def progress_bar_filter(text,
+	escape=re.compile(r"\x1B\[([0-?]*[ -/]*[@-~])"),
+	line_contents=re.compile(r".*(total|validation|this epoch|Estimated time|\d+ iter, \d+ epoch|\d+ \/ \d+ iteration).+\n"),
+	tqdm_progress = re.compile(r"\n? *\d+\%\|.+\n?")):
+
+	""" Filters out the progress bar of chainer """
+
+	_text = apply_backspaces_and_linefeeds(text)
+
+	_text = escape.sub("", _text)
+	_text = line_contents.sub("", _text)
+	_text = tqdm_progress.sub("", _text)
+	_text = re.sub(r"\n *\n*", "\n", _text)
+
+	return _text
+
+
+class Experiment(BaseExperiment):
+
+	ENV_KEYS: munch.Munch = munch.munchify(dict(
+		USER_NAME="MONGODB_USER_NAME",
+		PASSWORD="MONGODB_PASSWORD",
+		DB_NAME="MONGODB_DB_NAME",
+
+		HOST="MONGODB_HOST",
+		PORT="MONGODB_PORT",
+	))
+
+	def __init__(self, *args,
+		config: dict = {},
+		host: T.Optional[str] = None,
+		port: T.Optional[int] = None,
+		no_observe: bool = False,
+		output_filter: T.Callable = progress_bar_filter,
+		**kwargs):
+
+		if kwargs.get("base_dir") is None:
+			base_dir = Path(sys.argv[0]).resolve().parent
+			logging.info(f"Base experiment directory: {base_dir}")
+			kwargs["base_dir"] = str(base_dir)
+
+		super(Experiment, self).__init__(*args, **kwargs)
+
+		if no_observe:
+			return
+
+		self.logger = logging.getLogger()
+		self.captured_out_filter = output_filter
+
+		creds = Experiment.get_creds()
+		_mongo_observer = MongoObserver.create(
+			url=Experiment.auth_url(creds, host=host, port=port),
+			db_name=creds["db_name"],
+		)
+
+		self.observers.append(_mongo_observer)
+
+		self.add_config(**config)
+
+	def __call__(self, *args, **kwargs):
+		return self._create_run()(*args, **kwargs)
+
+	@classmethod
+	def get_creds(cls):
+		return dict(
+			user=cls._get_env_key(cls.ENV_KEYS.USER_NAME),
+			password=cls._get_env_key(cls.ENV_KEYS.PASSWORD),
+			db_name=cls._get_env_key(cls.ENV_KEYS.DB_NAME),
+		)
+
+	@classmethod
+	def auth_url(cls, creds, host="localhost", port=27017):
+		host = host or cls.get_host()
+		port = port or cls.get_port()
+		logging.info(f"MongoDB host: {host}:{port}")
+
+		url = "mongodb://{user}:{password}@{host}:{port}/{db_name}?authSource=admin".format(
+			host=host, port=port, **creds)
+		return url
+
+	@classmethod
+	def get_host(cls):
+		return cls._get_env_key(cls.ENV_KEYS.HOST, default="localhost")
+
+	@classmethod
+	def get_port(cls):
+		return cls._get_env_key(cls.ENV_KEYS.PORT, default=27017)
+
+	@classmethod
+	def _get_env_key(cls, key, default=None):
+		return quote_plus(str(os.environ.get(key, default)))
+
+
+
+__all__ = [
+	"Experiment",
+	"progress_bar_filter"
+]

+ 0 - 0
cvfinetune/utils/sacred_plotter.py → cvfinetune/utils/sacred/plotter.py