Răsfoiți Sursa

added hierarchy logic implementation

Dimitri Korsch 2 ani în urmă
părinte
comite
377986160d

+ 4 - 2
cvdatasets/__init__.py

@@ -3,6 +3,7 @@ from cvdatasets.annotation.base import BaseAnnotations
 from cvdatasets.annotation.files import AnnotationFiles
 from cvdatasets.annotation.mixins.bbox_mixin import BBoxMixin
 from cvdatasets.annotation.mixins.features_mixin import FeaturesMixin
+from cvdatasets.annotation.mixins.hierarchy_mixin import Hierarchy
 from cvdatasets.annotation.mixins.parts_mixin import PartsMixin
 from cvdatasets.annotation.types import AnnotationArgs
 from cvdatasets.annotation.types import AnnotationType
@@ -15,9 +16,9 @@ from cvdatasets.utils import _MetaInfo
 
 __all__ = [
 	"_MetaInfo",
-	"Annotations",
-	"AnnotationFiles",
 	"AnnotationArgs",
+	"AnnotationFiles",
+	"Annotations",
 	"AnnotationType",
 	"BaseAnnotations",
 	"BBoxMixin",
@@ -25,6 +26,7 @@ __all__ = [
 	"FileListAnnotations",
 	"FolderAnnotations",
 	"FolderAnnotations",
+	"Hierarchy",
 	"ImageWrapperDataset",
 	"JSONAnnotations",
 	"PartsMixin",

+ 9 - 0
cvdatasets/annotation/base.py

@@ -116,6 +116,14 @@ class BaseAnnotations(abc.ABC):
 	@labels.setter
 	def labels(self, labels):
 		self._orig_labels, self._labels = np.unique(labels, return_inverse=True)
+		self._orig2unq_labels = {lab:i for i, lab in enumerate(self._orig_labels)}
+		self._unq2orig_labels = {i:lab for i, lab in enumerate(self._orig_labels)}
+
+	def unq2orig(self, label):
+		return self._unq2orig_labels[int(label)]
+
+	def orig2unq(self, label):
+		return self._orig2unq_labels[int(label)]
 
 	def image_path(self, image) -> str:
 		return str(self.root / self.images_folder / image)
@@ -200,5 +208,6 @@ class Annotations(
 	mixins.MultiBoxMixin,
 	mixins.PartsMixin,
 	mixins.FeaturesMixin,
+	mixins.HierarchyMixin,
 	BaseAnnotations):
 	pass

+ 2 - 1
cvdatasets/annotation/mixins/__init__.py

@@ -1,12 +1,13 @@
 from cvdatasets.annotation.mixins.bbox_mixin import BBoxMixin
 from cvdatasets.annotation.mixins.features_mixin import FeaturesMixin
+from cvdatasets.annotation.mixins.hierarchy_mixin import HierarchyMixin
 from cvdatasets.annotation.mixins.multi_box_mixin import MultiBoxMixin
 from cvdatasets.annotation.mixins.parts_mixin import PartsMixin
 
 __all__ = [
 	"BBoxMixin",
 	"FeaturesMixin",
+	"HierarchyMixin",
 	"MultiBoxMixin",
-	"PartsMixin",
 ]
 

+ 132 - 0
cvdatasets/annotation/mixins/hierarchy_mixin.py

@@ -0,0 +1,132 @@
+import abc
+import networkx as nx
+import numpy as np
+import typing as T
+
+from collections import defaultdict
+from cvdatasets.annotation.files import AnnotationFiles
+
+# class Relations:
+# 	def __init__(self, tuples):
+# 		self.right_for_left = defaultdict(list)
+# 		self.left_for_right = defaultdict(list)
+# 		for (left, right) in tuples:
+
+# 			self.right_for_left[left].append(right)
+# 			self.left_for_right[right].append(left)
+
+# 	def get_left_for(self, right):
+# 		if right in self.left_for_right.keys():
+# 			return self.left_for_right[right]
+# 		else:
+# 			return set()
+
+# 	def get_right_for(self, left):
+# 		if left in self.right_for_left.keys():
+# 			return self.right_for_left[left]
+# 		else:
+# 			return set()
+
+### Code is inspired by https://github.com/cabrust/chia
+
+class Hierarchy:
+	_force_prediction_targets = False
+
+	def __init__(self, tuples: T.Tuple[int],
+		label_transform: T.Optional[T.Callable] = None):
+
+		self._label_transform = label_transform
+		self.graph = nx.DiGraph()
+
+		for child, parent in tuples:
+			self.graph.add_edge(parent, child)
+
+
+		topo_sorted_orig_labels = list(nx.topological_sort(self.graph))
+
+		self.orig_lab_to_dimension = {
+			lab: dimension
+				for dimension, lab in enumerate(topo_sorted_orig_labels)
+		}
+
+
+		# relations = Relations(tuples)
+
+		# used_pairs = set()
+
+		# for left in self._orig_labels:
+		# 	labs_to_right = relations.get_right_for(left)
+
+		# 	used_pairs |= {
+		# 		(left, right) for right in labs_to_right
+		# 	}
+
+	def label_transform(self, label):
+		func = self._label_transform
+		if func is None or not callable(func):
+			return label
+		return func(label)
+
+	@property
+	def n_concepts(self):
+		""" Returns number of concepts. In this context, a concept is
+			an element from a set of hierarchical labels.
+		"""
+		return len(self.orig_lab_to_dimension)
+
+	def embed_labels(self, labels: np.ndarray, *, xp=np, dtype=np.int32) -> np.ndarray:
+		embedding = xp.zeros((len(labels), self.n_concepts), dtype=dtype)
+
+		for i, label in enumerate(labels):
+			if isinstance(label, str) and label == "chia::UNCERTAIN":
+				# "enable" all labels of this sample
+				embedding[i] = 1.0
+				continue
+			label = self.label_transform(label)
+
+			embedding[i, self.orig_lab_to_dimension[label]] = 1.0
+			for ancestor in nx.ancestors(self.graph, label):
+				embedding[i, self.orig_lab_to_dimension[ancestor]] = 1.0
+
+		return embedding
+
+	def loss_mask(self, labels: np.ndarray, *, xp=np, dtype=bool) -> np.ndarray:
+
+		mask = xp.zeros((len(labels), self.n_concepts), dtype=bool)
+		for i, label in enumerate(labels):
+			label = self.label_transform(label)
+
+			mask[i, self.orig_lab_to_dimension[label]] = True
+
+			for ancestor in nx.ancestors(self.graph, label):
+				mask[i, self.orig_lab_to_dimension[ancestor]] = True
+				for successor in self.graph.successors(ancestor):
+					mask[i, self.orig_lab_to_dimension[successor]] = True
+					# This should also cover the node itself, but we do it anyway
+
+			if not self._force_prediction_targets:
+				# Learn direct successors in order to "stop"
+				# prediction at these nodes.
+				# If MLNP is active, then this can be ignored.
+				# Because we never want to predict
+				# inner nodes, we interpret labels at
+				# inner nodes as imprecise labels.
+				for successor in self.graph.successors(label):
+					mask[i, self.orig_lab_to_dimension[successor]] = True
+		return mask
+
+
+class HierarchyMixin(abc.ABC):
+
+	def parse_annotations(self):
+		super().parse_annotations()
+		self._parse_hierarchy()
+
+	def _parse_hierarchy(self):
+		if self.files.hierarchy is None:
+			return
+
+		tuples = [entry.split(" ") for entry in self.files.hierarchy]
+		tuples = [(int(child), int(parent)) for child, parent in tuples]
+
+		self.hierarchy = Hierarchy(tuples, self.unq2orig)

+ 6 - 1
cvdatasets/annotation/types/file_list.py

@@ -16,7 +16,12 @@ class FileListAnnotations(Annotations):
 		super(FileListAnnotations, self).__init__(*args, **kwargs)
 
 	def load_files(self, file_obj) -> AnnotationFiles:
-		file_obj.load_files("images.txt", "labels.txt", "tr_ID.txt")
+		file_obj.load_files(
+			"images.txt",
+			"labels.txt",
+			"tr_ID.txt",
+			("hierarchy.txt", True),
+		)
 		return file_obj
 
 	def _parse_uuids(self) -> None: