|
@@ -113,8 +113,9 @@ class SimpleThresholdCHILLAXExtrapolator(CHILLAXExtrapolator):
|
|
|
|
|
|
|
|
|
class DepthStepsCHILLAXExtrapolator(CHILLAXExtrapolator):
|
|
|
- def __init__(self, knowledge_base, apply_ground_truth, steps=1):
|
|
|
+ def __init__(self, knowledge_base, apply_ground_truth, steps=1, threshold=None):
|
|
|
self._steps = steps
|
|
|
+ self._threshold = threshold
|
|
|
self.rgraph = nx.DiGraph()
|
|
|
self.uid_to_depth = dict()
|
|
|
self.prediction_targets = set()
|
|
@@ -158,7 +159,12 @@ class DepthStepsCHILLAXExtrapolator(CHILLAXExtrapolator):
|
|
|
reverse=True,
|
|
|
)
|
|
|
)
|
|
|
- return candidates[0]
|
|
|
+ if self._threshold is not None:
|
|
|
+ candidates = [candidate for candidate in candidates if unconditional_probabilities[candidate] > self._threshold]
|
|
|
+ if len(candidates) > 0:
|
|
|
+ return candidates[0]
|
|
|
+ else:
|
|
|
+ return ground_truth_uid
|
|
|
else:
|
|
|
return ground_truth_uid
|
|
|
|