Browse Source

created base part mixin to handle ratio parameter

Dimitri Korsch 4 years ago
parent
commit
9347a32476
2 changed files with 21 additions and 14 deletions
  1. 5 3
      cvdatasets/dataset/__init__.py
  2. 16 11
      cvdatasets/dataset/mixins/parts.py

+ 5 - 3
cvdatasets/dataset/__init__.py

@@ -4,9 +4,10 @@ from cvdatasets.dataset.mixins.bounding_box import BBoxMixin
 from cvdatasets.dataset.mixins.bounding_box import MultiBoxMixin
 from cvdatasets.dataset.mixins.chainer_mixins import IteratorMixin
 from cvdatasets.dataset.mixins.features import PreExtractedFeaturesMixin
+from cvdatasets.dataset.mixins.parts import BasePartMixin
 from cvdatasets.dataset.mixins.parts import CroppedPartMixin
 from cvdatasets.dataset.mixins.parts import PartCropMixin
-from cvdatasets.dataset.mixins.parts import PartMixin
+from cvdatasets.dataset.mixins.parts import _PartMixin
 from cvdatasets.dataset.mixins.parts import PartRevealMixin
 from cvdatasets.dataset.mixins.parts import PartsInBBMixin
 from cvdatasets.dataset.mixins.parts import RandomBlackOutMixin
@@ -17,7 +18,7 @@ from cvdatasets.dataset.mixins.reading import ImageListReadingMixin
 from cvdatasets.dataset.mixins.transform import TransformMixin
 
 
-class ImageWrapperDataset(PartMixin, PreExtractedFeaturesMixin, AnnotationsReadMixin, IteratorMixin):
+class ImageWrapperDataset(_PartMixin, PreExtractedFeaturesMixin, AnnotationsReadMixin, IteratorMixin):
 	pass
 
 class Dataset(ImageWrapperDataset):
@@ -45,9 +46,10 @@ __all__ = [
 	"MultiBoxMixin",
 
 	# parts
+	"BasePartMixin",
 	"CroppedPartMixin",
 	"PartCropMixin",
-	"PartMixin",
+	"_PartMixin",
 	"PartRevealMixin",
 	"PartsInBBMixin",
 	"RandomBlackOutMixin",

+ 16 - 11
cvdatasets/dataset/mixins/parts.py

@@ -4,7 +4,14 @@ from cvdatasets.dataset.mixins.base import BaseMixin
 from cvdatasets.dataset.mixins.bounding_box import BBoxMixin
 from cvdatasets.dataset.mixins.bounding_box import BBCropMixin
 
-class PartsInBBMixin(BBoxMixin):
+class BasePartMixin(BaseMixin):
+
+	def __init__(self, ratio=None, *args, **kwargs):
+		super(BasePartMixin, self).__init__(*args, **kwargs)
+		self.ratio = ratio
+
+class PartsInBBMixin(BasePartMixin, BBoxMixin):
+
 	def __init__(self, parts_in_bb=False, *args, **kwargs):
 		super(PartsInBBMixin, self).__init__(*args, **kwargs)
 		self.parts_in_bb = parts_in_bb
@@ -17,7 +24,7 @@ class PartsInBBMixin(BBoxMixin):
 			return im_obj.hide_parts_outside_bb(*bb)
 		return im_obj
 
-class PartCropMixin(BaseMixin):
+class PartCropMixin(BasePartMixin):
 
 	def __init__(self, return_part_crops=False, *args, **kwargs):
 		super(PartCropMixin, self).__init__(*args, **kwargs)
@@ -30,7 +37,7 @@ class PartCropMixin(BaseMixin):
 		return im_obj
 
 
-class PartRevealMixin(BaseMixin):
+class PartRevealMixin(BasePartMixin):
 
 	def __init__(self, reveal_visible=False, *args, **kwargs):
 		super(PartRevealMixin, self).__init__(*args, **kwargs)
@@ -38,18 +45,16 @@ class PartRevealMixin(BaseMixin):
 
 	def get_example(self, i):
 		im_obj = super(PartRevealMixin, self).get_example(i)
-		assert hasattr(self, "ratio"), "\"ratio\" attribute is missing!"
 		if self.reveal_visible:
 			return im_obj.reveal_visible(self.ratio)
 		return im_obj
 
 
-class UniformPartMixin(BaseMixin):
+class UniformPartMixin(BasePartMixin):
 
-	def __init__(self, uniform_parts=False, ratio=None, *args, **kwargs):
+	def __init__(self, uniform_parts=False, *args, **kwargs):
 		super(UniformPartMixin, self).__init__(*args, **kwargs)
 		self.uniform_parts = uniform_parts
-		self.ratio = ratio
 
 	def get_example(self, i):
 		im_obj = super(UniformPartMixin, self).get_example(i)
@@ -57,7 +62,7 @@ class UniformPartMixin(BaseMixin):
 			return im_obj.uniform_parts(self.ratio)
 		return im_obj
 
-class RandomBlackOutMixin(BaseMixin):
+class RandomBlackOutMixin(BasePartMixin):
 
 	def __init__(self, seed=None, rnd_select=False, blackout_parts=None, *args, **kwargs):
 		super(RandomBlackOutMixin, self).__init__(*args, **kwargs)
@@ -74,18 +79,18 @@ class RandomBlackOutMixin(BaseMixin):
 
 # some shortcuts
 
-class PartMixin(RandomBlackOutMixin, PartsInBBMixin, UniformPartMixin, BBCropMixin):
+class _PartMixin(RandomBlackOutMixin, PartsInBBMixin, UniformPartMixin, BBCropMixin):
 	"""
 		TODO!
 	"""
 
-class RevealedPartMixin(PartRevealMixin, PartMixin):
+class RevealedPartMixin(PartRevealMixin, _PartMixin):
 	"""
 		TODO!
 	"""
 
 
-class CroppedPartMixin(PartCropMixin, PartMixin):
+class CroppedPartMixin(PartCropMixin, _PartMixin):
 	"""
 		TODO!
 	"""