Browse Source

added "ResNet{35,50}HD" versions by reducing strides at 2 stages in the CNN

Dimitri Korsch 3 years ago
parent
commit
34900f98ec

+ 4 - 0
cvmodelz/models/__init__.py

@@ -4,7 +4,9 @@ from cvmodelz.models.pretrained import InceptionV3
 from cvmodelz.models.pretrained import ResNet101
 from cvmodelz.models.pretrained import ResNet152
 from cvmodelz.models.pretrained import ResNet35
+from cvmodelz.models.pretrained import ResNet35HD
 from cvmodelz.models.pretrained import ResNet50
+from cvmodelz.models.pretrained import ResNet50HD
 from cvmodelz.models.pretrained import VGG16
 from cvmodelz.models.pretrained import VGG19
 
@@ -15,7 +17,9 @@ __all__ = [
 	"InceptionV3",
 
 	"ResNet50",
+	"ResNet50HD",
 	"ResNet35",
+	"ResNet35HD",
 	"ResNet101",
 	"ResNet152",
 

+ 2 - 0
cvmodelz/models/factory.py

@@ -40,7 +40,9 @@ class ModelFactory(abc.ABC):
 			pretrained.VGG19,
 
 			pretrained.ResNet35,
+			pretrained.ResNet35HD,
 			pretrained.ResNet50,
+			pretrained.ResNet50HD,
 			pretrained.ResNet101,
 			pretrained.ResNet152,
 

+ 4 - 0
cvmodelz/models/pretrained/__init__.py

@@ -4,7 +4,9 @@ from cvmodelz.models.pretrained.inception import InceptionV3HD
 from cvmodelz.models.pretrained.resnet import ResNet101
 from cvmodelz.models.pretrained.resnet import ResNet152
 from cvmodelz.models.pretrained.resnet import ResNet35
+from cvmodelz.models.pretrained.resnet import ResNet35HD
 from cvmodelz.models.pretrained.resnet import ResNet50
+from cvmodelz.models.pretrained.resnet import ResNet50HD
 from cvmodelz.models.pretrained.vgg import VGG16
 from cvmodelz.models.pretrained.vgg import VGG19
 
@@ -16,7 +18,9 @@ __all__ = [
 	"VGG19",
 
 	"ResNet35",
+	"ResNet35HD",
 	"ResNet50",
+	"ResNet50HD",
 	"ResNet101",
 	"ResNet152",
 

+ 15 - 0
cvmodelz/models/pretrained/resnet.py

@@ -42,6 +42,16 @@ class BaseResNet(PretrainedModelMixin):
 	def functions(self):
 		return super().functions
 
+class ResNetHDMixin:
+
+	def __init__(self, *args, **kwargs):
+		super().__init__(*args, **kwargs)
+		self.res4.a.conv1.stride = (1, 1)
+		self.res4.a.conv4.stride = (1, 1)
+
+		self.res5.a.conv1.stride = (1, 1)
+		self.res5.a.conv4.stride = (1, 1)
+
 """
 We need this to "extract" pretrained_model argument,
 otherwise it would be passed to the constructor of the
@@ -90,9 +100,14 @@ class ResNet35(BaseResNet, ResNet35Layers):
 				return x
 
 
+class ResNet35HD(ResNetHDMixin, ResNet35):
+	pass
+
 class ResNet50(BaseResNet, L.ResNet50Layers):
 	n_layers = 50
 
+class ResNet50HD(ResNetHDMixin, ResNet50):
+	pass
 
 class ResNet101(BaseResNet, L.ResNet101Layers):
 	n_layers = 101