Przeglądaj źródła

added possibility to rescale part positions

Dimitri Korsch 6 lat temu
rodzic
commit
95b704e9aa

+ 13 - 7
nabirds/dataset/image.py

@@ -21,27 +21,26 @@ class ImageWrapper(object):
 		return im
 
 
-	def __init__(self, im_path, label, parts=None, mode="RGB"):
+	def __init__(self, im_path, label, parts=None, mode="RGB", part_rescale_size=None):
+
 
 		self.mode = mode
 		self.im = im_path
 		self._im_array = None
 
 		self.label = label
-		self.parts = parts
+		self.parts = utils.rescale_parts(self.im, parts, part_rescale_size)
+
+		self.part_rescale_size = part_rescale_size
 
 		self.parent = None
 		self._feature = None
 
 	def __del__(self):
 		if isinstance(self._im, Image.Image):
-			if self._im is not None and self._im.fp is not None:
+			if self._im is not None and getattr(self._im, "fp", None) is not None:
 				self._im.close()
 
-	@property
-	def im(self):
-		return self._im
-
 	@property
 	def im_array(self):
 		if self._im_array is None:
@@ -59,11 +58,18 @@ class ImageWrapper(object):
 				raise ValueError()
 		return self._im_array
 
+	@property
+	def im(self):
+		if self._im.mode != self.mode:
+			self._im = self._im.convert(self.mode)
+		return self._im
+
 	@im.setter
 	def im(self, value):
 		if isinstance(value, str):
 			assert isfile(value), "Image \"{}\" does not exist!".format(value)
 			self._im = ImageWrapper.read_image(value, mode=self.mode)
+			self._im_path = value
 		else:
 			self._im = value
 

+ 11 - 2
nabirds/dataset/mixins/__init__.py

@@ -1,4 +1,6 @@
 from abc import ABC, abstractmethod
+import numpy as np
+import six
 
 class BaseMixin(ABC):
 
@@ -8,5 +10,12 @@ class BaseMixin(ABC):
 		if hasattr(s, "get_example"):
 			return s.get_example(i)
 
-	def __getitem__(self, i):
-		return self.get_example(i)
+	def __getitem__(self, index):
+		if isinstance(index, slice):
+			current, stop, step = index.indices(len(self))
+			return [self.get_example(i) for i in
+					six.moves.range(current, stop, step)]
+		elif isinstance(index, list) or isinstance(index, np.ndarray):
+			return [self.get_example(i) for i in index]
+		else:
+			return self.get_example(index)

+ 6 - 5
nabirds/dataset/mixins/reading.py

@@ -7,11 +7,12 @@ from ..image import ImageWrapper
 
 class AnnotationsReadMixin(BaseMixin):
 
-	def __init__(self, uuids, annotations, mode="RGB"):
+	def __init__(self, uuids, annotations, part_rescale_size=None, mode="RGB"):
 		super(AnnotationsReadMixin, self).__init__()
 		self.uuids = uuids
 		self._annot = annotations
 		self.mode = mode
+		self.part_rescale_size = part_rescale_size
 
 	def __len__(self):
 		return len(self.uuids)
@@ -20,14 +21,14 @@ class AnnotationsReadMixin(BaseMixin):
 		return getattr(self._annot, method)(self.uuids[i])
 
 	def get_example(self, i):
-		res = super(AnnotationsReadMixin, self).get_example(i)
-		# if the super class returns something, then the class inheritance is wrong
-		assert res is None, "AnnotationsReadMixin should be the last class in the hierarchy!"
+		# res = super(AnnotationsReadMixin, self).get_example(i)
+		# # if the super class returns something, then the class inheritance is wrong
+		# assert res is None, "AnnotationsReadMixin should be the last class in the hierarchy!"
 
 		methods = ["image", "parts", "label"]
 		im_path, parts, label = [self._get(m, i) for m in methods]
 
-		return ImageWrapper(im_path, int(label), parts, mode=self.mode)
+		return ImageWrapper(im_path, int(label), parts, mode=self.mode, part_rescale_size=self.part_rescale_size)
 
 	@property
 	def labels(self):

+ 13 - 0
nabirds/dataset/utils.py

@@ -6,8 +6,21 @@ DEFAULT_RATIO = np.sqrt(49 / 400)
 def __expand_parts(p):
 	return p[:, 0], p[:, 1:3], p[:, 3].astype(bool)
 
+def rescale_parts(im, parts, part_rescale_size):
+	if part_rescale_size is None or part_rescale_size < 0:
+		return parts
+
+	h, w, c = dimensions(im)
+	xy = parts[:, 1:3]
+	xy = xy / part_rescale_size * np.array([w, h])
+	parts[:, 1:3] = xy
+
+	return parts
+
 def dimensions(im):
 	if isinstance(im, np.ndarray):
+		if im.ndim != 3:
+			import pdb; pdb.set_trace()
 		assert im.ndim == 3, "Only RGB images are currently supported!"
 		return im.shape
 	elif isinstance(im, PIL_Image):

+ 16 - 0
nabirds/display.py

@@ -1,6 +1,17 @@
 #!/usr/bin/env python
 if __name__ != '__main__': raise Exception("Do not import me!")
 
+"""
+	Possible calls:
+
+	./display.sh /home/korsch1/korsch/datasets/birds/cub200_11 --dataset cub -s600 -n5 --features /home/korsch1/korsch/datasets/birds/features/{train,val}_16parts_gt.npz --ratio 0.31
+	> displays GT parts of CUB200
+
+	./display.sh /home/korsch1/korsch/datasets/birds/NAC/2017-bilinear/ --dataset cub -s600 -n5 --features /home/korsch1/korsch/datasets/birds/features/{train,val}_16parts_gt.npz --ratio 0.31 --rescale_size 227
+	> displays NAC parts of CUB200
+
+"""
+
 from argparse import ArgumentParser
 import logging
 import numpy as np
@@ -58,6 +69,7 @@ def main(args):
 
 	data = Dataset(
 		uuids=uuids, annotations=annot,
+		part_rescale_size=args.rescale_size,
 		features=features,
 
 		uniform_parts=args.uniform_parts,
@@ -147,6 +159,10 @@ parser.add_argument("--ratio",
 	help="Part extraction ratio",
 	type=float, default=.2)
 
+parser.add_argument("--rescale_size",
+	help="rescales the part positions from this size to original image size",
+	type=int, default=-1)
+
 parser.add_argument("--rnd",
 	help="select random subset of present parts",
 	action="store_true")