dataset.py
import logging
import warnings

import numpy as np


def _format_kwargs(kwargs):
    """Formats a kwargs dict as a single "key=value ..." string for log messages."""
    return " ".join([f"{key}={value}" for key, value in kwargs.items()])


def _uuid_check(uuids):
    """Checks whether the given IDs are unique."""
    assert len(np.unique(uuids)) == len(uuids), \
        "UUIDs are not unique!"


def new_iterator(data, n_jobs, batch_size, repeat=True, shuffle=True, n_prefetch=2):
    from chainer.iterators import SerialIterator, MultiprocessIterator

    if n_jobs > 0:
        it_cls = MultiprocessIterator
        try:
            # Disable OpenCV's internal threading; it can clash with the
            # worker processes spawned by MultiprocessIterator.
            import cv2
            cv2.setNumThreads(0)
        except ImportError:
            pass
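
        # Work out the (H, W) input size for the shared-memory buffer below.
        # The dataset may expose `size` as an int, a tuple, or another
        # iterable; anything unparseable falls back to (512, 512).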
        input_shape = getattr(data, "size", (512, 512))
        if isinstance(input_shape, int):
            input_shape = (input_shape, input_shape)
        elif not isinstance(input_shape, tuple):
            try:
                input_shape = tuple(input_shape)
            except TypeError:
                warnings.warn(f"Could not parse input_shape: \"{input_shape}\". "
                              "Falling back to a default value of (512, 512)")
                input_shape = (512, 512)
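
        # Size in bytes of one float32 batch (batch_size x 3 channels x H x W);
        # this is passed to MultiprocessIterator as its shared-memory budget.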
        shared_mem_shape = (batch_size, 3) + input_shape
        shared_mem = np.zeros(shared_mem_shape, dtype=np.float32).nbytes
        logging.info(f"Using {shared_mem / 1024**2: .3f} MiB of shared memory")

        it_kwargs = dict(
            n_processes=n_jobs,
            n_prefetch=n_prefetch,
            batch_size=batch_size,
            repeat=repeat, shuffle=shuffle,
            shared_mem=shared_mem)
    else:
        it_cls = SerialIterator
        it_kwargs = dict(
            batch_size=batch_size,
            repeat=repeat, shuffle=shuffle)

    it = it_cls(data, **it_kwargs)
    # Number of batches per epoch; a trailing partial batch counts as one.
    n_batches = int(np.ceil(len(data) / it.batch_size))
    logging.info(f"Using {it_cls.__name__} with {n_batches:,d} batches per epoch "
                 f"and kwargs: {_format_kwargs(it_kwargs)}")
    return it, n_batches
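

# A minimal usage sketch (illustrative; `ImageDataset` and its attributes are
# assumptions, not part of this module). Any chainer `DatasetMixin` subclass
# exposing a `size` attribute and returning (3, H, W) float32 arrays works:
#
#     from chainer.dataset import DatasetMixin
#
#     class ImageDataset(DatasetMixin):
#         size = 224  # picked up via getattr(data, "size", ...) above
#
#         def __len__(self):
#             return 10_000
#
#         def get_example(self, i):
#             ...  # load and return a (3, 224, 224) float32 array
#
#     it, n_batches = new_iterator(ImageDataset(), n_jobs=4, batch_size=32)
#     for _ in range(n_batches):  # one full pass over the data
#         batch = it.next()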