|
@@ -16,6 +16,7 @@ from pathlib import Path
|
|
|
|
|
|
from cvfinetune.finetuner.mixins.base import BaseMixin
|
|
|
from cvfinetune.training.extensions import SacredReport
|
|
|
+from cvfinetune.training.extensions import ManualGCCollect
|
|
|
from cvfinetune.utils.sacred import Experiment
|
|
|
|
|
|
@extension.make_extension(default_name="ManualGC", trigger=(1, "iteration"))
|
|
@@ -131,7 +132,7 @@ class _TrainerMixin(BaseMixin):
|
|
|
self.trainer = self._new_trainer(trainer_cls, opts, *args, **kwargs)
|
|
|
|
|
|
if self.manual_gc:
|
|
|
- self.trainer.extend(gc_collect)
|
|
|
+ self.trainer.extend(ManualGCCollect(trigger=(1, "iteration")))
|
|
|
|
|
|
self.save_meta_info()
|
|
|
|
|
@@ -165,8 +166,10 @@ class _TrainerMixin(BaseMixin):
|
|
|
if self.only_eval or self.no_snapshot:
|
|
|
return
|
|
|
|
|
|
- save_npz(self._trainer_output(f"clf_{suffix}.npz"), self.clf)
|
|
|
- save_npz(self._trainer_output(f"model_{suffix}.npz"), self.model)
|
|
|
+ clf_file = self._trainer_output(f"clf_{suffix}.npz")
|
|
|
+ logging.info(f"Storing classifier weights to {clf_file}")
|
|
|
+ save_npz(clf_file, self.clf)
|
|
|
+ # save_npz(self._trainer_output(f"model_{suffix}.npz"), self.model)
|
|
|
|
|
|
def save_meta_info(self, meta_folder: str = "meta"):
|
|
|
self._check_attr("config")
|