|
@@ -1,9 +1,5 @@
|
|
|
import abc
|
|
|
|
|
|
-from chainer import functions as F
|
|
|
-from typing import Callable
|
|
|
-
|
|
|
-from cvmodelz import models
|
|
|
from cvmodelz.models.base import BaseModel
|
|
|
from cvmodelz.models.meta_info import ModelInfo
|
|
|
|
|
@@ -19,16 +15,16 @@ class PretrainedModelMixin(BaseModel):
|
|
|
...
|
|
|
"""
|
|
|
|
|
|
- def __init__(self, n_classes: int = 1000, pooling: Callable = F.identity, *args, **kwargs):
|
|
|
+ def __init__(self, n_classes: int = 1000, *args, **kwargs):
|
|
|
+ from cvmodelz.models import ModelFactory
|
|
|
|
|
|
- if models.is_chainer_model(self):
|
|
|
- kwargs["pretrained_model"] = None
|
|
|
+ if ModelFactory.is_chainer_model(self):
|
|
|
+ kwargs["pretrained_model"] = kwargs.get("pretrained_model", None)
|
|
|
|
|
|
super(PretrainedModelMixin, self).__init__(*args, **kwargs)
|
|
|
|
|
|
with self.init_scope():
|
|
|
self.init_extra_layers(n_classes)
|
|
|
- self.pool = pooling
|
|
|
|
|
|
def __call__(self, X, layer_name=None):
|
|
|
assert hasattr(self, "meta"), "Did you forgot to initialize the meta attribute?"
|
|
@@ -45,13 +41,6 @@ class PretrainedModelMixin(BaseModel):
|
|
|
def init_extra_layers(self, *args, **kwargs):
|
|
|
pass
|
|
|
|
|
|
- # @abc.abstractproperty
|
|
|
- # def _links(self):
|
|
|
- # raise NotImplementedError()
|
|
|
-
|
|
|
- # @property
|
|
|
- # def functions(self):
|
|
|
- # return OrderedDict(self._links)
|
|
|
|
|
|
@property
|
|
|
def model_instance(self):
|