Browse Source

created model factory class

Dimitri Korsch 4 years ago
parent
commit
b6204a0adb
3 changed files with 131 additions and 98 deletions
  1. 3 3
      cvmodelz/model_info.py
  2. 18 95
      cvmodelz/models/__init__.py
  3. 110 0
      cvmodelz/models/factory.py

+ 3 - 3
cvmodelz/model_info.py

@@ -4,18 +4,18 @@ if __name__ != '__main__': raise Exception("Do not import me!")
 from cvargparse import Arg
 from cvargparse import BaseParser
 
-from cvmodelz import models
 from cvmodelz import utils
+from cvmodelz.models import ModelFactory
 
 def main(args):
 
-	model = models.new(args.model_type)
+	model = ModelFactory.new(args.model_type)
 	utils.print_model_info(model, input_size=args.input_size)
 
 parser = BaseParser()
 
 parser.add_args([
-	Arg("model_type", choices=models.get_all_models()),
+	Arg("model_type", choices=ModelFactory.get_all_models()),
 
 	Arg("--input_size", "-size", type=int, default=None)
 ])

+ 18 - 95
cvmodelz/models/__init__.py

@@ -1,101 +1,24 @@
-import pyaml
-
-from chainer import links as L
-from chainercv2.models import resnet as cv2resnet
-from chainercv2.models import inceptionv3 as cv2inceptionv3
-
 from cvmodelz.models.base import BaseModel
-from cvmodelz.models.wrapper import ModelWrapper
-from cvmodelz.models import pretrained
+from cvmodelz.models.factory import ModelFactory
+from cvmodelz.models.pretrained import InceptionV3
+from cvmodelz.models.pretrained import ResNet101
+from cvmodelz.models.pretrained import ResNet152
+from cvmodelz.models.pretrained import ResNet35
+from cvmodelz.models.pretrained import ResNet50
+from cvmodelz.models.pretrained import VGG16
+from cvmodelz.models.pretrained import VGG19
 
 __all__ = [
-	"BaseModel"
-]
-
-
-supported = dict(
-	chainer=(
-		L.ResNet50Layers,
-		L.ResNet101Layers,
-		L.ResNet152Layers,
-		L.VGG16Layers,
-		L.VGG19Layers,
-	),
-
-	chainercv=(),
-
-	chainercv2=(
-		cv2resnet.resnet50,
-		cv2resnet.resnet50b,
-
-		cv2inceptionv3.inceptionv3,
-	),
-
-	cvmodelz=(
-		pretrained.VGG16,
-		pretrained.VGG19,
-
-		pretrained.ResNet35,
-		pretrained.ResNet50,
-		pretrained.ResNet101,
-		pretrained.ResNet152,
-
-		pretrained.InceptionV3,
-	),
-)
-
-def _check(model, key):
-	global supported
-	return isinstance(model, supported[key])
-
-
-def is_chainer_model(model):
-	return _check(model, "chainer")
+	"BaseModel",
+	"ModelFactory",
 
-def is_cv_model(model):
-	return _check(model, "chainercv")
+	"InceptionV3",
 
-def is_cv2_model(model):
-	return _check(model, "chainercv2")
+	"ResNet50",
+	"ResNet35",
+	"ResNet101",
+	"ResNet152",
 
-def is_cvmodelz_model(model):
-	return _check(model, "cvmodelz")
-
-
-def get_all_models(key=None):
-	global supported
-	if key is not None:
-		return [f"{key}.{cls.__name__}" for cls in supported[key]]
-
-	res = []
-	for key in supported:
-		res += get_all_models(key)
-
-	return res
-
-def new(model_type, *args, **kwargs):
-	global supported
-	key, cls_name = model_type.split(".")
-
-
-	for cls in supported[key]:
-		if cls.__name__ == cls_name:
-			break
-	else:
-		raise ValueError(f"Could not find {model_type}!")
-
-	if cls in supported["chainer"]:
-		if "pretrained_model" not in kwargs:
-			kwargs["pretrained_model"] = None
-
-	elif cls in supported["chainercv2"]:
-		if "pretrained" not in kwargs:
-			kwargs["pretrained"] = False
-		return ModelWrapper(cls(*args, **kwargs))
-
-	return cls(*args, **kwargs)
-
-
-
-if __name__ == '__main__':
-	print(pyaml.dump(dict(Models=get_all_models()), indent=2))
+	"VGG16",
+	"VGG19",
+]

+ 110 - 0
cvmodelz/models/factory.py

@@ -0,0 +1,110 @@
+import abc
+import pyaml
+
+
+from chainer import links as L
+from chainercv2.models import inceptionv3 as cv2inceptionv3
+from chainercv2.models import resnet as cv2resnet
+from collections import OrderedDict
+
+from cvmodelz.models import pretrained
+from cvmodelz.models.wrapper import ModelWrapper
+
+class ModelFactory(abc.ABC):
+
+	@abc.abstractmethod
+	def __init__(self):
+		raise NotImplementedError("instance creation is not supported!")
+
+	supported = OrderedDict(
+		chainer=(
+			L.ResNet50Layers,
+			L.ResNet101Layers,
+			L.ResNet152Layers,
+			L.VGG16Layers,
+			L.VGG19Layers,
+		),
+
+		chainercv=(
+			# todo: chainercv.links.models.ssd
+		),
+
+		chainercv2=(
+			cv2resnet.resnet50,
+			cv2resnet.resnet50b,
+
+			cv2inceptionv3.inceptionv3,
+		),
+
+		cvmodelz=(
+			pretrained.VGG16,
+			pretrained.VGG19,
+
+			pretrained.ResNet35,
+			pretrained.ResNet50,
+			pretrained.ResNet101,
+			pretrained.ResNet152,
+
+			pretrained.InceptionV3,
+		),
+	)
+
+	@classmethod
+	def new(cls, model_type, *args, **kwargs):
+
+		key, cls_name = model_type.split(".")
+
+
+		for model_cls in cls.supported[key]:
+			if model_cls.__name__ == cls_name:
+				break
+		else:
+			raise ValueError(f"Could not find {model_type}!")
+
+		if model_cls in cls.supported["chainer"]:
+			if "pretrained_model" not in kwargs:
+				kwargs["pretrained_model"] = None
+
+		elif model_cls in cls.supported["chainercv2"]:
+			if "pretrained" not in kwargs:
+				kwargs["pretrained"] = False
+			return ModelWrapper(model_cls(*args, **kwargs))
+
+		return model_cls(*args, **kwargs)
+
+
+	@classmethod
+	def _check(cls, model, key):
+		return isinstance(model, cls.supported[key])
+
+	@classmethod
+	def is_chainer_model(cls, model):
+		return cls._check(model, "chainer")
+
+	@classmethod
+	def is_cv_model(cls, model):
+		return cls._check(model, "chainercv")
+
+	@classmethod
+	def is_cv2_model(cls, model):
+		return cls._check(model, "chainercv2")
+
+	@classmethod
+	def is_cvmodelz_model(cls, model):
+		return cls._check(model, "cvmodelz")
+
+	@classmethod
+	def get_all_models(cls, key=None):
+		if key is not None:
+			return [f"{key}.{model_cls.__name__}" for model_cls in cls.supported[key]]
+
+		res = []
+		for key in cls.supported:
+			res += cls.get_all_models(key)
+
+		return res
+
+
+
+if __name__ == '__main__':
+	print(pyaml.dump(dict(Models=ModelFactory.get_all_models()), indent=2))