Sfoglia il codice sorgente

renamed a parameter and added n_parts property to dataset class

Dimitri Korsch 6 anni fa
parent
commit
857e8c0818
2 ha cambiato i file con 7 aggiunte e 3 eliminazioni
  1. 3 3
      nabirds/dataset/mixins/parts.py
  2. 4 0
      nabirds/dataset/mixins/reading.py

+ 3 - 3
nabirds/dataset/mixins/parts.py

@@ -84,16 +84,16 @@ class UniformPartMixin(BaseMixin):
 
 class RandomBlackOutMixin(BaseMixin):
 
-	def __init__(self, seed=None, rnd_select=False, n_parts=None, *args, **kwargs):
+	def __init__(self, seed=None, rnd_select=False, blackout_parts=None, *args, **kwargs):
 		super(RandomBlackOutMixin, self).__init__(*args, **kwargs)
 		self.rnd = np.random.RandomState(seed)
 		self.rnd_select = rnd_select
-		self.n_parts = n_parts
+		self.blackout_parts = blackout_parts
 
 	def get_example(self, i):
 		im_obj = super(RandomBlackOutMixin, self).get_example(i)
 		if self.rnd_select:
-			return im_obj.select_random_parts(rnd=self.rnd, n_parts=self.n_parts)
+			return im_obj.select_random_parts(rnd=self.rnd, n_parts=self.blackout_parts)
 		return im_obj
 
 

+ 4 - 0
nabirds/dataset/mixins/reading.py

@@ -30,6 +30,10 @@ class AnnotationsReadMixin(BaseMixin):
 
 		return ImageWrapper(im_path, int(label), parts, mode=self.mode, part_rescale_size=self.part_rescale_size)
 
+	@property
+	def n_parts(self):
+		return self._annot.part_locs.shape[1]
+
 	@property
 	def labels(self):
 		return np.array([self._get("label", i) for i in range(len(self))])