
added label de-embedding for the hierarchical classification

Dimitri Korsch, 2 years ago
commit 0b8edf4858
1 changed file with 66 additions and 2 deletions

cvdatasets/annotation/mixins/hierarchy_mixin.py  +66 -2

@@ -10,6 +10,7 @@ from cvdatasets.annotation.files import AnnotationFiles
 
 class Hierarchy:
 	_force_prediction_targets = False
+	_raw_output = False
 
 	def __init__(self, tuples: T.Tuple[int],
 		label_transform: T.Optional[T.Callable] = None):
@@ -21,11 +22,11 @@ class Hierarchy:
 			self.graph.add_edge(parent, child)
 
 
-		topo_sorted_orig_labels = list(nx.topological_sort(self.graph))
+		self.topo_sorted_orig_labels = list(nx.topological_sort(self.graph))
 
 		self.orig_lab_to_dimension = {
 			lab: dimension
-				for dimension, lab in enumerate(topo_sorted_orig_labels)
+				for dimension, lab in enumerate(self.topo_sorted_orig_labels)
 		}
 
 	def label_transform(self, label):
@@ -83,6 +84,69 @@ class Hierarchy:
 		return mask
 
 
+	def deembed_dist(self, embedded_labels):
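+		""" De-embeds a batch of embedded label vectors, one sample at a time. """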
+		return [
+			self._deembed_single(embedded_label) for embedded_label in embedded_labels
+		]
+
+	def _deembed_single(self, embedded_label):
+		"""
+			Code adapted from CHIA:
+			https://github.com/cabrust/chia/blob/main/chia/components/classifiers/keras_idk_hc.py#L68
+		"""
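+		# Each entry of `embedded_label` is read as the conditional probability
+		# of one hierarchy node given its parent(s); the column order is
+		# defined by orig_lab_to_dimension.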
+		cond_probs = {
+			label: embedded_label[dim] for label, dim in self.orig_lab_to_dimension.items()
+		}
+
+		if self._raw_output:
+			# Directly output conditional probabilities
+			return list(cond_probs.items())
+		else:
+			# Stage 1 calculates the unconditional probabilities
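+			# For each non-root label: P(lab) = cond_probs[lab] * (1 - prod(1 - P(parent)));
+			# the topological order guarantees that all parents were processed first.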
+			uncond_probs = {}
+
+			for lab in self.topo_sorted_orig_labels:
+				unconditional_probability = cond_probs[lab]
+
+				no_parent_probability = 1.0
+				has_parents = False
+				for parent in self.graph.predecessors(lab):
+					has_parents = True
+					no_parent_probability *= 1.0 - uncond_probs[parent]
+
+				if has_parents:
+					unconditional_probability *= 1.0 - no_parent_probability
+
+				uncond_probs[lab] = unconditional_probability
+
+			# Stage 2 calculates the joint probability of the synset and "no children"
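+			# P(lab, no child) = P(lab) * prod(1 - P(child))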
+			joint_probabilities = {}
+			for lab in reversed(self.topo_sorted_orig_labels):
+				joint_probability = uncond_probs[lab]
+				no_child_probability = 1.0
+				for child in self.graph.successors(lab):
+					no_child_probability *= 1.0 - uncond_probs[child]
+
+				joint_probabilities[lab] = joint_probability * no_child_probability
+
+			tuples = joint_probabilities.items()
+			sorted_tuples = list(sorted(tuples, key=lambda tup: tup[1], reverse=True))
+
+			# If requested, only output scores for the forced prediction targets
+			if self._force_prediction_targets:
+				for i, (lab, p) in enumerate(sorted_tuples):
+					if lab not in self.prediction_target_uids:
+						sorted_tuples[i] = (lab, 0.0)
+
+				total_scores = sum([p for lab, p in sorted_tuples])
+				if total_scores > 0:
+					sorted_tuples = [
+						(lab, p / total_scores) for lab, p in sorted_tuples
+					]
+
+			return list(sorted_tuples)
+
+
+
 class HierarchyMixin(abc.ABC):
 
 	def parse_annotations(self):
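
For context, a minimal usage sketch of the new de-embedding (not part of this commit). It assumes an already constructed Hierarchy instance and a NumPy array of per-node sigmoid outputs whose column order follows orig_lab_to_dimension; the helper name top1_labels and the variable names are illustrative only.

import numpy as np

def top1_labels(hierarchy, sigmoid_outputs: np.ndarray):
	# In the default (non-raw) mode, deembed_dist returns, for each row of
	# `sigmoid_outputs`, a list of (label, joint probability) tuples sorted
	# by descending probability.
	distributions = hierarchy.deembed_dist(sigmoid_outputs)
	# keep only the most probable label per sample
	return [dist[0][0] for dist in distributions]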