Эх сурвалжийг харах

changed the snapshot object to the classifier (it just makes more sence)

Dimitri Korsch 3 жил өмнө
parent
commit
54fe82ab47

+ 13 - 7
cvfinetune/training/trainer/base.py

@@ -77,7 +77,12 @@ class Trainer(T):
 		if no_observe: return
 
 		### Snapshotting ###
-		self.setup_snapshots(opts, self.model, intervals.snapshot)
+		self.setup_snapshots(
+			enabled=not opts.no_snapshot,
+			obj=self.clf,
+			trigger=intervals.snapshot,
+			suffix="clf_epoch",
+		)
 
 		self.setup_reporter(opts, intervals.log, intervals.print)
 		self.setup_progress_bar(opts)
@@ -171,14 +176,15 @@ class Trainer(T):
 			return
 		self.extend(evaluator, trigger=trigger)
 
-	def setup_snapshots(self, opts, obj, trigger):
+	def setup_snapshots(self, enabled: bool, obj: object, trigger, suffix: str = "ft_model_epoch"):
 
-		if opts.no_snapshot:
+		if not enabled:
 			logging.warning("Models are not snapshot!")
-		else:
-			dump_fmt = "ft_model_epoch{0.updater.epoch:03d}.npz"
-			self.extend(extensions.snapshot_object(obj, dump_fmt), trigger=trigger)
-			logging.info("Snapshot format: \"{}\"".format(dump_fmt))
+			return
+
+		dump_fmt = suffix + "{0.updater.epoch:03d}.npz"
+		self.extend(extensions.snapshot_object(obj, dump_fmt), trigger=trigger)
+		logging.info("Snapshot format: \"{}\"".format(dump_fmt))
 
 	def eval_name(self, name):
 		if self.evaluator is None: