|
@@ -42,6 +42,16 @@ class BaseResNet(PretrainedModelMixin):
|
|
def functions(self):
|
|
def functions(self):
|
|
return super().functions
|
|
return super().functions
|
|
|
|
|
|
|
|
+class ResNetHDMixin:
|
|
|
|
+
|
|
|
|
+ def __init__(self, *args, **kwargs):
|
|
|
|
+ super().__init__(*args, **kwargs)
|
|
|
|
+ self.res4.a.conv1.stride = (1, 1)
|
|
|
|
+ self.res4.a.conv4.stride = (1, 1)
|
|
|
|
+
|
|
|
|
+ self.res5.a.conv1.stride = (1, 1)
|
|
|
|
+ self.res5.a.conv4.stride = (1, 1)
|
|
|
|
+
|
|
"""
|
|
"""
|
|
We need this to "extract" pretrained_model argument,
|
|
We need this to "extract" pretrained_model argument,
|
|
otherwise it would be passed to the constructor of the
|
|
otherwise it would be passed to the constructor of the
|
|
@@ -90,9 +100,14 @@ class ResNet35(BaseResNet, ResNet35Layers):
|
|
return x
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
+class ResNet35HD(ResNetHDMixin, ResNet35):
|
|
|
|
+ pass
|
|
|
|
+
|
|
class ResNet50(BaseResNet, L.ResNet50Layers):
|
|
class ResNet50(BaseResNet, L.ResNet50Layers):
|
|
n_layers = 50
|
|
n_layers = 50
|
|
|
|
|
|
|
|
+class ResNet50HD(ResNetHDMixin, ResNet50):
|
|
|
|
+ pass
|
|
|
|
|
|
class ResNet101(BaseResNet, L.ResNet101Layers):
|
|
class ResNet101(BaseResNet, L.ResNet101Layers):
|
|
n_layers = 101
|
|
n_layers = 101
|