|
@@ -1,11 +1,15 @@
|
|
import chainer
|
|
import chainer
|
|
|
|
|
|
from chainer import functions as F
|
|
from chainer import functions as F
|
|
|
|
+from chainer_addons.links.pooling import PoolingType # TODO: replace this!
|
|
|
|
+from collections import OrderedDict
|
|
from typing import Callable
|
|
from typing import Callable
|
|
|
|
|
|
from cvmodelz.models.base import BaseModel
|
|
from cvmodelz.models.base import BaseModel
|
|
from cvmodelz.models.meta_info import ModelInfo
|
|
from cvmodelz.models.meta_info import ModelInfo
|
|
|
|
|
|
|
|
+
|
|
|
|
+
|
|
class ModelWrapper(BaseModel, chainer.Chain):
|
|
class ModelWrapper(BaseModel, chainer.Chain):
|
|
"""
|
|
"""
|
|
This class is designed to wrap around chainercv2 models
|
|
This class is designed to wrap around chainercv2 models
|
|
@@ -13,7 +17,7 @@ class ModelWrapper(BaseModel, chainer.Chain):
|
|
The wrapped model is stored under self.wrapped
|
|
The wrapped model is stored under self.wrapped
|
|
"""
|
|
"""
|
|
|
|
|
|
- def __init__(self, model: chainer.Chain, pooling: Callable = F.identity):
|
|
|
|
|
|
+ def __init__(self, model: chainer.Chain, pooling: Callable = PoolingType.G_AVG.value()):
|
|
super(ModelWrapper, self).__init__()
|
|
super(ModelWrapper, self).__init__()
|
|
|
|
|
|
name = model.__class__.__name__
|
|
name = model.__class__.__name__
|
|
@@ -26,7 +30,7 @@ class ModelWrapper(BaseModel, chainer.Chain):
|
|
self.meta = ModelInfo(
|
|
self.meta = ModelInfo(
|
|
name=name,
|
|
name=name,
|
|
classifier_layers=("output/fc",),
|
|
classifier_layers=("output/fc",),
|
|
- conv_map_layer="stage4",
|
|
|
|
|
|
+ conv_map_layer="features",
|
|
feature_layer="pool",
|
|
feature_layer="pool",
|
|
)
|
|
)
|
|
|
|
|
|
@@ -38,9 +42,20 @@ class ModelWrapper(BaseModel, chainer.Chain):
|
|
self.meta.feature_size = self.clf_layer.W.shape[-1]
|
|
self.meta.feature_size = self.clf_layer.W.shape[-1]
|
|
|
|
|
|
@property
|
|
@property
|
|
- def model_instance(self):
|
|
|
|
|
|
+ def model_instance(self) -> chainer.Chain:
|
|
return self.wrapped
|
|
return self.wrapped
|
|
|
|
|
|
|
|
+ @property
|
|
|
|
+ def functions(self) -> OrderedDict:
|
|
|
|
+
|
|
|
|
+ links = [
|
|
|
|
+ ("features", [self.wrapped.features]),
|
|
|
|
+ ("pool", [self.pool]),
|
|
|
|
+ ("output/fc", [self.wrapped.output.fc]),
|
|
|
|
+ ]
|
|
|
|
+
|
|
|
|
+ return OrderedDict(links)
|
|
|
|
+
|
|
def load_for_inference(self, *args, path="", **kwargs):
|
|
def load_for_inference(self, *args, path="", **kwargs):
|
|
return super(ModelWrapper, self).load_for_inference(*args, path=f"{path}wrapped/", **kwargs)
|
|
return super(ModelWrapper, self).load_for_inference(*args, path=f"{path}wrapped/", **kwargs)
|
|
|
|
|