|
@@ -0,0 +1,132 @@
|
|
|
|
+import abc
|
|
|
|
+import networkx as nx
|
|
|
|
+import numpy as np
|
|
|
|
+import typing as T
|
|
|
|
+
|
|
|
|
+from collections import defaultdict
|
|
|
|
+from cvdatasets.annotation.files import AnnotationFiles
|
|
|
|
+
|
|
|
|
+# class Relations:
|
|
|
|
+# def __init__(self, tuples):
|
|
|
|
+# self.right_for_left = defaultdict(list)
|
|
|
|
+# self.left_for_right = defaultdict(list)
|
|
|
|
+# for (left, right) in tuples:
|
|
|
|
+
|
|
|
|
+# self.right_for_left[left].append(right)
|
|
|
|
+# self.left_for_right[right].append(left)
|
|
|
|
+
|
|
|
|
+# def get_left_for(self, right):
|
|
|
|
+# if right in self.left_for_right.keys():
|
|
|
|
+# return self.left_for_right[right]
|
|
|
|
+# else:
|
|
|
|
+# return set()
|
|
|
|
+
|
|
|
|
+# def get_right_for(self, left):
|
|
|
|
+# if left in self.right_for_left.keys():
|
|
|
|
+# return self.right_for_left[left]
|
|
|
|
+# else:
|
|
|
|
+# return set()
|
|
|
|
+
|
|
|
|
+### Code is inspired by https://github.com/cabrust/chia
|
|
|
|
+
|
|
|
|
class Hierarchy:
    """A label hierarchy backed by a directed graph (edges point parent -> child).

    Built from ``(child, parent)`` label pairs; each label is assigned a fixed
    embedding dimension so labels can be multi-hot encoded together with their
    ancestors.

    Code is inspired by https://github.com/cabrust/chia
    """

    # When False, ``loss_mask`` additionally enables the direct successors of a
    # label, so prediction can learn to "stop" at inner nodes (see loss_mask).
    _force_prediction_targets = False

    def __init__(self, tuples: T.Iterable[T.Tuple[int, int]],
                 label_transform: T.Optional[T.Callable] = None):
        """
        Args:
            tuples: iterable of ``(child, parent)`` label pairs.
            label_transform: optional callable mapping incoming labels into the
                label space used by the hierarchy graph (identity if None).
        """
        self._label_transform = label_transform
        self.graph = nx.DiGraph()

        for child, parent in tuples:
            # edge direction: parent -> child
            self.graph.add_edge(parent, child)

        # Assign every label a stable dimension. Topological order guarantees
        # ancestors receive lower dimensions than their descendants.
        topo_sorted_orig_labels = list(nx.topological_sort(self.graph))

        self.orig_lab_to_dimension = {
            lab: dimension
            for dimension, lab in enumerate(topo_sorted_orig_labels)
        }

    def label_transform(self, label):
        """Apply the configured label transform; identity if none was given."""
        func = self._label_transform
        if func is None or not callable(func):
            return label
        return func(label)

    @property
    def n_concepts(self):
        """ Returns number of concepts. In this context, a concept is
        an element from a set of hierarchical labels.
        """
        return len(self.orig_lab_to_dimension)

    def embed_labels(self, labels: np.ndarray, *, xp=np, dtype=np.int32) -> np.ndarray:
        """Multi-hot encode each label together with all of its ancestors.

        Args:
            labels: sequence of labels; the sentinel string
                ``"chia::UNCERTAIN"`` enables every concept for that sample.
            xp: array module (numpy-compatible) used for allocation.
            dtype: dtype of the returned embedding.

        Returns:
            Array of shape ``(len(labels), n_concepts)``.
        """
        embedding = xp.zeros((len(labels), self.n_concepts), dtype=dtype)

        for i, label in enumerate(labels):
            if isinstance(label, str) and label == "chia::UNCERTAIN":
                # "enable" all labels of this sample
                embedding[i] = 1.0
                continue
            label = self.label_transform(label)

            embedding[i, self.orig_lab_to_dimension[label]] = 1.0
            for ancestor in nx.ancestors(self.graph, label):
                embedding[i, self.orig_lab_to_dimension[ancestor]] = 1.0

        return embedding

    def loss_mask(self, labels: np.ndarray, *, xp=np, dtype=bool) -> np.ndarray:
        """Mask of concepts that should contribute to the loss per sample.

        Enables the label itself, all of its ancestors, and every direct child
        of each ancestor. Unless ``_force_prediction_targets`` is set, the
        direct children of the label are enabled as well.

        Args:
            labels: sequence of labels (transformed via ``label_transform``).
            xp: array module (numpy-compatible) used for allocation.
            dtype: dtype of the returned mask.

        Returns:
            Array of shape ``(len(labels), n_concepts)``.
        """
        # Fix: honor the ``dtype`` argument (it was hard-coded to bool before).
        mask = xp.zeros((len(labels), self.n_concepts), dtype=dtype)
        for i, label in enumerate(labels):
            label = self.label_transform(label)

            mask[i, self.orig_lab_to_dimension[label]] = True

            for ancestor in nx.ancestors(self.graph, label):
                mask[i, self.orig_lab_to_dimension[ancestor]] = True
                for successor in self.graph.successors(ancestor):
                    mask[i, self.orig_lab_to_dimension[successor]] = True
                # This should also cover the node itself, but we do it anyway

            if not self._force_prediction_targets:
                # Learn direct successors in order to "stop"
                # prediction at these nodes.
                # If MLNP is active, then this can be ignored.
                # Because we never want to predict
                # inner nodes, we interpret labels at
                # inner nodes as imprecise labels.
                for successor in self.graph.successors(label):
                    mask[i, self.orig_lab_to_dimension[successor]] = True
        return mask
|
|
|
|
+
|
|
|
|
+
|
|
|
|
class HierarchyMixin(abc.ABC):
    """Mixin that builds a label ``Hierarchy`` after regular annotation parsing."""

    def parse_annotations(self):
        """Parse the base annotations, then the optional hierarchy file."""
        super().parse_annotations()
        self._parse_hierarchy()

    def _parse_hierarchy(self):
        """Create ``self.hierarchy`` from ``self.files.hierarchy``.

        Each entry is a space-separated ``"<child> <parent>"`` pair of integer
        labels. Does nothing when no hierarchy file is present.
        """
        entries = self.files.hierarchy
        if entries is None:
            return

        pairs = []
        for entry in entries:
            child, parent = entry.split(" ")
            pairs.append((int(child), int(parent)))

        self.hierarchy = Hierarchy(pairs, self.unq2orig)
|