Преглед на файлове

some changes in image reading

Dimitri Korsch преди 6 години
родител
ревизия
0cbe788f84
променени са 3 файла, в които са добавени 15 реда и са изтрити 7 реда
  1. 2 2
      nabirds/dataset/image.py
  2. 8 2
      nabirds/dataset/utils.py
  3. 5 3
      nabirds/display.py

+ 2 - 2
nabirds/dataset/image.py

@@ -50,7 +50,7 @@ class ImageWrapper(object):
 			elif isinstance(self._im, np.ndarray):
 				if self.mode == "RGB" and self._im.ndim == 2:
 					self._im_array = np.stack((self._im,) * 3, axis=-1)
-				elif self._im.ndim == 3:
+				elif self._im.ndim in (3, 4):
 					self._im_array = self._im
 				else:
 					raise ValueError()
@@ -60,7 +60,7 @@ class ImageWrapper(object):
 
 	@property
 	def im(self):
-		if self._im.mode != self.mode:
+		if isinstance(self._im, Image.Image) and self._im.mode != self.mode:
 			self._im = self._im.convert(self.mode)
 		return self._im
 

+ 8 - 2
nabirds/dataset/utils.py

@@ -9,12 +9,18 @@ def __expand_parts(p):
 def rescale_parts(im, parts, part_rescale_size):
 	if part_rescale_size is None or part_rescale_size < 0:
 		return parts
-
 	h, w, c = dimensions(im)
+	scale = np.array([w, h]) / part_rescale_size
+
 	xy = parts[:, 1:3]
-	xy = xy / part_rescale_size * np.array([w, h])
+	xy = xy * scale
 	parts[:, 1:3] = xy
 
+	if parts.shape[1] == 5:
+		wh = parts[:, 3:5]
+		wh = wh * scale
+		parts[:, 3:5] = wh
+
 	return parts
 
 def dimensions(im):

+ 5 - 3
nabirds/display.py

@@ -33,7 +33,7 @@ def init_logger(args):
 		filename=args.logfile or None,
 		filemode="w")
 
-def plot_crops(crops, title, scatter_mid=False):
+def plot_crops(crops, title, scatter_mid=False, names=None):
 
 	fig = plt.figure(figsize=(16,9))
 	fig.suptitle(title, fontsize=16)
@@ -44,6 +44,8 @@ def plot_crops(crops, title, scatter_mid=False):
 
 	for j, crop in enumerate(crops, 1):
 		ax = fig.add_subplot(rows, cols, j)
+		if names is not None:
+			ax.set_title(names[j-1])
 		ax.imshow(crop)
 		ax.axis("off")
 		if scatter_mid:
@@ -115,8 +117,8 @@ def main(args):
 		ax.imshow(reveal_parts(im, xy, ratio=args.ratio))
 		# ax.scatter(*xy, marker="x", c=idxs)
 		ax.axis("off")
-
-		plot_crops(part_crops, "Selected parts")
+		crop_names = list(data._annot.part_names.values())
+		plot_crops(part_crops, "Selected parts", names=crop_names)
 
 		if args.rnd:
 			plot_crops(action_crops, "Actions")