@@ -60,20 +60,23 @@ class _ModelMixin(abc.ABC):
 		]))

 	def _loss_func(self, opts):
-		if opts.l1_loss:
+		if getattr(opts, "l1_loss", False):
 			return F.hinge
-		elif opts.label_smoothing >= 0:
-			assert opts.label_smoothing < 1, \
+		elif getattr(opts, "label_smoothing", 0) >= 0:
+			assert getattr(opts, "label_smoothing", 0) < 1, \
 				"Label smoothing factor must be less than 1!"
 			return partial(smoothed_cross_entropy,
 				N=self.n_classes,
-				eps=opts.label_smoothing)
+				eps=getattr(opts, "label_smoothing", 0))
 		else:
 			return F.softmax_cross_entropy

 	def init_optimizer(self, opts):
 		"""Creates an optimizer for the classifier """
+		if not hasattr(opts, "optimizer"):
+			self.opt = None
+			return

 		opt_kwargs = {}
 		if opts.optimizer == "rmsprop":
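Note: the `getattr()` defaults above make `_loss_func` and `init_optimizer` tolerant of pared-down options objects that lack training-only flags. A minimal sketch of the resulting dispatch (`clf` and the `SimpleNamespace` stand-ins are assumptions, not part of the patch); with the default of `0`, the `label_smoothing >= 0` branch is taken even when the flag is absent, yielding `smoothed_cross_entropy` with `eps=0`, which is effectively plain cross-entropy:

```python
# Hypothetical options objects; only the attributes shown matter here.
from types import SimpleNamespace

train_opts = SimpleNamespace(l1_loss=False, label_smoothing=0.1)
eval_opts = SimpleNamespace()  # e.g. built by an evaluation-only script

loss = clf._loss_func(train_opts)  # partial(smoothed_cross_entropy, ..., eps=0.1)
loss = clf._loss_func(eval_opts)   # previously AttributeError; now eps=0
clf.init_optimizer(eval_opts)      # no "optimizer" attribute -> clf.opt is None
```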
@@ -100,7 +103,7 @@ class _ModelMixin(abc.ABC):
 			logging.info(f"Adding {reg_cls.__name__} ({opts.decay:e})")
 			self.opt.add_hook(reg_cls(opts.decay, **reg_kwargs))

-		if opts.only_head:
+		if getattr(opts, "only_head", False):
 			assert not opts.recurrent, "FIX ME! Not supported yet!"

 			logging.warning("========= Fine-tuning only classifier layer! =========")
@@ -123,7 +126,7 @@ class _ModelMixin(abc.ABC):
 		)

 	def load_model_weights(self, args):
-		if args.from_scratch:
+		if getattr(args, "from_scratch", False):
 			logging.info("Training a {0.__class__.__name__} model from scratch!".format(self.model))
 			loader = self.model.reinitialize_clf
 			self.weights = None
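The same defaulting pattern covers the boolean flags `only_head` and `from_scratch`: behavior is unchanged when the attribute exists, and the safe default applies when it does not. A sketch (again with an assumed `SimpleNamespace` options object):

```python
from types import SimpleNamespace

opts = SimpleNamespace(only_head=True)

getattr(opts, "only_head", False)     # True, same as opts.only_head
getattr(opts, "from_scratch", False)  # False instead of AttributeError,
                                      # so load_model_weights falls through
                                      # to loading pre-trained weights
```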
@@ -200,12 +203,12 @@ class _DatasetMixin(abc.ABC):
 		# 	no_glob=opts.no_global,
 		# ))

-		if not opts.only_head:
+		if not getattr(opts, "only_head", False):
 			kwargs.update(dict(
 				preprocess=self.prepare,
 				augment=augment,
 				size=size,
-				center_crop_on_val=not opts.no_center_crop_on_val,
+				center_crop_on_val=not getattr(opts, "no_center_crop_on_val", False),

 			))

@@ -225,7 +228,7 @@ class _DatasetMixin(abc.ABC):
 		self.model_info = self.data_info.MODELS[opts.model_type]
 		self.part_info = self.data_info.PARTS[opts.parts]

-		if opts.only_head:
+		if getattr(opts, "only_head", False):
 			self.annot.feature_model = opts.model_type

 		self.dataset_cls.label_shift = opts.label_shift
@@ -237,7 +240,7 @@ class _DatasetMixin(abc.ABC):

 		self.prepare = partial(PrepareType[opts.prepare_type](self.model),
 			swap_channels=opts.swap_channels,
-			keep_ratio=not opts.no_center_crop_on_val,
+			keep_ratio=not getattr(opts, "no_center_crop_on_val", False),
 		)

 		logging.info(" ".join([
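Both dataset hunks negate `no_center_crop_on_val`, so the default of `False` preserves the previous behavior (center crop on validation, keep aspect ratio) whenever the flag is missing:

```python
from types import SimpleNamespace

opts = SimpleNamespace()  # flag not present
not getattr(opts, "no_center_crop_on_val", False)  # True -> crop / keep ratio

opts.no_center_crop_on_val = True  # explicit opt-out still works
not getattr(opts, "no_center_crop_on_val", False)  # False
```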
@@ -281,6 +284,11 @@ class _TrainerMixin(abc.ABC):

 	def init_updater(self):
 		"""Creates an updater from training iterator and the optimizer."""
+
+		if self.opt is None:
+			self.updater = None
+			return
+
 		self.updater = self.updater_cls(
 			iterator=self.train_iter,
 			optimizer=self.opt,
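Together with the early return in `init_optimizer`, this guard lets the trainer be assembled without an optimizer at all, e.g. for evaluation-only runs: `self.opt` stays `None` and `init_updater` now skips updater construction instead of raising. A sketch of the resulting call sequence (`trainer` and `eval_opts` are assumed names):

```python
from types import SimpleNamespace

eval_opts = SimpleNamespace()      # no "optimizer" attribute

trainer.init_optimizer(eval_opts)  # early return: trainer.opt is None
trainer.init_updater()             # early return: trainer.updater is None

assert trainer.opt is None and trainer.updater is None
```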
|