@@ -10,8 +10,11 @@ from chainer_addons.training import optimizer
 from chainer_addons.training import optimizer_hooks
 from cvdatasets.dataset.image import Size
 from cvdatasets.utils import pretty_print_dict
+from cvmodelz.models import ModelFactory
 from functools import partial
 from pathlib import Path
+from typing import Callable
+from typing import Optional
+from typing import Tuple


 class _ModelMixin(abc.ABC):
@@ -32,6 +35,14 @@ class _ModelMixin(abc.ABC):
 	def model_info(self):
 		return self.data_info.MODELS[self.model_type]

+	def init_model(self, opts):
+		"""Creates the backbone CNN model. This model is later wrapped by the classifier."""
+
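+		# the factory resolves the model_type key to a concrete backbone,
+		# replacing the manual "cv2_" prefix handling that is removed below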
+		self.model = ModelFactory.new(self.model_type,
+			input_size=Size(opts.input_size),
+			**self.model_kwargs
+		)
+
 	def init_classifier(self, opts):

 		clf_class, kwargs = self.classifier_cls, self.classifier_kwargs
@@ -97,63 +108,59 @@ class _ModelMixin(abc.ABC):
 			logging.warning("========= Fine-tuning only classifier layer! =========")
 			enable_only_head(self.clf)

-	def init_model(self, opts):
-		"""creates backbone CNN model. This model is wrapped around the classifier later"""
-
-		if self.model_type.startswith("cv2_"):
-			model_type = args.model_type.split("cv2_")[-1]
+	def _get_loader(self, opts) -> Tuple[Optional[bool], Optional[str]]:
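+		"""Decides which weights to load and whether they still need fine-tuning.
+
+		Returns a (finetune, weights) pair: (None, None) means training from
+		scratch, False marks already fine-tuned weights, True pre-trained ones.
+		"""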
+		if getattr(opts, "from_scratch", False):
+			logging.info("Training a {0.__class__.__name__} model from scratch!".format(self.model))
+			return None, None
+
+		if getattr(opts, "load", None):
+			weights = opts.load
+			logging.info(f"Loading already fine-tuned weights from \"{weights}\"")
+			return False, weights
+
+		elif getattr(opts, "weights", None):
+			weights = opts.weights
+			logging.info(f"Loading custom pre-trained weights from \"{weights}\"")
+			return True, weights
+
 		else:
-			model_type = self.model_info.class_key
+			weights = self._default_weights(opts)
+			logging.info(f"Loading default pre-trained weights from \"{weights}\"")
+			return True, weights

-		self.model = ModelType.new(
-			model_type=model_type,
-			input_size=Size(opts.input_size),
-			**self.model_kwargs,
+	def _default_weights(self, opts):
+		ds_info = self.data_info
+		model_info = self.model_info
+
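+		# default weights are expected under <BASE_DIR>/<MODEL_DIR>/<model folder>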
+		base_dir = Path(ds_info.BASE_DIR)
+		weights_dir = base_dir / ds_info.MODEL_DIR / model_info.folder
+
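+		# prefer iNaturalist pre-training if available, fall back to ImageNet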
+		weights = model_info.weights
+		### TODO: make pre-training a command-line argument!
+		return str(weights_dir / weights.get("inat", weights.get("imagenet")))
+
+
+	def load_weights(self, opts) -> None:
+
+		finetune, weights = self._get_loader(opts)
+
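+		# the finetune flag selects between the former load_for_finetune (True)
+		# and load_for_inference (False) code paths of the removed method below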
+		self.clf.load(weights,
+			n_classes=self.n_classes,
+			finetune=finetune,
+
+			path=opts.load_path,
+			strict=opts.load_strict,
+			headless=opts.headless
 		)

-	def load_model_weights(self, args):
-		if getattr(args, "from_scratch", False):
-			logging.info("Training a {0.__class__.__name__} model from scratch!".format(self.model))
-			loader = self.model.reinitialize_clf
-			self.weights = None
-		else:
-			if args.load:
-				self.weights = args.load
-				msg = "Loading already fine-tuned weights from \"{}\""
-				loader_func = self.model.load_for_inference
-			else:
-				if args.weights:
-					msg = "Loading custom pre-trained weights \"{}\""
-					self.weights = args.weights
-
-				else:
-					msg = "Loading default pre-trained weights \"{}\""
-					self.weights = str(Path(
-						self.data_info.BASE_DIR,
-						self.data_info.MODEL_DIR,
-						self.model_info.folder,
-						self.model_info.weights
-					))
-
-				loader_func = self.model.load_for_finetune
-
-			logging.info(msg.format(self.weights))
-			kwargs = dict(
-				weights=self.weights,
-				strict=args.load_strict,
-				path=args.load_path,
-				headless=args.headless,
-			)
-			loader = partial(loader_func, **kwargs)
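+		# clear any stale gradients on the freshly loaded parameters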
+		self.clf.cleargrads()

 		feat_size = self.model.meta.feature_size

 		if hasattr(self.clf, "output_size"):
 			feat_size = self.clf.output_size

-		if hasattr(self.clf, "loader"):
-			loader = self.clf.loader(loader)
+		### TODO: handle feature size!

 		logging.info(f"Part features size after encoding: {feat_size}")
-		loader(n_classes=self.n_classes, feat_size=feat_size)
-		self.clf.cleargrads()