Przeglądaj źródła

added dataset mixing for under-/oversampling

Dimitri Korsch 2 lat temu
rodzic
commit
e5e1a85f53

+ 7 - 0
cvdatasets/dataset/__init__.py

@@ -3,6 +3,8 @@ from cvdatasets.dataset.mixins.bounding_box import BBCropMixin
 from cvdatasets.dataset.mixins.bounding_box import BBoxMixin
 from cvdatasets.dataset.mixins.bounding_box import MultiBoxMixin
 from cvdatasets.dataset.mixins.chainer_mixins import IteratorMixin
+from cvdatasets.dataset.mixins.chainer_mixins import SamplingMixin
+from cvdatasets.dataset.mixins.chainer_mixins import SamplingType
 from cvdatasets.dataset.mixins.features import PreExtractedFeaturesMixin
 from cvdatasets.dataset.mixins.image_profiler import ImageProfilerMixin
 from cvdatasets.dataset.mixins.parts import BasePartMixin
@@ -62,4 +64,9 @@ __all__ = [
 
 	# transform mixin
 	"TransformMixin",
+
+	# chainer-assuming mixins / types
+	"IteratorMixin",
+	"SamplingMixin",
+	"SamplingType",
 ]

+ 3 - 1
cvdatasets/dataset/mixins/chainer_mixins/__init__.py

@@ -1 +1,3 @@
-from .iterator_mixin import IteratorMixin
+from cvdatasets.dataset.mixins.chainer_mixins.iterator_mixin import IteratorMixin
+from cvdatasets.dataset.mixins.chainer_mixins.sampling_mixin import SamplingMixin
+from cvdatasets.dataset.mixins.chainer_mixins.sampling_mixin import SamplingType

+ 1 - 1
cvdatasets/dataset/mixins/chainer_mixins/iterator_mixin.py

@@ -1,7 +1,7 @@
 import numpy as np
 import logging
 
-from .base import BaseChainerMixin
+from cvdatasets.dataset.mixins.chainer_mixins.base import BaseChainerMixin
 from cvdatasets.utils import new_iterator
 
 class IteratorMixin(BaseChainerMixin):

+ 83 - 0
cvdatasets/dataset/mixins/chainer_mixins/sampling_mixin.py

@@ -0,0 +1,83 @@
+import enum
+import logging
+import numpy as np
+import typing as T
+
+from cvdatasets.dataset.mixins.chainer_mixins.iterator_mixin import IteratorMixin
+
+
+class SamplingType(enum.Enum):
+
+	undersample = enum.auto()
+	oversample = enum.auto()
+
+	def __call__(self, dataset,
+				 random_state=None,
+				 min_count: int = 10,
+				 max_count: int = 100,
+				 ):
+		if random_state is None:
+			rnd = np.random.RandomState()
+
+		elif isinstance(random_state, int):
+			rnd = np.random.RandomState(random_state)
+
+		else:
+			rnd = random_state
+
+		labs = dataset.labels
+		cls_count = np.bincount(labs)
+
+		def sampler(current_order, current_position):
+
+			labs_now = dataset.labels
+
+			idxs = []
+			if self == SamplingType.undersample:
+				# logging.debug("UNDERSAMPLING")
+				count = max(max_count, cls_count.min())
+
+				for cls in np.unique(labs):
+					mask = cls == labs
+					cls_idxs = np.where(mask)[0]
+					if len(cls_idxs) > count:
+						cls_idxs = rnd.choice(cls_idxs, count, replace=False)
+
+					idxs.extend(cls_idxs)
+
+			elif self == SamplingType.oversample:
+				# logging.debug("OVERSAMPLING")
+				count = min(min_count, cls_count.max())
+
+				for cls in np.unique(labs):
+					mask = cls == labs
+					cls_idxs = np.where(mask)[0]
+					if len(cls_idxs) < count:
+						cls_idxs = rnd.choice(cls_idxs, count, replace=True)
+
+					idxs.extend(cls_idxs)
+
+			return rnd.permutation(idxs)
+
+		return sampler
+
+class SamplingMixin(IteratorMixin):
+
+	def __init__(self, sampling_type: T.Optional[SamplingType] = None,
+		         *args, **kwargs):
+		self._sampling_type = sampling_type
+		super().__init__(*args, **kwargs)
+
+	def new_iterator(self, **kwargs):
+		it, n_batches = super().new_iterator(**kwargs)
+
+		if None not in (it.order_sampler, self._sampling_type):
+			it.order_sampler = self._sampling_type(self,
+				random_state=it.order_sampler._random)
+
+			if hasattr(it, "_initialize_loop"):
+				it._initialize_loop()
+			else:
+				it.reset()
+
+		return it, n_batches