|
@@ -7,7 +7,6 @@ from cvmodelz.models.base import BaseModel
|
|
|
from cvmodelz.models.meta_info import ModelInfo
|
|
|
|
|
|
|
|
|
-
|
|
|
class ModelWrapper(BaseModel):
|
|
|
"""
|
|
|
This class is designed to wrap around chainercv2 models
|
|
@@ -16,11 +15,11 @@ class ModelWrapper(BaseModel):
|
|
|
"""
|
|
|
|
|
|
def __init__(self, model: chainer.Chain, *args, **kwargs):
|
|
|
- super().__init__(*args, **kwargs)
|
|
|
-
|
|
|
name = model.__class__.__name__
|
|
|
self.__class__.__name__ = name
|
|
|
- self.meta.name = name
|
|
|
+ self.model_name = name
|
|
|
+ super().__init__(*args, **kwargs)
|
|
|
+
|
|
|
|
|
|
if hasattr(model, "meta"):
|
|
|
self.meta = model.meta
|
|
@@ -29,15 +28,40 @@ class ModelWrapper(BaseModel):
|
|
|
self.wrapped = model
|
|
|
delattr(self.wrapped.features, "final_pool")
|
|
|
|
|
|
- self.meta.feature_size = self.clf_layer.W.shape[-1]
|
|
|
|
|
|
def init_model_info(self):
|
|
|
- self.meta = ModelInfo(
|
|
|
- classifier_layers=("output/fc",),
|
|
|
+ info = dict(
|
|
|
+ name=self.model_name,
|
|
|
+ feature_size=2048,
|
|
|
+ n_conv_maps=2048,
|
|
|
+ classifier_layers=["output/fc"],
|
|
|
conv_map_layer="features",
|
|
|
feature_layer="pool",
|
|
|
)
|
|
|
|
|
|
+ if self.model_name == "InceptionResNetV1":
|
|
|
+ info.update(dict(
|
|
|
+ input_size=299,
|
|
|
+ feature_size=1792,
|
|
|
+ n_conv_maps=1792,
|
|
|
+ classifier_layers=[
|
|
|
+ "output/fc1",
|
|
|
+ "output/fc2"
|
|
|
+ ],
|
|
|
+ ))
|
|
|
+
|
|
|
+ elif self.model_name == "InceptionV3":
|
|
|
+ info.update(dict(
|
|
|
+ input_size=299,
|
|
|
+ ))
|
|
|
+
|
|
|
+ elif self.model_name == "ResNet":
|
|
|
+ info.update(dict(
|
|
|
+ input_size=224,
|
|
|
+ ))
|
|
|
+
|
|
|
+ self.meta = ModelInfo(**info)
|
|
|
+
|
|
|
@property
|
|
|
def model_instance(self) -> chainer.Chain:
|
|
|
return self.wrapped
|
|
@@ -46,9 +70,9 @@ class ModelWrapper(BaseModel):
|
|
|
def functions(self) -> OrderedDict:
|
|
|
|
|
|
links = [
|
|
|
- ("features", [self.wrapped.features]),
|
|
|
- ("pool", [self.pool]),
|
|
|
- ("output/fc", [self.wrapped.output.fc]),
|
|
|
+ (self.meta.conv_map_layer, [self.wrapped.features]),
|
|
|
+ (self.meta.feature_layer, [self.pool]),
|
|
|
+ (self.clf_layer_name, [self.wrapped.output]),
|
|
|
]
|
|
|
|
|
|
return OrderedDict(links)
|