|
|
@@ -21,6 +21,7 @@ from chainer_addons.training import optimizer_hooks
|
|
|
from cvdatasets import AnnotationType
|
|
|
from cvdatasets.utils import new_iterator
|
|
|
from cvdatasets.utils import pretty_print_dict
|
|
|
+from cvdatasets.dataset.image import Size
|
|
|
|
|
|
from functools import partial
|
|
|
from os.path import join
|
|
|
@@ -112,14 +113,8 @@ class _ModelMixin(abc.ABC):
|
|
|
|
|
|
self.model = ModelType.new(
|
|
|
model_type=self.model_info.class_key,
|
|
|
- input_size=opts.input_size,
|
|
|
+ input_size=Size(opts.input_size),
|
|
|
**self.model_kwargs,
|
|
|
- # pooling=opts.pooling,
|
|
|
- # pooling_params=dict(
|
|
|
- # init_alpha=opts.init_alpha,
|
|
|
- # output_dim=8192,
|
|
|
- # normalize=opts.normalize),
|
|
|
- # aux_logits=False
|
|
|
)
|
|
|
|
|
|
def load_model_weights(self, args):
|
|
|
@@ -183,10 +178,10 @@ class _DatasetMixin(abc.ABC):
|
|
|
def n_classes(self):
|
|
|
return self.ds_info.n_classes + self.dataset_cls.label_shift
|
|
|
|
|
|
- def new_dataset(self, opts, size, subset, augment):
|
|
|
+ def new_dataset(self, opts, size, subset):
|
|
|
"""Creates a dataset for a specific subset and certain options"""
|
|
|
if self.dataset_kwargs_factory is not None and callable(self.dataset_kwargs_factory):
|
|
|
- kwargs = self.dataset_kwargs_factory(opts, subset, augment)
|
|
|
+ kwargs = self.dataset_kwargs_factory(opts, subset)
|
|
|
else:
|
|
|
kwargs = dict()
|
|
|
|
|
|
@@ -195,25 +190,18 @@ class _DatasetMixin(abc.ABC):
|
|
|
dataset_cls=self.dataset_cls,
|
|
|
))
|
|
|
|
|
|
- # if opts.use_parts:
|
|
|
- # kwargs.update(dict(
|
|
|
- # no_glob=opts.no_global,
|
|
|
- # ))
|
|
|
|
|
|
if not getattr(opts, "only_head", False):
|
|
|
kwargs.update(dict(
|
|
|
- preprocess=self.prepare,
|
|
|
- augment=augment,
|
|
|
+ prepare=self.prepare,
|
|
|
size=size,
|
|
|
- center_crop_on_val=not getattr(opts, "no_center_crop_on_val", False),
|
|
|
+ center_crop_on_val=getattr(opts, "center_crop_on_val", False),
|
|
|
|
|
|
))
|
|
|
|
|
|
- d = self.annot.new_dataset(**kwargs)
|
|
|
- logging.info("Loaded {} images".format(len(d)))
|
|
|
- logging.info("Data augmentation is {}abled".format("en" if augment else "dis"))
|
|
|
- # logging.info("Global feature is {}used".format("not " if opts.no_global else ""))
|
|
|
- return d
|
|
|
+ ds = self.annot.new_dataset(**kwargs)
|
|
|
+ logging.info("Loaded {} images".format(len(ds)))
|
|
|
+ return ds
|
|
|
|
|
|
def init_annotations(self, opts):
|
|
|
"""Reads annotations and creates annotation instance, which holds important infos about the dataset"""
|
|
|
@@ -237,7 +225,7 @@ class _DatasetMixin(abc.ABC):
|
|
|
|
|
|
self.prepare = partial(PrepareType[opts.prepare_type](self.model),
|
|
|
swap_channels=opts.swap_channels,
|
|
|
- keep_ratio=not getattr(opts, "no_center_crop_on_val", False),
|
|
|
+ keep_ratio=getattr(opts, "center_crop_on_val", False),
|
|
|
)
|
|
|
|
|
|
logging.info(" ".join([
|
|
|
@@ -246,8 +234,8 @@ class _DatasetMixin(abc.ABC):
|
|
|
f"Image input size: {size}",
|
|
|
]))
|
|
|
|
|
|
- self.train_data = self.new_dataset(opts, size, "train", True)
|
|
|
- self.val_data = self.new_dataset(opts, size, "test", False)
|
|
|
+ self.train_data = self.new_dataset(opts, size, "train")
|
|
|
+ self.val_data = self.new_dataset(opts, size, "test")
|
|
|
|
|
|
def init_iterators(self, opts):
|
|
|
"""Creates training and validation iterators from training and validation datasets"""
|