|
|
@@ -16,8 +16,8 @@ class _ClassifierCreator:
|
|
|
self.kwargs = kwargs
|
|
|
|
|
|
def __call__(self, *args, **kwargs):
|
|
|
- kwargs = dict(self.kwargs, **kwargs)
|
|
|
- return self.cls(*args, **kwargs)
|
|
|
+ self.kwargs = dict(self.kwargs, **kwargs)
|
|
|
+ return self.cls(*args, **self.kwargs)
|
|
|
|
|
|
class _ClassifierMixin(BaseMixin):
|
|
|
"""
|
|
|
@@ -39,12 +39,13 @@ class _ClassifierMixin(BaseMixin):
|
|
|
self._label_smoothing = label_smoothing
|
|
|
|
|
|
|
|
|
- def init_classifier(self):
|
|
|
+ def init_classifier(self, **kwargs):
|
|
|
self._check_attr("model")
|
|
|
self._check_attr("n_classes")
|
|
|
|
|
|
- self.clf = self._clf_creator(model=self.model,
|
|
|
- loss_func=self.loss_func)
|
|
|
+ self.clf = self._clf_creator(self.model,
|
|
|
+ loss_func=self.loss_func,
|
|
|
+ **kwargs)
|
|
|
|
|
|
kwargs = self._clf_creator.kwargs
|
|
|
logging.info(
|