|
|
@@ -55,7 +55,7 @@ class Trainer(T):
|
|
|
self.setup_lr_schedule(
|
|
|
lr=opts.learning_rate,
|
|
|
lr_target=opts.lr_target,
|
|
|
- lr_shift_trigger=(opts.lr_shift, "epochs"),
|
|
|
+ lr_shift_trigger=(opts.lr_shift, "epoch"),
|
|
|
lr_decrease_rate=opts.lr_decrease_rate,
|
|
|
|
|
|
# needed for cosine annealing
|
|
|
@@ -129,8 +129,9 @@ class Trainer(T):
|
|
|
|
|
|
def setup_warm_up(self, epochs: int, after_warm_up_lr: float, warm_up_lr: float):
|
|
|
|
|
|
- if epochs == 0:
|
|
|
+ if epochs is None or epochs == 0:
|
|
|
return
|
|
|
+
|
|
|
assert epochs > 0, "Warm-up argument must be positive!"
|
|
|
self.offset = epochs
|
|
|
|