Quellcode durchsuchen

refactored dataset mixins: added some and sorted them in different modules

Dimitri Korsch vor 6 Jahren
Ursprung
Commit
297e5c6e92

+ 6 - 105
nabirds/dataset/__init__.py

@@ -1,106 +1,7 @@
-import numpy as np
+from .reading import AnnotationsReadMixin, ImageListReadingMixin
+from .parts import PartMixin, RevealedPartMixin, CroppedPartMixin
 
-from abc import ABC, abstractmethod
-from imageio import imread
-
-from . import utils
-
-class BaseMixin(ABC):
-
-	@abstractmethod
-	def get_example(self, i):
-		pass
-
-	def __getitem__(self, i):
-		return self.get_example(i)
-
-
-class ImageReadMixin(BaseMixin):
-
-	def __init__(self, uuids, annotations, mode="RGB"):
-		super(BaseMixin, self).__init__()
-		self.uuids = uuids
-		self._annot = annotations
-		self.mode = mode
-
-	def __len__(self):
-		return len(self.uuids)
-
-	def _get(self, method, i):
-		return getattr(self._annot, method)(self.uuids[i])
-
-	def get_example(self, i):
-		res = super(ImageReadMixin, self).get_example(i)
-		# if the super class returns something, then the class inheritance is wrong
-		assert res is None, "ImageReadMixin should be the last class in the hierarchy!"
-
-		methods = ["image", "parts", "label"]
-		im_path, parts, label = [self._get(m, i) for m in methods]
-
-		im = imread(im_path, pilmode=self.mode)
-
-		return im, parts, label
-
-class BBCropMixin(BaseMixin):
-
-	def __init__(self, crop_to_bb=False, crop_uniform=False, *args, **kwargs):
-		super(BBCropMixin, self).__init__(*args, **kwargs)
-		self.crop_to_bb = crop_to_bb
-		self.crop_uniform = crop_uniform
-
-	def bounding_box(self, i):
-		bbox = self._get("bounding_box", i)
-		x,y,w,h = [bbox[attr] for attr in "xywh"]
-		if self.crop_uniform:
-			x0 = x + w//2
-			y0 = y + h//2
-
-			crop_size = max(w//2, h//2)
-
-			x,y = max(x0 - crop_size, 0), max(y0 - crop_size, 0)
-			w = h = crop_size * 2
-		return x,y,w,h
-
-	def get_example(self, i):
-		im, parts, label = super(BBCropMixin, self).get_example(i)
-		if self.crop_to_bb:
-			x,y,w,h = self.bounding_box(i)
-			im = im[y:y+h, x:x+w]
-			parts[:, 1] -= x
-			parts[:, 2] -= y
-		return im, parts, label
-
-class UniformPartMixin(BaseMixin):
-
-	def __init__(self, uniform_parts=False, ratio=utils.DEFAULT_RATIO, *args, **kwargs):
-		super(UniformPartMixin, self).__init__(*args, **kwargs)
-		self.uniform_parts = uniform_parts
-		self.ratio = ratio
-
-	def get_example(self, i):
-		im, parts, label = super(UniformPartMixin, self).get_example(i)
-		if self.uniform_parts:
-			parts = utils.uniform_parts(im, ratio=self.ratio)
-		return im, parts, label
-
-class RandomBlackOutMixin(BaseMixin):
-
-	def __init__(self, seed=None, rnd_select=False, n_parts=None, *args, **kwargs):
-		super(RandomBlackOutMixin, self).__init__(*args, **kwargs)
-		self.rnd = np.random.RandomState(seed)
-		self.rnd_select = rnd_select
-		self.n_parts = n_parts
-
-	def get_example(self, i):
-		im, parts, lab = super(RandomBlackOutMixin, self).get_example(i)
-		if self.rnd_select:
-			idxs, xy = utils.visible_part_locs(parts)
-			rnd_idxs = utils.random_idxs(idxs, rnd=self.rnd, n_parts=self.n_parts)
-
-			parts[:, -1] = 0
-			parts[rnd_idxs, -1] = 1
-
-		return im, parts, lab
-
-class Dataset(RandomBlackOutMixin, UniformPartMixin, BBCropMixin, ImageReadMixin):
-	pass
+class Dataset(PartMixin, AnnotationsReadMixin):
+	"""
+		TODO!
+	"""

+ 12 - 0
nabirds/dataset/base.py

@@ -0,0 +1,12 @@
+from abc import ABC, abstractmethod
+
+class BaseMixin(ABC):
+
+	@abstractmethod
+	def get_example(self, i):
+		s = super(BaseMixin, self)
+		if hasattr(s, "get_example"):
+			return s.get_example(i)
+
+	def __getitem__(self, i):
+		return self.get_example(i)

+ 134 - 0
nabirds/dataset/parts.py

@@ -0,0 +1,134 @@
+import numpy as np
+
+from .base import BaseMixin
+from . import utils
+
+
+class BasePartMixin(BaseMixin):
+
+	def get_example(self, i):
+		res = super(BasePartMixin, self).get_example(i)
+		if len(res) == 2:
+			# result has only image and label
+			im, lab = res
+			parts = None
+		else:
+			# result has already parts
+			im, parts, lab = res
+
+		return im, parts, lab
+
+class BBCropMixin(BasePartMixin):
+
+	def __init__(self, crop_to_bb=False, crop_uniform=False, *args, **kwargs):
+		super(BBCropMixin, self).__init__(*args, **kwargs)
+		self.crop_to_bb = crop_to_bb
+		self.crop_uniform = crop_uniform
+
+	def bounding_box(self, i):
+		bbox = self._get("bounding_box", i)
+		x,y,w,h = [bbox[attr] for attr in "xywh"]
+		if self.crop_uniform:
+			x0 = x + w//2
+			y0 = y + h//2
+
+			crop_size = max(w//2, h//2)
+
+			x,y = max(x0 - crop_size, 0), max(y0 - crop_size, 0)
+			w = h = crop_size * 2
+		return x,y,w,h
+
+	def get_example(self, i):
+		im, parts, label = super(BBCropMixin, self).get_example(i)
+		if self.crop_to_bb:
+			x,y,w,h = self.bounding_box(i)
+			im = im[y:y+h, x:x+w]
+			if parts is not None:
+				parts[:, 1] -= x
+				parts[:, 2] -= y
+		return im, parts, label
+
+class PartCropMixin(BasePartMixin):
+	def __init__(self, return_part_crops=False, *args, **kwargs):
+		super(PartCropMixin, self).__init__(*args, **kwargs)
+		self.return_part_crops = return_part_crops
+
+	def get_example(self, i):
+		im, parts, label = super(PartCropMixin, self).get_example(i)
+		assert hasattr(self, "ratio"), "\"ratio\" attribute is missing!"
+		if not self.return_part_crops or parts is None or not hasattr(self, "ratio"):
+			return im, label
+
+		crops = utils.visible_crops(im, parts)
+		idxs, _ = utils.visible_part_locs(parts)
+
+		return crops[idxs], label
+
+
+class PartRevealMixin(BasePartMixin):
+	def __init__(self, reveal_visible=False, *args, **kwargs):
+		super(PartRevealMixin, self).__init__(*args, **kwargs)
+		self.reveal_visible = reveal_visible
+
+	def get_example(self, i):
+		im, parts, label = super(PartRevealMixin, self).get_example(i)
+		assert hasattr(self, "ratio"), "\"ratio\" attribute is missing!"
+		if not self.reveal_visible or parts is None or not hasattr(self, "ratio"):
+			return im, label
+
+		_, xy = utils.visible_part_locs(parts)
+		im = utils.reveal_parts(im, xy, ratio=self.ratio)
+		return im, lab
+
+
+class UniformPartMixin(BasePartMixin):
+
+	def __init__(self, uniform_parts=False, ratio=utils.DEFAULT_RATIO, *args, **kwargs):
+		super(UniformPartMixin, self).__init__(*args, **kwargs)
+		self.uniform_parts = uniform_parts
+		self.ratio = ratio
+
+	def get_example(self, i):
+		im, parts, label = super(UniformPartMixin, self).get_example(i)
+		if self.uniform_parts:
+			parts = utils.uniform_parts(im, ratio=self.ratio)
+		return im, parts, label
+
+class RandomBlackOutMixin(BasePartMixin):
+
+	def __init__(self, seed=None, rnd_select=False, n_parts=None, *args, **kwargs):
+		super(RandomBlackOutMixin, self).__init__(*args, **kwargs)
+		self.rnd = np.random.RandomState(seed)
+		self.rnd_select = rnd_select
+		self.n_parts = n_parts
+
+	def get_example(self, i):
+		im, parts, lab = super(RandomBlackOutMixin, self).get_example(i)
+		if self.rnd_select:
+			idxs, xy = utils.visible_part_locs(parts)
+			rnd_idxs = utils.random_idxs(idxs, rnd=self.rnd, n_parts=self.n_parts)
+
+			parts[:, -1] = 0
+			parts[rnd_idxs, -1] = 1
+
+		return im, parts, lab
+
+
+
+# some shortcuts
+
+class PartMixin(RandomBlackOutMixin, UniformPartMixin, BBCropMixin):
+	"""
+		TODO!
+	"""
+
+class RevealedPartMixin(PartRevealMixin, PartMixin):
+	"""
+		TODO!
+	"""
+
+
+class CroppedPartMixin(PartCropMixin, PartMixin):
+	"""
+		TODO!
+	"""

+ 0 - 0
nabirds/dataset/postprocess.py


+ 54 - 0
nabirds/dataset/reading.py

@@ -0,0 +1,54 @@
+from imageio import imread
+from os.path import join, isfile
+
+from .base import BaseMixin
+
+class AnnotationsReadMixin(BaseMixin):
+
+	def __init__(self, uuids, annotations, mode="RGB"):
+		super(AnnotationsReadMixin, self).__init__()
+		self.uuids = uuids
+		self._annot = annotations
+		self.mode = mode
+
+	def __len__(self):
+		return len(self.uuids)
+
+	def _get(self, method, i):
+		return getattr(self._annot, method)(self.uuids[i])
+
+	def get_example(self, i):
+		res = super(AnnotationsReadMixin, self).get_example(i)
+		# if the super class returns something, then the class inheritance is wrong
+		assert res is None, "AnnotationsReadMixin should be the last class in the hierarchy!"
+
+		methods = ["image", "parts", "label"]
+		im_path, parts, label = [self._get(m, i) for m in methods]
+
+		im = imread(im_path, pilmode=self.mode)
+
+		return im, parts, label
+
+
+class ImageListReadingMixin(BaseMixin):
+	def __init__(self, pairs, root="."):
+		super(ImageListReadingMixin, self).__init__()
+		with open(pairs) as f:
+			self._pairs = [line.strip().split() for line in f]
+
+		assert all([len(pair) == 2 for pair in self._pairs]), \
+			"Invalid format of the pairs file!"
+
+		self._root = root
+
+	def __len__(self):
+		return len(self._pairs)
+
+	def get_example(self, i):
+		im_file, label = self._pairs[i]
+
+		im_path = join(self._root, im_file)
+		assert isfile(im_path), "Image \"{}\" does not exist!".format(im_path)
+		im = imread(im_path, pilmode="RGB")
+
+		return im, int(label)