|
@@ -0,0 +1,110 @@
|
|
|
+import abc
|
|
|
+import pyaml
|
|
|
+
|
|
|
+
|
|
|
+from chainer import links as L
|
|
|
+from chainercv2.models import inceptionv3 as cv2inceptionv3
|
|
|
+from chainercv2.models import resnet as cv2resnet
|
|
|
+from collections import OrderedDict
|
|
|
+
|
|
|
+from cvmodelz.models import pretrained
|
|
|
+from cvmodelz.models.wrapper import ModelWrapper
|
|
|
+
|
|
|
+class ModelFactory(abc.ABC):
|
|
|
+
|
|
|
+ @abc.abstractmethod
|
|
|
+ def __init__(self):
|
|
|
+ raise NotImplementedError("instance creation is not supported!")
|
|
|
+
|
|
|
+ supported = OrderedDict(
|
|
|
+ chainer=(
|
|
|
+ L.ResNet50Layers,
|
|
|
+ L.ResNet101Layers,
|
|
|
+ L.ResNet152Layers,
|
|
|
+ L.VGG16Layers,
|
|
|
+ L.VGG19Layers,
|
|
|
+ ),
|
|
|
+
|
|
|
+ chainercv=(
|
|
|
+ # todo: chainercv.links.models.ssd
|
|
|
+ ),
|
|
|
+
|
|
|
+ chainercv2=(
|
|
|
+ cv2resnet.resnet50,
|
|
|
+ cv2resnet.resnet50b,
|
|
|
+
|
|
|
+ cv2inceptionv3.inceptionv3,
|
|
|
+ ),
|
|
|
+
|
|
|
+ cvmodelz=(
|
|
|
+ pretrained.VGG16,
|
|
|
+ pretrained.VGG19,
|
|
|
+
|
|
|
+ pretrained.ResNet35,
|
|
|
+ pretrained.ResNet50,
|
|
|
+ pretrained.ResNet101,
|
|
|
+ pretrained.ResNet152,
|
|
|
+
|
|
|
+ pretrained.InceptionV3,
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def new(cls, model_type, *args, **kwargs):
|
|
|
+
|
|
|
+ key, cls_name = model_type.split(".")
|
|
|
+
|
|
|
+
|
|
|
+ for model_cls in cls.supported[key]:
|
|
|
+ if model_cls.__name__ == cls_name:
|
|
|
+ break
|
|
|
+ else:
|
|
|
+ raise ValueError(f"Could not find {model_type}!")
|
|
|
+
|
|
|
+ if model_cls in cls.supported["chainer"]:
|
|
|
+ if "pretrained_model" not in kwargs:
|
|
|
+ kwargs["pretrained_model"] = None
|
|
|
+
|
|
|
+ elif model_cls in cls.supported["chainercv2"]:
|
|
|
+ if "pretrained" not in kwargs:
|
|
|
+ kwargs["pretrained"] = False
|
|
|
+ return ModelWrapper(model_cls(*args, **kwargs))
|
|
|
+
|
|
|
+ return model_cls(*args, **kwargs)
|
|
|
+
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def _check(cls, model, key):
|
|
|
+ return isinstance(model, cls.supported[key])
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def is_chainer_model(cls, model):
|
|
|
+ return cls._check(model, "chainer")
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def is_cv_model(cls, model):
|
|
|
+ return cls._check(model, "chainercv")
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def is_cv2_model(cls, model):
|
|
|
+ return cls._check(model, "chainercv2")
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def is_cvmodelz_model(cls, model):
|
|
|
+ return cls._check(model, "cvmodelz")
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def get_all_models(cls, key=None):
|
|
|
+ if key is not None:
|
|
|
+ return [f"{key}.{model_cls.__name__}" for model_cls in cls.supported[key]]
|
|
|
+
|
|
|
+ res = []
|
|
|
+ for key in cls.supported:
|
|
|
+ res += cls.get_all_models(key)
|
|
|
+
|
|
|
+ return res
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == '__main__':
|
|
|
+ print(pyaml.dump(dict(Models=ModelFactory.get_all_models()), indent=2))
|