@@ -15,22 +15,22 @@ class PretrainedModelMixin(BaseModel):
 		...
 	"""
 
-	def __init__(self, n_classes: int = 1000, *args, **kwargs):
+	def __init__(self, *args, n_classes: int = 1000, pretrained_model: str = None, **kwargs):
 		from cvmodelz.models import ModelFactory
 
 		if ModelFactory.is_chainer_model(self):
-			kwargs["pretrained_model"] = kwargs.get("pretrained_model", None)
+			kwargs["pretrained_model"] = pretrained_model
 
 		super(PretrainedModelMixin, self).__init__(*args, **kwargs)
 
 		with self.init_scope():
 			self.init_extra_layers(n_classes)
 
-	def __call__(self, X, layer_name=None):
+	def forward(self, X, layer_name=None):
 		assert hasattr(self, "meta"), "Did you forgot to initialize the meta attribute?"
 
 		layer_name = layer_name or self.meta.classifier_layers[-1]
-		caller = super(PretrainedModelMixin, self).__call__
+		caller = super(PretrainedModelMixin, self).forward
 		activations = caller(X, layers=[layer_name])
 
 		if isinstance(activations, dict):