
added support for initializing chainercv2 models

Dimitri Korsch, 4 years ago
commit 28a6bc9aa6

+ 4 - 2
cvmodelz/model_info.py

@@ -9,13 +9,15 @@ from cvmodelz import utils
 
 def main(args):
 
-	model = models.new(args.model_type, pretrained_model=None)
-	utils.print_model_info(model)
+	model = models.new(args.model_type)
+	utils.print_model_info(model, input_size=args.input_size)
 
 parser = BaseParser()
 
 parser.add_args([
 	Arg("model_type", choices=models.get_all_models()),
+
+	Arg("--input_size", "-size", type=int, default=None)
 ])
 
 main(parser.parse_args())

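For reference, a minimal sketch of what the updated script does when run programmatically. The model name string and the value 224 are placeholders, and importing models alongside utils is an assumption about the rest of model_info.py; the actual choices come from models.get_all_models().

from cvmodelz import models, utils

# assumed model key; see models.get_all_models() for the real choices
model = models.new("resnet50")
# --input_size is now forwarded to print_model_info (224 is just an example value)
utils.print_model_info(model, input_size=224)
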
+ 21 - 10
cvmodelz/models/__init__.py

@@ -1,6 +1,11 @@
+import pyaml
+
 from chainer import links as L
+from chainercv2.models import resnet as cv2resnet
+from chainercv2.models import inceptionv3 as cv2inceptionv3
 
 from cvmodelz.models.base import BaseModel
+from cvmodelz.models.wrapper import ModelWrapper
 from cvmodelz.models import pretrained
 
 __all__ = [
@@ -19,7 +24,12 @@ supported = dict(
 
 	chainercv=(),
 
-	chainercv2=(),
+	chainercv2=(
+		cv2resnet.resnet50,
+		cv2resnet.resnet50b,
+
+		cv2inceptionv3.inceptionv3,
+	),
 
 	cvmodelz=(
 		pretrained.VGG16,
@@ -74,17 +84,18 @@ def new(model_type, *args, **kwargs):
 	else:
 		raise ValueError(f"Could not find {model_type}!")
 
-	return cls(*args, **kwargs)
-
+	if cls in supported["chainer"]:
+		if "pretrained_model" not in kwargs:
+			kwargs["pretrained_model"] = None
 
+	elif cls in supported["chainercv2"]:
+		if "pretrained" not in kwargs:
+			kwargs["pretrained"] = False
+		return ModelWrapper(cls(*args, **kwargs))
 
-if __name__ == '__main__':
-	# from cvmodelz import utils
+	return cls(*args, **kwargs)
 
-	print(get_all_models())
 
-	# model = L.VGG19Layers(pretrained_model=None)
-	# model = pretrained.ResNet35()
-	# print(model.pool)
-	# utils.print_model_info(model)
 
+if __name__ == '__main__':
+	print(pyaml.dump(dict(Models=get_all_models()), indent=2))

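A short usage sketch of the new dispatch in models.new(): plain chainer classes get pretrained_model=None injected when the caller does not pass it, while chainercv2 classes default to pretrained=False and come back wrapped in a ModelWrapper. The concrete name strings below are assumptions; the actual keys are whatever get_all_models() reports.

from cvmodelz import models
from cvmodelz.models.wrapper import ModelWrapper

print(models.get_all_models())            # every registered model name

# assumed name keys, for illustration only
plain = models.new("ResNet50Layers")      # chainer class: pretrained_model=None injected
wrapped = models.new("resnet50")          # chainercv2 class: pretrained=False, then wrapped

assert isinstance(wrapped, ModelWrapper)  # chainercv2 models are always wrapped
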
+ 12 - 5
cvmodelz/models/base.py

@@ -1,6 +1,9 @@
 import abc
-import chainer.functions as F
-import chainer.links as L
+import chainer
+
+from chainer import functions as F
+from chainer import links as L
+from collections import OrderedDict
 
 from cvmodelz import utils
 
@@ -11,15 +14,19 @@ class BaseModel(abc.ABC):
 		pass
 
 	@abc.abstractproperty
-	def model_instance(self):
+	def functions(self) -> OrderedDict:
+		return super(BaseModel, self).functions
+
+	@abc.abstractproperty
+	def model_instance(self) -> chainer.Chain:
 		raise NotImplementedError()
 
 	@property
-	def clf_layer_name(self):
+	def clf_layer_name(self) -> str:
 		return self.meta.classifier_layers[-1]
 
 	@property
-	def clf_layer(self):
+	def clf_layer(self) -> chainer.Link:
 		return utils.get_attr_from_path(self.model_instance, self.clf_layer_name)
 
 	def loss(self, pred, gt, loss_func=F.softmax_cross_entropy):

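The annotated abstract properties spell out the contract for every model: model_instance returns the underlying chainer.Chain and functions returns an OrderedDict mapping layer names to the callables that compute them. Below is a hypothetical minimal subclass, only to illustrate that contract; it ignores any further abstract members or constructor arguments BaseModel may require, and the layers are made up.

import chainer
import chainer.functions as F
import chainer.links as L
from collections import OrderedDict

from cvmodelz.models.base import BaseModel

class TinyModel(BaseModel, chainer.Chain):
	"""Hypothetical two-layer model illustrating the BaseModel contract."""

	def __init__(self):
		super(TinyModel, self).__init__()
		with self.init_scope():
			self.conv1 = L.Convolution2D(None, 16, ksize=3)
			self.fc = L.Linear(None, 10)

	@property
	def model_instance(self) -> chainer.Chain:
		return self  # this model is its own chain

	@property
	def functions(self) -> OrderedDict:
		# layer name -> list of callables applied in order
		return OrderedDict([
			("conv1", [self.conv1, F.relu]),
			("fc", [self.fc]),
		])
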
+ 1 - 2
cvmodelz/models/pretrained/inception/inception_v3.py

@@ -3,13 +3,12 @@ import chainer.functions as F
 import chainer.links as L
 import numpy as np
 
+from chainer_addons.links.pooling import PoolingType # TODO: replace this!
 from chainercv.transforms import resize
 from chainercv.transforms import scale
 from collections import OrderedDict
 from collections.abc import Iterable
 from os.path import isfile
-# TODO: replace this!
-from chainer_addons.links.pooling import PoolingType
 
 
 from cvmodelz.models.meta_info import ModelInfo

+ 17 - 10
cvmodelz/models/pretrained/resnet.py

@@ -28,6 +28,10 @@ class BaseResNet(PretrainedModelMixin):
 			classifier_layers=["fc6"],
 		)
 
+	@property
+	def functions(self):
+		return super(BaseResNet, self).functions
+
 
 class ResNet35(BaseResNet, chainer.Chain):
 	n_layers = 35
@@ -44,23 +48,26 @@ class ResNet35(BaseResNet, chainer.Chain):
 	@property
 	def functions(self):
 		links = [
-				("conv1", [self.conv1, self.bn1, F.relu]),
-				("pool1", [partial(F.max_pooling_2d, ksize=3, stride=2)]),
-				("res2", [self.res2]),
-				("res3", [self.res3]),
-				("res4", [self.res4]),
-				("res5", [self.res5]),
-				("pool5", [self.pool]),
-				("fc6", [self.fc6]),
-				("prob", [F.softmax]),
-			]
+			("conv1", [self.conv1, self.bn1, F.relu]),
+			("pool1", [partial(F.max_pooling_2d, ksize=3, stride=2)]),
+			("res2", [self.res2]),
+			("res3", [self.res3]),
+			("res4", [self.res4]),
+			("res5", [self.res5]),
+			("pool5", [self.pool]),
+			("fc6", [self.fc6]),
+			("prob", [F.softmax]),
+		]
 		return OrderedDict(links)
 
 class ResNet50(BaseResNet, L.ResNet50Layers):
 	n_layers = 50
 
+
 class ResNet101(BaseResNet, L.ResNet101Layers):
 	n_layers = 101
 
+
 class ResNet152(BaseResNet, L.ResNet152Layers):
 	n_layers = 152
+

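The functions OrderedDict (reformatted above) is what allows the model to be evaluated layer by layer. A hedged sketch of how such a dict can be consumed to extract the output of a named layer; the helper itself is hypothetical, the package's own utilities do something along these lines.

def extract_until(model, x, final_layer="pool5"):
	"""Hypothetical helper: apply model.functions in order until final_layer is reached."""
	h = x
	for name, callables in model.functions.items():
		for fn in callables:     # e.g. [conv1, bn1, F.relu] for the "conv1" entry
			h = fn(h)
		if name == final_layer:
			return h
	raise KeyError(f"{final_layer} is not a layer of this model")
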
+ 18 - 3
cvmodelz/models/wrapper.py

@@ -1,11 +1,15 @@
 import chainer
 
 from chainer import functions as F
+from chainer_addons.links.pooling import PoolingType # TODO: replace this!
+from collections import OrderedDict
 from typing import Callable
 
 from cvmodelz.models.base import BaseModel
 from cvmodelz.models.meta_info import ModelInfo
 
+
+
 class ModelWrapper(BaseModel, chainer.Chain):
 	"""
 		This class is designed to wrap around chainercv2 models
@@ -13,7 +17,7 @@ class ModelWrapper(BaseModel, chainer.Chain):
 		The wrapped model is stored under self.wrapped
 	"""
 
-	def __init__(self, model: chainer.Chain, pooling: Callable = F.identity):
+	def __init__(self, model: chainer.Chain, pooling: Callable = PoolingType.G_AVG.value()):
 		super(ModelWrapper, self).__init__()
 
 		name = model.__class__.__name__
@@ -26,7 +30,7 @@ class ModelWrapper(BaseModel, chainer.Chain):
 			self.meta = ModelInfo(
 				name=name,
 				classifier_layers=("output/fc",),
-				conv_map_layer="stage4",
+				conv_map_layer="features",
 				feature_layer="pool",
 			)
 
@@ -38,9 +42,20 @@ class ModelWrapper(BaseModel, chainer.Chain):
 		self.meta.feature_size = self.clf_layer.W.shape[-1]
 
 	@property
-	def model_instance(self):
+	def model_instance(self) -> chainer.Chain:
 		return self.wrapped
 
+	@property
+	def functions(self) -> OrderedDict:
+
+		links = [
+			("features", [self.wrapped.features]),
+			("pool", [self.pool]),
+			("output/fc", [self.wrapped.output.fc]),
+		]
+
+		return OrderedDict(links)
+
 	def load_for_inference(self, *args, path="", **kwargs):
 		return super(ModelWrapper, self).load_for_inference(*args, path=f"{path}wrapped/", **kwargs)
 

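ModelWrapper stores the chainercv2 model under self.wrapped and re-exposes it through the same functions interface: the chainercv2 feature extractor ("features"), the configurable pooling, and the final classifier ("output/fc"). A minimal sketch of wrapping a model by hand; pretrained=False mirrors what models.new() passes, everything else follows the diff above.

from chainercv2.models.resnet import resnet50
from cvmodelz.models.wrapper import ModelWrapper

# build the raw chainercv2 model without downloading weights, then wrap it
wrapped = ModelWrapper(resnet50(pretrained=False))

print(wrapped.clf_layer_name)       # "output/fc"
print(wrapped.clf_layer)            # the wrapped model's final Linear link
print(list(wrapped.functions))      # ["features", "pool", "output/fc"]
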
+ 3 - 1
cvmodelz/utils/__init__.py

@@ -14,7 +14,9 @@ def get_attr_from_path(obj, path, *, sep="/"):
 	return reduce(getter, path.split(sep), obj)
 
 def _get_activation_shapes(model, input_size, input_var, batch_size=2, n_channels=3):
-	assert hasattr(model, "functions"), "Model should have functions defined!"
+	assert hasattr(model, "functions"), \
+		"Model should have functions defined!"
+
 	if input_var is None:
 		input_shape = (batch_size, n_channels, input_size, input_size)
 		x = model.xp.zeros(input_shape, dtype=model.xp.float32)
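
_get_activation_shapes builds a dummy float32 batch of shape (batch_size, n_channels, input_size, input_size) on the model's array module (model.xp) and requires the model to define functions. How the rest of the helper consumes that input is not shown in this hunk; the following is a plausible, purely illustrative sketch of the shape-probing idea.

def probe_shapes(model, input_size, batch_size=2, n_channels=3):
	"""Hypothetical standalone variant of the shape probing in cvmodelz.utils."""
	assert hasattr(model, "functions"), \
		"Model should have functions defined!"

	# dummy input on the model's array module (numpy or cupy), as in the diff above
	shape = (batch_size, n_channels, input_size, input_size)
	h = model.xp.zeros(shape, dtype=model.xp.float32)

	shapes = []
	for name, callables in model.functions.items():
		for fn in callables:
			h = fn(h)
		shapes.append((name, h.shape))  # record the output shape of each named layer
	return shapes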