Browse Source

added support for ResNeXt models

Dimitri Korsch 2 years ago
parent
commit
2b35d82311
2 changed files with 3 additions and 1 deletions
  1. 2 0
      cvmodelz/models/factory.py
  2. 1 1
      cvmodelz/models/wrapper.py

+ 2 - 0
cvmodelz/models/factory.py

@@ -5,6 +5,7 @@ import pyaml
 from chainer import links as L
 from chainer import links as L
 from chainercv2.models import inceptionv3 as cv2inceptionv3
 from chainercv2.models import inceptionv3 as cv2inceptionv3
 from chainercv2.models import inceptionresnetv1 as cv2inceptionresnetv1
 from chainercv2.models import inceptionresnetv1 as cv2inceptionresnetv1
+from chainercv2.models import resnext as cv2resnext
 from chainercv2.models import resnet as cv2resnet
 from chainercv2.models import resnet as cv2resnet
 from chainercv.links.model import ssd
 from chainercv.links.model import ssd
 from chainercv.links.model import faster_rcnn
 from chainercv.links.model import faster_rcnn
@@ -35,6 +36,7 @@ class ModelFactory(abc.ABC):
 
 
 			cv2inceptionv3.inceptionv3,
 			cv2inceptionv3.inceptionv3,
 			cv2inceptionresnetv1.inceptionresnetv1,
 			cv2inceptionresnetv1.inceptionresnetv1,
+			cv2resnext.resnext50_32x4d,
 		),
 		),
 
 
 		cvmodelz=(
 		cvmodelz=(

+ 1 - 1
cvmodelz/models/wrapper.py

@@ -55,7 +55,7 @@ class ModelWrapper(BaseModel):
 				input_size=299,
 				input_size=299,
 			))
 			))
 
 
-		elif self.model_name == "ResNet":
+		elif self.model_name in ["ResNet", "ResNeXt"]:
 			info.update(dict(
 			info.update(dict(
 				input_size=224,
 				input_size=224,
 			))
 			))