|
|
@@ -132,13 +132,19 @@ class _ModelMixin(abc.ABC):
|
|
|
msg = "Loading already fine-tuned weights from \"{}\""
|
|
|
loader_func = self.model.load_for_inference
|
|
|
else:
|
|
|
- self.weights = join(
|
|
|
- self.data_info.BASE_DIR,
|
|
|
- self.data_info.MODEL_DIR,
|
|
|
- self.model_info.folder,
|
|
|
- self.model_info.weights
|
|
|
- )
|
|
|
- msg = "Loading pre-trained weights \"{}\""
|
|
|
+ if args.weights:
|
|
|
+ msg = "Loading custom pre-trained weights \"{}\""
|
|
|
+ self.weights = args.weights
|
|
|
+
|
|
|
+ else:
|
|
|
+ msg = "Loading default pre-trained weights \"{}\""
|
|
|
+ self.weights = join(
|
|
|
+ self.data_info.BASE_DIR,
|
|
|
+ self.data_info.MODEL_DIR,
|
|
|
+ self.model_info.folder,
|
|
|
+ self.model_info.weights
|
|
|
+ )
|
|
|
+
|
|
|
loader_func = self.model.load_for_finetune
|
|
|
|
|
|
logging.info(msg.format(self.weights))
|