Răsfoiți Sursa

Added threshold to DepthSteps

Clemens-Alexander Brust 5 ani în urmă
părinte
comite
9e30961e3f
2 a modificat fișierele cu 9 adăugiri și 3 ștergeri
  1. 8 2
      chillax/chillax_extrapolator.py
  2. 1 1
      chillax/version.py

+ 8 - 2
chillax/chillax_extrapolator.py

@@ -113,8 +113,9 @@ class SimpleThresholdCHILLAXExtrapolator(CHILLAXExtrapolator):
 
 
 class DepthStepsCHILLAXExtrapolator(CHILLAXExtrapolator):
-    def __init__(self, knowledge_base, apply_ground_truth, steps=1):
+    def __init__(self, knowledge_base, apply_ground_truth, steps=1, threshold=None):
         self._steps = steps
+        self._threshold = threshold
         self.rgraph = nx.DiGraph()
         self.uid_to_depth = dict()
         self.prediction_targets = set()
@@ -158,7 +159,12 @@ class DepthStepsCHILLAXExtrapolator(CHILLAXExtrapolator):
                     reverse=True,
                 )
             )
-            return candidates[0]
+            if self._threshold is not None:
+                candidates = [candidate for candidate in candidates if unconditional_probabilities[candidate] > self._threshold]
+            if len(candidates) > 0:
+                return candidates[0]
+            else:
+                return ground_truth_uid
         else:
             return ground_truth_uid
 

+ 1 - 1
chillax/version.py

@@ -1 +1 @@
-__version__ = "0.1a3"
+__version__ = "0.1a4"