Ver Fonte

moved prepare function initialization from dataset to model mixin

Dimitri Korsch há 4 anos atrás
pai
commit
8e95d9c768

+ 2 - 2
cvfinetune/finetuner/base.py

@@ -28,9 +28,9 @@ class DefaultFinetuner(mixins._ModelMixin, mixins._DatasetMixin, mixins._Trainer
 		self.gpu_config(opts)
 		cuda.get_device_from_id(self.device).use()
 
-		self.init_annotations(opts)
-		self.init_model(opts)
+		self.read_annotations(opts)
 
+		self.init_model(opts)
 		self.init_datasets(opts)
 		self.init_iterators(opts)
 

+ 1 - 12
cvfinetune/finetuner/mixins/dataset.py

@@ -1,7 +1,6 @@
 import abc
 import logging
 
-from chainer_addons.models import PrepareType
 from cvdatasets import AnnotationType
 from cvdatasets.dataset.image import Size
 from cvdatasets.utils import new_iterator
@@ -56,7 +55,7 @@ class _DatasetMixin(abc.ABC):
 		return ds
 
 
-	def init_annotations(self, opts):
+	def read_annotations(self, opts):
 		"""Reads annotations and creates annotation instance, which holds important infos about the dataset"""
 
 		self.annot = AnnotationType.new_annotation(opts, load_strict=False)
@@ -69,16 +68,6 @@ class _DatasetMixin(abc.ABC):
 		part_size = getattr(opts, "parts_input_size", None)
 		part_size = size if part_size is None else Size(part_size)
 
-		self.prepare = partial(PrepareType[opts.prepare_type](self.model),
-			swap_channels=opts.swap_channels,
-			keep_ratio=getattr(opts, "center_crop_on_val", False),
-		)
-
-		logging.info(" ".join([
-			f"Created {self.model.__class__.__name__} model",
-			f"with \"{opts.prepare_type}\" prepare function."
-		]))
-
 		logging.info(" ".join([
 			f"Image input size: {size}",
 			f"Image parts input size: {part_size}",

+ 16 - 0
cvfinetune/finetuner/mixins/model.py

@@ -6,6 +6,7 @@ from chainer.optimizer_hooks import Lasso
 from chainer.optimizer_hooks import WeightDecay
 from chainer_addons.functions import smoothed_cross_entropy
 from chainer_addons.models import ModelType
+from chainer_addons.models import PrepareType
 from chainer_addons.training import optimizer
 from chainer_addons.training import optimizer_hooks
 from cvdatasets.dataset.image import Size
@@ -43,6 +44,21 @@ class _ModelMixin(abc.ABC):
 			**self.model_kwargs
 		)
 
+
+		if self.model_type.startswith("chainercv2"):
+			opts.prepare_type = "chainercv2"
+
+		self.prepare = partial(PrepareType[opts.prepare_type](self.model),
+			swap_channels=opts.swap_channels,
+			keep_ratio=getattr(opts, "center_crop_on_val", False),
+		)
+
+		logging.info(" ".join([
+			f"Created {self.model.__class__.__name__} model",
+			f"with \"{opts.prepare_type}\" prepare function."
+		]))
+
+
 	def init_classifier(self, opts):
 
 		clf_class, kwargs = self.classifier_cls, self.classifier_kwargs