소스 검색

manual GC collection is now an extra extension and is executed in a dedicated thread

Dimitri Korsch 3 년 전
부모
커밋
f677e5074b
3개의 파일 변경됨, 65줄 추가 및 3줄 삭제
  1. 6 3
      cvfinetune/finetuner/mixins/trainer.py
  2. 1 0
      cvfinetune/training/extensions/__init__.py
  3. 58 0
      cvfinetune/training/extensions/gc_collect.py

+ 6 - 3
cvfinetune/finetuner/mixins/trainer.py

@@ -16,6 +16,7 @@ from pathlib import Path
 
 from cvfinetune.finetuner.mixins.base import BaseMixin
 from cvfinetune.training.extensions import SacredReport
+from cvfinetune.training.extensions import ManualGCCollect
 from cvfinetune.utils.sacred import Experiment
 
 @extension.make_extension(default_name="ManualGC", trigger=(1, "iteration"))
@@ -131,7 +132,7 @@ class _TrainerMixin(BaseMixin):
         self.trainer = self._new_trainer(trainer_cls, opts, *args, **kwargs)
 
         if self.manual_gc:
-            self.trainer.extend(gc_collect)
+            self.trainer.extend(ManualGCCollect(trigger=(1, "iteration")))
 
         self.save_meta_info()
 
@@ -165,8 +166,10 @@ class _TrainerMixin(BaseMixin):
         if self.only_eval or self.no_snapshot:
             return
 
-        save_npz(self._trainer_output(f"clf_{suffix}.npz"), self.clf)
-        save_npz(self._trainer_output(f"model_{suffix}.npz"), self.model)
+        clf_file = self._trainer_output(f"clf_{suffix}.npz")
+        logging.info(f"Storing classifier weights to {clf_file}")
+        save_npz(clf_file, self.clf)
+        # save_npz(self._trainer_output(f"model_{suffix}.npz"), self.model)
 
     def save_meta_info(self, meta_folder: str = "meta"):
         self._check_attr("config")

+ 1 - 0
cvfinetune/training/extensions/__init__.py

@@ -1 +1,2 @@
+from cvfinetune.training.extensions.gc_collect import ManualGCCollect
 from cvfinetune.training.extensions.sacred import SacredReport

+ 58 - 0
cvfinetune/training/extensions/gc_collect.py

@@ -0,0 +1,58 @@
+import gc
+import logging
+import threading
+
+
+from chainer.training import extension
+from chainer.training import trigger as trigger_module
+
+class ManualGCCollect(extension.Extension):
+
+	# should one of the last extensions
+	priority = extension.PRIORITY_WRITER + 1
+	SLEEP = 3
+
+	def __init__(self, trigger=(1, "iteration")):
+		super().__init__()
+
+		self._trigger = trigger_module.get_trigger(trigger)
+
+		self.thread = threading.Thread(target=self.work)
+
+		self.stop = threading.Event()
+		self.trigger_gc = threading.Event()
+
+	def work(self):
+		logging.info("GC Thread working ...")
+
+		while True:
+			if self.stop.is_set():
+				break
+
+			if not self.trigger_gc.wait(self.SLEEP):
+				continue
+
+			self.trigger_gc.clear()
+			gc.collect()
+
+		logging.info("GC Thread finished")
+
+	def initialize(self, trainer):
+		logging.info("Starting GC Thread ...")
+		self.thread.start()
+
+	def __call__(self, trainer):
+		if not self._trigger(trainer):
+			return
+		self.trigger_gc.set()
+
+
+	def finalize(self):
+		logging.info("Finilizing GC Thread")
+		self.stop.set()
+		self.thread.join()
+
+	def on_error(self, trainer, exc, tb):
+		logging.info("Error occured, stopping GC Thread")
+		self.stop.set()
+		self.thread.join()