|
@@ -3,24 +3,50 @@ from chia import components, knowledge, instrumentation
|
|
|
from chillax import information_content
|
|
|
|
|
|
import abc
|
|
|
+import typing
|
|
|
|
|
|
import networkx as nx
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
-class CHILLAXExtrapolator(instrumentation.Observer, abc.ABC):
|
|
|
+class CHILLAXExtrapolator(
|
|
|
+ instrumentation.Observer, instrumentation.Observable, abc.ABC
|
|
|
+):
|
|
|
def __init__(
|
|
|
- self, knowledge_base: knowledge.KnowledgeBase, apply_ground_truth: bool
|
|
|
+ self,
|
|
|
+ knowledge_base: knowledge.KnowledgeBase,
|
|
|
+ apply_ground_truth: bool,
|
|
|
+ ic_method: typing.Optional[str] = None,
|
|
|
):
|
|
|
+ instrumentation.Observable.__init__(self)
|
|
|
self.knowledge_base = knowledge_base
|
|
|
self.knowledge_base.register(self)
|
|
|
|
|
|
self.apply_ground_truth = apply_ground_truth
|
|
|
|
|
|
self.is_updated = False
|
|
|
- self._update_relations_and_concepts()
|
|
|
|
|
|
- def extrapolate(self, ground_truth_uid, unconditional_probabilities):
|
|
|
+ # Graph Cache
|
|
|
+ self._rgraph = nx.DiGraph()
|
|
|
+ self._uid_to_depth = dict()
|
|
|
+ self._prediction_targets = set()
|
|
|
+
|
|
|
+ # Information Content Cache
|
|
|
+ self._ic_calc: information_content.InformationContentCalculator = (
|
|
|
+ information_content.InformationContentCalculatorFactory.create(
|
|
|
+ {"name": ic_method if ic_method is not None else "zhou_2008_modified"}
|
|
|
+ )
|
|
|
+ )
|
|
|
+ self._ic_cache = dict()
|
|
|
+
|
|
|
+ self.update_relations_and_concepts()
|
|
|
+
|
|
|
+ # Reporting
|
|
|
+ self._reporting_samples_total = 0
|
|
|
+ self._reporting_samples_changed = 0
|
|
|
+ self._reporting_cum_ic_gain = 0
|
|
|
+
|
|
|
+ def extrapolate(self, extrapolator_inputs):
|
|
|
if not self.is_updated:
|
|
|
raise RuntimeError(
|
|
|
"This extrapolator is not updated. "
|
|
@@ -28,14 +54,85 @@ class CHILLAXExtrapolator(instrumentation.Observer, abc.ABC):
|
|
|
"RelationChange and ConceptChange messages."
|
|
|
)
|
|
|
|
|
|
- return self._extrapolate(ground_truth_uid, unconditional_probabilities)
|
|
|
+ outputs = []
|
|
|
+ for ground_truth_uid, unconditional_probabilities in extrapolator_inputs:
|
|
|
+ outputs += [
|
|
|
+ self._extrapolate(ground_truth_uid, unconditional_probabilities)
|
|
|
+ ]
|
|
|
+
|
|
|
+ self._reporting_update(
|
|
|
+ zip([gt_uid for gt_uid, _ in extrapolator_inputs], outputs)
|
|
|
+ )
|
|
|
+ return outputs
|
|
|
+
|
|
|
+ def _reporting_update(self, label_pairs):
|
|
|
+ for gt_uid, ext_uid in label_pairs:
|
|
|
+ if gt_uid != ext_uid:
|
|
|
+ self._reporting_samples_changed += 1
|
|
|
+ self._reporting_cum_ic_gain += (
|
|
|
+ self._ic_cache[ext_uid] - self._ic_cache[gt_uid]
|
|
|
+ )
|
|
|
+
|
|
|
+ self._reporting_samples_total += 1
|
|
|
+
|
|
|
+ def _reporting_reset(self):
|
|
|
+ self._reporting_samples_total = 0
|
|
|
+ self._reporting_samples_changed = 0
|
|
|
+ self._reporting_cum_ic_gain = 0
|
|
|
+
|
|
|
+ def reporting_report(self, current_step):
|
|
|
+ if self._reporting_samples_total == 0:
|
|
|
+ return
|
|
|
+
|
|
|
+ self.report_metric(
|
|
|
+ "extrapolation_changed_sample_fraction",
|
|
|
+ self._reporting_samples_changed / float(self._reporting_samples_total),
|
|
|
+ step=current_step,
|
|
|
+ )
|
|
|
+ self.report_metric(
|
|
|
+ "extrapolation_avg_ic_gain",
|
|
|
+ self._reporting_cum_ic_gain / float(self._reporting_samples_total),
|
|
|
+ step=current_step,
|
|
|
+ )
|
|
|
+
|
|
|
+ self._reporting_reset()
|
|
|
|
|
|
def update(self, message: instrumentation.Message):
|
|
|
if isinstance(message, knowledge.RelationChangeMessage) or isinstance(
|
|
|
message, knowledge.ConceptChangeMessage
|
|
|
):
|
|
|
self.is_updated = False
|
|
|
- self._update_relations_and_concepts()
|
|
|
+ self.update_relations_and_concepts()
|
|
|
+
|
|
|
+ def update_relations_and_concepts(self):
|
|
|
+ try:
|
|
|
+ # Update Information Content Cache
|
|
|
+ self._ic_cache = dict()
|
|
|
+ rgraph = self.knowledge_base.get_hyponymy_relation_rgraph()
|
|
|
+ for concept in self.knowledge_base.concepts():
|
|
|
+ self._ic_cache[
|
|
|
+ concept.uid
|
|
|
+ ] = self._ic_calc.calculate_information_content(concept.uid, rgraph)
|
|
|
+
|
|
|
+ # Graph Update
|
|
|
+ self._rgraph = self.knowledge_base.get_hyponymy_relation_rgraph()
|
|
|
+ self._prediction_targets = {
|
|
|
+ concept.uid
|
|
|
+ for concept in self.knowledge_base.concepts(
|
|
|
+ flags={knowledge.ConceptFlag.PREDICTION_TARGET}
|
|
|
+ )
|
|
|
+ }
|
|
|
+
|
|
|
+ root = list(nx.topological_sort(self._rgraph))[0]
|
|
|
+ self._uid_to_depth = {
|
|
|
+ concept.uid: len(nx.shortest_path(self._rgraph, root, concept.uid))
|
|
|
+ for concept in self.knowledge_base.concepts()
|
|
|
+ }
|
|
|
+
|
|
|
+ except ValueError as verr:
|
|
|
+ self.log_warning(f"Could not update extrapolator. {verr.args}")
|
|
|
+
|
|
|
+ self._update_relations_and_concepts()
|
|
|
|
|
|
@abc.abstractmethod
|
|
|
def _extrapolate(self, ground_truth_uid, unconditional_probabilities):
|
|
@@ -55,20 +152,16 @@ class SimpleThresholdCHILLAXExtrapolator(CHILLAXExtrapolator):
|
|
|
self,
|
|
|
knowledge_base,
|
|
|
apply_ground_truth,
|
|
|
- ic_method: str = "sanchez_2011_modified",
|
|
|
+ ic_method: typing.Optional[str] = None,
|
|
|
threshold=0.55,
|
|
|
):
|
|
|
- self._ic_calc: information_content.InformationContentCalculator = information_content.InformationContentCalculatorFactory.create(
|
|
|
- {"name": ic_method}
|
|
|
+ super().__init__(
|
|
|
+ knowledge_base=knowledge_base,
|
|
|
+ apply_ground_truth=apply_ground_truth,
|
|
|
+ ic_method=ic_method,
|
|
|
)
|
|
|
|
|
|
self._threshold = threshold
|
|
|
- self._ic_cache = dict()
|
|
|
-
|
|
|
- # This needs to come later because ic_calc is needed for _update_relations_and_concepts
|
|
|
- super().__init__(
|
|
|
- knowledge_base=knowledge_base, apply_ground_truth=apply_ground_truth
|
|
|
- )
|
|
|
|
|
|
def _extrapolate(self, ground_truth_uid, unconditional_probabilities):
|
|
|
candidates = [
|
|
@@ -98,44 +191,36 @@ class SimpleThresholdCHILLAXExtrapolator(CHILLAXExtrapolator):
|
|
|
else:
|
|
|
return ground_truth_uid
|
|
|
|
|
|
- def _update_relations_and_concepts(self):
|
|
|
- try:
|
|
|
- self._ic_cache = dict()
|
|
|
- rgraph = self.knowledge_base.get_hyponymy_relation_rgraph()
|
|
|
- for concept in self.knowledge_base.concepts():
|
|
|
- self._ic_cache[
|
|
|
- concept.uid
|
|
|
- ] = self._ic_calc.calculate_information_content(concept.uid, rgraph)
|
|
|
- self.is_updated = True
|
|
|
-
|
|
|
- except ValueError:
|
|
|
- return
|
|
|
-
|
|
|
|
|
|
class DepthStepsCHILLAXExtrapolator(CHILLAXExtrapolator):
|
|
|
- def __init__(self, knowledge_base, apply_ground_truth, steps=1, threshold=None):
|
|
|
- self._steps = steps
|
|
|
- self._threshold = threshold
|
|
|
- self.rgraph = nx.DiGraph()
|
|
|
- self.uid_to_depth = dict()
|
|
|
- self.prediction_targets = set()
|
|
|
-
|
|
|
- # This needs to come later because otherwise uid_to_depth will be overwritten
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ knowledge_base,
|
|
|
+ apply_ground_truth,
|
|
|
+ ic_method: typing.Optional[str] = None,
|
|
|
+ steps=1,
|
|
|
+ threshold=None,
|
|
|
+ ):
|
|
|
super().__init__(
|
|
|
- knowledge_base=knowledge_base, apply_ground_truth=apply_ground_truth
|
|
|
+ knowledge_base=knowledge_base,
|
|
|
+ apply_ground_truth=apply_ground_truth,
|
|
|
+ ic_method=ic_method,
|
|
|
)
|
|
|
|
|
|
+ self._steps = steps
|
|
|
+ self._threshold = threshold
|
|
|
+
|
|
|
def _extrapolate(self, ground_truth_uid, unconditional_probabilities):
|
|
|
- original_depth = self.uid_to_depth[ground_truth_uid]
|
|
|
+ original_depth = self._uid_to_depth[ground_truth_uid]
|
|
|
allowed_depth = original_depth + self._steps
|
|
|
|
|
|
allowed_uids = set()
|
|
|
- for descendant in nx.descendants(self.rgraph, ground_truth_uid):
|
|
|
- if self.uid_to_depth[descendant] == allowed_depth:
|
|
|
+ for descendant in nx.descendants(self._rgraph, ground_truth_uid):
|
|
|
+ if self._uid_to_depth[descendant] == allowed_depth:
|
|
|
allowed_uids |= {descendant}
|
|
|
elif (
|
|
|
- self.uid_to_depth[descendant] < allowed_depth
|
|
|
- and descendant in self.prediction_targets
|
|
|
+ self._uid_to_depth[descendant] < allowed_depth
|
|
|
+ and descendant in self._prediction_targets
|
|
|
):
|
|
|
# We need to allow leaf nodes if they are shallower than the allowed depth.
|
|
|
# Otherwise, we won't have any candidates sometimes.
|
|
@@ -160,7 +245,11 @@ class DepthStepsCHILLAXExtrapolator(CHILLAXExtrapolator):
|
|
|
)
|
|
|
)
|
|
|
if self._threshold is not None:
|
|
|
- candidates = [candidate for candidate in candidates if unconditional_probabilities[candidate] > self._threshold]
|
|
|
+ candidates = [
|
|
|
+ candidate
|
|
|
+ for candidate in candidates
|
|
|
+ if unconditional_probabilities[candidate] > self._threshold
|
|
|
+ ]
|
|
|
if len(candidates) > 0:
|
|
|
return candidates[0]
|
|
|
else:
|
|
@@ -168,32 +257,17 @@ class DepthStepsCHILLAXExtrapolator(CHILLAXExtrapolator):
|
|
|
else:
|
|
|
return ground_truth_uid
|
|
|
|
|
|
- def _update_relations_and_concepts(self):
|
|
|
- try:
|
|
|
- self.rgraph = self.knowledge_base.get_hyponymy_relation_rgraph()
|
|
|
- self.prediction_targets = {
|
|
|
- concept.uid
|
|
|
- for concept in self.knowledge_base.concepts(
|
|
|
- flags={knowledge.ConceptFlag.PREDICTION_TARGET}
|
|
|
- )
|
|
|
- }
|
|
|
-
|
|
|
- root = list(nx.topological_sort(self.rgraph))[0]
|
|
|
- self.uid_to_depth = {
|
|
|
- concept.uid: len(nx.shortest_path(self.rgraph, root, concept.uid))
|
|
|
- for concept in self.knowledge_base.concepts()
|
|
|
- }
|
|
|
-
|
|
|
- self.is_updated = True
|
|
|
- except ValueError:
|
|
|
- return
|
|
|
-
|
|
|
|
|
|
class ForcePredictionTargetCHILLAXExtrapolator(CHILLAXExtrapolator):
|
|
|
- def __init__(self, knowledge_base, apply_ground_truth):
|
|
|
+ def __init__(
|
|
|
+ self, knowledge_base, apply_ground_truth, ic_method: typing.Optional[str] = None
|
|
|
+ ):
|
|
|
super().__init__(
|
|
|
- knowledge_base=knowledge_base, apply_ground_truth=apply_ground_truth
|
|
|
+ knowledge_base=knowledge_base,
|
|
|
+ apply_ground_truth=apply_ground_truth,
|
|
|
+ ic_method=ic_method,
|
|
|
)
|
|
|
+
|
|
|
self.prediction_targets = set()
|
|
|
|
|
|
def _extrapolate(self, ground_truth_uid, unconditional_probabilities):
|