
added VGG models to the supported models and fixed some tests related to them

Dimitri Korsch 4 years ago
parent commit ca047c5d0e

+ 8 - 6
cvmodelz/models/base.py

@@ -92,13 +92,15 @@ class BaseModel(abc.ABC, chainer.Chain):
 		self.load(weights, path=path, strict=strict, headless=headless)
 
 	def load(self, weights, *, path="", strict=False, headless=False):
-		if weights not in [None, "auto"]:
-			ignore_names = None
-			if headless:
-				ignore_names = lambda name: name.startswith(path + self.clf_layer_name)
+		if weights in [None, "auto"]:
+			return
 
-			npz.load_npz(weights, self.model_instance,
-				path=path, strict=strict, ignore_names=ignore_names)
+		ignore_names = None
+		if headless:
+			ignore_names = lambda name: name.startswith(path + self.clf_layer_name)
+
+		npz.load_npz(weights, self.model_instance,
+			path=path, strict=strict, ignore_names=ignore_names)
 
 	def save(self, path, *args, **kwargs):
 		npz.save_npz(path, self, *args, **kwargs)
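
Note: the refactor flattens the nesting with an early return for None/"auto" weights, so the headless branch reads straight down. Usage stays the same; a sketch of a headless (classifier-skipping) load, with a hypothetical weights file name and assuming InceptionV3 is exported from the pretrained namespace:

    from cvmodelz.models import pretrained

    model = pretrained.InceptionV3(n_classes=200)

    # headless=True builds ignore_names so that every parameter whose
    # name starts with the classifier layer name (here "fc") is skipped,
    # keeping the freshly initialized 200-class head
    model.load("inception_imagenet.npz", strict=False, headless=True)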

+ 2 - 2
cvmodelz/models/factory.py

@@ -33,8 +33,8 @@ class ModelFactory(abc.ABC):
 		),
 
 		cvmodelz=(
-			# pretrained.VGG16,
-			# pretrained.VGG19,
+			pretrained.VGG16,
+			pretrained.VGG19,
 
 			pretrained.ResNet35,
 			pretrained.ResNet50,
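
Note: with the two entries uncommented, VGG16 and VGG19 join the factory's cvmodelz tuple. They can also be built directly from the pretrained namespace referenced above; a minimal sketch (no pretrained_model is given, so BaseModel.load() returns early and the weights stay randomly initialized):

    from cvmodelz.models import pretrained

    vgg16 = pretrained.VGG16(n_classes=20)
    vgg19 = pretrained.VGG19(n_classes=20)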

+ 10 - 9
cvmodelz/models/pretrained/base.py

@@ -1,4 +1,5 @@
 import abc
+from chainer import links as L
 
 from cvmodelz.models.base import BaseModel
 from cvmodelz.models.meta_info import ModelInfo
@@ -16,16 +17,13 @@ class PretrainedModelMixin(BaseModel):
 	"""
 
 	def __init__(self, *args, n_classes: int = 1000, pretrained_model: str = None, **kwargs):
-		from cvmodelz.models import ModelFactory
-
-		if ModelFactory.is_chainer_model(self):
-			kwargs["pretrained_model"] = pretrained_model
-
-		super(PretrainedModelMixin, self).__init__(*args, **kwargs)
+		super(PretrainedModelMixin, self).__init__(*args, pretrained_model=pretrained_model, **kwargs)
 
 		with self.init_scope():
 			self.init_extra_layers(n_classes)
 
+		self.load(pretrained_model, strict=True)
+
 	def forward(self, X, layer_name=None):
		assert hasattr(self, "meta"), "Did you forget to initialize the meta attribute?"
 
@@ -38,12 +36,15 @@ class PretrainedModelMixin(BaseModel):
 
 		return activations
 
-	def init_extra_layers(self, *args, **kwargs):
-		pass
+	def init_extra_layers(self, n_classes, **kwargs) -> None:
+		if hasattr(self, self.clf_layer_name):
+			delattr(self, self.clf_layer_name)
 
+		clf_layer = L.Linear(self.meta.feature_size, n_classes)
+		setattr(self, self.clf_layer_name, clf_layer)
 
 	@property
-	def model_instance(self):
+	def model_instance(self) -> BaseModel:
 		""" since it is a mixin, we are the model """
 
 		return self
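
Note: init_extra_layers is no longer a no-op; it drops any link already registered under clf_layer_name and attaches a fresh L.Linear sized by meta.feature_size. A standalone sketch of that mechanism, with a stub in place of the real meta object (ToyMixin/ToyModel are illustrative names):

    import chainer
    from chainer import links as L

    class ToyMixin:
        clf_layer_name = "fc"
        feature_size = 2048  # stands in for self.meta.feature_size

        def init_extra_layers(self, n_classes):
            # mirror PretrainedModelMixin: replace the registered
            # classifier with a fresh n_classes-way linear layer
            if hasattr(self, self.clf_layer_name):
                delattr(self, self.clf_layer_name)
            setattr(self, self.clf_layer_name,
                    L.Linear(self.feature_size, n_classes))

    class ToyModel(ToyMixin, chainer.Chain):
        def __init__(self, n_classes):
            super().__init__()
            with self.init_scope():
                self.init_extra_layers(n_classes)

    model = ToyModel(n_classes=7)
    assert model.fc.out_size == 7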

+ 26 - 13
cvmodelz/models/pretrained/inception/inception_v3.py

@@ -8,7 +8,6 @@ from chainercv.transforms import resize
 from chainercv.transforms import scale
 from collections import OrderedDict
 from collections.abc import Iterable
-from os.path import isfile
 
 
 from cvmodelz.models.meta_info import ModelInfo
@@ -32,15 +31,22 @@ def _assign_batch_norm(name, link, beta, avg_mean, avg_var):
 	_assign(name, link.avg_var, avg_var)
 
 
-class InceptionV3(PretrainedModelMixin, chainer.Chain):
+
+"""
+We need this to "extract" pretrained_model argument,
+otherwise it would be passed to the constructor of the
+chainer.Chain class, where it raises an error
+"""
+class InceptionV3Layers(chainer.Chain):
+
 
 	def __init__(self, pretrained_model=None, aux_logits=False, *args, **kwargs):
 		self.aux_logits = aux_logits
 		pooling = PoolingType.G_AVG.value()
-		super(InceptionV3, self).__init__(*args, pooling=pooling, **kwargs)
+		super(InceptionV3Layers, self).__init__(*args, pooling=pooling, **kwargs)
+
 
-		if pretrained_model is not None and isfile(pretrained_model):
-			self.load(pretrained_model, strict=True)
+class InceptionV3(PretrainedModelMixin, InceptionV3Layers):
 
 	def init_model_info(self):
 		self.meta = ModelInfo(
@@ -55,6 +61,12 @@ class InceptionV3(PretrainedModelMixin, chainer.Chain):
 			classifier_layers=["fc"],
 		)
 
+	@property
+	def functions(self):
+		return super(InceptionV3, self).functions
+
+
+
 	def forward(self, x, layer_name='fc'):
 		aux_logit = None
 		for key, funcs in self.functions.items():
@@ -98,15 +110,15 @@ class InceptionV3(PretrainedModelMixin, chainer.Chain):
 		return self.pool(x)
 
 
-	def load(self, weights, *args, **kwargs):
+	def load(self, weights, *args, **kwargs) -> None:
 		if isinstance(weights, str) and weights.endswith(".h5"):
-			self._load_from_keras(weights)
+			return self._load_from_keras(weights)
 		elif isinstance(weights, str) and weights.endswith(".ckpt.npz"):
-			self._load_from_ckpt_weights(weights)
-		else:
-			return super(InceptionV3, self).load(weights, *args, **kwargs)
+			return self._load_from_ckpt_weights(weights)
 
-	def init_extra_layers(self, n_classes):
+		return super(InceptionV3, self).load(weights, *args, **kwargs)
+
+	def init_extra_layers(self, n_classes) -> None:
 		# input 3 x 299 x 299
 		self.head = blocks.InceptionHead()
 		# output 192 x 35 x 35
@@ -152,7 +164,9 @@ class InceptionV3(PretrainedModelMixin, chainer.Chain):
 		# input 2048 x 8 x 8
 		# global average pooling
 		# output 2048 x 1 x 1
-		self.fc = L.Linear(2048, n_classes)
+
+		# the final fc layer is initialized by PretrainedModelMixin
+		super(InceptionV3, self).init_extra_layers(n_classes)
 
 	def loss(self, pred, gt, loss_func=F.softmax_cross_entropy, alpha=0.4):
 		if isinstance(pred, tuple):
@@ -233,4 +247,3 @@ class InceptionV3(PretrainedModelMixin, chainer.Chain):
 					_assign_batch_norm(name, link, beta, avg_mean, avg_var)
 				else:
 					raise ValueError("Unkown link type: {}!".format(type(link)))
-
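
Note: the intermediate InceptionV3Layers class exists purely because of Python's MRO: it sits between the mixin and chainer.Chain and consumes pretrained_model, so the kwarg never reaches chainer.Chain.__init__. A minimal chainer-free sketch of the pattern (all names below are stand-ins):

    class Strict:
        """Stands in for chainer.Chain: rejects unknown kwargs."""
        def __init__(self, **kwargs):
            if kwargs:
                raise TypeError("unexpected kwargs: {}".format(sorted(kwargs)))

    class Layers(Strict):
        """Plays the role of InceptionV3Layers / ResNet35Layers:
        swallows pretrained_model before it reaches Strict."""
        def __init__(self, *args, pretrained_model=None, **kwargs):
            super().__init__(*args, **kwargs)

    class Mixin:
        """Plays the role of PretrainedModelMixin: forwards the kwarg."""
        def __init__(self, *args, pretrained_model=None, **kwargs):
            super().__init__(*args, pretrained_model=pretrained_model, **kwargs)

    class Model(Mixin, Layers):
        pass

    Model(pretrained_model="weights.npz")   # fine: Layers consumed the kwarg
    try:
        Strict(pretrained_model="weights.npz")
    except TypeError as e:
        print(e)  # unexpected kwargs: ['pretrained_model']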

+ 15 - 7
cvmodelz/models/pretrained/resnet.py

@@ -26,21 +26,25 @@ class BaseResNet(PretrainedModelMixin):
 			classifier_layers=["fc6"],
 		)
 
-	def init_extra_layers(self, n_classes, **kwargs):
-		if hasattr(self, "fc6"):
-			delattr(self, "fc6")
-		self.fc6 = L.Linear(2048, n_classes)
-
 	@property
 	def functions(self):
 		return super(BaseResNet, self).functions
 
+"""
+We need this to "extract" pretrained_model argument,
+otherwise it would be passed to the constructor of the
+chainer.Chain class, where it raises an error
+"""
+class ResNet35Layers(chainer.Chain):
+
+	def __init__(self, *args, pretrained_model=None, **kwargs):
+		super(ResNet35Layers, self).__init__(*args, **kwargs)
 
-class ResNet35(BaseResNet, chainer.Chain):
+
+class ResNet35(BaseResNet, ResNet35Layers):
 	n_layers = 35
 
 	def init_extra_layers(self, *args, **kwargs):
-		super(ResNet35, self).init_extra_layers(*args, **kwargs)
 		self.conv1 = L.Convolution2D(3, 64, 7, 2, 3, **kwargs)
 		self.bn1 = L.BatchNormalization(64)
 		self.res2 = BuildingBlock(2, 64, 64, 256, 1, **kwargs)
@@ -48,6 +52,9 @@ class ResNet35(BaseResNet, chainer.Chain):
 		self.res4 = BuildingBlock(3, 512, 256, 1024, 2, **kwargs)
 		self.res5 = BuildingBlock(3, 1024, 512, 2048, 2, **kwargs)
 
+		# the final fc layer is initialized by PretrainedModelMixin
+		super(ResNet35, self).init_extra_layers(*args, **kwargs)
+
 	@property
 	def functions(self):
 		links = [
@@ -70,6 +77,7 @@ class ResNet35(BaseResNet, chainer.Chain):
 			if key == layer_name:
 				return x
 
+
 class ResNet50(BaseResNet, L.ResNet50Layers):
 	n_layers = 50
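
Note: the hard-coded L.Linear(2048, n_classes) is gone from BaseResNet; ResNet35 now builds its backbone first and calls the inherited init_extra_layers last, which attaches fc6 via PretrainedModelMixin. A quick check of the resulting head (a sketch; assumes pretrained exports ResNet35, as the factory tuple suggests):

    from cvmodelz.models import pretrained

    model = pretrained.ResNet35(n_classes=10)

    # "fc6" per meta.classifier_layers; created by the mixin with
    # meta.feature_size inputs and n_classes outputs
    clf = getattr(model, model.clf_layer_name)
    assert clf.out_size == 10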
 

+ 4 - 0
cvmodelz/models/pretrained/vgg.py

@@ -23,6 +23,8 @@ class VGG19(PretrainedModelMixin, L.VGG19Layers):
 
 	def __init__(self, *args, **kwargs):
 		super(VGG19, self).__init__(*args, pooling=_max_pooling_2d, **kwargs)
+
+	def init_model_info(self):
 		self.meta = _vgg_meta("conv5_3")
 
 	@property
@@ -33,6 +35,8 @@ class VGG16(PretrainedModelMixin, L.VGG16Layers):
 
 	def __init__(self, *args, **kwargs):
 		super(VGG16, self).__init__(*args, pooling=_max_pooling_2d, **kwargs)
+
+	def init_model_info(self):
 		self.meta = _vgg_meta("conv5_4")
 
 	@property
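
Note: moving the meta assignment out of __init__ and into init_model_info is what makes the VGG models work with the new mixin: PretrainedModelMixin.init_extra_layers reads self.meta.feature_size, but in the old code meta was only set after the super().__init__ chain (and hence after init_extra_layers) had already run. The base model presumably invokes init_model_info early in its constructor; a condensed, chainer-free sketch of that ordering (FakeBase/FakeVGG and the feature size are illustrative):

    class FakeBase:
        """Stand-in for BaseModel: fixes the hook call order."""
        def __init__(self, n_classes):
            self.init_model_info()             # sets self.meta first ...
            self.init_extra_layers(n_classes)  # ... which this call reads

    class FakeVGG(FakeBase):
        def init_model_info(self):
            self.meta = {"feature_size": 4096}  # illustrative value

        def init_extra_layers(self, n_classes):
            self.fc8 = (self.meta["feature_size"], n_classes)

    model = FakeVGG(n_classes=5)
    assert model.fc8 == (4096, 5)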