Browse Source

Fixed kwargs printing (replaced the local `_format_kwargs` helper with `cvdatasets.utils.pretty_print_dict`) and updated the code for the newer cvdatasets package API (`AnnotationType.new_annotation`, `DATASETS`/`PART_TYPES` info-file keys)

Dimitri Korsch 5 years ago
parent
commit
1df61dc33e
2 changed files with 9 additions and 12 deletions
  1. 8 11
      cvfinetune/finetuner/base.py
  2. 1 1
      cvfinetune/parser/dataset_args.py

+ 8 - 11
cvfinetune/finetuner/base.py

@@ -18,8 +18,9 @@ from chainer_addons.models import PrepareType
 from chainer_addons.training import optimizer
 from chainer_addons.training import optimizer_hooks
 
-from cvdatasets.annotations import AnnotationType
+from cvdatasets import AnnotationType
 from cvdatasets.utils import new_iterator
+from cvdatasets.utils import pretty_print_dict
 
 from functools import partial
 from os.path import join
@@ -30,10 +31,6 @@ from bdb import BdbQuit
 def check_param_for_decay(param):
 	return param.name != "alpha"
 
-
-def _format_kwargs(kwargs):
-	return " ".join([f"{key}={value}" for key, value in kwargs.items()])
-
 class _ModelMixin(abc.ABC):
 	"""This mixin is responsible for optimizer creation, model creation,
 	model wrapping around a classifier and model weights loading.
@@ -56,7 +53,7 @@ class _ModelMixin(abc.ABC):
 
 		logging.info(" ".join([
 			f"Wrapped the model around {clf_class.__name__}",
-			f"with kwargs: {_format_kwargs(kwargs)}",
+			f"with kwargs: {pretty_print_dict(kwargs)}",
 		]))
 
 	def _loss_func(self, opts):
@@ -184,7 +181,7 @@ class _DatasetMixin(abc.ABC):
 
 	@property
 	def n_classes(self):
-		return self.part_info.n_classes + self.dataset_cls.label_shift
+		return self.ds_info.n_classes + self.dataset_cls.label_shift
 
 	def new_dataset(self, opts, size, subset, augment):
 		"""Creates a dataset for a specific subset and certain options"""
@@ -221,12 +218,12 @@ class _DatasetMixin(abc.ABC):
 	def init_annotations(self, opts):
 		"""Reads annotations and creates annotation instance, which holds important infos about the dataset"""
 
-		annot_cls = AnnotationType.get(opts.dataset).value
-		self.annot = annot_cls(root_or_infofile=opts.data, parts=opts.parts, load_strict=False)
+		self.annot = AnnotationType.new_annotation(opts, load_strict=False)
 
 		self.data_info = self.annot.info
 		self.model_info = self.data_info.MODELS[opts.model_type]
-		self.part_info = self.data_info.PARTS[opts.parts]
+		self.ds_info = self.data_info.DATASETS[opts.dataset]
+		# self.part_info = self.data_info.PART_TYPES[opts.parts]
 
 		if getattr(opts, "only_head", False):
 			self.annot.feature_model = opts.model_type
@@ -298,7 +295,7 @@ class _TrainerMixin(abc.ABC):
 		logging.info(" ".join([
 			f"Using single GPU: {self.device}.",
 			f"{self.updater_cls.__name__} is initialized",
-			f"with following kwargs: {_format_kwargs(self.updater_kwargs)}"
+			f"with following kwargs: {pretty_print_dict(self.updater_kwargs)}"
 			])
 		)
 

+ 1 - 1
cvfinetune/parser/dataset_args.py

@@ -19,7 +19,7 @@ def add_dataset_args(parser):
 		_args = [
 			Arg("data", default=DEFAULT_INFO_FILE),
 			Arg("dataset", choices=info_file.DATASETS.keys()),
-			Arg("parts", choices=info_file.PARTS.keys()),
+			Arg("parts", choices=info_file.PART_TYPES.keys()),
 		]
 
 	_args.extend([