Browse Source

added new "labels" property to the dataset class

Dimitri Korsch 6 years ago
parent
commit
faefeeb0fc
3 changed files with 24 additions and 4 deletions
  1. 1 1
      nabirds/__init__.py
  2. 14 3
      nabirds/dataset/image.py
  3. 9 0
      nabirds/dataset/mixins/reading.py

+ 1 - 1
nabirds/__init__.py

@@ -1,4 +1,4 @@
 from .dataset import Dataset
 from .annotations import NAB_Annotations, CUB_Annotations
 
-__version__ = "0.2.0"
+__version__ = "0.2.1"

+ 14 - 3
nabirds/dataset/image.py

@@ -35,7 +35,8 @@ class ImageWrapper(object):
 
 	def __del__(self):
 		if isinstance(self._im, Image.Image):
-			self._im.close()
+			if self._im is not None and self._im.fp is not None:
+				self._im.close()
 
 	@property
 	def im(self):
@@ -44,8 +45,18 @@ class ImageWrapper(object):
 	@property
 	def im_array(self):
 		if self._im_array is None:
-			self.im = self.im.convert(self.mode)
-			self._im_array = utils.asarray(self.im)
+			if isinstance(self._im, Image.Image):
+				_im = self._im.convert(self.mode)
+				self._im_array = utils.asarray(_im)
+			elif isinstance(self._im, np.ndarray):
+				if self.mode == "RGB" and self._im.ndim == 2:
+					self._im_array = np.stack((self._im,) * 3, axis=-1)
+				elif self._im.ndim == 3:
+					self._im_array = self._im
+				else:
+					raise ValueError()
+			else:
+				raise ValueError()
 		return self._im_array
 
 	@im.setter

+ 9 - 0
nabirds/dataset/mixins/reading.py

@@ -27,6 +27,10 @@ class AnnotationsReadMixin(BaseMixin):
 
 		return ImageWrapper(im_path, int(label), parts, mode=self.mode)
 
+	@property
+	def labels(self):
+		return np.array([self._get("label", i) for i in range(len(self))])
+
 
 class ImageListReadingMixin(BaseMixin):
 
@@ -48,3 +52,8 @@ class ImageListReadingMixin(BaseMixin):
 		im_path = join(self._root, im_file)
 
 		return ImageWrapper(im_path, int(label))
+
+	@property
+	def labels(self):
+		return np.array([label for (_, label) in self._pairs])
+