Browse Source

refactored dataset class: added options for random part selection, uniform part creation and bounding box cropping

Dimitri Korsch 6 years ago
parent
commit
4283cd17a0
2 changed files with 104 additions and 30 deletions
  1. 77 22
      nabirds/dataset/__init__.py
  2. 27 8
      nabirds/display.py

+ 77 - 22
nabirds/dataset/__init__.py

@@ -1,12 +1,50 @@
-from imageio import imread
 import numpy as np
 
+from abc import ABC, abstractmethod
+from imageio import imread
+
+from . import utils
+
+class BaseMixin(ABC):
+
+	@abstractmethod
+	def get_example(self, i):
+		pass
+
+	def __getitem__(self, i):
+		return self.get_example(i)
 
-class Dataset(object):
-	def __init__(self, uuids, annotations, crop_to_bb=False, crop_uniform=False):
-		super(Dataset, self).__init__()
+
+class ImageReadMixin(BaseMixin):
+
+	def __init__(self, uuids, annotations, mode="RGB"):
+		super(BaseMixin, self).__init__()
 		self.uuids = uuids
 		self._annot = annotations
+		self.mode = mode
+
+	def __len__(self):
+		return len(self.uuids)
+
+	def _get(self, method, i):
+		return getattr(self._annot, method)(self.uuids[i])
+
+	def get_example(self, i):
+		res = super(ImageReadMixin, self).get_example(i)
+		# if the super class returns something, then the class inheritance is wrong
+		assert res is None, "ImageReadMixin should be the last class in the hierarchy!"
+
+		methods = ["image", "parts", "label"]
+		im_path, parts, label = [self._get(m, i) for m in methods]
+
+		im = imread(im_path, pilmode=self.mode)
+
+		return im, parts, label
+
+class BBCropMixin(BaseMixin):
+
+	def __init__(self, crop_to_bb=False, crop_uniform=False, *args, **kwargs):
+		super(BBCropMixin, self).__init__(*args, **kwargs)
 		self.crop_to_bb = crop_to_bb
 		self.crop_uniform = crop_uniform
 
@@ -23,29 +61,46 @@ class Dataset(object):
 			w = h = crop_size * 2
 		return x,y,w,h
 
-	def __len__(self):
-		return len(self.uuids)
-
-	def _get(self, method, i):
-		return getattr(self._annot, method)(self.uuids[i])
-
-
-	def get_example(self, i, mode="RGB"):
-		methods = ["image", "parts", "label"]
-		im_path, parts, label = [self._get(m, i) for m in methods]
-
-		im = imread(im_path, pilmode=mode)
-
+	def get_example(self, i):
+		im, parts, label = super(BBCropMixin, self).get_example(i)
 		if self.crop_to_bb:
 			x,y,w,h = self.bounding_box(i)
 			im = im[y:y+h, x:x+w]
 			parts[:, 1] -= x
 			parts[:, 2] -= y
+		return im, parts, label
+
+class UniformPartMixin(BaseMixin):
+
+	def __init__(self, uniform_parts=False, ratio=utils.DEFAULT_RATIO, *args, **kwargs):
+		super(UniformPartMixin, self).__init__(*args, **kwargs)
+		self.uniform_parts = uniform_parts
+		self.ratio = ratio
 
-		h,w,c = im.shape
-		# fit to the dimensions of the image
-		parts[:, 1] = np.minimum(parts[:, 1], w - 1)
-		parts[:, 2] = np.minimum(parts[:, 2], h - 1)
+	def get_example(self, i):
+		im, parts, label = super(UniformPartMixin, self).get_example(i)
+		if self.uniform_parts:
+			parts = utils.uniform_parts(im, ratio=self.ratio)
 		return im, parts, label
 
-	__getitem__  = get_example
+class RandomBlackOutMixin(BaseMixin):
+
+	def __init__(self, seed=None, rnd_select=False, n_parts=None, *args, **kwargs):
+		super(RandomBlackOutMixin, self).__init__(*args, **kwargs)
+		self.rnd = np.random.RandomState(seed)
+		self.rnd_select = rnd_select
+		self.n_parts = n_parts
+
+	def get_example(self, i):
+		im, parts, lab = super(RandomBlackOutMixin, self).get_example(i)
+		if self.rnd_select:
+			idxs, xy = utils.visible_part_locs(parts)
+			rnd_idxs = utils.random_idxs(idxs, rnd=self.rnd, n_parts=self.n_parts)
+
+			parts[:, -1] = 0
+			parts[rnd_idxs, -1] = 1
+
+		return im, parts, lab
+
+class Dataset(RandomBlackOutMixin, UniformPartMixin, BBCropMixin, ImageReadMixin):
+	pass

+ 27 - 8
nabirds/display.py

@@ -33,7 +33,19 @@ def main(args):
 	annot = annotation_cls.get(args.dataset.lower())(args.data)
 
 	uuids = getattr(annot, "{}_uuids".format(args.subset.lower()))
-	data = Dataset(uuids, annot)
+	data = Dataset(
+		uuids=uuids, annotations=annot,
+
+		uniform_parts=args.uniform_parts,
+
+		crop_to_bb=args.crop_to_bb,
+		crop_uniform=args.crop_uniform,
+
+		rnd_select=args.rnd,
+		ratio=args.ratio,
+		seed=args.seed
+
+	)
 	n_images = len(data)
 	logging.info("Found {} images in the {} subset".format(n_images, args.subset))
 
@@ -41,10 +53,6 @@ def main(args):
 		if i + 1 <= args.start: continue
 		im, parts, label = data[i]
 
-
-		if args.uniform_parts:
-			parts = uniform_parts(im, ratio=args.ratio)
-
 		idxs, xy = visible_part_locs(parts)
 		part_crops = visible_crops(im, parts, ratio=args.ratio)
 
@@ -59,9 +67,6 @@ def main(args):
 		ax.scatter(*xy, marker="x", c=idxs)
 		ax.axis("off")
 
-		if args.rnd:
-			idxs, xy, part_crops = random_select(idxs, xy, part_crops)
-
 		ax = fig1.add_subplot(2,1,2)
 		ax.set_title("{}selected parts".format("randomly " if args.rnd else ""))
 		ax.imshow(reveal_parts(im, xy, ratio=args.ratio))
@@ -123,6 +128,16 @@ parser.add_argument("--uniform_parts", "-u",
 	help="Do not use GT parts, but sample parts uniformly from the image",
 	action="store_true")
 
+parser.add_argument("--crop_to_bb",
+	help="Crop image to the bounding box",
+	action="store_true")
+
+parser.add_argument("--crop_uniform",
+	help="Try to extend the bounding box to same height and width",
+	action="store_true")
+
+
+
 parser.add_argument(
 	'--logfile', type=str, default='',
 	help='File for logging output')
@@ -131,4 +146,8 @@ parser.add_argument(
 	'--loglevel', type=str, default='INFO',
 	help='logging level. see logging module for more information')
 
+parser.add_argument(
+	'--seed', type=int, default=12311123,
+	help='random seed')
+
 main(parser.parse_args())