resnet.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. import chainer
  2. from chainer import functions as F
  3. from chainer import links as L
  4. from chainer.links.model.vision.resnet import BuildingBlock
  5. from chainer.links.model.vision.resnet import _global_average_pooling_2d
  6. from collections import OrderedDict
  7. from functools import partial
  8. from cvmodelz.models.meta_info import ModelInfo
  9. from cvmodelz.models.pretrained.base import PretrainedModelMixin
  10. class BaseResNet(PretrainedModelMixin):
  11. n_layers = ""
  12. def __init__(self, *args, **kwargs):
  13. super(BaseResNet, self).__init__(*args, pooling=_global_average_pooling_2d, **kwargs)
  14. self.meta = ModelInfo(
  15. name=f"ResNet{self.n_layers}",
  16. input_size=224,
  17. feature_size=2048,
  18. n_conv_maps=2048,
  19. conv_map_layer="res5",
  20. feature_layer="pool5",
  21. classifier_layers=["fc6"],
  22. )
  23. class ResNet35(BaseResNet, chainer.Chain):
  24. n_layers = 35
  25. def init_extra_layers(self, n_classes, **kwargs):
  26. self.conv1 = L.Convolution2D(3, 64, 7, 2, 3, **kwargs)
  27. self.bn1 = L.BatchNormalization(64)
  28. self.res2 = BuildingBlock(2, 64, 64, 256, 1, **kwargs)
  29. self.res3 = BuildingBlock(3, 256, 128, 512, 2, **kwargs)
  30. self.res4 = BuildingBlock(3, 512, 256, 1024, 2, **kwargs)
  31. self.res5 = BuildingBlock(3, 1024, 512, 2048, 2, **kwargs)
  32. self.fc6 = L.Linear(2048, n_classes)
  33. @property
  34. def functions(self):
  35. links = [
  36. ("conv1", [self.conv1, self.bn1, F.relu]),
  37. ("pool1", [partial(F.max_pooling_2d, ksize=3, stride=2)]),
  38. ("res2", [self.res2]),
  39. ("res3", [self.res3]),
  40. ("res4", [self.res4]),
  41. ("res5", [self.res5]),
  42. ("pool5", [self.pool]),
  43. ("fc6", [self.fc6]),
  44. ("prob", [F.softmax]),
  45. ]
  46. return OrderedDict(links)
  47. class ResNet50(BaseResNet, L.ResNet50Layers):
  48. n_layers = 50
  49. class ResNet101(BaseResNet, L.ResNet101Layers):
  50. n_layers = 101
  51. class ResNet152(BaseResNet, L.ResNet152Layers):
  52. n_layers = 152