浏览代码

added support for Weights-and-Biases

Dimitri Korsch 3 年之前
父节点
当前提交
9cce4a3486

+ 2 - 1
cvfinetune/training/extensions/__init__.py

@@ -1,2 +1,3 @@
 from cvfinetune.training.extensions.gc_collect import ManualGCCollect
-from cvfinetune.training.extensions.sacred import SacredReport
+from cvfinetune.training.extensions.reporters.sacred import SacredReport
+from cvfinetune.training.extensions.reporters.wandb import WandbReport

+ 0 - 0
cvfinetune/training/extensions/reporters/__init__.py


+ 25 - 7
cvfinetune/training/extensions/sacred.py → cvfinetune/training/extensions/reporters/base.py

@@ -1,18 +1,19 @@
+import abc
+
 from chainer import reporter
 from chainer.training import trigger as trigger_module
 from chainer.training.extension import Extension
 
-class SacredReport(Extension):
-	def __init__(self, *, ex, keys=None, trigger=(1, "epoch")):
-		super(SacredReport, self).__init__()
-		self.ex = ex
+class BaseReport(abc.ABC, Extension):
+	def __init__(self, *, keys=None, trigger=(1, "epoch")):
+		super().__init__()
 		self._keys = keys
 		self._trigger = trigger_module.get_trigger(trigger)
 
 		self._init_summary()
 
 	def __call__(self, trainer):
-		if self.ex is None or self.ex.current_run is None:
+		if not self.reporter_enabled():
 			return
 
 		obs = trainer.observation
@@ -27,11 +28,28 @@ class SacredReport(Extension):
 			return
 
 		stats = self._summary.compute_mean()
-		epoch = trainer.updater.epoch
+
+		step = None
+		if self._trigger.unit == "epoch":
+			step = trainer.updater.epoch
+
+		elif self._trigger.unit == "iteration":
+			step = trainer.updater.iteration
+
 		for name in stats:
-			self.ex.log_scalar(name, float(stats[name]), step=epoch)
+			self.log(name, float(stats[name]), step=step)
 
 		self._init_summary()
 
 	def _init_summary(self):
 		self._summary = reporter.DictSummary()
+
+	@abc.abstractmethod
+	def reporter_enabled(self) -> bool:
+		pass
+
+	@abc.abstractmethod
+	def log(self, key: str, value: float, step: int = 0) -> None:
+		pass
+
+

+ 12 - 0
cvfinetune/training/extensions/reporters/sacred.py

@@ -0,0 +1,12 @@
+from cvfinetune.training.extensions.reporters.base import BaseReport
+
+class SacredReport(BaseReport):
+	def __init__(self, *args, ex, **kwargs):
+		self.ex = ex
+		super().__init__(*args, **kwargs)
+
+	def reporter_enabled(self) -> bool:
+		return None not in [self.ex, self.ex.current_run]
+
+	def log(self, key: str, value: float, step: int = 0) -> None:
+		self.ex.log_scalar(key, value, step=step)

+ 17 - 0
cvfinetune/training/extensions/reporters/wandb.py

@@ -0,0 +1,17 @@
+import wandb
+
+from cvfinetune.training.extensions.reporters.base import BaseReport
+
+class WandbReport(BaseReport):
+
+	def reporter_enabled(self) -> bool:
+		return True
+
+	def log(self, key: str, value: float, step: int = 0) -> None:
+		wandb.log({key: value}, step=step, commit=False)
+
+	def __call__(self, trainer):
+		# self.log will be called with commit=False,
+		# so we need to call it once with commit=True
+		super().__call__(trainer)
+		wandb.log({}, commit=True)