Quellcode durchsuchen

Updated handling of the count parameter for over-/undersampling: the separate min_count/max_count arguments are replaced by a single count argument (default -1)

Dimitri Korsch vor 2 Jahren
Ursprung
Commit
287055cba8
1 geänderte Dateien mit 12 neuen und 8 gelöschten Zeilen
  1. 12 8
      cvdatasets/dataset/mixins/chainer_mixins/sampling_mixin.py

+ 12 - 8
cvdatasets/dataset/mixins/chainer_mixins/sampling_mixin.py

@@ -13,8 +13,7 @@ class SamplingType(enum.Enum):
 
 	def __call__(self, dataset,
 				 random_state=None,
-				 min_count: int = 10,
-				 max_count: int = 100,
+				 count: int = -1,
 				 ):
 		if random_state is None:
 			rnd = np.random.RandomState()
@@ -35,25 +34,25 @@ class SamplingType(enum.Enum):
 			idxs = []
 			if self == SamplingType.undersample:
 				# logging.debug("UNDERSAMPLING")
-				count = max(max_count, cls_count.min())
+				_count = max(count, cls_count.min())
 
 				for cls in np.unique(labs):
 					mask = cls == labs
 					cls_idxs = np.where(mask)[0]
-					if len(cls_idxs) > count:
-						cls_idxs = rnd.choice(cls_idxs, count, replace=False)
+					if len(cls_idxs) > _count:
+						cls_idxs = rnd.choice(cls_idxs, _count, replace=False)
 
 					idxs.extend(cls_idxs)
 
 			elif self == SamplingType.oversample:
 				# logging.debug("OVERSAMPLING")
-				count = min(min_count, cls_count.max())
+				_count = min(count, cls_count.max())
 
 				for cls in np.unique(labs):
 					mask = cls == labs
 					cls_idxs = np.where(mask)[0]
-					if len(cls_idxs) < count:
-						cls_idxs = rnd.choice(cls_idxs, count, replace=True)
+					if len(cls_idxs) < _count:
+						cls_idxs = rnd.choice(cls_idxs, _count, replace=True)
 
 					idxs.extend(cls_idxs)
 
@@ -64,8 +63,10 @@ class SamplingType(enum.Enum):
 class SamplingMixin(IteratorMixin):
 
 	def __init__(self, sampling_type: T.Optional[SamplingType] = None,
+		         sampling_count: int = -1,
 		         *args, **kwargs):
 		self._sampling_type = sampling_type
+		self._sampling_count = sampling_count
 		super().__init__(*args, **kwargs)
 
 	def new_iterator(self, **kwargs):
@@ -73,8 +74,11 @@ class SamplingMixin(IteratorMixin):
 
 		if None not in (it.order_sampler, self._sampling_type):
 			it.order_sampler = self._sampling_type(self,
+				count=self._sampling_count,
 				random_state=it.order_sampler._random)
 
+			logging.info(f"Initialized new sampler: {self._sampling_type}")
+
 			if hasattr(it, "_initialize_loop"):
 				it._initialize_loop()
 			else: