Forráskód Böngészése

refactored mixins. Added Image/Parts/Label wrapper object

Dimitri Korsch 6 éve
szülő
commit
517f55e9e7

+ 6 - 5
nabirds/dataset/__init__.py

@@ -1,7 +1,8 @@
-from .reading import AnnotationsReadMixin, ImageListReadingMixin
-from .parts import PartMixin, RevealedPartMixin, CroppedPartMixin
+from .mixins.reading import AnnotationsReadMixin, ImageListReadingMixin
+from .mixins.parts import PartMixin, RevealedPartMixin, CroppedPartMixin
 
 
 class Dataset(PartMixin, AnnotationsReadMixin):
 class Dataset(PartMixin, AnnotationsReadMixin):
-	"""
-		TODO!
-	"""
+
+	def get_example(self, i):
+		im_obj = super(Dataset, self).get_example(i)
+		return im_obj.as_tuple()

+ 98 - 0
nabirds/dataset/image.py

@@ -0,0 +1,98 @@
+from imageio import imread
+from os.path import isfile
+
+import copy
+import numpy as np
+
+from . import utils
+
+def should_have_parts(func):
+	def inner(self, *args, **kwargs):
+		assert self.has_parts, "parts are not present!"
+		return func(self, *args, **kwargs)
+	return inner
+
+class ImageWrapper(object):
+	def __init__(self, im_path, label, parts=None, mode="RGB"):
+		if isinstance(im_path, str):
+			assert isfile(im_path), "Image \"{}\" does not exist!".format(im_path)
+			self.im = imread(im_path, pilmode=mode)
+		else:
+			self.im = im_path
+
+		self.label = label
+		self.parts = parts
+
+		self.parent = None
+
+	def as_tuple(self):
+		return self.im, self.parts, self.label
+
+	def copy(self):
+		new = copy.deepcopy(self)
+		new.parent = self
+		return new
+
+
+	def crop(self, x, y, w, h):
+		result = self.copy()
+		result.im = self.im[y:y+h, x:x+w]
+		if self.has_parts:
+			result.parts[:, 1] -= x
+			result.parts[:, 2] -= y
+		return result
+
+	@should_have_parts
+	def hide_parts_outside_bb(self, x, y, w, h):
+		idxs, (xs,ys) = self.visible_part_locs()
+		f = np.logical_and
+		mask = f(f(x <= xs, xs <= x+w), f(y <= ys, ys <= y+h))
+		result = self.copy()
+		result.parts[:, -1] = mask.astype(self.parts.dtype)
+
+		return result
+
+	def uniform_parts(self, ratio):
+		result = self.copy()
+		result.parts = utils.uniform_parts(self.im, ratio=ratio)
+		return result
+
+	@should_have_parts
+	def select_random_parts(self, rnd, n_parts):
+
+		idxs, xy = self.visible_part_locs()
+		rnd_idxs = utils.random_idxs(idxs, rnd=rnd, n_parts=n_parts)
+		result = self.copy()
+
+		result.parts[:, -1] = 0
+		result.parts[rnd_idxs, -1] = 1
+
+		return result
+
+	@should_have_parts
+	def visible_crops(self, ratio):
+		return utils.visible_crops(self.im, self.parts, ratio=ratio)
+
+	@should_have_parts
+	def visible_part_locs(self):
+		return utils.visible_part_locs(self.parts)
+
+	@should_have_parts
+	def reveal_visible(self, ratio):
+		_, xy = self.visible_part_locs()
+		result = self.copy()
+		result.im = utils.reveal_parts(self.im, xy, ratio=ratio)
+		return result
+
+	@should_have_parts
+	def part_crops(self, ratio):
+		crops = self.visible_crops(ratio)
+		idxs, _ = self.visible_part_locs()
+		result = self.copy()
+		result.im = crops[idxs]
+		return result
+
+	@property
+	def has_parts(self):
+		return self.parts is not None
+

+ 0 - 0
nabirds/dataset/base.py → nabirds/dataset/mixins/__init__.py


+ 116 - 0
nabirds/dataset/mixins/parts.py

@@ -0,0 +1,116 @@
+import numpy as np
+
+from . import BaseMixin
+
+class BBCropMixin(BaseMixin):
+
+	def __init__(self, crop_to_bb=False, crop_uniform=False, *args, **kwargs):
+		super(BBCropMixin, self).__init__(*args, **kwargs)
+		self.crop_to_bb = crop_to_bb
+		self.crop_uniform = crop_uniform
+
+	def bounding_box(self, i):
+		bbox = self._get("bounding_box", i)
+		x,y,w,h = [bbox[attr] for attr in "xywh"]
+		if self.crop_uniform:
+			x0 = x + w//2
+			y0 = y + h//2
+
+			crop_size = max(w//2, h//2)
+
+			x,y = max(x0 - crop_size, 0), max(y0 - crop_size, 0)
+			w = h = crop_size * 2
+		return x,y,w,h
+
+	def get_example(self, i):
+		im_obj = super(BBCropMixin, self).get_example(i)
+		if self.crop_to_bb:
+			bb = self.bounding_box(i)
+			return im_obj.crop(*bb)
+		return im_obj
+
+class PartsInBBMixin(BaseMixin):
+	def __init__(self, parts_in_bb=False, *args, **kwargs):
+		super(PartsInBBMixin, self).__init__(*args, **kwargs)
+		self.parts_in_bb = parts_in_bb
+
+	def get_example(self, i):
+		im_obj = super(PartsInBBMixin, self).get_example(i)
+
+		if self.parts_in_bb:
+			bb = self.bounding_box(i)
+			return im_obj.hide_parts_outside_bb(*bb)
+		return im_obj
+
+class PartCropMixin(BaseMixin):
+
+	def __init__(self, return_part_crops=False, *args, **kwargs):
+		super(PartCropMixin, self).__init__(*args, **kwargs)
+		self.return_part_crops = return_part_crops
+
+	def get_example(self, i):
+		im_obj = super(PartCropMixin, self).get_example(i)
+		if self.return_part_crops:
+			return im_obj.part_crops(self.ratio)
+		return im_obj
+
+
+class PartRevealMixin(BaseMixin):
+
+	def __init__(self, reveal_visible=False, *args, **kwargs):
+		super(PartRevealMixin, self).__init__(*args, **kwargs)
+		self.reveal_visible = reveal_visible
+
+	def get_example(self, i):
+		im_obj = super(PartRevealMixin, self).get_example(i)
+		assert hasattr(self, "ratio"), "\"ratio\" attribute is missing!"
+		if not self.reveal_visible:
+			return im_obj.reveal_visible(self.ratio)
+		return im_obj
+
+
+class UniformPartMixin(BaseMixin):
+
+	def __init__(self, uniform_parts=False, ratio=None, *args, **kwargs):
+		super(UniformPartMixin, self).__init__(*args, **kwargs)
+		self.uniform_parts = uniform_parts
+		self.ratio = ratio
+
+	def get_example(self, i):
+		im_obj = super(UniformPartMixin, self).get_example(i)
+		if self.uniform_parts:
+			return im_obj.uniform_parts(self.ratio)
+		return im_obj
+
+class RandomBlackOutMixin(BaseMixin):
+
+	def __init__(self, seed=None, rnd_select=False, n_parts=None, *args, **kwargs):
+		super(RandomBlackOutMixin, self).__init__(*args, **kwargs)
+		self.rnd = np.random.RandomState(seed)
+		self.rnd_select = rnd_select
+		self.n_parts = n_parts
+
+	def get_example(self, i):
+		im_obj = super(RandomBlackOutMixin, self).get_example(i)
+		if self.rnd_select:
+			return im_obj.select_random_parts(rnd=self.rnd, n_parts=self.n_parts)
+		return im_obj
+
+
+# some shortcuts
+
+class PartMixin(RandomBlackOutMixin, PartsInBBMixin, UniformPartMixin, BBCropMixin):
+	"""
+		TODO!
+	"""
+
+class RevealedPartMixin(PartRevealMixin, PartMixin):
+	"""
+		TODO!
+	"""
+
+
+class CroppedPartMixin(PartCropMixin, PartMixin):
+	"""
+		TODO!
+	"""

+ 0 - 0
nabirds/dataset/postprocess.py → nabirds/dataset/mixins/postprocess.py


+ 5 - 10
nabirds/dataset/reading.py → nabirds/dataset/mixins/reading.py

@@ -1,7 +1,7 @@
-from imageio import imread
-from os.path import join, isfile
+from os.path import join
 
 
-from .base import BaseMixin
+from . import BaseMixin
+from ..image import ImageWrapper
 
 
 class AnnotationsReadMixin(BaseMixin):
 class AnnotationsReadMixin(BaseMixin):
 
 
@@ -25,9 +25,7 @@ class AnnotationsReadMixin(BaseMixin):
 		methods = ["image", "parts", "label"]
 		methods = ["image", "parts", "label"]
 		im_path, parts, label = [self._get(m, i) for m in methods]
 		im_path, parts, label = [self._get(m, i) for m in methods]
 
 
-		im = imread(im_path, pilmode=self.mode)
-
-		return im, parts, label
+		return ImageWrapper(im_path, int(label), parts, mode=self.mode)
 
 
 
 
 class ImageListReadingMixin(BaseMixin):
 class ImageListReadingMixin(BaseMixin):
@@ -46,9 +44,6 @@ class ImageListReadingMixin(BaseMixin):
 
 
 	def get_example(self, i):
 	def get_example(self, i):
 		im_file, label = self._pairs[i]
 		im_file, label = self._pairs[i]
-
 		im_path = join(self._root, im_file)
 		im_path = join(self._root, im_file)
-		assert isfile(im_path), "Image \"{}\" does not exist!".format(im_path)
-		im = imread(im_path, pilmode="RGB")
 
 
-		return im, int(label)
+		return ImageWrapper(im_path, int(label))

+ 0 - 134
nabirds/dataset/parts.py

@@ -1,134 +0,0 @@
-import numpy as np
-
-from .base import BaseMixin
-from . import utils
-
-
-class BasePartMixin(BaseMixin):
-
-	def get_example(self, i):
-		res = super(BasePartMixin, self).get_example(i)
-		if len(res) == 2:
-			# result has only image and label
-			im, lab = res
-			parts = None
-		else:
-			# result has already parts
-			im, parts, lab = res
-
-		return im, parts, lab
-
-class BBCropMixin(BasePartMixin):
-
-	def __init__(self, crop_to_bb=False, crop_uniform=False, *args, **kwargs):
-		super(BBCropMixin, self).__init__(*args, **kwargs)
-		self.crop_to_bb = crop_to_bb
-		self.crop_uniform = crop_uniform
-
-	def bounding_box(self, i):
-		bbox = self._get("bounding_box", i)
-		x,y,w,h = [bbox[attr] for attr in "xywh"]
-		if self.crop_uniform:
-			x0 = x + w//2
-			y0 = y + h//2
-
-			crop_size = max(w//2, h//2)
-
-			x,y = max(x0 - crop_size, 0), max(y0 - crop_size, 0)
-			w = h = crop_size * 2
-		return x,y,w,h
-
-	def get_example(self, i):
-		im, parts, label = super(BBCropMixin, self).get_example(i)
-		if self.crop_to_bb:
-			x,y,w,h = self.bounding_box(i)
-			im = im[y:y+h, x:x+w]
-			if parts is not None:
-				parts[:, 1] -= x
-				parts[:, 2] -= y
-		return im, parts, label
-
-class PartCropMixin(BasePartMixin):
-	def __init__(self, return_part_crops=False, *args, **kwargs):
-		super(PartCropMixin, self).__init__(*args, **kwargs)
-		self.return_part_crops = return_part_crops
-
-	def get_example(self, i):
-		im, parts, label = super(PartCropMixin, self).get_example(i)
-		assert hasattr(self, "ratio"), "\"ratio\" attribute is missing!"
-		if not self.return_part_crops or parts is None or not hasattr(self, "ratio"):
-			return im, label
-
-		crops = utils.visible_crops(im, parts)
-		idxs, _ = utils.visible_part_locs(parts)
-
-		return crops[idxs], label
-
-
-class PartRevealMixin(BasePartMixin):
-	def __init__(self, reveal_visible=False, *args, **kwargs):
-		super(PartRevealMixin, self).__init__(*args, **kwargs)
-		self.reveal_visible = reveal_visible
-
-	def get_example(self, i):
-		im, parts, label = super(PartRevealMixin, self).get_example(i)
-		assert hasattr(self, "ratio"), "\"ratio\" attribute is missing!"
-		if not self.reveal_visible or parts is None or not hasattr(self, "ratio"):
-			return im, label
-
-		_, xy = utils.visible_part_locs(parts)
-		im = utils.reveal_parts(im, xy, ratio=self.ratio)
-		return im, lab
-
-
-class UniformPartMixin(BasePartMixin):
-
-	def __init__(self, uniform_parts=False, ratio=utils.DEFAULT_RATIO, *args, **kwargs):
-		super(UniformPartMixin, self).__init__(*args, **kwargs)
-		self.uniform_parts = uniform_parts
-		self.ratio = ratio
-
-	def get_example(self, i):
-		im, parts, label = super(UniformPartMixin, self).get_example(i)
-		if self.uniform_parts:
-			parts = utils.uniform_parts(im, ratio=self.ratio)
-		return im, parts, label
-
-class RandomBlackOutMixin(BasePartMixin):
-
-	def __init__(self, seed=None, rnd_select=False, n_parts=None, *args, **kwargs):
-		super(RandomBlackOutMixin, self).__init__(*args, **kwargs)
-		self.rnd = np.random.RandomState(seed)
-		self.rnd_select = rnd_select
-		self.n_parts = n_parts
-
-	def get_example(self, i):
-		im, parts, lab = super(RandomBlackOutMixin, self).get_example(i)
-		if self.rnd_select:
-			idxs, xy = utils.visible_part_locs(parts)
-			rnd_idxs = utils.random_idxs(idxs, rnd=self.rnd, n_parts=self.n_parts)
-
-			parts[:, -1] = 0
-			parts[rnd_idxs, -1] = 1
-
-		return im, parts, lab
-
-
-
-# some shortcuts
-
-class PartMixin(RandomBlackOutMixin, UniformPartMixin, BBCropMixin):
-	"""
-		TODO!
-	"""
-
-class RevealedPartMixin(PartRevealMixin, PartMixin):
-	"""
-		TODO!
-	"""
-
-
-class CroppedPartMixin(PartCropMixin, PartMixin):
-	"""
-		TODO!
-	"""

+ 6 - 0
nabirds/display.py

@@ -41,6 +41,8 @@ def main(args):
 		crop_to_bb=args.crop_to_bb,
 		crop_to_bb=args.crop_to_bb,
 		crop_uniform=args.crop_uniform,
 		crop_uniform=args.crop_uniform,
 
 
+		parts_in_bb=args.parts_in_bb,
+
 		rnd_select=args.rnd,
 		rnd_select=args.rnd,
 		ratio=args.ratio,
 		ratio=args.ratio,
 		seed=args.seed
 		seed=args.seed
@@ -136,6 +138,10 @@ parser.add_argument("--crop_uniform",
 	help="Try to extend the bounding box to same height and width",
 	help="Try to extend the bounding box to same height and width",
 	action="store_true")
 	action="store_true")
 
 
+parser.add_argument("--parts_in_bb",
+	help="Only display parts, that are inside the bounding box",
+	action="store_true")
+
 
 
 
 
 parser.add_argument(
 parser.add_argument(