浏览代码

added support for inceptionresnetv1

Dimitri Korsch 2 年之前
父节点
当前提交
2d1f2b9b80
共有 2 个文件被更改,包括 36 次插入10 次删除
  1. 2 0
      cvmodelz/models/factory.py
  2. 34 10
      cvmodelz/models/wrapper.py

+ 2 - 0
cvmodelz/models/factory.py

@@ -4,6 +4,7 @@ import pyaml
 
 from chainer import links as L
 from chainercv2.models import inceptionv3 as cv2inceptionv3
+from chainercv2.models import inceptionresnetv1 as cv2inceptionresnetv1
 from chainercv2.models import resnet as cv2resnet
 from chainercv.links.model import ssd
 from chainercv.links.model import faster_rcnn
@@ -33,6 +34,7 @@ class ModelFactory(abc.ABC):
 			cv2resnet.resnet18,
 
 			cv2inceptionv3.inceptionv3,
+			cv2inceptionresnetv1.inceptionresnetv1,
 		),
 
 		cvmodelz=(

+ 34 - 10
cvmodelz/models/wrapper.py

@@ -7,7 +7,6 @@ from cvmodelz.models.base import BaseModel
 from cvmodelz.models.meta_info import ModelInfo
 
 
-
 class ModelWrapper(BaseModel):
 	"""
 		This class is designed to wrap around chainercv2 models
@@ -16,11 +15,11 @@ class ModelWrapper(BaseModel):
 	"""
 
 	def __init__(self, model: chainer.Chain, *args, **kwargs):
-		super().__init__(*args, **kwargs)
-
 		name = model.__class__.__name__
 		self.__class__.__name__ = name
-		self.meta.name = name
+		self.model_name = name
+		super().__init__(*args, **kwargs)
+
 
 		if hasattr(model, "meta"):
 			self.meta = model.meta
@@ -29,15 +28,40 @@ class ModelWrapper(BaseModel):
 			self.wrapped = model
 			delattr(self.wrapped.features, "final_pool")
 
-		self.meta.feature_size = self.clf_layer.W.shape[-1]
 
 	def init_model_info(self):
-		self.meta = ModelInfo(
-			classifier_layers=("output/fc",),
+		info = dict(
+			name=self.model_name,
+			feature_size=2048,
+			n_conv_maps=2048,
+			classifier_layers=["output/fc"],
 			conv_map_layer="features",
 			feature_layer="pool",
 		)
 
+		if self.model_name == "InceptionResNetV1":
+			info.update(dict(
+				input_size=299,
+				feature_size=1792,
+				n_conv_maps=1792,
+				classifier_layers=[
+					"output/fc1",
+					"output/fc2"
+				],
+			))
+
+		elif self.model_name == "InceptionV3":
+			info.update(dict(
+				input_size=299,
+			))
+
+		elif self.model_name == "ResNet":
+			info.update(dict(
+				input_size=224,
+			))
+
+		self.meta = ModelInfo(**info)
+
 	@property
 	def model_instance(self) -> chainer.Chain:
 		return self.wrapped
@@ -46,9 +70,9 @@ class ModelWrapper(BaseModel):
 	def functions(self) -> OrderedDict:
 
 		links = [
-			("features", [self.wrapped.features]),
-			("pool", [self.pool]),
-			("output/fc", [self.wrapped.output.fc]),
+			(self.meta.conv_map_layer, [self.wrapped.features]),
+			(self.meta.feature_layer, [self.pool]),
+			(self.clf_layer_name, [self.wrapped.output]),
 		]
 
 		return OrderedDict(links)