|
|
@@ -178,7 +178,7 @@ class _DatasetMixin(abc.ABC):
|
|
|
def n_classes(self):
|
|
|
return self.ds_info.n_classes + self.dataset_cls.label_shift
|
|
|
|
|
|
- def new_dataset(self, opts, size, subset):
|
|
|
+ def new_dataset(self, opts, size, part_size, subset):
|
|
|
"""Creates a dataset for a specific subset and certain options"""
|
|
|
if self.dataset_kwargs_factory is not None and callable(self.dataset_kwargs_factory):
|
|
|
kwargs = self.dataset_kwargs_factory(opts, subset)
|
|
|
@@ -195,6 +195,7 @@ class _DatasetMixin(abc.ABC):
|
|
|
kwargs.update(dict(
|
|
|
prepare=self.prepare,
|
|
|
size=size,
|
|
|
+ part_size=part_size,
|
|
|
center_crop_on_val=getattr(opts, "center_crop_on_val", False),
|
|
|
|
|
|
))
|
|
|
@@ -221,7 +222,9 @@ class _DatasetMixin(abc.ABC):
|
|
|
|
|
|
def init_datasets(self, opts):
|
|
|
|
|
|
- size = self.model.meta.input_size
|
|
|
+ size = Size(opts.input_size)
|
|
|
+ part_size = getattr(opts, "part_size")
|
|
|
+ part_size = size if part_size is None else Size(part_size)
|
|
|
|
|
|
self.prepare = partial(PrepareType[opts.prepare_type](self.model),
|
|
|
swap_channels=opts.swap_channels,
|
|
|
@@ -230,12 +233,16 @@ class _DatasetMixin(abc.ABC):
|
|
|
|
|
|
logging.info(" ".join([
|
|
|
f"Created {self.model.__class__.__name__} model",
|
|
|
- f"with \"{opts.prepare_type}\" prepare function.",
|
|
|
+ f"with \"{opts.prepare_type}\" prepare function."
|
|
|
+ ]))
|
|
|
+
|
|
|
+ logging.info(" ".join([
|
|
|
f"Image input size: {size}",
|
|
|
+ f"Image parts input size: {part_size}",
|
|
|
]))
|
|
|
|
|
|
- self.train_data = self.new_dataset(opts, size, "train")
|
|
|
- self.val_data = self.new_dataset(opts, size, "test")
|
|
|
+ self.train_data = self.new_dataset(opts, size, part_size, "train")
|
|
|
+ self.val_data = self.new_dataset(opts, size, part_size, "test")
|
|
|
|
|
|
def init_iterators(self, opts):
|
|
|
"""Creates training and validation iterators from training and validation datasets"""
|