|
|
@@ -9,6 +9,7 @@ from chainer_addons.models import ModelType
|
|
|
from chainer_addons.models import PrepareType
|
|
|
from chainer_addons.training import optimizer
|
|
|
from chainer_addons.training import optimizer_hooks
|
|
|
+from chainercv2.models import model_store
|
|
|
from cvdatasets.dataset.image import Size
|
|
|
from cvdatasets.utils import pretty_print_dict
|
|
|
from cvmodelz.models import ModelFactory
|
|
|
@@ -146,17 +147,24 @@ class _ModelMixin(abc.ABC):
|
|
|
return True, weights
|
|
|
|
|
|
def _default_weights(self, opts):
|
|
|
- ds_info = self.data_info
|
|
|
- model_info = self.model_info
|
|
|
+ if self.model_type.startswith("chainercv2"):
|
|
|
+ model_name = self.model_type.split(".")[-1]
|
|
|
+ return model_store.get_model_file(
|
|
|
+ model_name=model_name,
|
|
|
+ local_model_store_dir_path=str(Path.home() / ".chainer" / "models"))
|
|
|
+
|
|
|
+ else:
|
|
|
+ ds_info = self.data_info
|
|
|
+ model_info = self.model_info
|
|
|
|
|
|
- base_dir = Path(ds_info.BASE_DIR)
|
|
|
- weights_dir = base_dir / ds_info.MODEL_DIR / model_info.folder
|
|
|
+ base_dir = Path(ds_info.BASE_DIR)
|
|
|
+ weights_dir = base_dir / ds_info.MODEL_DIR / model_info.folder
|
|
|
|
|
|
- weights = model_info.weights
|
|
|
- assert opts.pre_training in weights, \
|
|
|
- f"Weights for \"{opts.pre_training}\" pre-training were not found!"
|
|
|
+ weights = model_info.weights
|
|
|
+ assert opts.pre_training in weights, \
|
|
|
+ f"Weights for \"{opts.pre_training}\" pre-training were not found!"
|
|
|
|
|
|
- return str(weights_dir / weights[opts.pre_training])
|
|
|
+ return str(weights_dir / weights[opts.pre_training])
|
|
|
|
|
|
|
|
|
def load_weights(self, opts) -> None:
|