|
@@ -13,8 +13,7 @@ class SamplingType(enum.Enum):
|
|
|
|
|
|
def __call__(self, dataset,
|
|
|
random_state=None,
|
|
|
- min_count: int = 10,
|
|
|
- max_count: int = 100,
|
|
|
+ count: int = -1,
|
|
|
):
|
|
|
if random_state is None:
|
|
|
rnd = np.random.RandomState()
|
|
@@ -35,25 +34,25 @@ class SamplingType(enum.Enum):
|
|
|
idxs = []
|
|
|
if self == SamplingType.undersample:
|
|
|
# logging.debug("UNDERSAMPLING")
|
|
|
- count = max(max_count, cls_count.min())
|
|
|
+ _count = max(count, cls_count.min())
|
|
|
|
|
|
for cls in np.unique(labs):
|
|
|
mask = cls == labs
|
|
|
cls_idxs = np.where(mask)[0]
|
|
|
- if len(cls_idxs) > count:
|
|
|
- cls_idxs = rnd.choice(cls_idxs, count, replace=False)
|
|
|
+ if len(cls_idxs) > _count:
|
|
|
+ cls_idxs = rnd.choice(cls_idxs, _count, replace=False)
|
|
|
|
|
|
idxs.extend(cls_idxs)
|
|
|
|
|
|
elif self == SamplingType.oversample:
|
|
|
# logging.debug("OVERSAMPLING")
|
|
|
- count = min(min_count, cls_count.max())
|
|
|
+ _count = min(count, cls_count.max())
|
|
|
|
|
|
for cls in np.unique(labs):
|
|
|
mask = cls == labs
|
|
|
cls_idxs = np.where(mask)[0]
|
|
|
- if len(cls_idxs) < count:
|
|
|
- cls_idxs = rnd.choice(cls_idxs, count, replace=True)
|
|
|
+ if len(cls_idxs) < _count:
|
|
|
+ cls_idxs = rnd.choice(cls_idxs, _count, replace=True)
|
|
|
|
|
|
idxs.extend(cls_idxs)
|
|
|
|
|
@@ -64,8 +63,10 @@ class SamplingType(enum.Enum):
|
|
|
class SamplingMixin(IteratorMixin):
|
|
|
|
|
|
def __init__(self, sampling_type: T.Optional[SamplingType] = None,
|
|
|
+ sampling_count: int = -1,
|
|
|
*args, **kwargs):
|
|
|
self._sampling_type = sampling_type
|
|
|
+ self._sampling_count = sampling_count
|
|
|
super().__init__(*args, **kwargs)
|
|
|
|
|
|
def new_iterator(self, **kwargs):
|
|
@@ -73,8 +74,11 @@ class SamplingMixin(IteratorMixin):
|
|
|
|
|
|
if None not in (it.order_sampler, self._sampling_type):
|
|
|
it.order_sampler = self._sampling_type(self,
|
|
|
+ count=self._sampling_count,
|
|
|
random_state=it.order_sampler._random)
|
|
|
|
|
|
+ logging.info(f"Initialized new sampler: {self._sampling_type}")
|
|
|
+
|
|
|
if hasattr(it, "_initialize_loop"):
|
|
|
it._initialize_loop()
|
|
|
else:
|