瀏覽代碼

fix about the labels in the CUB200 dataset

Dimitri Korsch 7 年之前
父節點
當前提交
6e99ed7fb9
共有 2 個文件被更改,包括 6 次插入1 次删除
  1. 1 1
      example_cub.py
  2. 5 0
      nabirds/annotations.py

+ 1 - 1
example_cub.py

@@ -27,7 +27,7 @@ for i, (im, parts, label) in enumerate(data, 1):
 	fig2 = plt.figure(figsize=(16,9))
 	n_parts = parts.shape[0]
 
-	for j, crop in enumerate(visible_crops(im, parts, .5), 1):
+	for j, crop in enumerate(visible_crops(im, parts), 1):
 		ax = fig2.add_subplot(3, 5, j)
 		ax.imshow(crop)
 

+ 5 - 0
nabirds/annotations.py

@@ -133,6 +133,11 @@ class CUB_Annotations(BaseAnnotations):
 		]
 		return info
 
+	def __init__(self, *args, **kwargs):
+		super(CUB_Annotations, self).__init__(*args, **kwargs)
+		# set labels from [1..200] to [0..199]
+		self.labels -= 1
+
 	def _load_split(self):
 		assert self._split is not None, "Train-test split was not loaded!"
 		uuid_to_split = {uuid: int(split) for uuid, split in zip(self.uuids, self._split)}