Browse Source

Fixed bugs in DepthSteps method

Clemens-Alexander Brust 5 years ago
parent
commit
f0a5eaafbe
2 changed files with 7 additions and 5 deletions
  1. 6 4
      chillax/chillax_extrapolator.py
  2. 1 1
      chillax/version.py

+ 6 - 4
chillax/chillax_extrapolator.py

@@ -114,14 +114,16 @@ class SimpleThresholdCHILLAXExtrapolator(CHILLAXExtrapolator):
 
 class DepthStepsCHILLAXExtrapolator(CHILLAXExtrapolator):
     def __init__(self, knowledge_base, apply_ground_truth, steps=1):
-        super().__init__(
-            knowledge_base=knowledge_base, apply_ground_truth=apply_ground_truth
-        )
         self._steps = steps
         self.rgraph = nx.DiGraph()
         self.uid_to_depth = dict()
         self.prediction_targets = set()
 
+        # This needs to come later because otherwise uid_to_depth will be overwritten
+        super().__init__(
+            knowledge_base=knowledge_base, apply_ground_truth=apply_ground_truth
+        )
+
     def _extrapolate(self, ground_truth_uid, unconditional_probabilities):
         original_depth = self.uid_to_depth[ground_truth_uid]
         allowed_depth = original_depth + self._steps
@@ -172,7 +174,7 @@ class DepthStepsCHILLAXExtrapolator(CHILLAXExtrapolator):
 
             root = list(nx.topological_sort(self.rgraph))[0]
             self.uid_to_depth = {
-                concept.uid: nx.shortest_path(self.rgraph, root, concept.uid)
+                concept.uid: len(nx.shortest_path(self.rgraph, root, concept.uid))
                 for concept in self.knowledge_base.concepts()
             }
 

+ 1 - 1
chillax/version.py

@@ -1 +1 @@
-__version__ = "0.1a2"
+__version__ = "0.1a3"