|
@@ -114,14 +114,16 @@ class SimpleThresholdCHILLAXExtrapolator(CHILLAXExtrapolator):
|
|
|
|
|
|
class DepthStepsCHILLAXExtrapolator(CHILLAXExtrapolator):
|
|
|
def __init__(self, knowledge_base, apply_ground_truth, steps=1):
|
|
|
- super().__init__(
|
|
|
- knowledge_base=knowledge_base, apply_ground_truth=apply_ground_truth
|
|
|
- )
|
|
|
self._steps = steps
|
|
|
self.rgraph = nx.DiGraph()
|
|
|
self.uid_to_depth = dict()
|
|
|
self.prediction_targets = set()
|
|
|
|
|
|
+ # This needs to come later because otherwise uid_to_depth will be overwritten
|
|
|
+ super().__init__(
|
|
|
+ knowledge_base=knowledge_base, apply_ground_truth=apply_ground_truth
|
|
|
+ )
|
|
|
+
|
|
|
def _extrapolate(self, ground_truth_uid, unconditional_probabilities):
|
|
|
original_depth = self.uid_to_depth[ground_truth_uid]
|
|
|
allowed_depth = original_depth + self._steps
|
|
@@ -172,7 +174,7 @@ class DepthStepsCHILLAXExtrapolator(CHILLAXExtrapolator):
|
|
|
|
|
|
root = list(nx.topological_sort(self.rgraph))[0]
|
|
|
self.uid_to_depth = {
|
|
|
- concept.uid: nx.shortest_path(self.rgraph, root, concept.uid)
|
|
|
+ concept.uid: len(nx.shortest_path(self.rgraph, root, concept.uid))
|
|
|
for concept in self.knowledge_base.concepts()
|
|
|
}
|
|
|
|