Просмотр исходного кода

moved dataset.BaseMixin to a separate file. Refactored dataset.image module. Added TransformMixin (for augmentations) and Size class

Dimitri Korsch 4 лет назад
Родитель
Сommit
10468761e4

+ 32 - 0
cvdatasets/dataset/__init__.py

@@ -1,3 +1,4 @@
+from cvdatasets.dataset.mixins.base import BaseMixin
 from cvdatasets.dataset.mixins.chainer_mixins import IteratorMixin
 from cvdatasets.dataset.mixins.features import PreExtractedFeaturesMixin
 from cvdatasets.dataset.mixins.parts import BBCropMixin
@@ -13,6 +14,7 @@ from cvdatasets.dataset.mixins.parts import RevealedPartMixin
 from cvdatasets.dataset.mixins.parts import UniformPartMixin
 from cvdatasets.dataset.mixins.reading import AnnotationsReadMixin
 from cvdatasets.dataset.mixins.reading import ImageListReadingMixin
+from cvdatasets.dataset.mixins.transform import TransformMixin
 
 
 class ImageWrapperDataset(PartMixin, PreExtractedFeaturesMixin, AnnotationsReadMixin, IteratorMixin):
@@ -23,3 +25,33 @@ class Dataset(ImageWrapperDataset):
 	def get_example(self, i):
 		im_obj = super(Dataset, self).get_example(i)
 		return im_obj.as_tuple()
+
+__all__ = [
+	"Dataset",
+	"ImageWrapperDataset",
+
+	### mixins ###
+	"BaseMixin",
+	# reading
+	"AnnotationsReadMixin",
+	"ImageListReadingMixin",
+
+	# features
+	"PreExtractedFeaturesMixin",
+
+	# parts / bounding boxes
+	"BBCropMixin",
+	"BBoxMixin",
+	"CroppedPartMixin",
+	"MultiBoxMixin",
+	"PartCropMixin",
+	"PartMixin",
+	"PartRevealMixin",
+	"PartsInBBMixin",
+	"RandomBlackOutMixin",
+	"RevealedPartMixin",
+	"UniformPartMixin",
+
+	# transform mixin
+	"TransformMixin",
+]

+ 8 - 0
cvdatasets/dataset/image/__init__.py

@@ -0,0 +1,8 @@
+from cvdatasets.dataset.image.image_wrapper import ImageWrapper
+from cvdatasets.dataset.image.size import Size
+
+
+__all__ = [
+	"ImageWrapper",
+	"Size",
+]

+ 11 - 23
cvdatasets/dataset/image.py → cvdatasets/dataset/image/image_wrapper.py

@@ -1,14 +1,12 @@
-from PIL import Image
-from imageio import imread
-from os.path import isfile
-
 import copy
 import numpy as np
 
-from .part import Parts
-from .part import UniformParts
-from .part import SurrogateType
+from PIL import Image
+
 from cvdatasets import utils
+from cvdatasets.dataset.part import Parts
+from cvdatasets.dataset.part import SurrogateType
+from cvdatasets.dataset.part import UniformParts
 
 def should_have_parts(func):
 	def inner(self, *args, **kwargs):
@@ -17,21 +15,6 @@ def should_have_parts(func):
 	return inner
 
 class ImageWrapper(object):
-	@staticmethod
-	def read_image(im_path, mode="RGB", n_retries=5):
-		_read = lambda: Image.open(im_path, mode="r")
-		if n_retries <= 0:
-			assert isfile(im_path), "Image \"{}\" does not exist!".format(im_path)
-			return _read()
-		else:
-			error = None
-			for i in range(n_retries):
-				try:
-					return _read()
-				except Exception as e:
-					error = e
-
-			raise RuntimeError("Reading image \"{}\" failed after {} n_retries! ({})".format(im_path, n_retries, error))
 
 
 	def __init__(self, im_path, label,
@@ -64,16 +47,21 @@ class ImageWrapper(object):
 	@property
 	def im_array(self):
 		if self._im_array is None:
+
 			if isinstance(self._im, Image.Image):
 				_im = self._im.convert(self.mode)
 				self._im_array = utils.asarray(_im)
+
 			elif isinstance(self._im, np.ndarray):
 				if self.mode == "RGB" and self._im.ndim == 2:
 					self._im_array = np.stack((self._im,) * 3, axis=-1)
+
 				elif self._im.ndim in (3, 4):
 					self._im_array = self._im
+
 				else:
 					raise ValueError()
+
 			else:
 				raise ValueError()
 		return self._im_array
@@ -87,7 +75,7 @@ class ImageWrapper(object):
 	@im.setter
 	def im(self, value):
 		if isinstance(value, str):
-			self._im = ImageWrapper.read_image(value, mode=self.mode)
+			self._im = utils.read_image(value, n_retries=5)
 			self._im_path = value
 		else:
 			self._im = value

+ 50 - 0
cvdatasets/dataset/image/size.py

@@ -0,0 +1,50 @@
+import numpy as np
+
+from collections.abc import Iterable
+
+class Size(object):
+	dtype=np.int32
+
+	def __init__(self, value):
+		self._size = np.zeros(2, dtype=self.dtype)
+		if isinstance(value, int):
+			self._size[:] = value
+
+		elif isinstance(value, Size):
+			self._size[:] = value._size
+
+		elif isinstance(value, Iterable):
+			assert len(value) <= 2, \
+				"only iterables of maximum size 2 are supported, but was {}!".format(len(value))
+			self._size[:] = np.round(value)
+
+
+		else:
+			raise ValueError("Unsupported data type: {}!".format(type(value)))
+
+	def __str__(self):
+		return "<Size {}x{}>".format(*self._size)
+
+	def __repr__(self):
+		return str(self)
+
+	def __add__(self, other):
+		return self.__class__(self._size + other)
+
+	def __sub__(self, other):
+		return self.__class__(self._size - other)
+
+	def __mul__(self, other):
+		return self.__class__(self._size * other)
+
+	def __truediv__(self, other):
+		return self.__class__(self._size / other)
+
+	def __floordiv__(self, other):
+		return self.__class__(self._size // other)
+
+	def __iter__(self):
+		return iter(self._size)
+
+	def __len__(self):
+		return len(self._size)

+ 0 - 33
cvdatasets/dataset/mixins/__init__.py

@@ -1,33 +0,0 @@
-from abc import ABC, abstractmethod
-
-import numpy as np
-import six
-
-from matplotlib.patches import Rectangle
-
-class BaseMixin(ABC):
-
-	@abstractmethod
-	def get_example(self, i):
-		s = super(BaseMixin, self)
-		if hasattr(s, "get_example"):
-			return s.get_example(i)
-
-	def plot_bounding_box(self, i, ax, fill=False, linestyle="--", **kwargs):
-		x, y, w, h = self.bounding_box(i)
-		ax.add_patch(Rectangle(
-			(x,y), w, h,
-			fill=False,
-			linestyle="-.",
-			**kwargs
-		))
-
-	def __getitem__(self, index):
-		if isinstance(index, slice):
-			current, stop, step = index.indices(len(self))
-			return [self.get_example(i) for i in
-					six.moves.range(current, stop, step)]
-		elif isinstance(index, list) or isinstance(index, np.ndarray):
-			return [self.get_example(i) for i in index]
-		else:
-			return self.get_example(index)

+ 32 - 0
cvdatasets/dataset/mixins/base.py

@@ -0,0 +1,32 @@
+import abc
+import numpy as np
+import six
+
+from matplotlib.patches import Rectangle
+
+class BaseMixin(abc.ABC):
+
+	@abc.abstractmethod
+	def get_example(self, i):
+		s = super(BaseMixin, self)
+		if hasattr(s, "get_example"):
+			return s.get_example(i)
+
+	def plot_bounding_box(self, i, ax, fill=False, linestyle="--", **kwargs):
+		x, y, w, h = self.bounding_box(i)
+		ax.add_patch(Rectangle(
+			(x,y), w, h,
+			fill=False,
+			linestyle="-.",
+			**kwargs
+		))
+
+	def __getitem__(self, index):
+		if isinstance(index, slice):
+			current, stop, step = index.indices(len(self))
+			return [self.get_example(i) for i in
+					six.moves.range(current, stop, step)]
+		elif isinstance(index, list) or isinstance(index, np.ndarray):
+			return [self.get_example(i) for i in index]
+		else:
+			return self.get_example(index)

+ 2 - 2
cvdatasets/dataset/mixins/chainer_mixins/base.py

@@ -5,9 +5,9 @@ except ImportError:
 else:
 	has_chainer = True
 
-from abc import ABC
+import abc
 
-class BaseChainerMixin(ABC):
+class BaseChainerMixin(abc.ABC):
 
 	def chainer_check(self):
 		global has_chainer

+ 1 - 1
cvdatasets/dataset/mixins/features.py

@@ -2,7 +2,7 @@ import numpy as np
 
 from os.path import isfile
 
-from . import BaseMixin
+from cvdatasets.dataset.mixins.base import BaseMixin
 
 
 class PreExtractedFeaturesMixin(BaseMixin):

+ 1 - 1
cvdatasets/dataset/mixins/parts.py

@@ -1,6 +1,6 @@
 import numpy as np
 
-from cvdatasets.dataset.mixins import BaseMixin
+from cvdatasets.dataset.mixins.base import BaseMixin
 
 class BBoxMixin(BaseMixin):
 

+ 0 - 0
cvdatasets/dataset/mixins/postprocess.py


+ 2 - 2
cvdatasets/dataset/mixins/reading.py

@@ -2,8 +2,8 @@ import numpy as np
 
 from os.path import join
 
-from . import BaseMixin
-from ..image import ImageWrapper
+from cvdatasets.dataset.mixins.base import BaseMixin
+from cvdatasets.dataset.image import ImageWrapper
 
 class AnnotationsReadMixin(BaseMixin):
 

+ 27 - 0
cvdatasets/dataset/mixins/transform.py

@@ -0,0 +1,27 @@
+import abc
+
+from cvdatasets.dataset.mixins.base import BaseMixin
+from cvdatasets.dataset.image.size import Size
+
+class TransformMixin(BaseMixin):
+
+	def __init__(self, size, *args, **kwargs):
+		super(TransformMixin, self).__init__(*args, **kwargs)
+
+		self.size = size
+
+	@abc.abstractmethod
+	def transform(self, im_obj):
+		pass
+
+	def get_example(self, i):
+		im_obj = super(TransformMixin, self).get_example(i)
+		return self.transform(im_obj)
+
+	@property
+	def size(self):
+		return self._size
+
+	@size.setter
+	def size(self, value):
+		self._size = Size(value)

+ 5 - 4
cvdatasets/utils/__init__.py

@@ -63,7 +63,8 @@ class _MetaInfo(object):
 		self.structure = []
 
 
-from .dataset import new_iterator
-from .image import asarray
-from .image import dimensions
-from .image import rescale
+from cvdatasets.utils.dataset import new_iterator
+from cvdatasets.utils.image import asarray
+from cvdatasets.utils.image import dimensions
+from cvdatasets.utils.image import rescale
+from cvdatasets.utils.image import read_image

+ 22 - 3
cvdatasets/utils/image.py

@@ -1,5 +1,24 @@
 import numpy as np
-from PIL.Image import Image as PIL_Image
+
+from os.path import isfile
+from PIL import Image
+
+def read_image(im_path, n_retries=5):
+	_read = lambda: Image.open(im_path, mode="r")
+	if n_retries <= 0:
+		assert isfile(im_path), "Image \"{}\" does not exist!".format(im_path)
+		return _read()
+
+	else:
+		error = None
+		for i in range(n_retries):
+			try:
+				return _read()
+			except Exception as e:
+				error = e
+
+		raise RuntimeError("Reading image \"{}\" failed after {} n_retries! ({})".format(im_path, n_retries, error))
+
 
 def rescale(im, coords, rescale_size, center_cropped=True, no_offset=False):
 	h, w, c = dimensions(im)
@@ -22,7 +41,7 @@ def dimensions(im):
 			import pdb; pdb.set_trace()
 		assert im.ndim == 3, "Only RGB images are currently supported!"
 		return im.shape
-	elif isinstance(im, PIL_Image):
+	elif isinstance(im, Image.Image):
 		w, h = im.size
 		c = len(im.getbands())
 		# assert c == 3, "Only RGB images are currently supported!"
@@ -33,7 +52,7 @@ def dimensions(im):
 def asarray(im, dtype=np.uint8):
 	if isinstance(im, np.ndarray):
 		return im.astype(dtype)
-	elif isinstance(im, PIL_Image):
+	elif isinstance(im, Image.Image):
 		return np.asarray(im, dtype=dtype)
 	else:
 		raise ValueError("Unknown image instance ({})!".format(type(im)))