resnet.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. import chainer
  2. from chainer import functions as F
  3. from chainer import links as L
  4. from chainer.links.model.vision.resnet import BuildingBlock
  5. from chainer.links.model.vision.resnet import _global_average_pooling_2d
  6. from collections import OrderedDict
  7. from functools import partial
  8. from cvmodelz.models.meta_info import ModelInfo
  9. from cvmodelz.models.pretrained.base import PretrainedModelMixin
  10. class BaseResNet(PretrainedModelMixin):
  11. n_layers = ""
  12. def __init__(self, *args, **kwargs):
  13. super(BaseResNet, self).__init__(*args, pooling=_global_average_pooling_2d, **kwargs)
  14. self.meta = ModelInfo(
  15. name=f"ResNet{self.n_layers}",
  16. input_size=224,
  17. feature_size=2048,
  18. n_conv_maps=2048,
  19. conv_map_layer="res5",
  20. feature_layer="pool5",
  21. classifier_layers=["fc6"],
  22. )
  23. @property
  24. def functions(self):
  25. return super(BaseResNet, self).functions
  26. class ResNet35(BaseResNet, chainer.Chain):
  27. n_layers = 35
  28. def init_extra_layers(self, n_classes, **kwargs):
  29. self.conv1 = L.Convolution2D(3, 64, 7, 2, 3, **kwargs)
  30. self.bn1 = L.BatchNormalization(64)
  31. self.res2 = BuildingBlock(2, 64, 64, 256, 1, **kwargs)
  32. self.res3 = BuildingBlock(3, 256, 128, 512, 2, **kwargs)
  33. self.res4 = BuildingBlock(3, 512, 256, 1024, 2, **kwargs)
  34. self.res5 = BuildingBlock(3, 1024, 512, 2048, 2, **kwargs)
  35. self.fc6 = L.Linear(2048, n_classes)
  36. @property
  37. def functions(self):
  38. links = [
  39. ("conv1", [self.conv1, self.bn1, F.relu]),
  40. ("pool1", [partial(F.max_pooling_2d, ksize=3, stride=2)]),
  41. ("res2", [self.res2]),
  42. ("res3", [self.res3]),
  43. ("res4", [self.res4]),
  44. ("res5", [self.res5]),
  45. ("pool5", [self.pool]),
  46. ("fc6", [self.fc6]),
  47. ("prob", [F.softmax]),
  48. ]
  49. return OrderedDict(links)
  50. class ResNet50(BaseResNet, L.ResNet50Layers):
  51. n_layers = 50
  52. class ResNet101(BaseResNet, L.ResNet101Layers):
  53. n_layers = 101
  54. class ResNet152(BaseResNet, L.ResNet152Layers):
  55. n_layers = 152