Parcourir la source

refactored dataset module: added utils module and moved Dataset

Dimitri Korsch il y a 7 ans
Parent
commit
b682032279
3 fichiers modifiés avec 35 ajouts et 30 suppressions
  1. 24 0
      nabirds/dataset/__init__.py
  2. 0 26
      nabirds/dataset/utils.py
  3. 11 4
      nabirds/display.py

+ 24 - 0
nabirds/dataset/__init__.py

@@ -0,0 +1,24 @@
+from imageio import imread
+import numpy as np
+
+
+class Dataset(object):
+	def __init__(self, uuids, annotations):
+		super(Dataset, self).__init__()
+		self.uuids = uuids
+		self._annot = annotations
+
+	def __len__(self):
+		return len(self.uuids)
+
+	def _get(self, method, i):
+		return getattr(self._annot, method)(self.uuids[i])
+
+
+
+	def get_example(self, i, mode="RGB"):
+		methods = ["image", "parts", "label"]
+		im_path, parts, label = [self._get(m, i) for m in methods]
+		return imread(im_path, pilmode=mode), parts, label
+
+	__getitem__  = get_example

+ 0 - 26
nabirds/dataset.py → nabirds/dataset/utils.py

@@ -1,31 +1,5 @@
-from imageio import imread
 import numpy as np
 
-
-class Dataset(object):
-	def __init__(self, uuids, annotations):
-		super(Dataset, self).__init__()
-		self.uuids = uuids
-		self._annot = annotations
-
-	def __len__(self):
-		return len(self.uuids)
-
-	def _get(self, method, i):
-		return getattr(self._annot, method)(self.uuids[i])
-
-
-
-	def get_example(self, i, mode="RGB"):
-		methods = ["image", "parts", "label"]
-		im_path, parts, label = [self._get(m, i) for m in methods]
-		return imread(im_path, pilmode=mode), parts, label
-
-	__getitem__  = get_example
-
-
-# some convention functions
-
 DEFAULT_RATIO = np.sqrt(49 / 400)
 
 def __expand_parts(p):

+ 11 - 4
nabirds/display.py

@@ -6,7 +6,8 @@ import logging
 import numpy as np
 
 from annotations import NAB_Annotations, CUB_Annotations
-from dataset import Dataset, reveal_parts, \
+from dataset import Dataset
+from dataset.utils import reveal_parts, \
 	visible_part_locs, visible_crops, \
 	uniform_part_locs, crops
 
@@ -50,27 +51,29 @@ def main(args):
 
 		fig1 = plt.figure(figsize=(16,9))
 		ax = fig1.add_subplot(2,1,1)
-
 		ax.imshow(im)
 		ax.scatter(*xy, marker="x", c=idxs)
+		ax.axis("off")
 
 		ax = fig1.add_subplot(2,1,2)
 		ax.imshow(reveal_parts(im, xy, ratio=args.ratio))
 		ax.scatter(*xy, marker="x", c=idxs)
-
-		fig2 = plt.figure(figsize=(16,9))
+		ax.axis("off")
 
 		if args.uniform_parts:
 			part_crops = crops(im, xy, ratio=args.ratio)
 		else:
 			part_crops = visible_crops(im, parts, ratio=args.ratio)
 
+		fig2 = plt.figure(figsize=(16,9))
 		n_crops = len(part_crops)
 		rows = int(np.ceil(np.sqrt(n_crops)))
 		cols = int(np.ceil(n_crops / rows))
+
 		for j, crop in enumerate(part_crops, 1):
 			ax = fig2.add_subplot(rows, cols, j)
 			ax.imshow(crop)
+			ax.axis("off")
 
 			middle_h, middle_w = crop.shape[0] / 2, crop.shape[1] / 2
 			ax.scatter(middle_w, middle_h, marker="x")
@@ -109,6 +112,10 @@ parser.add_argument("--ratio",
 	help="Part extraction ratio",
 	type=float, default=.2)
 
+parser.add_argument("--rnd",
+	help="select random subset of present parts",
+	action="store_true")
+
 parser.add_argument("--uniform_parts", "-u",
 	help="Do not use GT parts, but sample parts uniformly from the image",
 	action="store_true")