wrapper.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. import chainer
  2. from chainer import functions as F
  3. from collections import OrderedDict
  4. from cvmodelz.models.base import BaseModel
  5. from cvmodelz.models.meta_info import ModelInfo
  6. class ModelWrapper(BaseModel):
  7. """
  8. This class is designed to wrap around chainercv2 models
  9. and provide the loading API of the BaseModel class.
  10. The wrapped model is stored under self.wrapped
  11. """
  12. def __init__(self, model: chainer.Chain, *args, **kwargs):
  13. name = model.__class__.__name__
  14. self.__class__.__name__ = name
  15. self.model_name = name
  16. super().__init__(*args, **kwargs)
  17. if hasattr(model, "meta"):
  18. self.meta = model.meta
  19. with self.init_scope():
  20. self.wrapped = model
  21. delattr(self.wrapped.features, "final_pool")
  22. def init_model_info(self):
  23. info = dict(
  24. name=self.model_name,
  25. feature_size=2048,
  26. n_conv_maps=2048,
  27. classifier_layers=["output/fc"],
  28. conv_map_layer="features",
  29. feature_layer="pool",
  30. )
  31. if self.model_name == "InceptionResNetV1":
  32. info.update(dict(
  33. input_size=299,
  34. feature_size=1792,
  35. n_conv_maps=1792,
  36. classifier_layers=[
  37. "output/fc1",
  38. "output/fc2"
  39. ],
  40. ))
  41. elif self.model_name == "InceptionV3":
  42. info.update(dict(
  43. input_size=299,
  44. ))
  45. elif self.model_name in ["ResNet", "ResNeXt"]:
  46. info.update(dict(
  47. input_size=224,
  48. ))
  49. self.meta = ModelInfo(**info)
  50. @property
  51. def model_instance(self) -> chainer.Chain:
  52. return self.wrapped
  53. @property
  54. def functions(self) -> OrderedDict:
  55. links = [
  56. (self.meta.conv_map_layer, [self.wrapped.features]),
  57. (self.meta.feature_layer, [self.pool]),
  58. (self.clf_layer_name, [self.wrapped.output]),
  59. ]
  60. return OrderedDict(links)
  61. def load(self, *args, path="", **kwargs):
  62. paths = [path, f"{path}wrapped/"]
  63. for _path in paths:
  64. try:
  65. return super().load(*args, path=_path, **kwargs)
  66. except KeyError as e:
  67. pass
  68. raise RuntimeError(f"tried to load weights with paths {paths}, but did not succeeed")
  69. def forward(self, X, layer_name=None):
  70. if layer_name is None:
  71. res = self.wrapped(X)
  72. elif layer_name == self.meta.conv_map_layer:
  73. res = self.wrapped.features(X)
  74. elif layer_name == self.meta.feature_layer:
  75. conv = self.wrapped.features(X)
  76. res = self.pool(conv)
  77. elif layer_name == self.clf_layer_name:
  78. conv = self.wrapped.features(X)
  79. feat = self.pool(conv)
  80. res = self.wrapped.output(feat)
  81. else:
  82. raise ValueError(f"Dont know how to compute \"{layer_name}\"!")
  83. return res