transform.py 1.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. import abc
  2. try:
  3. import chainer
  4. def is_train() -> bool:
  5. return chainer.config.train
  6. except ImportError as e:
  7. """ other frameworks (e.g., PyTorch) do not have this global flag """
  8. def is_train() -> bool:
  9. return False
  10. from cvdatasets.dataset.image.size import Size
  11. from cvdatasets.dataset.mixins.base import BaseMixin
  12. class TransformMixin(BaseMixin):
  13. def __init__(self, size, part_size=None, *args, **kwargs):
  14. super(TransformMixin, self).__init__(*args, **kwargs)
  15. self.size = size
  16. self.part_size = size if part_size is None else part_size
  17. @abc.abstractmethod
  18. def transform(self, im_obj):
  19. pass
  20. def get_example(self, i):
  21. im_obj = super(TransformMixin, self).get_example(i)
  22. return self.transform(im_obj)
  23. @property
  24. def size(self):
  25. if is_train():
  26. return self._size // 0.875
  27. else:
  28. return self._size
  29. @size.setter
  30. def size(self, value):
  31. self._size = Size(value)
  32. @property
  33. def part_size(self):
  34. if is_train():
  35. return self._part_size // 0.875
  36. else:
  37. return self._part_size
  38. @part_size.setter
  39. def part_size(self, value):
  40. self._part_size = Size(value)