@@ -6,7 +6,6 @@ from chainer import functions as F
 from chainer.optimizer_hooks import Lasso
 from chainer.optimizer_hooks import WeightDecay
 from chainer_addons.functions import smoothed_cross_entropy
-from chainer_addons.models import ModelType
 from chainer_addons.models import PrepareType
 from chainer_addons.training import optimizer
 from chainer_addons.training import optimizer_hooks
@@ -16,7 +15,6 @@ from cvdatasets.utils import pretty_print_dict
 from cvmodelz.models import ModelFactory
 from functools import partial
 from pathlib import Path
-from typing import Callable
 from typing import Tuple
 
 def check_param_for_decay(param):
@@ -66,10 +64,10 @@ class _ModelMixin(abc.ABC):
 			keep_ratio=getattr(opts, "center_crop_on_val", False),
 		)
 
 
-		logging.info(" ".join([
-			f"Created {self.model.__class__.__name__} model",
-			f"with \"{opts.prepare_type}\" prepare function."
-		]))
+		logging.info(
+			f"Created {self.model.__class__.__name__} model"
+			f" with \"{opts.prepare_type}\" prepare function."
+		)
 
 	def init_classifier(self, opts):
@@ -81,10 +79,10 @@ class _ModelMixin(abc.ABC):
 			loss_func=self._loss_func(opts),
 			**kwargs)
 
-		logging.info(" ".join([
-			f"Wrapped the model around {clf_class.__name__}",
-			f"with kwargs: {pretty_print_dict(kwargs)}",
-		]))
+		logging.info(
+			f"Wrapped the model around {clf_class.__name__}"
+			f" with kwargs: {pretty_print_dict(kwargs)}"
+		)
 
 	def _loss_func(self, opts):
 		if getattr(opts, "l1_loss", False):
@@ -114,6 +112,11 @@ class _ModelMixin(abc.ABC):
 			decay=0, gradient_clipping=False, **opt_kwargs
 		)
 
+		logging.info(
+			f"Initialized {self.opt.__class__.__name__} optimizer"
+			f" with initial LR {opts.learning_rate} and kwargs: {pretty_print_dict(opt_kwargs)}"
+		)
+
 		if opts.decay > 0:
 			reg_kwargs = {}
 			if opts.l1_loss: