Sfoglia il codice sorgente

added function for uniform sampling

Dimitri Korsch 7 anni fa
parent
commit
24ef9c4a34
5 file modificati con 76 aggiunte e 35 eliminazioni
  1. 1 1
      nabirds/__init__.py
  2. 4 3
      nabirds/annotations.py
  3. 41 16
      nabirds/dataset.py
  4. 25 8
      nabirds/display.py
  5. 5 7
      setup.py

+ 1 - 1
nabirds/__init__.py

@@ -1,4 +1,4 @@
 from .dataset import Dataset
 from .annotations import NAB_Annotations, CUB_Annotations
 
-__version__ = "0.1.3"
+__version__ = "0.1.4"

+ 4 - 3
nabirds/annotations.py

@@ -92,13 +92,12 @@ class BaseAnnotations(abc.ABC):
 	def test_uuids(self):
 		return self._uuids(self.test_split)
 
-
 class NAB_Annotations(BaseAnnotations):
 	@property
 	def meta(self):
 		info = _MetaInfo(
-			images_file="images.txt",
 			images_folder="images",
+			images_file="images.txt",
 			labels_file="labels.txt",
 			hierarchy_file="hierarchy.txt",
 			split_file="train_test_split.txt",
@@ -118,8 +117,8 @@ class CUB_Annotations(BaseAnnotations):
 	@property
 	def meta(self):
 		info = _MetaInfo(
-			images_file="images.txt",
 			images_folder="images",
+			images_file="images.txt",
 			labels_file="labels.txt",
 			split_file="tr_ID.txt",
 			parts_file=join("parts", "part_locs.txt"),
@@ -148,3 +147,5 @@ class CUB_Annotations(BaseAnnotations):
 		super(CUB_Annotations, self)._load_parts()
 		# set part idxs from 1-idxs to 0-idxs
 		self.part_locs[..., 0] -= 1
+
+

+ 41 - 16
nabirds/dataset.py

@@ -31,39 +31,64 @@ DEFAULT_RATIO = np.sqrt(49 / 400)
 def __expand_parts(p):
 	return p[:, 0], p[:, 1:3], p[:, 3].astype(bool)
 
+
+
+def uniform_part_locs(im, ratio=DEFAULT_RATIO, round_op=np.floor):
+	h, w, c = im.shape
+
+	part_w = round_op(w * ratio).astype(np.int32)
+	part_h = round_op(h * ratio).astype(np.int32)
+
+	n, m = w // part_w, h // part_h
+
+	idxs = np.arange(n*m)
+	locs = np.zeros((2, n*m), dtype=np.int32)
+
+
+	for x in range(n):
+		for y in range(m):
+			i = y * n + x
+			x0, y0 = x * part_w, y * part_h
+			locs[:, i] = [x0 + part_w // 2, y0 + part_h // 2]
+
+	return idxs, locs
+
 def visible_part_locs(p):
 	idxs, locs, vis = __expand_parts(p)
 	return idxs[vis], locs[vis].T
 
-def visible_crops(im, p, ratio=DEFAULT_RATIO, padding_mode="edge"):
+
+def crops(im, xy, ratio=DEFAULT_RATIO, padding_mode="edge"):
 	assert im.ndim == 3, "Only RGB images are currently supported!"
-	idxs, locs, vis = __expand_parts(p)
+
 	h, w, c = im.shape
-	crop_h = crop_w = int(np.sqrt(h * w) * ratio)
-	crops = np.zeros((len(idxs), crop_h, crop_w, c), dtype=im.dtype)
+	crop_h, crop_w = int(h * ratio), int(w * ratio)
+	crops = np.zeros((xy.shape[1], crop_h, crop_w, c), dtype=im.dtype)
 
-	padding = np.array([crop_h, crop_w]) // 2
+	pad_h, pad_w = crop_h // 2, crop_w // 2
 
-	padded_im = np.pad(im, [padding, padding, [0,0]], mode=padding_mode)
+	padded_im = np.pad(im, [(pad_h, pad_h), (pad_w, pad_w), [0,0]], mode=padding_mode)
 
-	for i, loc, is_vis in zip(idxs, locs, vis):
-		if not is_vis: continue
-		x0, y0 = loc - crop_h // 2 + padding
+	for i, (x, y) in enumerate(xy.T):
+		x0, y0 = x - crop_w // 2 + pad_w, y - crop_h // 2 + pad_h
 		crops[i] = padded_im[y0:y0+crop_h, x0:x0+crop_w]
 
 	return crops
 
+def visible_crops(im, p, *args, **kw):
+	idxs, locs, vis = __expand_parts(p)
+	parts = crops(im, locs[vis].T, *args, **kw)
+	res = np.zeros((len(idxs),) + parts.shape[1:], dtype=parts.dtype)
+	res[vis] = parts
+	return res
+
 def reveal_parts(im, xy, ratio=DEFAULT_RATIO):
 	h, w, c = im.shape
-	crop_h = crop_w = int(np.sqrt(h * w) * ratio)
-
-	x0y0 = xy - crop_h // 2
+	crop_h, crop_w = int(h * ratio), int(w * ratio)
 
 	res = np.zeros_like(im)
-	for x0, y0 in x0y0.T:
-		x1, y1 = x0 + crop_w, y0 + crop_w
-		x0, y0 = max(x0, 0), max(y0, 0)
+	for x, y in xy.T:
+		x0, y0 = max(x - crop_w // 2, 0), max(y - crop_h // 2, 0)
 		res[y0:y0+crop_h, x0:x0+crop_w] = im[y0:y0+crop_h, x0:x0+crop_w]
 
 	return res
-

+ 25 - 8
nabirds/display.py

@@ -5,8 +5,11 @@ from argparse import ArgumentParser
 import logging
 import numpy as np
 
-from nabirds import Dataset, NAB_Annotations, CUB_Annotations
-from nabirds.dataset import visible_part_locs, visible_crops, reveal_parts
+from annotations import NAB_Annotations, CUB_Annotations
+from dataset import Dataset, reveal_parts, \
+	visible_part_locs, visible_crops, \
+	uniform_part_locs, crops
+
 import matplotlib.pyplot as plt
 
 def init_logger(args):
@@ -37,7 +40,10 @@ def main(args):
 
 		im, parts, label = data[i]
 
-		idxs, xy = visible_part_locs(parts)
+		if args.uniform_parts:
+			idxs, xy = uniform_part_locs(im, ratio=args.ratio)
+		else:
+			idxs, xy = visible_part_locs(parts)
 
 		logging.debug(label)
 		logging.debug(idxs)
@@ -53,14 +59,21 @@ def main(args):
 		ax.scatter(*xy, marker="x", c=idxs)
 
 		fig2 = plt.figure(figsize=(16,9))
-		n_parts = parts.shape[0]
-		rows, cols = (2,6) if args.dataset.lower() == "nab" else (3,5)
-		for j, crop in enumerate(visible_crops(im, parts, ratio=args.ratio), 1):
+
+		if args.uniform_parts:
+			part_crops = crops(im, xy, ratio=args.ratio)
+		else:
+			part_crops = visible_crops(im, parts, ratio=args.ratio)
+
+		n_crops = len(part_crops)
+		rows = int(np.ceil(np.sqrt(n_crops)))
+		cols = int(np.ceil(n_crops / rows))
+		for j, crop in enumerate(part_crops, 1):
 			ax = fig2.add_subplot(rows, cols, j)
 			ax.imshow(crop)
 
-			middle = crop.shape[0] / 2
-			ax.scatter(middle, middle, marker="x")
+			middle_h, middle_w = crop.shape[0] / 2, crop.shape[1] / 2
+			ax.scatter(middle_w, middle_h, marker="x")
 
 		plt.show()
 		plt.close(fig1)
@@ -96,6 +109,10 @@ parser.add_argument("--ratio",
 	help="Part extraction ratio",
 	type=float, default=.2)
 
+parser.add_argument("--uniform_parts", "-u",
+	help="Do not use GT parts, but sample parts uniformly from the image",
+	action="store_true")
+
 parser.add_argument(
 	'--logfile', type=str, default='',
 	help='File for logging output')

+ 5 - 7
setup.py

@@ -7,9 +7,9 @@ import sys
 from setuptools import setup, find_packages
 
 try: # for pip >= 10
-    from pip._internal.req import parse_requirements
+	from pip._internal.req import parse_requirements
 except ImportError: # for pip <= 9.0.3
-    from pip.req import parse_requirements
+	from pip.req import parse_requirements
 
 import nabirds
 install_requires = [line.strip() for line in open("requirements.txt").readlines()]
@@ -20,14 +20,12 @@ setup(
 	description='Wrapper (inofficial) for NA-Birds bataset (http://dl.allaboutbirds.org/nabirds)',
 	author='Dimitri Korsch',
 	author_email='korschdima@gmail.com',
-	# url='https://chainer.org/',
 	license='MIT License',
 	packages=find_packages(),
 	zip_safe=False,
 	setup_requires=[],
 	install_requires=install_requires,
-    package_data={'': ['requirements.txt']},
-    data_files=[('.',['requirements.txt'])],
-    include_package_data=True,
-	# tests_require=['mock', 'nose'],
+	package_data={'': ['requirements.txt']},
+	data_files=[('.',['requirements.txt'])],
+	include_package_data=True,
 )