Эх сурвалжийг харах

refactoring: moved some initialization logic into the BaseModel class

Dimitri Korsch 4 жил өмнө
parent
commit
a054a66952

+ 2 - 2
cvmodelz/model_info.py

@@ -9,8 +9,8 @@ from cvmodelz.models import ModelFactory
 
 def main(args):
 
-	model = ModelFactory.new(args.model_type)
-	utils.print_model_info(model, input_size=args.input_size)
+	model = ModelFactory.new(args.model_type, input_size=args.input_size)
+	utils.print_model_info(model)
 
 parser = BaseParser()
 

+ 18 - 0
cvmodelz/models/base.py

@@ -3,12 +3,30 @@ import chainer
 
 from chainer import functions as F
 from chainer import links as L
+from chainer_addons.links.pooling import PoolingType # TODO: replace this!
 from collections import OrderedDict
+from typing import Callable
 
 from cvmodelz import utils
+from cvmodelz.models.meta_info import ModelInfo
 
 class BaseModel(abc.ABC):
 
+	def __init__(self, pooling: Callable = PoolingType.G_AVG.value(), input_size=None, *args, **kwargs):
+		super(BaseModel, self).__init__(*args, **kwargs)
+		self.init_model_info()
+
+		with self.init_scope():
+			self.pool = pooling
+
+
+		if input_size is not None:
+			self.meta.input_size = input_size
+
+
+	def init_model_info(self):
+		self.meta = ModelInfo()
+
 	@abc.abstractmethod
 	def __call__(self, X, layer_name=None):
 		pass

+ 8 - 6
cvmodelz/models/factory.py

@@ -12,10 +12,6 @@ from cvmodelz.models.wrapper import ModelWrapper
 
 class ModelFactory(abc.ABC):
 
-	@abc.abstractmethod
-	def __init__(self):
-		raise NotImplementedError("instance creation is not supported!")
-
 	supported = OrderedDict(
 		chainer=(
 			L.ResNet50Layers,
@@ -49,12 +45,15 @@ class ModelFactory(abc.ABC):
 		),
 	)
 
+	@abc.abstractmethod
+	def __init__(self):
+		raise NotImplementedError("instance creation is not supported!")
+
 	@classmethod
 	def new(cls, model_type, *args, **kwargs):
 
 		key, cls_name = model_type.split(".")
 
-
 		for model_cls in cls.supported[key]:
 			if model_cls.__name__ == cls_name:
 				break
@@ -64,11 +63,14 @@ class ModelFactory(abc.ABC):
 		if model_cls in cls.supported["chainer"]:
 			if "pretrained_model" not in kwargs:
 				kwargs["pretrained_model"] = None
+			kwargs.pop("input_size")
 
 		elif model_cls in cls.supported["chainercv2"]:
 			if "pretrained" not in kwargs:
 				kwargs["pretrained"] = False
-			return ModelWrapper(model_cls(*args, **kwargs))
+
+			input_size = kwargs.pop("input_size")
+			return ModelWrapper(model_cls(*args, **kwargs), input_size=input_size)
 
 		return model_cls(*args, **kwargs)
 

+ 4 - 15
cvmodelz/models/pretrained/base.py

@@ -1,9 +1,5 @@
 import abc
 
-from chainer import functions as F
-from typing import Callable
-
-from cvmodelz import models
 from cvmodelz.models.base import BaseModel
 from cvmodelz.models.meta_info import ModelInfo
 
@@ -19,16 +15,16 @@ class PretrainedModelMixin(BaseModel):
 				...
 	"""
 
-	def __init__(self, n_classes: int = 1000, pooling: Callable = F.identity, *args, **kwargs):
+	def __init__(self, n_classes: int = 1000, *args, **kwargs):
+		from cvmodelz.models import ModelFactory
 
-		if models.is_chainer_model(self):
-			kwargs["pretrained_model"] = None
+		if ModelFactory.is_chainer_model(self):
+			kwargs["pretrained_model"] = kwargs.get("pretrained_model", None)
 
 		super(PretrainedModelMixin, self).__init__(*args, **kwargs)
 
 		with self.init_scope():
 			self.init_extra_layers(n_classes)
-			self.pool = pooling
 
 	def __call__(self, X, layer_name=None):
 		assert hasattr(self, "meta"), "Did you forgot to initialize the meta attribute?"
@@ -45,13 +41,6 @@ class PretrainedModelMixin(BaseModel):
 	def init_extra_layers(self, *args, **kwargs):
 		pass
 
-	# @abc.abstractproperty
-	# def _links(self):
-	# 	raise NotImplementedError()
-
-	# @property
-	# def functions(self):
-	# 	return OrderedDict(self._links)
 
 	@property
 	def model_instance(self):

+ 1 - 3
cvmodelz/models/pretrained/resnet.py

@@ -3,7 +3,6 @@ import chainer
 from chainer import functions as F
 from chainer import links as L
 from chainer.links.model.vision.resnet import BuildingBlock
-from chainer.links.model.vision.resnet import _global_average_pooling_2d
 from collections import OrderedDict
 from functools import partial
 
@@ -14,8 +13,7 @@ from cvmodelz.models.pretrained.base import PretrainedModelMixin
 class BaseResNet(PretrainedModelMixin):
 	n_layers = ""
 
-	def __init__(self, *args, **kwargs):
-		super(BaseResNet, self).__init__(*args, pooling=_global_average_pooling_2d, **kwargs)
+	def init_model_info(self):
 		self.meta = ModelInfo(
 			name=f"ResNet{self.n_layers}",
 			input_size=224,

+ 8 - 0
cvmodelz/models/pretrained/vgg.py

@@ -25,9 +25,17 @@ class VGG19(PretrainedModelMixin, L.VGG19Layers):
 		super(VGG19, self).__init__(*args, pooling=_max_pooling_2d, **kwargs)
 		self.meta = _vgg_meta("conv5_3")
 
+	@property
+	def functions(self):
+		return super(VGG19, self).functions
+
 class VGG16(PretrainedModelMixin, L.VGG16Layers):
 
 	def __init__(self, *args, **kwargs):
 		super(VGG16, self).__init__(*args, pooling=_max_pooling_2d, **kwargs)
 		self.meta = _vgg_meta("conv5_4")
 
+	@property
+	def functions(self):
+		return super(VGG16, self).functions
+

+ 10 - 13
cvmodelz/models/wrapper.py

@@ -1,9 +1,7 @@
 import chainer
 
 from chainer import functions as F
-from chainer_addons.links.pooling import PoolingType # TODO: replace this!
 from collections import OrderedDict
-from typing import Callable
 
 from cvmodelz.models.base import BaseModel
 from cvmodelz.models.meta_info import ModelInfo
@@ -17,30 +15,29 @@ class ModelWrapper(BaseModel, chainer.Chain):
 		The wrapped model is stored under self.wrapped
 	"""
 
-	def __init__(self, model: chainer.Chain, pooling: Callable = PoolingType.G_AVG.value()):
-		super(ModelWrapper, self).__init__()
+	def __init__(self, model: chainer.Chain, *args, **kwargs):
+		super(ModelWrapper, self).__init__(*args, **kwargs)
 
 		name = model.__class__.__name__
 		self.__class__.__name__ = name
+		self.meta.name = name
 
 		if hasattr(model, "meta"):
 			self.meta = model.meta
 
-		else:
-			self.meta = ModelInfo(
-				name=name,
-				classifier_layers=("output/fc",),
-				conv_map_layer="features",
-				feature_layer="pool",
-			)
-
 		with self.init_scope():
 			self.wrapped = model
-			self.pool = pooling
 			delattr(self.wrapped.features, "final_pool")
 
 		self.meta.feature_size = self.clf_layer.W.shape[-1]
 
+	def init_model_info(self):
+		self.meta = ModelInfo(
+			classifier_layers=("output/fc",),
+			conv_map_layer="features",
+			feature_layer="pool",
+		)
+
 	@property
 	def model_instance(self) -> chainer.Chain:
 		return self.wrapped