factory.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. import abc
  2. import pyaml
  3. from chainer import links as L
  4. from chainercv2.models import inceptionv3 as cv2inceptionv3
  5. from chainercv2.models import resnet as cv2resnet
  6. from collections import OrderedDict
  7. from cvmodelz.models import pretrained
  8. from cvmodelz.models.wrapper import ModelWrapper
  9. class ModelFactory(abc.ABC):
  10. supported = OrderedDict(
  11. chainer=(
  12. L.ResNet50Layers,
  13. L.ResNet101Layers,
  14. L.ResNet152Layers,
  15. L.VGG16Layers,
  16. L.VGG19Layers,
  17. ),
  18. chainercv=(
  19. # todo: chainercv.links.models.ssd
  20. ),
  21. chainercv2=(
  22. cv2resnet.resnet50,
  23. cv2resnet.resnet50b,
  24. cv2inceptionv3.inceptionv3,
  25. ),
  26. cvmodelz=(
  27. pretrained.VGG16,
  28. pretrained.VGG19,
  29. pretrained.ResNet35,
  30. pretrained.ResNet50,
  31. pretrained.ResNet101,
  32. pretrained.ResNet152,
  33. pretrained.InceptionV3,
  34. ),
  35. )
  36. @abc.abstractmethod
  37. def __init__(self):
  38. raise NotImplementedError("instance creation is not supported!")
  39. @classmethod
  40. def new(cls, model_type, *args, **kwargs):
  41. key, cls_name = model_type.split(".")
  42. for model_cls in cls.supported[key]:
  43. if model_cls.__name__ == cls_name:
  44. break
  45. else:
  46. raise ValueError(f"Could not find {model_type}!")
  47. if model_cls in cls.supported["chainer"]:
  48. kwargs["pretrained_model"] = kwargs.get("pretrained_model", None)
  49. kwargs.pop("input_size", None)
  50. elif model_cls in cls.supported["chainercv2"]:
  51. kwargs["pretrained"] = kwargs.get("pretrained", False)
  52. input_size = kwargs.pop("input_size", None)
  53. return ModelWrapper(model_cls(*args, **kwargs), input_size=input_size)
  54. return model_cls(*args, **kwargs)
  55. @classmethod
  56. def _check(cls, model, key):
  57. return isinstance(model, cls.supported[key])
  58. @classmethod
  59. def is_chainer_model(cls, model):
  60. return cls._check(model, "chainer")
  61. @classmethod
  62. def is_cv_model(cls, model):
  63. return cls._check(model, "chainercv")
  64. @classmethod
  65. def is_cv2_model(cls, model):
  66. return cls._check(model, "chainercv2")
  67. @classmethod
  68. def is_cvmodelz_model(cls, model):
  69. return cls._check(model, "cvmodelz")
  70. @classmethod
  71. def get_all_models(cls, key=None):
  72. if key is not None:
  73. return [f"{key}.{model_cls.__name__}" for model_cls in cls.supported[key]]
  74. return cls.get_models(cls.supported.keys())
  75. @classmethod
  76. def get_models(cls, keys=None):
  77. if keys is None:
  78. keys = cls.supported.keys()
  79. res = []
  80. for key in keys:
  81. res += cls.get_all_models(key)
  82. return res
  83. if __name__ == '__main__':
  84. print(pyaml.dump(dict(Models=ModelFactory.get_all_models()), indent=2))