소스 검색

added reveal routine to parts class

Dimitri Korsch 6 년 전
부모
커밋
f6d8da96b4
2개의 변경된 파일29개의 추가작업 그리고 4개의 파일을 삭제
  1. 28 1
      nabirds/dataset/part.py
  2. 1 3
      scripts/display_from_info.py

+ 28 - 1
nabirds/dataset/part.py

@@ -55,6 +55,18 @@ class Parts(object):
 		for i, p in enumerate(self._parts):
 			p.plot(color=cmap(i/len(self._parts)), **kwargs)
 
+	def reveal(self, im, ratio, *args, **kwargs):
+		res = np.zeros_like(im)
+
+		for part in self._parts:
+			if not part.is_visible: continue
+			x, y, crop = part.reveal(im, ratio=ratio, *args, **kwargs)
+			h, w, _ = crop.shape
+			res[y:y+h, x:x+w] = crop
+
+		return res
+
+
 class BasePart(ABC):
 	def __init__(self, annotation):
 		super(BasePart, self).__init__()
@@ -112,6 +124,15 @@ class LocationPart(BasePart):
 			return utils.crop(image, self.xy, w, h,
 				padding_mode, is_location=True)
 
+
+	def reveal(self, im, ratio, *args, **kwargs):
+		_h, _w, c = utils.dimensions(im)
+		w, h = int(_w * ratio), int(_h * ratio)
+		x,y = self.xy
+		x, y = max(x - w // 2, 0), max(y - h // 2, 0)
+		return x, y, im[y:y+h, x:x+w]
+
+
 	def plot(self, im, ax, ratio, fill=False, linestyle="--", **kwargs):
 		if not self.is_visible: return
 		x, y = self.xy
@@ -123,7 +144,6 @@ class LocationPart(BasePart):
 			**kwargs
 		))
 
-
 class BBoxPart(BasePart):
 
 	def read_annotation(self, annotation):
@@ -135,6 +155,13 @@ class BBoxPart(BasePart):
 		return utils.crop(image, self.xy, self.w, self.h,
 			padding_mode, is_location=False)
 
+	def reveal(self, im, ratio, *args, **kwargs):
+		_h, _w, c = utils.dimensions(im)
+		x,y = self.xy
+		return x, y, im[y:y+self.h, x:x+self.w]
+
+
+
 	def plot(self, im, ax, ratio, fill=False, linestyle="--", **kwargs):
 		ax.add_patch(Rectangle(
 			(self.x, self.y), self.w, self.h,

+ 1 - 3
scripts/display_from_info.py

@@ -17,8 +17,6 @@ from matplotlib.patches import Rectangle
 from argparse import ArgumentParser
 
 from nabirds import CUB_Annotations, Dataset
-from nabirds.dataset import utils
-
 
 def init_logger(args):
 	fmt = "%(levelname)s - [%(asctime)s] %(filename)s:%(lineno)d [%(funcName)s]: %(message)s"
@@ -96,7 +94,7 @@ def main(args):
 
 		axs[1].axis("off")
 		axs[1].set_title("{}selected parts".format("randomly " if args.rnd else ""))
-		axs[1].imshow(utils.reveal_parts(im, xy, ratio=data.ratio))
+		axs[1].imshow(parts.reveal(im, ratio=data.ratio))
 		crop_names = list(data._annot.part_names.values())
 		plot_crops(part_crops, "Selected parts", names=crop_names)