浏览代码

added some plotting routines

Dimitri Korsch 6 年之前
父节点
当前提交
b489749f0e
共有 3 个文件被更改,包括 48 次插入29 次删除
  1. 12 0
      nabirds/dataset/mixins/__init__.py
  2. 34 22
      nabirds/dataset/part.py
  3. 2 7
      scripts/display_from_info.py

+ 12 - 0
nabirds/dataset/mixins/__init__.py

@@ -1,7 +1,10 @@
 from abc import ABC, abstractmethod
+
 import numpy as np
 import six
 
+from matplotlib.patches import Rectangle
+
 class BaseMixin(ABC):
 
 	@abstractmethod
@@ -10,6 +13,15 @@ class BaseMixin(ABC):
 		if hasattr(s, "get_example"):
 			return s.get_example(i)
 
+	def plot_bounding_box(self, i, ax, fill=False, linestyle="--", **kwargs):
+		x, y, w, h = self.bounding_box(i)
+		ax.add_patch(Rectangle(
+			(x,y), w, h,
+			fill=False,
+			linestyle="-.",
+			**kwargs
+		))
+
 	def __getitem__(self, index):
 		if isinstance(index, slice):
 			current, stop, step = index.indices(len(self))

+ 34 - 22
nabirds/dataset/part.py

@@ -1,5 +1,8 @@
 import numpy as np
+from matplotlib import pyplot as plt
+
 from contextlib import contextmanager
+from matplotlib.patches import Rectangle
 from abc import ABC, abstractmethod, abstractproperty
 
 from . import utils
@@ -48,19 +51,21 @@ class Parts(object):
 	def visible_crops(self, *args, **kwargs):
 		return np.array([p.crop(*args, **kwargs) for p in self._parts])
 
+	def plot(self, cmap=plt.cm.jet, **kwargs):
+		for i, p in enumerate(self._parts):
+			p.plot(color=cmap(i/len(self._parts)), **kwargs)
 
 class BasePart(ABC):
-	def __init__(self, image, annotation):
+	def __init__(self, annotation):
 		super(BasePart, self).__init__()
-		self.image = image
 		self.read_annotation(annotation)
 
 	@staticmethod
-	def new(image, annotation):
+	def new(annotation):
 		if len(annotation) == 4:
-			return LocationPart(image, annotation)
+			return LocationPart(annotation)
 		elif len(annotation) == 5:
-			return BBoxPart(image, annotation)
+			return BBoxPart(annotation)
 		else:
 			raise ValueError("Unknown part annotation format: {}".format(annotation))
 
@@ -80,26 +85,14 @@ class BasePart(ABC):
 	@property
 	def xy(self):
 		return np.array([self.x, self.y])
-	@property
-	def c(self):
-		h, w, c = utils.dimensions(self.image)
-		return c
 
 	@abstractmethod
-	def crop(self, ratio=None, padding_mode="edge"):
+	def crop(self, *args, **kwargs):
 		raise NotImplementedError
 
+	def plot(self, **kwargs):
+		return
 
-class LocationPart(BasePart):
-
-	def read_annotation(self, annotation):
-		# here x,y are the center of the part
-		self._id, self.x, self.y, self.is_visible = annotation
-		self._ratio = None
-
-	@abstractmethod
-	def crop(self, padding_mode="edge", *args, **kwargs):
-		raise NotImplementedError
 
 class LocationPart(BasePart):
 
@@ -119,6 +112,18 @@ class LocationPart(BasePart):
 			return utils.crop(image, self.xy, w, h,
 				padding_mode, is_location=True)
 
+	def plot(self, im, ax, ratio, fill=False, linestyle="--", **kwargs):
+		if not self.is_visible: return
+		x, y = self.xy
+		_h, _w, c = utils.dimensions(im)
+		w, h = int(_w * ratio), int(_h * ratio)
+		ax.add_patch(Rectangle(
+			(x-w//2, y-h//2), w, h,
+			fill=fill, linestyle=linestyle,
+			**kwargs
+		))
+
+
 class BBoxPart(BasePart):
 
 	def read_annotation(self, annotation):
@@ -126,6 +131,13 @@ class BBoxPart(BasePart):
 		self._id, self.x, self.y, self.w, self.h = annotation
 		self._is_visible = True
 
-	def crop(self, padding_mode="edge", *args, **kwargs):
-		return utils.crop(self.image, self.xy, self.w, self.h,
+	def crop(self, image, padding_mode="edge", *args, **kwargs):
+		return utils.crop(image, self.xy, self.w, self.h,
 			padding_mode, is_location=False)
+
+	def plot(self, im, ax, ratio, fill=False, linestyle="--", **kwargs):
+		ax.add_patch(Rectangle(
+			(self.x, self.y), self.w, self.h,
+			fill=fill, linestyle=linestyle,
+			**kwargs
+		))

+ 2 - 7
scripts/display_from_info.py

@@ -91,13 +91,8 @@ def main(args):
 		axs[0].set_title("Visible Parts")
 		axs[0].imshow(im)
 		if not args.crop_to_bb:
-			x, y, w, h = data.bounding_box(i)
-			axs[0].add_patch(Rectangle(
-				(x,y), w, h,
-				fill=False,
-				linestyle="--"
-			))
-		axs[0].scatter(*xy, marker="x", c=idxs)
+			data.plot_bounding_box(i, axs[0])
+		parts.plot(im=im, ax=axs[0], ratio=data.ratio)
 
 		axs[1].axis("off")
 		axs[1].set_title("{}selected parts".format("randomly " if args.rnd else ""))