Selaa lähdekoodia

Refactored classifier and extrapolator, added IC gain reporting

Clemens-Alexander Brust 5 vuotta sitten
vanhempi
commit
f2e1e8d471
2 muutettua tiedostoa, joissa 150 lisäystä ja 90 poistoa
  1. + 11 - 25
      chillax/chillax_classifier.py
  2. + 139 - 65
      chillax/chillax_extrapolator.py

+ 11 - 25
chillax/chillax_classifier.py

@@ -401,7 +401,7 @@ class CHILLAXKerasHC(
         # Only do anything if there is an extrapolator
         if self.extrapolator is not None:
             epn = embedded_prediction.numpy()
-            extrapolated_ground_truth = []
+            extrapolator_inputs = []
             for i, ground_truth_element in enumerate(ground_truth):
                 # Get the raw scores
                 conditional_probabilities = self._calculate_conditional_probabilities(
@@ -425,37 +425,23 @@ class CHILLAXKerasHC(
                         )
 
                 # Calculate unconditionals and extrapolate
-                unconditional_probabilities = self._calculate_unconditional_probabilities(
-                    conditional_probabilities
-                )
-                extrapolated_ground_truth += [
-                    self.extrapolator.extrapolate(
-                        ground_truth_element, unconditional_probabilities
+                unconditional_probabilities = (
+                    self._calculate_unconditional_probabilities(
+                        conditional_probabilities
                     )
+                )
+                extrapolator_inputs += [
+                    (ground_truth_element, unconditional_probabilities)
                 ]
 
-            # Handle reporting
-            self._running_sample_count += len(ground_truth)
-            self._running_changed_samples += sum(
-                [
-                    1
-                    for egt, gt in zip(extrapolated_ground_truth, ground_truth)
-                    if egt != gt
-                ]
+            extrapolated_ground_truth = self.extrapolator.extrapolate(
+                extrapolator_inputs
             )
 
+            # Handle reporting
             if self._reporting_step_counter % 10 == 9:
                 if self._last_reported_step < self._reporting_step_counter:
-                    if self._running_sample_count > 0:
-                        self.report_metric(
-                            "extrapolation_changed_sample_fraction",
-                            float(self._running_changed_samples)
-                            / float(self._running_sample_count),
-                            step=self._reporting_step_counter,
-                        )
-
-                    self._running_changed_samples = 0
-                    self._running_sample_count = 0
+                    self.extrapolator.reporting_report(self._reporting_step_counter)
 
                     self._last_reported_step = self._reporting_step_counter
 

+ 139 - 65
chillax/chillax_extrapolator.py

@@ -3,24 +3,50 @@ from chia import components, knowledge, instrumentation
 from chillax import information_content
 
 import abc
+import typing
 
 import networkx as nx
 import numpy as np
 
 
-class CHILLAXExtrapolator(instrumentation.Observer, abc.ABC):
+class CHILLAXExtrapolator(
+    instrumentation.Observer, instrumentation.Observable, abc.ABC
+):
     def __init__(
-        self, knowledge_base: knowledge.KnowledgeBase, apply_ground_truth: bool
+        self,
+        knowledge_base: knowledge.KnowledgeBase,
+        apply_ground_truth: bool,
+        ic_method: typing.Optional[str] = None,
     ):
+        instrumentation.Observable.__init__(self)
         self.knowledge_base = knowledge_base
         self.knowledge_base.register(self)
 
         self.apply_ground_truth = apply_ground_truth
 
         self.is_updated = False
-        self._update_relations_and_concepts()
 
-    def extrapolate(self, ground_truth_uid, unconditional_probabilities):
+        # Graph Cache
+        self._rgraph = nx.DiGraph()
+        self._uid_to_depth = dict()
+        self._prediction_targets = set()
+
+        # Information Content Cache
+        self._ic_calc: information_content.InformationContentCalculator = (
+            information_content.InformationContentCalculatorFactory.create(
+                {"name": ic_method if ic_method is not None else "zhou_2008_modified"}
+            )
+        )
+        self._ic_cache = dict()
+
+        self.update_relations_and_concepts()
+
+        # Reporting
+        self._reporting_samples_total = 0
+        self._reporting_samples_changed = 0
+        self._reporting_cum_ic_gain = 0
+
+    def extrapolate(self, extrapolator_inputs):
         if not self.is_updated:
             raise RuntimeError(
                 "This extrapolator is not updated. "
@@ -28,14 +54,85 @@ class CHILLAXExtrapolator(instrumentation.Observer, abc.ABC):
                 "RelationChange and ConceptChange messages."
             )
 
-        return self._extrapolate(ground_truth_uid, unconditional_probabilities)
+        outputs = []
+        for ground_truth_uid, unconditional_probabilities in extrapolator_inputs:
+            outputs += [
+                self._extrapolate(ground_truth_uid, unconditional_probabilities)
+            ]
+
+        self._reporting_update(
+            zip([gt_uid for gt_uid, _ in extrapolator_inputs], outputs)
+        )
+        return outputs
+
+    def _reporting_update(self, label_pairs):
+        for gt_uid, ext_uid in label_pairs:
+            if gt_uid != ext_uid:
+                self._reporting_samples_changed += 1
+                self._reporting_cum_ic_gain += (
+                    self._ic_cache[ext_uid] - self._ic_cache[gt_uid]
+                )
+
+            self._reporting_samples_total += 1
+
+    def _reporting_reset(self):
+        self._reporting_samples_total = 0
+        self._reporting_samples_changed = 0
+        self._reporting_cum_ic_gain = 0
+
+    def reporting_report(self, current_step):
+        if self._reporting_samples_total == 0:
+            return
+
+        self.report_metric(
+            "extrapolation_changed_sample_fraction",
+            self._reporting_samples_changed / float(self._reporting_samples_total),
+            step=current_step,
+        )
+        self.report_metric(
+            "extrapolation_avg_ic_gain",
+            self._reporting_cum_ic_gain / float(self._reporting_samples_total),
+            step=current_step,
+        )
+
+        self._reporting_reset()
 
     def update(self, message: instrumentation.Message):
         if isinstance(message, knowledge.RelationChangeMessage) or isinstance(
             message, knowledge.ConceptChangeMessage
         ):
             self.is_updated = False
-            self._update_relations_and_concepts()
+            self.update_relations_and_concepts()
+
+    def update_relations_and_concepts(self):
+        try:
+            # Update Information Content Cache
+            self._ic_cache = dict()
+            rgraph = self.knowledge_base.get_hyponymy_relation_rgraph()
+            for concept in self.knowledge_base.concepts():
+                self._ic_cache[
+                    concept.uid
+                ] = self._ic_calc.calculate_information_content(concept.uid, rgraph)
+
+            # Graph Update
+            self._rgraph = self.knowledge_base.get_hyponymy_relation_rgraph()
+            self._prediction_targets = {
+                concept.uid
+                for concept in self.knowledge_base.concepts(
+                    flags={knowledge.ConceptFlag.PREDICTION_TARGET}
+                )
+            }
+
+            root = list(nx.topological_sort(self._rgraph))[0]
+            self._uid_to_depth = {
+                concept.uid: len(nx.shortest_path(self._rgraph, root, concept.uid))
+                for concept in self.knowledge_base.concepts()
+            }
+
+        except ValueError as verr:
+            self.log_warning(f"Could not update extrapolator. {verr.args}")
+
+        self._update_relations_and_concepts()
 
     @abc.abstractmethod
     def _extrapolate(self, ground_truth_uid, unconditional_probabilities):
@@ -55,20 +152,16 @@ class SimpleThresholdCHILLAXExtrapolator(CHILLAXExtrapolator):
         self,
         knowledge_base,
         apply_ground_truth,
-        ic_method: str = "sanchez_2011_modified",
+        ic_method: typing.Optional[str] = None,
         threshold=0.55,
     ):
-        self._ic_calc: information_content.InformationContentCalculator = information_content.InformationContentCalculatorFactory.create(
-            {"name": ic_method}
+        super().__init__(
+            knowledge_base=knowledge_base,
+            apply_ground_truth=apply_ground_truth,
+            ic_method=ic_method,
         )
 
         self._threshold = threshold
-        self._ic_cache = dict()
-
-        # This needs to come later because ic_calc is needed for _update_relations_and_concepts
-        super().__init__(
-            knowledge_base=knowledge_base, apply_ground_truth=apply_ground_truth
-        )
 
     def _extrapolate(self, ground_truth_uid, unconditional_probabilities):
         candidates = [
@@ -98,44 +191,36 @@ class SimpleThresholdCHILLAXExtrapolator(CHILLAXExtrapolator):
         else:
             return ground_truth_uid
 
-    def _update_relations_and_concepts(self):
-        try:
-            self._ic_cache = dict()
-            rgraph = self.knowledge_base.get_hyponymy_relation_rgraph()
-            for concept in self.knowledge_base.concepts():
-                self._ic_cache[
-                    concept.uid
-                ] = self._ic_calc.calculate_information_content(concept.uid, rgraph)
-            self.is_updated = True
-
-        except ValueError:
-            return
-
 
 class DepthStepsCHILLAXExtrapolator(CHILLAXExtrapolator):
-    def __init__(self, knowledge_base, apply_ground_truth, steps=1, threshold=None):
-        self._steps = steps
-        self._threshold = threshold
-        self.rgraph = nx.DiGraph()
-        self.uid_to_depth = dict()
-        self.prediction_targets = set()
-
-        # This needs to come later because otherwise uid_to_depth will be overwritten
+    def __init__(
+        self,
+        knowledge_base,
+        apply_ground_truth,
+        ic_method: typing.Optional[str] = None,
+        steps=1,
+        threshold=None,
+    ):
         super().__init__(
-            knowledge_base=knowledge_base, apply_ground_truth=apply_ground_truth
+            knowledge_base=knowledge_base,
+            apply_ground_truth=apply_ground_truth,
+            ic_method=ic_method,
         )
 
+        self._steps = steps
+        self._threshold = threshold
+
     def _extrapolate(self, ground_truth_uid, unconditional_probabilities):
-        original_depth = self.uid_to_depth[ground_truth_uid]
+        original_depth = self._uid_to_depth[ground_truth_uid]
         allowed_depth = original_depth + self._steps
 
         allowed_uids = set()
-        for descendant in nx.descendants(self.rgraph, ground_truth_uid):
-            if self.uid_to_depth[descendant] == allowed_depth:
+        for descendant in nx.descendants(self._rgraph, ground_truth_uid):
+            if self._uid_to_depth[descendant] == allowed_depth:
                 allowed_uids |= {descendant}
             elif (
-                self.uid_to_depth[descendant] < allowed_depth
-                and descendant in self.prediction_targets
+                self._uid_to_depth[descendant] < allowed_depth
+                and descendant in self._prediction_targets
             ):
                 # We need to allow leaf nodes if they are shallower than the allowed depth.
                 # Otherwise, we won't have any candidates sometimes.
@@ -160,7 +245,11 @@ class DepthStepsCHILLAXExtrapolator(CHILLAXExtrapolator):
                 )
             )
             if self._threshold is not None:
-                candidates = [candidate for candidate in candidates if unconditional_probabilities[candidate] > self._threshold]
+                candidates = [
+                    candidate
+                    for candidate in candidates
+                    if unconditional_probabilities[candidate] > self._threshold
+                ]
             if len(candidates) > 0:
                 return candidates[0]
             else:
@@ -168,32 +257,17 @@ class DepthStepsCHILLAXExtrapolator(CHILLAXExtrapolator):
         else:
             return ground_truth_uid
 
-    def _update_relations_and_concepts(self):
-        try:
-            self.rgraph = self.knowledge_base.get_hyponymy_relation_rgraph()
-            self.prediction_targets = {
-                concept.uid
-                for concept in self.knowledge_base.concepts(
-                    flags={knowledge.ConceptFlag.PREDICTION_TARGET}
-                )
-            }
-
-            root = list(nx.topological_sort(self.rgraph))[0]
-            self.uid_to_depth = {
-                concept.uid: len(nx.shortest_path(self.rgraph, root, concept.uid))
-                for concept in self.knowledge_base.concepts()
-            }
-
-            self.is_updated = True
-        except ValueError:
-            return
-
 
 class ForcePredictionTargetCHILLAXExtrapolator(CHILLAXExtrapolator):
-    def __init__(self, knowledge_base, apply_ground_truth):
+    def __init__(
+        self, knowledge_base, apply_ground_truth, ic_method: typing.Optional[str] = None
+    ):
         super().__init__(
-            knowledge_base=knowledge_base, apply_ground_truth=apply_ground_truth
+            knowledge_base=knowledge_base,
+            apply_ground_truth=apply_ground_truth,
+            ic_method=ic_method,
         )
+
         self.prediction_targets = set()
 
     def _extrapolate(self, ground_truth_uid, unconditional_probabilities):