factory.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. import abc
  2. import pyaml
  3. from chainer import links as L
  4. from chainercv2.models import inceptionv3 as cv2inceptionv3
  5. from chainercv2.models import inceptionresnetv1 as cv2inceptionresnetv1
  6. from chainercv2.models import resnext as cv2resnext
  7. from chainercv2.models import resnet as cv2resnet
  8. from chainercv.links.model import ssd
  9. from chainercv.links.model import faster_rcnn
  10. from collections import OrderedDict
  11. from cvmodelz.models import pretrained
  12. from cvmodelz.models.wrapper import ModelWrapper
  13. class ModelFactory(abc.ABC):
  14. supported = OrderedDict(
  15. chainer=(
  16. L.ResNet50Layers,
  17. L.ResNet101Layers,
  18. L.ResNet152Layers,
  19. # L.VGG16Layers,
  20. # L.VGG19Layers,
  21. ),
  22. chainercv=(
  23. ssd.SSD300,
  24. faster_rcnn.FasterRCNNVGG16,
  25. ),
  26. chainercv2=(
  27. cv2resnet.resnet50,
  28. cv2resnet.resnet18,
  29. cv2inceptionv3.inceptionv3,
  30. cv2inceptionresnetv1.inceptionresnetv1,
  31. cv2resnext.resnext50_32x4d,
  32. ),
  33. cvmodelz=(
  34. pretrained.VGG16,
  35. pretrained.VGG19,
  36. pretrained.ResNet35,
  37. pretrained.ResNet35HD,
  38. pretrained.ResNet50,
  39. pretrained.ResNet50HD,
  40. pretrained.ResNet101,
  41. pretrained.ResNet152,
  42. pretrained.InceptionV3,
  43. pretrained.InceptionV3HD,
  44. ),
  45. )
  46. @abc.abstractmethod
  47. def __init__(self):
  48. raise NotImplementedError("instance creation is not supported!")
  49. @classmethod
  50. def new(cls, model_type, **kwargs):
  51. key, cls_name = model_type.split(".")
  52. for model_cls in cls.supported[key]:
  53. if model_cls.__name__ == cls_name:
  54. break
  55. else:
  56. raise ValueError(f"Could not find {model_type}!")
  57. if model_cls in cls.supported["chainercv2"]:
  58. n_classes = kwargs.pop("n_classes", 1000)
  59. pretrained = kwargs.pop("pretrained_model", None) == "auto"
  60. model = model_cls(classes=n_classes, pretrained=pretrained)
  61. kwargs["model"] = model
  62. model_cls = ModelWrapper
  63. return model_cls(**kwargs)
  64. @classmethod
  65. def _check(cls, model, key):
  66. return isinstance(model, cls.supported[key])
  67. @classmethod
  68. def is_chainer_model(cls, model):
  69. return cls._check(model, "chainer")
  70. @classmethod
  71. def is_cv_model(cls, model):
  72. return cls._check(model, "chainercv")
  73. @classmethod
  74. def is_cv2_model(cls, model):
  75. return cls._check(model, "chainercv2")
  76. @classmethod
  77. def is_cvmodelz_model(cls, model):
  78. return cls._check(model, "cvmodelz")
  79. @classmethod
  80. def get_all_models(cls, key=None):
  81. if key is not None:
  82. return [f"{key}.{model_cls.__name__}" for model_cls in cls.supported[key]]
  83. return cls.get_models(cls.supported.keys())
  84. @classmethod
  85. def get_models(cls, keys=None):
  86. if keys is None:
  87. keys = cls.supported.keys()
  88. res = []
  89. for key in keys:
  90. res += cls.get_all_models(key)
  91. return res
  92. if __name__ == '__main__':
  93. print(pyaml.dump(dict(Models=ModelFactory.get_all_models()), indent=2))