浏览代码

refactored for CUB200 annotations

Dimitri Korsch 7 年之前
父节点
当前提交
806c9be667
共有 6 个文件被更改,包括 119 次插入30 次删除
  1. 1 1
      README.md
  2. 42 0
      example_cub.py
  3. 2 2
      example_nab.py
  4. 1 1
      nabirds/__init__.py
  5. 72 25
      nabirds/annotations.py
  6. 1 1
      setup.py

+ 1 - 1
README.md

@@ -2,4 +2,4 @@
 
 NA-Birds dataset:  http://dl.allaboutbirds.org/nabirds
 
-Some example code how to use this library can be found in `example.py`
+Some example code how to use this library can be found in `example_nab.py` or `example_cub.py`

+ 42 - 0
example_cub.py

@@ -0,0 +1,42 @@
+#!/usr/bin/env python
+if __name__ != '__main__': raise Exception("Do not import me!")
+
+from nabirds import Dataset, CUB_Annotations
+from nabirds.dataset import visible_part_locs, visible_crops
+import matplotlib.pyplot as plt
+
+annot = CUB_Annotations(root="/home/korsch1/korsch/datasets/birds/cub200_11")
+
+print(annot.labels.shape)
+data = Dataset(annot.train_uuids, annot)
+
+for i, (im, parts, label) in enumerate(data, 1):
+	if i <= 15: continue
+
+	idxs, (xs, ys) = visible_part_locs(parts)
+
+	print(label)
+	print(idxs)
+
+	fig1 = plt.figure(figsize=(16,9))
+	ax = fig1.add_subplot(111)
+
+	ax.imshow(im)
+	ax.scatter(xs, ys, marker="x", c=idxs)
+
+	fig2 = plt.figure(figsize=(16,9))
+	n_parts = parts.shape[0]
+
+	for j, crop in enumerate(visible_crops(im, parts, .5), 1):
+		ax = fig2.add_subplot(3, 5, j)
+		ax.imshow(crop)
+
+		middle = crop.shape[0] / 2
+		ax.scatter(middle, middle, marker="x")
+
+	plt.show()
+	plt.close(fig1)
+	plt.close(fig2)
+
+	if i >= 20: break
+

+ 2 - 2
example.py → example_nab.py

@@ -1,11 +1,11 @@
 #!/usr/bin/env python
 if __name__ != '__main__': raise Exception("Do not import me!")
 
-from nabirds import Dataset, Annotations
+from nabirds import Dataset, NAB_Annotations
 from nabirds.dataset import visible_part_locs, visible_crops
 import matplotlib.pyplot as plt
 
-annot = Annotations("/home/korsch1/korsch/datasets/birds/nabirds")
+annot = NAB_Annotations("/home/korsch1/korsch/datasets/birds/nabirds")
 
 print(annot.labels.shape)
 data = Dataset(annot.train_uuids, annot)

+ 1 - 1
nabirds/__init__.py

@@ -1,3 +1,3 @@
 from .dataset import Dataset
-from .annotations import Annotations
+from .annotations import NAB_Annotations, CUB_Annotations
 

+ 72 - 25
nabirds/annotations.py

@@ -1,24 +1,20 @@
 from os.path import join, isfile
 import numpy as np
 from collections import defaultdict
+import abc
+import warnings
 
+class _MetaInfo(object):
+	def __init__(self, **kwargs):
+		for name, value in kwargs.items():
+			setattr(self, name, value)
+		self.structure = []
 
-class Annotations(object):
-	class meta:
-		images_file = "images.txt"
-		images_folder = "images"
-		labels_file = "labels.txt"
-		hierarchy_file = "hierarchy.txt"
-		split_file = "train_test_split.txt"
-		parts_file = join("parts", "part_locs.txt")
-
-		structure = [
-			[images_file, "_images"],
-			[labels_file, "labels"],
-			[hierarchy_file, "hierarchy"],
-			[split_file, "_split"],
-			[parts_file, "_part_locs"],
-		]
+class BaseAnnotations(abc.ABC):
+	@property
+	@abc.abstractmethod
+	def meta(self):
+		pass
 
 	def _path(self, file):
 		return join(self.root, file)
@@ -28,17 +24,20 @@ class Annotations(object):
 
 	def read_content(self, file, attr):
 		content = None
-		if isfile(self._path(file)):
+		fpath = self._path(file)
+		if isfile(fpath):
 			with self._open(file) as f:
 				content = [line.strip() for line in f if line.strip()]
+		else:
+			warnings.warn("File \"{}\" was not found!".format(fpath))
 
 		setattr(self, attr, content)
 
 	def __init__(self, root):
-		super(Annotations, self).__init__()
+		super(BaseAnnotations, self).__init__()
 		self.root = root
 
-		for fname, attr in Annotations.meta.structure:
+		for fname, attr in self.meta.structure:
 			self.read_content(fname, attr)
 
 		self.labels = np.array([int(l) for l in self.labels], dtype=np.int32)
@@ -58,8 +57,9 @@ class Annotations(object):
 		# this part is quite slow... TODO: some runtime improvements?
 		uuid_to_parts = defaultdict(list)
 		for content in [i.split() for i in self._part_locs]:
-			uuid_to_parts[content[0]].append([int(i) for i in content[1:]])
-		self.part_locs = np.stack([uuid_to_parts[uuid] for uuid in self.uuids])
+			uuid_to_parts[content[0]].append([float(i) for i in content[1:]])
+
+		self.part_locs = np.stack([uuid_to_parts[uuid] for uuid in self.uuids]).astype(int)
 
 	def _load_split(self):
 		assert self._split is not None, "Train-test split was not loaded!"
@@ -68,7 +68,7 @@ class Annotations(object):
 		self.test_split = np.logical_not(self.train_split)
 
 	def image_path(self, image):
-		return join(self.root, Annotations.meta.images_folder, image)
+		return join(self.root, self.meta.images_folder, image)
 
 	def image(self, uuid):
 		fname = self.images[self.uuid_to_idx[uuid]]
@@ -83,9 +83,6 @@ class Annotations(object):
 
 	def _uuids(self, split):
 		return self.uuids[split]
-		# for i in np.where(split)[0]:
-		# 	uuid = self.image_list[i]
-		# 	yield uuid
 
 	@property
 	def train_uuids(self):
@@ -96,3 +93,53 @@ class Annotations(object):
 		return self._uuids(self.test_split)
 
 
+class NAB_Annotations(BaseAnnotations):
+	@property
+	def meta(self):
+		info = _MetaInfo(
+			images_file="images.txt",
+			images_folder="images",
+			labels_file="labels.txt",
+			hierarchy_file="hierarchy.txt",
+			split_file="train_test_split.txt",
+			parts_file=join("parts", "part_locs.txt"),
+		)
+
+		info.structure = [
+			[info.images_file, "_images"],
+			[info.labels_file, "labels"],
+			[info.hierarchy_file, "hierarchy"],
+			[info.split_file, "_split"],
+			[info.parts_file, "_part_locs"],
+		]
+		return info
+
+class CUB_Annotations(BaseAnnotations):
+	@property
+	def meta(self):
+		info = _MetaInfo(
+			images_file="images.txt",
+			images_folder="images",
+			labels_file="labels.txt",
+			split_file="tr_ID.txt",
+			parts_file=join("parts", "part_locs.txt"),
+		)
+
+		info.structure = [
+			[info.images_file, "_images"],
+			[info.labels_file, "labels"],
+			[info.split_file, "_split"],
+			[info.parts_file, "_part_locs"],
+		]
+		return info
+
+	def _load_split(self):
+		assert self._split is not None, "Train-test split was not loaded!"
+		uuid_to_split = {uuid: int(split) for uuid, split in zip(self.uuids, self._split)}
+		self.train_split = np.array([uuid_to_split[uuid] for uuid in self.uuids], dtype=bool)
+		self.test_split = np.logical_not(self.train_split)
+
+	def _load_parts(self):
+		super(CUB_Annotations, self)._load_parts()
+		# set part idxs from 1-idxs to 0-idxs
+		self.part_locs[..., 0] -= 1

+ 1 - 1
setup.py

@@ -12,7 +12,7 @@ install_requires = [line.strip() for line in open("requirements.txt").readlines(
 setup(
 	name='nabirds',
 	version='0.1.0',
-	description='Wrapper for NA-Birds bataset (http://dl.allaboutbirds.org/nabirds)',
+	description='Wrapper (inofficial) for NA-Birds bataset (http://dl.allaboutbirds.org/nabirds)',
 	author='Dimitri Korsch',
 	author_email='korschdima@gmail.com',
 	# url='https://chainer.org/',