Dimitri Korsch 5 жил өмнө
parent
commit
dda2afc76b

+ 6 - 1
cvdatasets/annotation/base.py

@@ -187,5 +187,10 @@ class BaseAnnotations(abc.ABC):
 		pass
 
 
-class Annotations(mixins.BBoxMixin, mixins.PartsMixin, mixins.FeaturesMixin, BaseAnnotations):
+class Annotations(
+	mixins.BBoxMixin,
+	mixins.MultiBoxMixin,
+	mixins.PartsMixin,
+	mixins.FeaturesMixin,
+	BaseAnnotations):
 	pass

+ 2 - 0
cvdatasets/annotation/mixins/__init__.py

@@ -1,10 +1,12 @@
 from cvdatasets.annotation.mixins.bbox_mixin import BBoxMixin
 from cvdatasets.annotation.mixins.features_mixin import FeaturesMixin
+from cvdatasets.annotation.mixins.multi_box_mixin import MultiBoxMixin
 from cvdatasets.annotation.mixins.parts_mixin import PartsMixin
 
 __all__ = [
 	"BBoxMixin",
 	"FeaturesMixin",
+	"MultiBoxMixin",
 	"PartsMixin",
 ]
 

+ 49 - 0
cvdatasets/annotation/mixins/multi_box_mixin.py

@@ -0,0 +1,49 @@
+import abc
+import logging
+import numpy as np
+
+from cvdatasets.annotation.files import AnnotationFiles
+
+
+class MultiBoxMixin(abc.ABC):
+
+	def read_annotation_files(self) -> AnnotationFiles:
+		files = super(MultiBoxMixin, self).read_annotation_files()
+
+		files.load_files(
+			multi_boxes=("multi_boxes.json", True),
+		)
+		return files
+
+	@property
+	def has_multi_boxes(self) -> bool:
+		return self.files.multi_boxes is not None
+
+	def parse_annotations(self) -> None:
+		super(MultiBoxMixin, self).parse_annotations()
+		if self.has_multi_boxes:
+			self._parse_multi_boxes()
+
+	def _parse_multi_boxes(self) -> None:
+		logging.debug("Parsing multi-box annotations")
+
+		assert self.has_multi_boxes, \
+			"Multi-boxes were not loaded!"
+
+		self.multi_boxes = {}
+
+		for uuid in self.uuids:
+			idx = self.uuid_to_idx[uuid]
+			im_name = self.image_names[idx]
+			multi_box = self.files.multi_boxes[idx]
+			assert im_name == multi_box["image"], \
+				f"{im_name} != {multi_box['image']}"
+
+			self.multi_boxes[uuid] = multi_box
+
+	def multi_box(self, uuid) -> np.ndarray:
+		if self.has_multi_boxes:
+			return self.multi_boxes[self.uuid_to_idx[uuid]]
+
+		fname = self.image_names[self.uuid_to_idx[uuid]]
+		return dict(image=fname, objects=[dict(x0=0, x1=0, y0=1, y1=1)])