|
@@ -0,0 +1,98 @@
|
|
|
+import numpy as np
|
|
|
+
|
|
|
+from chainer import iterators
|
|
|
+from chainer.dataset import DatasetMixin
|
|
|
+from chainer.datasets import TransformDataset
|
|
|
+from chainercv import transforms as tr
|
|
|
+from imageio import imread
|
|
|
+from pathlib import Path
|
|
|
+from typing import Callable
|
|
|
+
|
|
|
class Dataset(DatasetMixin):
    """Image dataset described by plain-text annotation files.

    The following files are read from the ``root`` folder:

    - ``class_names.txt``: the class names
    - ``images.txt``: rows of ``<id> <file name>``
    - ``labels.txt``: the integer labels
    - ``tr_ID.txt``: the integer split IDs

    The same selection mask is applied to the file names and to the
    labels, so the rows of the last three files are presumably aligned
    (one row per image) -- confirm against the dataset layout.

    With ``is_train=True`` the dataset holds every sample whose split ID
    differs from ``split_id``; with ``is_train=False`` it holds exactly
    the samples whose split ID equals ``split_id``.
    """

    def __init__(self, root: str, split_id: int, is_train: bool = True):
        super().__init__()

        self._root = Path(root)
        self.class_names = np.loadtxt(
            self._root / "class_names.txt", dtype="U255")

        # read the annotations from the root folder
        image_dtype = [("id", np.int32), ("fname", "U255")]
        image_table = np.loadtxt(self._root / "images.txt", dtype=image_dtype)
        all_labels = np.loadtxt(self._root / "labels.txt", dtype=np.int32)
        fold_ids = np.loadtxt(self._root / "tr_ID.txt", dtype=np.int32)

        # training: every split except the given one;
        # evaluation: only the given split
        selected = (fold_ids != split_id) if is_train else (fold_ids == split_id)

        self.images = image_table["fname"][selected]
        self.labels = all_labels[selected]

    def __len__(self):
        """Return the number of samples selected for this split."""
        return len(self.images)

    def get_example(self, i):
        """Read the i-th image from ``<root>/images`` and return it with its label.

        The image is loaded lazily (on access) with ``pilmode="RGB"``.
        """
        fname, label = self.images[i], self.labels[i]
        image = imread(self._root / "images" / fname, pilmode="RGB")
        return image, label
|
|
|
+
|
|
|
class DataTransformer(object):
    """Callable that converts an ``(image, label)`` sample into CNN input.

    ``prepare`` is the model's preparation function; it is called as
    ``prepare(image, size)``. ``size`` is the side length of the square
    input expected by the CNN.
    """

    def __init__(self, prepare: Callable, size: int):
        super().__init__()
        self.prepare, self.size = prepare, size

    def __call__(self, data):
        """
        Transform a single sample before it is passed to the CNN:
            - resize with the "prepare" function of the model
            - center crop to the size of the CNN input
            - rescale the pixel range from [0..1] to [-1..1]
              (the CNN was trained with this pixel range)

        NOTE: the rescaling assumes that "prepare" returns pixel
        values in [0..1] -- confirm for the model in use.

        Returns the transformed image together with the unchanged label.
        """
        image, label = data
        side = self.size

        resized = self.prepare(image, side)
        cropped = tr.center_crop(resized, size=(side, side))

        # map the pixel range 0..1 onto -1..1
        return cropped * 2 - 1, label
|
|
|
+
|
|
|
def load_datasets(root: Path, model_input_size: int, prepare: Callable, split_id: int):
    """
        Build the training and the evaluation dataset for the given
        split ID. Both are wrapped in a TransformDataset that applies
        the same DataTransformer, and are returned as the tuple
        (train_ds, val_ds).
    """
    transform = DataTransformer(prepare, model_input_size)

    # is_train=True  -> all splits except split_id (training)
    # is_train=False -> only split_id (evaluation)
    train_ds, val_ds = [
        TransformDataset(
            Dataset(root, split_id=split_id, is_train=flag),
            transform,
        )
        for flag in (True, False)
    ]

    return train_ds, val_ds
|
|
|
+
|
|
|
def new_iterator(dataset, n_jobs: int = -1, **kwargs):
    """
        Create the iterator that gathers the samples of the dataset
        and groups them into batches.

        For n_jobs >= 1 a MultithreadIterator with n_jobs threads is
        created; otherwise a single-thread SerialIterator. All further
        keyword arguments are forwarded to the iterator class.
    """
    if n_jobs >= 1:
        # n_jobs always determines the thread count, even if the
        # caller passed its own "n_threads" keyword
        options = {**kwargs, "n_threads": n_jobs}
        return iterators.MultithreadIterator(dataset, **options)

    return iterators.SerialIterator(dataset, **kwargs)
|