Selaa lähdekoodia

added loading of default chainercv2 weights

Dimitri Korsch 4 vuotta sitten
vanhempi
commit
a7f5c1e19b
1 muutettua tiedostoa jossa 16 lisäystä ja 8 poistoa
  1. 16 8
      cvfinetune/finetuner/mixins/model.py

+ 16 - 8
cvfinetune/finetuner/mixins/model.py

@@ -9,6 +9,7 @@ from chainer_addons.models import ModelType
 from chainer_addons.models import PrepareType
 from chainer_addons.training import optimizer
 from chainer_addons.training import optimizer_hooks
+from chainercv2.models import model_store
 from cvdatasets.dataset.image import Size
 from cvdatasets.utils import pretty_print_dict
 from cvmodelz.models import ModelFactory
@@ -146,17 +147,24 @@ class _ModelMixin(abc.ABC):
 			return True, weights
 
 	def _default_weights(self, opts):
-		ds_info = self.data_info
-		model_info = self.model_info
+		if self.model_type.startswith("chainercv2"):
+			model_name = self.model_type.split(".")[-1]
+			return model_store.get_model_file(
+				model_name=model_name,
+				local_model_store_dir_path=str(Path.home() / ".chainer" / "models"))
+
+		else:
+			ds_info = self.data_info
+			model_info = self.model_info
 
-		base_dir = Path(ds_info.BASE_DIR)
-		weights_dir = base_dir / ds_info.MODEL_DIR / model_info.folder
+			base_dir = Path(ds_info.BASE_DIR)
+			weights_dir = base_dir / ds_info.MODEL_DIR / model_info.folder
 
-		weights = model_info.weights
-		assert opts.pre_training in weights, \
-			f"Weights for \"{opts.pre_training}\" pre-training were not found!"
+			weights = model_info.weights
+			assert opts.pre_training in weights, \
+				f"Weights for \"{opts.pre_training}\" pre-training were not found!"
 
-		return str(weights_dir / weights[opts.pre_training])
+			return str(weights_dir / weights[opts.pre_training])
 
 
 	def load_weights(self, opts) -> None: