Explorar el Código

added saving of meta information

Dimitri Korsch hace 4 años
padre
commit
6968bc3bc3
Se han modificado 1 ficheros con 17 adiciones y 10 borrados
  1. 17 10
      cvfinetune/finetuner/base.py

+ 17 - 10
cvfinetune/finetuner/base.py

@@ -4,6 +4,7 @@ import numpy as np
 
 import abc
 import logging
+import pyaml
 
 from chainer.backends import cuda
 from chainer.optimizer_hooks import Lasso
@@ -19,14 +20,13 @@ from chainer_addons.training import optimizer
 from chainer_addons.training import optimizer_hooks
 
 from cvdatasets import AnnotationType
+from cvdatasets.dataset.image import Size
 from cvdatasets.utils import new_iterator
 from cvdatasets.utils import pretty_print_dict
-from cvdatasets.dataset.image import Size
-
-from functools import partial
-from os.path import join
 
 from bdb import BdbQuit
+from functools import partial
+from pathlib import Path
 
 
 def check_param_for_decay(param):
@@ -134,12 +134,12 @@ class _ModelMixin(abc.ABC):
 
 				else:
 					msg = "Loading default pre-trained weights \"{}\""
-					self.weights = join(
+					self.weights = str(Path(
 						self.data_info.BASE_DIR,
 						self.data_info.MODEL_DIR,
 						self.model_info.folder,
 						self.model_info.weights
-					)
+					))
 
 				loader_func = self.model.load_for_finetune
 
@@ -315,16 +315,16 @@ class _TrainerMixin(abc.ABC):
 			*args, **kwargs
 		)
 
+		self.save_meta_info(opts, folder=Path(trainer.out, "meta"))
+
 		logging.info("Snapshotting is {}abled".format("dis" if opts.no_snapshot else "en"))
 
 		def dump(suffix):
 			if opts.only_eval or opts.no_snapshot:
 				return
 
-			save_npz(join(trainer.out,
-				"clf_{}.npz".format(suffix)), self.clf)
-			save_npz(join(trainer.out,
-				"model_{}.npz".format(suffix)), self.model)
+			save_npz(Path(trainer.out, f"clf_{suffix}.npz"), self.clf)
+			save_npz(Path(trainer.out, f"model_{suffix}.npz"), self.model)
 
 		try:
 			trainer.run(opts.init_eval or opts.only_eval)
@@ -336,6 +336,13 @@ class _TrainerMixin(abc.ABC):
 		else:
 			dump("final")
 
+	def save_meta_info(self, opts, folder: Path):
+		folder.mkdir(parents=True, exist_ok=True)
+
+		with open(folder / "args.yml", "w") as f:
+			pyaml.dump(args.__dict__, f, sort_keys=True)
+
+
 
 class DefaultFinetuner(_ModelMixin, _DatasetMixin, _TrainerMixin):
 	""" The default Finetuner gathers together the creations of all needed