|
|
@@ -1,18 +1,19 @@
|
|
|
+import abc
|
|
|
+
|
|
|
from chainer import reporter
|
|
|
from chainer.training import trigger as trigger_module
|
|
|
from chainer.training.extension import Extension
|
|
|
|
|
|
-class SacredReport(Extension):
|
|
|
- def __init__(self, *, ex, keys=None, trigger=(1, "epoch")):
|
|
|
- super(SacredReport, self).__init__()
|
|
|
- self.ex = ex
|
|
|
+class BaseReport(abc.ABC, Extension):
|
|
|
+ def __init__(self, *, keys=None, trigger=(1, "epoch")):
|
|
|
+ super().__init__()
|
|
|
self._keys = keys
|
|
|
self._trigger = trigger_module.get_trigger(trigger)
|
|
|
|
|
|
self._init_summary()
|
|
|
|
|
|
def __call__(self, trainer):
|
|
|
- if self.ex is None or self.ex.current_run is None:
|
|
|
+ if not self.reporter_enabled():
|
|
|
return
|
|
|
|
|
|
obs = trainer.observation
|
|
|
@@ -27,11 +28,28 @@ class SacredReport(Extension):
|
|
|
return
|
|
|
|
|
|
stats = self._summary.compute_mean()
|
|
|
- epoch = trainer.updater.epoch
|
|
|
+
|
|
|
+ step = None
|
|
|
+ if self._trigger.unit == "epoch":
|
|
|
+ step = trainer.updater.epoch
|
|
|
+
|
|
|
+ elif self._trigger.unit == "iteration":
|
|
|
+ step = trainer.updater.iteration
|
|
|
+
|
|
|
for name in stats:
|
|
|
- self.ex.log_scalar(name, float(stats[name]), step=epoch)
|
|
|
+ self.log(name, float(stats[name]), step=step)
|
|
|
|
|
|
self._init_summary()
|
|
|
|
|
|
def _init_summary(self):
|
|
|
self._summary = reporter.DictSummary()
|
|
|
+
|
|
|
+ @abc.abstractmethod
|
|
|
+ def reporter_enabled(self) -> bool:
|
|
|
+ pass
|
|
|
+
|
|
|
+ @abc.abstractmethod
|
|
|
+ def log(self, key: str, value: float, step: int = 0) -> None:
|
|
|
+ pass
|
|
|
+
|
|
|
+
|