|
|
@@ -90,14 +90,13 @@ class _ModelMixin(abc.ABC):
|
|
|
if getattr(opts, "l1_loss", False):
|
|
|
return F.hinge
|
|
|
|
|
|
- elif getattr(opts, "label_smoothing", 0) >= 0:
|
|
|
- assert getattr(opts, "label_smoothing", 0) < 1, \
|
|
|
- "Label smoothing factor must be less than 1!"
|
|
|
- return partial(smoothed_cross_entropy,
|
|
|
- N=self.n_classes,
|
|
|
- eps=getattr(opts, "label_smoothing", 0))
|
|
|
- else:
|
|
|
- return F.softmax_cross_entropy
|
|
|
+ label_smoothing = getattr(opts, "label_smoothing", 0)
|
|
|
+ if label_smoothing > 0:
|
|
|
+ assert label_smoothing < 1, "Label smoothing factor must be less than 1!"
|
|
|
+
|
|
|
+ return partial(smoothed_cross_entropy, N=self.n_classes, eps=label_smoothing)
|
|
|
+
|
|
|
+ return F.softmax_cross_entropy
|
|
|
|
|
|
def init_optimizer(self, opts):
|
|
|
"""Creates an optimizer for the classifier """
|