import numpy as np
from chainer import iterators
from chainer.dataset import DatasetMixin
from chainer.datasets import TransformDataset
from chainercv import transforms as tr
from imageio import imread
from pathlib import Path
from typing import Callable


class Dataset(DatasetMixin):
    """Image/label dataset described by annotation files in a root folder.

    The following files are read from ``root`` when the dataset is created:

    - ``class_names.txt``: the class names (kept in ``self.class_names``)
    - ``images.txt``: rows of ``<id> <file name>``
    - ``labels.txt``: integer labels
    - ``tr_ID.txt``: integer split IDs

    Labels and split IDs are assumed to be listed in the same order as the
    rows of ``images.txt``. The image files themselves are looked up in the
    ``images`` subfolder of ``root`` and are only read on access.
    """

    def __init__(self, root: str, split_id: int, is_train: bool = True):
        """Read the annotations and keep only the rows of the wanted split.

        Args:
            root: folder containing the annotation files and ``images/``.
            split_id: ID of the split that is held out for evaluation.
            is_train: if true, keep every image whose split ID differs from
                ``split_id``; otherwise keep only the images of that split.
        """
        super().__init__()

        base_dir = Path(root)
        self._root = base_dir
        self.class_names = np.loadtxt(
            base_dir / "class_names.txt", dtype="U255")

        # load the annotation tables stored next to the images
        image_table = np.loadtxt(
            base_dir / "images.txt",
            dtype=[("id", np.int32), ("fname", "U255")])
        all_labels = np.loadtxt(base_dir / "labels.txt", dtype=np.int32)
        split_of_image = np.loadtxt(base_dir / "tr_ID.txt", dtype=np.int32)

        # training: everything outside the given split
        # evaluation: only the given split
        keep = (
            (split_of_image != split_id)
            if is_train
            else (split_of_image == split_id)
        )

        self.images = image_table["fname"][keep]
        self.labels = all_labels[keep]

    def __len__(self):
        """Return the number of images selected for this split."""
        return len(self.images)

    def get_example(self, i):
        """Load the ``i``-th image from disk and return it with its label.

        The image is read with ``pilmode="RGB"``.
        """
        image_file = self._root / "images" / self.images[i]
        target = self.labels[i]
        return imread(image_file, pilmode="RGB"), target


class DataTransformer:
    """Callable that turns a raw ``(image, label)`` pair into CNN input."""

    def __init__(self, prepare: Callable, size: int):
        """Store the preprocessing function and the CNN input size.

        Args:
            prepare: the model's "prepare" function; it is called as
                ``prepare(image, size)``.
            size: edge length of the square center crop.
        """
        super().__init__()
        self.size = size
        self.prepare = prepare

    def __call__(self, data):
        """Transform one example before it is passed to the CNN.

        Steps:
        - resize with the "prepare" function of the model
        - center crop to the size of the CNN input
        - rescale the pixel range from [0..1] to [-1..1]
          (the CNN was trained with this pixel range)

        Returns:
            The transformed image together with the unchanged label.
        """
        image, label = data
        crop_shape = (self.size, self.size)

        resized = self.prepare(image, self.size)
        cropped = tr.center_crop(resized, size=crop_shape)

        # map the pixel values from 0..1 to -1..1
        return cropped * 2 - 1, label


def load_datasets(root: Path, model_input_size: int, prepare: Callable,
                  split_id: int):
    """Create the training and the evaluation dataset for one split.

    Both datasets are ``Dataset`` instances wrapped in a
    ``TransformDataset`` that applies the same ``DataTransformer``.

    Args:
        root: folder containing the annotation files and the images.
        model_input_size: input size of the CNN, used for the center crop.
        prepare: the model's "prepare" function.
        split_id: ID of the split used for evaluation; all other splits
            form the training set.

    Returns:
        A tuple ``(train_ds, val_ds)``.
    """
    transformer = DataTransformer(prepare, model_input_size)

    def build(is_train: bool):
        # one raw dataset per split, sharing a single transformer
        raw = Dataset(root, split_id=split_id, is_train=is_train)
        return TransformDataset(raw, transformer)

    return build(True), build(False)


def new_iterator(dataset, n_jobs: int = -1, **kwargs):
    """Create an iterator that gathers dataset examples into batches.

    With ``n_jobs >= 1`` a ``MultithreadIterator`` using ``n_jobs`` threads
    is created; otherwise (the default) a ``SerialIterator`` is used. All
    remaining keyword arguments are forwarded to the iterator class.
    """
    if n_jobs >= 1:
        # the multi-thread iterator takes the worker count as "n_threads"
        kwargs["n_threads"] = n_jobs
        return iterators.MultithreadIterator(dataset, **kwargs)

    return iterators.SerialIterator(dataset, **kwargs)