Ver código fonte

Added more comprehensive pooling handling and a prepare_func attribute to the ModelInfo class

Dimitri Korsch 4 anos atrás
pai
commit
6abed3cca5

+ 5 - 4
cvmodelz/classifiers/base.py

@@ -43,7 +43,7 @@ class Classifier(chainer.Chain):
 	def save(self, weights_file):
 		npz.save_npz(weights_file, self)
 
-	def load(self, weights_file: str, n_classes: int, *, finetune: bool = False) -> None:
+	def load(self, weights_file: str, n_classes: int, *, finetune: bool = False, **kwargs) -> None:
 		""" Loading a classifier has following use cases:
 
 			(0) No loading.
@@ -68,7 +68,7 @@ class Classifier(chainer.Chain):
 			pass
 
 		# Case (1)
-		self.load_model(weights_file, n_classes=n_classes, finetune=finetune)
+		self.load_model(weights_file, n_classes=n_classes, finetune=finetune, **kwargs)
 
 		# else:
 		# 	# Case (0)
@@ -90,9 +90,10 @@ class Classifier(chainer.Chain):
 			return model.load_for_inference
 
 
-	def load_model(self, weights_file: str, n_classes: int, *, finetune: bool = False):
+	def load_model(self, weights_file: str, n_classes: int, *, finetune: bool = False, **kwargs):
 		model_loader = self.get_model_loader(finetune=finetune, model=self.model)
-		model_loader(weights=weights_file, n_classes=n_classes, strict=True)
+		kwargs["strict"] = kwargs.get("strict", True)
+		model_loader(weights=weights_file, n_classes=n_classes, **kwargs)
 
 	@property
 	def feat_size(self) -> int:

+ 4 - 1
cvmodelz/models/base.py

@@ -16,11 +16,14 @@ from cvmodelz.models.meta_info import ModelInfo
 
 class BaseModel(abc.ABC, chainer.Chain):
 
-	def __init__(self, pooling: Callable = PoolingType.G_AVG.value(),
+	def __init__(self, pooling: Callable = PoolingType.G_AVG,
 		input_size=None, *args, **kwargs):
 		super().__init__(*args, **kwargs)
 		self.init_model_info()
 
+		if isinstance(pooling, (PoolingType, str)):
+			pooling = PoolingType.new(pooling)
+
 		with self.init_scope():
 			self.pool = pooling
 

+ 3 - 0
cvmodelz/models/meta_info.py

@@ -2,6 +2,7 @@ import pyaml
 
 from dataclasses import dataclass
 from typing import Tuple
+from typing import Callable
 
 
 @dataclass
@@ -17,6 +18,8 @@ class ModelInfo(object):
 
 	classifier_layers:          Tuple[str]  = ("fc",)
 
+	prepare_func:               Callable    = None
+
 	def __str__(self):
 		obj = dict(ModelInfo=self.__dict__)
 		return pyaml.dump(obj, sort_dicts=False, )

+ 27 - 2
cvmodelz/models/pretrained/inception/inception_v3.py

@@ -42,8 +42,7 @@ class InceptionV3Layers(chainer.Chain):
 
 	def __init__(self, pretrained_model=None, aux_logits=False, *args, **kwargs):
 		self.aux_logits = aux_logits
-		pooling = PoolingType.G_AVG.value()
-		super().__init__(*args, pooling=pooling, **kwargs)
+		super().__init__(*args, **kwargs)
 
 
 class InceptionV3(PretrainedModelMixin, InceptionV3Layers):
@@ -59,12 +58,38 @@ class InceptionV3(PretrainedModelMixin, InceptionV3Layers):
 			feature_layer="pool",
 
 			classifier_layers=["fc"],
+
+			prepare_func=self.prepare
 		)
 
 	@property
 	def functions(self):
 		super().functions
 
+	def prepare(self, x, size=None, *, swap_channels=False, keep_ratio=True):
+		size = size or self.meta.input_size
+
+		# [0 .. 255] -> [0 .. 1]
+		x = x.astype(np.float32) / 255
+
+		# HWC -> CHW channel order
+		x = x.transpose(2,0,1)
+
+		if keep_ratio:
+			if isinstance(size, Iterable):
+				size = min(size)
+			# scale the smallest side to <size>
+			x = scale(x, size=size)
+		else:
+			if isinstance(size, int):
+				size = (size, size)
+
+			# resize the image to <size>x<size>, ignoring the aspect ratio
+			x = resize(x, size)
+
+		if swap_channels:
+			x = x[::-1]
+		return x
 
 
 	def forward(self, x, layer_name='fc'):

+ 3 - 0
cvmodelz/models/pretrained/resnet.py

@@ -3,6 +3,7 @@ import chainer
 from chainer import functions as F
 from chainer import links as L
 from chainer.links.model.vision.resnet import BuildingBlock
+from chainer.links.model.vision.resnet import prepare
 from collections import OrderedDict
 from functools import partial
 
@@ -24,6 +25,8 @@ class BaseResNet(PretrainedModelMixin):
 			feature_layer="pool5",
 
 			classifier_layers=["fc6"],
+
+			prepare_func=prepare,
 		)
 
 	@property

+ 11 - 15
cvmodelz/models/pretrained/vgg.py

@@ -16,30 +16,26 @@ def _vgg_meta(final_conv_layer):
 		feature_layer="fc7",
 
 		classifier_layers=["fc6", "fc7", "fc8"],
-	)
-
 
-class VGG19(PretrainedModelMixin, L.VGG19Layers):
+		prepare_func=prepare,
+	)
 
+class BaseVGG(PretrainedModelMixin):
 	def __init__(self, *args, **kwargs):
-		super().__init__(*args, pooling=_max_pooling_2d, **kwargs)
-
-	def init_model_info(self):
-		self.meta = _vgg_meta("conv5_3")
+		kwargs["pooling"] = kwargs.get("pooling", _max_pooling_2d)
+		super().__init__(*args, **kwargs)
 
 	@property
 	def functions(self):
 		return super().functions
 
-class VGG16(PretrainedModelMixin, L.VGG16Layers):
+	def init_model_info(self):
+		self.meta = _vgg_meta(self.final_conv_layer)
 
-	def __init__(self, *args, **kwargs):
-		super().__init__(*args, pooling=_max_pooling_2d, **kwargs)
+class VGG19(BaseVGG, L.VGG19Layers):
+	final_conv_layer = "conv5_4"
 
-	def init_model_info(self):
-		self.meta = _vgg_meta("conv5_4")
 
-	@property
-	def functions(self):
-		return super().functions
+class VGG16(BaseVGG, L.VGG16Layers):
+	final_conv_layer = "conv5_3"
 

+ 10 - 0
tests/model_tests/creation.py

@@ -3,11 +3,18 @@ import unittest
 
 from cvmodelz.models import ModelFactory
 from cvmodelz.models.pretrained.base import PretrainedModelMixin
+from chainer_addons.links.pooling import GlobalAveragePooling # TODO: replace this!
 from cvmodelz.models.wrapper import ModelWrapper
 
 
 class ModelCreationsTests(unittest.TestCase):
 
+	def with_pooling_string(self, key):
+		model = ModelFactory.new(key, pooling="g_avg")
+		self.assertIsNotNone(model)
+
+		self.assertIsInstance(model.pool, GlobalAveragePooling)
+
 	def cv2model_creation(self, key):
 
 		model = ModelFactory.new(key)
@@ -27,3 +34,6 @@ test_utils.add_tests(ModelCreationsTests.cv2model_creation,
 test_utils.add_tests(ModelCreationsTests.pretrained_model_creation,
 	model_list=ModelFactory.get_models(["cvmodelz"]))
 
+test_utils.add_tests(ModelCreationsTests.with_pooling_string,
+	model_list=ModelFactory.get_models(["cvmodelz"]))
+