1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798 |
- import numpy as np
- from chainer import iterators
- from chainer.dataset import DatasetMixin
- from chainer.datasets import TransformDataset
- from chainercv import transforms as tr
- from imageio import imread
- from pathlib import Path
- from typing import Callable
class Dataset(DatasetMixin):
    """Image/label dataset described by annotation files in a root folder.

    The following files are read from ``root``:

    - ``class_names.txt``: class names, stored as ``class_names``
    - ``images.txt``: records of ``(id, fname)``
    - ``labels.txt``: integer labels
    - ``tr_ID.txt``: integer split IDs

    NOTE(review): the same boolean mask is applied to the file names and
    the labels, so the three per-image files are presumably row-aligned --
    confirm against the data.

    Args:
        root: folder containing the annotation files and an ``images``
            sub-folder.
        split_id: ID of the split that is held out.
        is_train: if true, keep every entry whose split ID differs from
            ``split_id``; otherwise keep only the entries of that split.
    """

    def __init__(self, root: str, split_id: int, is_train: bool = True):
        super().__init__()
        self._root = Path(root)
        base = self._root

        self.class_names = np.loadtxt(base / "class_names.txt", dtype="U255")

        # read the annotations from the root folder
        image_records = np.loadtxt(
            base / "images.txt",
            dtype=[("id", np.int32), ("fname", "U255")],
        )
        all_labels = np.loadtxt(base / "labels.txt", dtype=np.int32)
        fold_ids = np.loadtxt(base / "tr_ID.txt", dtype=np.int32)

        # training: everything except the given split;
        # evaluation: only the given split
        keep = (fold_ids != split_id) if is_train else (fold_ids == split_id)

        self.images = image_records["fname"][keep]
        self.labels = all_labels[keep]

    def __len__(self):
        """Return the number of selected images."""
        return len(self.images)

    def get_example(self, i):
        """Load the i-th image as RGB and return it together with its label."""
        image_file = self._root / "images" / self.images[i]
        target = self.labels[i]
        pixels = imread(image_file, pilmode="RGB")
        return pixels, target
class DataTransformer(object):
    """Callable that prepares an ``(image, label)`` pair for the CNN.

    Args:
        prepare: function called as ``prepare(image, size)``.
        size: edge length passed to ``prepare`` and used for the square
            center crop.
    """

    def __init__(self, prepare: Callable, size: int):
        super().__init__()
        self.size = size
        self.prepare = prepare

    def __call__(self, data):
        """Transform one example and return ``(transformed_image, label)``.

        Steps applied to the image (the label is passed through untouched):

        - apply the model's "prepare" function with the configured size
        - center crop to ``(size, size)``, the size of the CNN input
        - rescale the pixel values via ``x * 2 - 1``; this maps [0..1] to
          [-1..1], assuming "prepare" returns values in [0..1]
        """
        raw, target = data
        edge = self.size

        prepared = self.prepare(raw, edge)
        cropped = tr.center_crop(prepared, size=(edge, edge))

        # transform the pixel range from 0..1 to -1..1
        rescaled = cropped * 2 - 1
        return rescaled, target
def load_datasets(root: Path, model_input_size: int, prepare: Callable, split_id: int):
    """Build the training and evaluation datasets for one split.

    Two ``Dataset`` instances are created over ``root`` (first with
    ``is_train=True``, then with ``is_train=False``) and both are wrapped
    in a ``TransformDataset`` sharing a single ``DataTransformer``.

    Args:
        root: folder handed to ``Dataset``.
        model_input_size: size handed to ``DataTransformer``.
        prepare: prepare function handed to ``DataTransformer``.
        split_id: split ID handed to ``Dataset``.

    Returns:
        Tuple ``(train_ds, val_ds)`` of transformed datasets.
    """
    # training split first, evaluation split second
    raw_splits = [
        Dataset(root, split_id=split_id, is_train=training)
        for training in (True, False)
    ]

    transform = DataTransformer(prepare, model_input_size)
    train_ds, val_ds = (TransformDataset(ds, transform) for ds in raw_splits)
    return train_ds, val_ds
def new_iterator(dataset, n_jobs: int = -1, **kwargs):
    """Create a batch iterator over ``dataset``.

    With ``n_jobs >= 1`` a ``MultithreadIterator`` is created and
    ``n_jobs`` is passed as its ``n_threads`` argument (replacing any
    ``n_threads`` given in ``kwargs``). Otherwise a ``SerialIterator`` is
    created. All remaining keyword arguments are forwarded unchanged to
    the iterator class.

    Iterators are responsible for gathering the images from the dataset
    and grouping them into batches.
    """
    if n_jobs >= 1:
        threaded_kwargs = {**kwargs, "n_threads": n_jobs}
        return iterators.MultithreadIterator(dataset, **threaded_kwargs)

    return iterators.SerialIterator(dataset, **kwargs)
|