dataset.py

from pathlib import Path
from typing import Callable

import numpy as np
from chainer import iterators
from chainer.dataset import DatasetMixin
from chainer.datasets import TransformDataset
from chainercv import transforms as tr
from imageio import imread

class Dataset(DatasetMixin):

    def __init__(self, root: str, split_id: int, is_train: bool = True):
        super().__init__()
        root = Path(root)
        self._root = root
        self.class_names = np.loadtxt(root / "class_names.txt", dtype="U255")

        # read the annotations from the root folder
        _images = np.loadtxt(root / "images.txt",
                             dtype=[("id", np.int32), ("fname", "U255")])
        _labels = np.loadtxt(root / "labels.txt", dtype=np.int32)
        _split_ids = np.loadtxt(root / "tr_ID.txt", dtype=np.int32)

        if is_train:
            # training set: select the images of all other splits
            split_mask = _split_ids != split_id
        else:
            # evaluation set: select only the images of the given split ID
            split_mask = _split_ids == split_id

        self.images = _images["fname"][split_mask]
        self.labels = _labels[split_mask]

    def __len__(self):
        return len(self.images)

    def get_example(self, i):
        """Load the i-th image as an RGB array together with its label."""
        im_path = self._root / "images" / self.images[i]
        label = self.labels[i]
        return imread(im_path, pilmode="RGB"), label
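
# Usage sketch. The directory layout below is an assumption inferred from
# the loadtxt calls above, not documented elsewhere; adjust it to your data:
#
#   root/
#     class_names.txt  # one class name per line
#     images.txt       # "<id> <relative/file/name>" per line
#     labels.txt       # one integer label per line, aligned with images.txt
#     tr_ID.txt        # one split ID per line, aligned with images.txt
#     images/          # the image files listed in images.txt
#
#   >>> ds = Dataset("path/to/root", split_id=0, is_train=False)
#   >>> image, label = ds[0]  # DatasetMixin routes indexing to get_example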

class DataTransformer:

    def __init__(self, prepare: Callable, size: int):
        super().__init__()
        self.prepare = prepare
        self.size = size

    def __call__(self, data):
        """
        Before the data is passed to the CNN, it needs to be transformed:
            - resize the image with the "prepare" function of the model
            - center-crop it to the size of the CNN input
            - rescale the pixel range from [0..1] to [-1..1]
              (the pixel range the CNN was trained with)
        """
        image, label = data
        new_image = self.prepare(image, self.size)
        new_image = tr.center_crop(new_image, size=(self.size, self.size))
        # rescale the pixel range from 0..1 to -1..1
        new_image = new_image * 2 - 1
        return new_image, label
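
# Sketch of a compatible "prepare" callable. This is an assumption for
# illustration: any function returning a CHW float image in [0..1], large
# enough for the center crop above, fits the contract.
#
#   >>> def prepare(image, size):
#   ...     chw = image.transpose(2, 0, 1).astype(np.float32) / 255.0
#   ...     return tr.scale(chw, size)  # scale the shorter side to `size`
#   >>> transformer = DataTransformer(prepare, size=224)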

def load_datasets(root: Path, model_input_size: int, prepare: Callable, split_id: int):
    """
    Load the two dataset splits (training and evaluation) and
    return them wrapped in TransformDataset instances.
    """
    train_ds = Dataset(root, split_id=split_id, is_train=True)
    val_ds = Dataset(root, split_id=split_id, is_train=False)

    transformer = DataTransformer(prepare, model_input_size)
    train_ds = TransformDataset(train_ds, transformer)
    val_ds = TransformDataset(val_ds, transformer)

    return train_ds, val_ds
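
# Example wiring (the root path, input size, and split ID are illustrative):
#
#   >>> train_ds, val_ds = load_datasets(Path("path/to/root"),
#   ...     model_input_size=224, prepare=prepare, split_id=0)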

def new_iterator(dataset, n_jobs: int = -1, **kwargs):
    """
    Depending on the n_jobs argument, create either a single-thread
    serial iterator (n_jobs < 1) or a multi-thread iterator.
    Iterators are responsible for gathering the examples from the
    dataset and grouping them into a batch.
    """
    it_cls = iterators.SerialIterator
    if n_jobs >= 1:
        kwargs["n_threads"] = n_jobs
        it_cls = iterators.MultithreadIterator

    return it_cls(dataset, **kwargs)
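
# Example (batch_size and n_jobs are illustrative; both iterator classes
# yield batches as lists of (image, label) tuples):
#
#   >>> train_it = new_iterator(train_ds, n_jobs=4,
#   ...     batch_size=32, repeat=True, shuffle=True)
#   >>> batch = train_it.next()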