فهرست منبع

added part size handling

Dimitri Korsch 4 سال پیش
والد
کامیت
488da11c09
1فایلهای تغییر یافته به همراه12 افزوده شده و 5 حذف شده
  1. 12 5
      cvfinetune/finetuner/base.py

+ 12 - 5
cvfinetune/finetuner/base.py

@@ -178,7 +178,7 @@ class _DatasetMixin(abc.ABC):
 	def n_classes(self):
 		return self.ds_info.n_classes + self.dataset_cls.label_shift
 
-	def new_dataset(self, opts, size, subset):
+	def new_dataset(self, opts, size, part_size, subset):
 		"""Creates a dataset for a specific subset and certain options"""
 		if self.dataset_kwargs_factory is not None and callable(self.dataset_kwargs_factory):
 			kwargs = self.dataset_kwargs_factory(opts, subset)
@@ -195,6 +195,7 @@ class _DatasetMixin(abc.ABC):
 			kwargs.update(dict(
 				prepare=self.prepare,
 				size=size,
+				part_size=part_size,
 				center_crop_on_val=getattr(opts, "center_crop_on_val", False),
 
 			))
@@ -221,7 +222,9 @@ class _DatasetMixin(abc.ABC):
 
 	def init_datasets(self, opts):
 
-		size = self.model.meta.input_size
+		size = Size(opts.input_size)
+		part_size = getattr(opts, "part_size")
+		part_size = size if part_size is None else Size(part_size)
 
 		self.prepare = partial(PrepareType[opts.prepare_type](self.model),
 			swap_channels=opts.swap_channels,
@@ -230,12 +233,16 @@ class _DatasetMixin(abc.ABC):
 
 		logging.info(" ".join([
 			f"Created {self.model.__class__.__name__} model",
-			f"with \"{opts.prepare_type}\" prepare function.",
+			f"with \"{opts.prepare_type}\" prepare function."
+		]))
+
+		logging.info(" ".join([
 			f"Image input size: {size}",
+			f"Image parts input size: {part_size}",
 		]))
 
-		self.train_data = self.new_dataset(opts, size, "train")
-		self.val_data = self.new_dataset(opts, size, "test")
+		self.train_data = self.new_dataset(opts, size, part_size, "train")
+		self.val_data = self.new_dataset(opts, size, part_size, "test")
 
 	def init_iterators(self, opts):
 		"""Creates training and validation iterators from training and validation datasets"""