Quellcode durchsuchen

fix about the labels in the CUB200 dataset

Dimitri Korsch vor 7 Jahren
Ursprung
Commit
6e99ed7fb9
2 geänderte Dateien mit 6 neuen und 1 gelöschten Zeilen
  1. 1 1
      example_cub.py
  2. 5 0
      nabirds/annotations.py

+ 1 - 1
example_cub.py

@@ -27,7 +27,7 @@ for i, (im, parts, label) in enumerate(data, 1):
 	fig2 = plt.figure(figsize=(16,9))
 	n_parts = parts.shape[0]
 
-	for j, crop in enumerate(visible_crops(im, parts, .5), 1):
+	for j, crop in enumerate(visible_crops(im, parts), 1):
 		ax = fig2.add_subplot(3, 5, j)
 		ax.imshow(crop)
 

+ 5 - 0
nabirds/annotations.py

@@ -133,6 +133,11 @@ class CUB_Annotations(BaseAnnotations):
 		]
 		return info
 
+	def __init__(self, *args, **kwargs):
+		super(CUB_Annotations, self).__init__(*args, **kwargs)
+		# set labels from [1..200] to [0..199]
+		self.labels -= 1
+
 	def _load_split(self):
 		assert self._split is not None, "Train-test split was not loaded!"
 		uuid_to_split = {uuid: int(split) for uuid, split in zip(self.uuids, self._split)}