|
|
@@ -77,7 +77,12 @@ class Trainer(T):
|
|
|
if no_observe: return
|
|
|
|
|
|
### Snapshotting ###
|
|
|
- self.setup_snapshots(opts, self.model, intervals.snapshot)
|
|
|
+ self.setup_snapshots(
|
|
|
+ enabled=not opts.no_snapshot,
|
|
|
+ obj=self.clf,
|
|
|
+ trigger=intervals.snapshot,
|
|
|
+ suffix="clf_epoch",
|
|
|
+ )
|
|
|
|
|
|
self.setup_reporter(opts, intervals.log, intervals.print)
|
|
|
self.setup_progress_bar(opts)
|
|
|
@@ -171,14 +176,15 @@ class Trainer(T):
|
|
|
return
|
|
|
self.extend(evaluator, trigger=trigger)
|
|
|
|
|
|
- def setup_snapshots(self, opts, obj, trigger):
|
|
|
+ def setup_snapshots(self, enabled: bool, obj: object, trigger, suffix: str = "ft_model_epoch"):
|
|
|
|
|
|
- if opts.no_snapshot:
|
|
|
+ if not enabled:
|
|
|
logging.warning("Models are not snapshot!")
|
|
|
- else:
|
|
|
- dump_fmt = "ft_model_epoch{0.updater.epoch:03d}.npz"
|
|
|
- self.extend(extensions.snapshot_object(obj, dump_fmt), trigger=trigger)
|
|
|
- logging.info("Snapshot format: \"{}\"".format(dump_fmt))
|
|
|
+ return
|
|
|
+
|
|
|
+ dump_fmt = suffix + "{0.updater.epoch:03d}.npz"
|
|
|
+ self.extend(extensions.snapshot_object(obj, dump_fmt), trigger=trigger)
|
|
|
+ logging.info("Snapshot format: \"{}\"".format(dump_fmt))
|
|
|
|
|
|
def eval_name(self, name):
|
|
|
if self.evaluator is None:
|