Ver código fonte

first version of the NA birds wrapper

Dimitri Korsch 7 anos atrás
commit
83808eb9eb
7 arquivos alterados com 327 adições e 0 exclusões
  1. 99 0
      .gitignore
  2. 42 0
      example.py
  3. 3 0
      nabirds/__init__.py
  4. 98 0
      nabirds/annotations.py
  5. 53 0
      nabirds/dataset.py
  6. 4 0
      requirements.txt
  7. 28 0
      setup.py

+ 99 - 0
.gitignore

@@ -0,0 +1,99 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/

+ 42 - 0
example.py

@@ -0,0 +1,42 @@
+#!/usr/bin/env python
+if __name__ != '__main__': raise Exception("Do not import me!")
+
+from nabirds import Dataset, Annotations
+from nabirds.dataset import visible_part_locs, visible_crops
+import matplotlib.pyplot as plt
+
+annot = Annotations("/home/korsch1/korsch/datasets/birds/nabirds")
+
+print(annot.labels.shape)
+data = Dataset(annot.train_uuids, annot)
+
+for i, (im, parts, label) in enumerate(data, 1):
+	if i <= 15: continue
+
+	idxs, (xs, ys) = visible_part_locs(parts)
+
+	print(label)
+	print(idxs)
+
+	fig1 = plt.figure(figsize=(16,9))
+	ax = fig1.add_subplot(111)
+
+	ax.imshow(im)
+	ax.scatter(xs, ys, marker="x", c=idxs)
+
+	fig2 = plt.figure(figsize=(16,9))
+	n_parts = parts.shape[0]
+
+	for j, crop in enumerate(visible_crops(im, parts, .5), 1):
+		ax = fig2.add_subplot(2, 6, j)
+		ax.imshow(crop)
+
+		middle = crop.shape[0] / 2
+		ax.scatter(middle, middle, marker="x")
+
+	plt.show()
+	plt.close(fig1)
+	plt.close(fig2)
+
+	if i >= 20: break
+

+ 3 - 0
nabirds/__init__.py

@@ -0,0 +1,3 @@
+from .dataset import Dataset
+from .annotations import Annotations
+

+ 98 - 0
nabirds/annotations.py

@@ -0,0 +1,98 @@
+from os.path import join, isfile
+import numpy as np
+from collections import defaultdict
+
+
+class Annotations(object):
+	class meta:
+		images_file = "images.txt"
+		images_folder = "images"
+		labels_file = "labels.txt"
+		hierarchy_file = "hierarchy.txt"
+		split_file = "train_test_split.txt"
+		parts_file = join("parts", "part_locs.txt")
+
+		structure = [
+			[images_file, "_images"],
+			[labels_file, "labels"],
+			[hierarchy_file, "hierarchy"],
+			[split_file, "_split"],
+			[parts_file, "_part_locs"],
+		]
+
+	def _path(self, file):
+		return join(self.root, file)
+
+	def _open(self, file):
+		return open(self._path(file))
+
+	def read_content(self, file, attr):
+		content = None
+		if isfile(self._path(file)):
+			with self._open(file) as f:
+				content = [line.strip() for line in f if line.strip()]
+
+		setattr(self, attr, content)
+
+	def __init__(self, root):
+		super(Annotations, self).__init__()
+		self.root = root
+
+		for fname, attr in Annotations.meta.structure:
+			self.read_content(fname, attr)
+
+		self.labels = np.array([int(l) for l in self.labels], dtype=np.int32)
+
+		self._load_uuids()
+		self._load_parts()
+		self._load_split()
+
+	def _load_uuids(self):
+		assert self._images is not None, "Images were not loaded!"
+		uuid_fnames = [i.split() for i in self._images]
+		self.uuids, self.images = map(np.array, zip(*uuid_fnames))
+		self.uuid_to_idx = {uuid: i for i, uuid in enumerate(self.uuids)}
+
+	def _load_parts(self):
+		assert self._part_locs is not None, "Part locations were not loaded!"
+		# this part is quite slow... TODO: some runtime improvements?
+		uuid_to_parts = defaultdict(list)
+		for content in [i.split() for i in self._part_locs]:
+			uuid_to_parts[content[0]].append([int(i) for i in content[1:]])
+		self.part_locs = np.stack([uuid_to_parts[uuid] for uuid in self.uuids])
+
+	def _load_split(self):
+		assert self._split is not None, "Train-test split was not loaded!"
+		uuid_to_split = {uuid: int(split) for uuid, split in [i.split() for i in self._split]}
+		self.train_split = np.array([uuid_to_split[uuid] for uuid in self.uuids], dtype=bool)
+		self.test_split = np.logical_not(self.train_split)
+
+	def image_path(self, image):
+		return join(self.root, Annotations.meta.images_folder, image)
+
+	def image(self, uuid):
+		fname = self.images[self.uuid_to_idx[uuid]]
+		return self.image_path(fname)
+
+	def label(self, uuid):
+		return self.labels[self.uuid_to_idx[uuid]]
+
+	def parts(self, uuid):
+		return self.part_locs[self.uuid_to_idx[uuid]]
+
+
+	def _uuids(self, split):
+		return self.uuids[split]
+		# for i in np.where(split)[0]:
+		# 	uuid = self.image_list[i]
+		# 	yield uuid
+
+	@property
+	def train_uuids(self):
+		return self._uuids(self.train_split)
+
+	@property
+	def test_uuids(self):
+		return self._uuids(self.test_split)
+
+

+ 53 - 0
nabirds/dataset.py

@@ -0,0 +1,53 @@
+from imageio import imread
+import numpy as np
+
+
+class Dataset(object):
+	def __init__(self, uuids, annotations):
+		super(Dataset, self).__init__()
+		self.uuids = uuids
+		self._annot = annotations
+
+	def __len__(self):
+		return len(self.uuids)
+
+	def _get(self, method, i):
+		return getattr(self._annot, method)(self.uuids[i])
+
+
+
+	def get_example(self, i, mode="RGB"):
+		methods = ["image", "parts", "label"]
+		im_path, parts, label = [self._get(m, i) for m in methods]
+		return imread(im_path, pilmode=mode), parts, label
+
+	__getitem__  = get_example
+
+
+# some convenience functions
+
+def __expand_parts(p):
+	return p[:, 0], p[:, 1:3], p[:, 3].astype(bool)
+
+def visible_part_locs(p):
+	idxs, locs, vis = __expand_parts(p)
+	return idxs[vis], locs[vis].T
+
+def visible_crops(im, p, ratio=np.sqrt(49 / 400), padding_mode="edge"):
+	assert im.ndim == 3, "Only RGB images are currently supported!"
+	idxs, locs, vis = __expand_parts(p)
+	h, w, c = im.shape
+	crop_h = crop_w = int(np.sqrt(h*w) * ratio)
+	crops = np.zeros((len(idxs), crop_h, crop_w, c), dtype=im.dtype)
+
+	padding = np.array([crop_h, crop_w]) // 2
+
+	padded_im = np.pad(im, [padding, padding, [0,0]], mode=padding_mode)
+
+	for i, loc, is_vis in zip(*__expand_parts(p)):
+		if not is_vis: continue
+		x0, y0 = loc - crop_h // 2 + padding
+		crops[i] = padded_im[y0:y0+crop_h, x0:x0+crop_w]
+
+	return crops
+

+ 4 - 0
requirements.txt

@@ -0,0 +1,4 @@
+imageio
+numpy
+pillow
+matplotlib

+ 28 - 0
setup.py

@@ -0,0 +1,28 @@
+#!/usr/bin/env python
+
+import os
+import pkg_resources
+import sys
+
+from setuptools import setup, find_packages
+from pip.req import parse_requirements
+
+install_requires = [line.strip() for line in open("requirements.txt").readlines()]
+
+setup(
+	name='nabirds',
+	version='0.1.0',
+	description='Wrapper for NA-Birds bataset (http://dl.allaboutbirds.org/nabirds)',
+	author='Dimitri Korsch',
+	author_email='korschdima@gmail.com',
+	# url='https://chainer.org/',
+	license='MIT License',
+	packages=find_packages(),
+	zip_safe=False,
+	setup_requires=[],
+	install_requires=install_requires,
+    package_data={'': ['requirements.txt']},
+    data_files=[('.',['requirements.txt'])],
+    include_package_data=True,
+	# tests_require=['mock', 'nose'],
+)