Переглянути джерело

Added adaptive IC gain and range methods

Clemens-Alexander Brust 4 роки тому
батько
коміт
5e9f011a29
1 змінених файлів з 140 додано та 0 видалено
  1. 140 0
      chillax/chillax_extrapolator.py

+ 140 - 0
chillax/chillax_extrapolator.py

@@ -3,6 +3,7 @@ from chia import components, knowledge, instrumentation
 from chillax import information_content
 
 import abc
+import collections
 import typing
 
 import networkx as nx
@@ -292,10 +293,149 @@ class ForcePredictionTargetCHILLAXExtrapolator(CHILLAXExtrapolator):
             return ground_truth_uid
 
 
+class ICGainRangeCHILLAXExtrapolator(CHILLAXExtrapolator):
+    def __init__(
+        self,
+        knowledge_base,
+        apply_ground_truth,
+        ic_method: typing.Optional[str] = None,
+        ic_gain_target=None,
+        ic_range=0.2,
+        probability_threshold=0.55,
+    ):
+        super().__init__(
+            knowledge_base=knowledge_base,
+            apply_ground_truth=apply_ground_truth,
+            ic_method=ic_method,
+        )
+
+        self._ic_gain_target = ic_gain_target
+        self._ic_range = ic_range
+        self._probability_threshold = probability_threshold
+
+    def _extrapolate(self, ground_truth_uid, unconditional_probabilities):
+        # We need this more often
+        ground_truth_ic = self._ic_cache[ground_truth_uid]
+        target_ic = ground_truth_ic + self._ic_gain_target
+        half_range = self._ic_range / 2.0
+
+        candidates = [
+            uid
+            for (uid, probability) in unconditional_probabilities.items()
+            if -half_range <= (self._ic_cache[uid] - target_ic) <= half_range
+            and probability >= self._probability_threshold
+        ]
+
+        if len(candidates) > 0:
+            candidates_with_ic = [(uid, self._ic_cache[uid]) for uid in candidates]
+
+            # Sort by probability first, see other methods for explanation of noise
+            candidates_with_ic = list(
+                sorted(
+                    candidates_with_ic,
+                    key=lambda x: unconditional_probabilities[x[0]]
+                    + np.random.normal(0, 0.0001),
+                    reverse=True,
+                )
+            )
+
+            # Sort by IC. Stable sorting is guaranteed by python.
+            candidates_with_ic = list(
+                sorted(candidates_with_ic, key=lambda x: x[1], reverse=True)
+            )
+            return candidates_with_ic[0][0]
+        else:
+            return ground_truth_uid
+
+
+class AdaptiveICGainCHILLAXExtrapolator(CHILLAXExtrapolator):
+    def __init__(
+        self,
+        knowledge_base,
+        apply_ground_truth,
+        ic_method: typing.Optional[str] = None,
+        ic_gain_target=None,
+        min_threshold=0.55,
+        max_threshold=1.0,
+        learning_rate=1.0,
+    ):
+        super().__init__(
+            knowledge_base=knowledge_base,
+            apply_ground_truth=apply_ground_truth,
+            ic_method=ic_method,
+        )
+
+        self._ic_gain_target = ic_gain_target
+        self._min_threshold = min_threshold
+        self._max_threshold = max_threshold
+        self._learning_rate = learning_rate
+
+        # Initialize the threshold
+        self._threshold = min_threshold
+
+        self._last_ic_gains = collections.deque(maxlen=64)
+
+    def _extrapolate(self, ground_truth_uid, unconditional_probabilities):
+        """This is basically the same as SimpleThresholdCHILLAXExtrapolator, just with added reporting etc."""
+        candidates = [
+            uid
+            for (uid, probability) in unconditional_probabilities.items()
+            if probability >= self._threshold
+        ]
+
+        if len(candidates) > 0:
+            candidates_with_ic = [(uid, self._ic_cache[uid]) for uid in candidates]
+
+            # Sort by probability first, see other methods for explanation of noise
+            candidates_with_ic = list(
+                sorted(
+                    candidates_with_ic,
+                    key=lambda x: unconditional_probabilities[x[0]]
+                    + np.random.normal(0, 0.0001),
+                    reverse=True,
+                )
+            )
+
+            # Sort by IC second. Stable sorting is guaranteed by python.
+            candidates_with_ic = list(
+                sorted(candidates_with_ic, key=lambda x: x[1], reverse=True)
+            )
+            return_value = candidates_with_ic[0][0]
+        else:
+            return_value = ground_truth_uid
+
+        # Compute the actual IC gain of our actions
+        realized_ic_gain = (
+            self._ic_cache[return_value] - self._ic_cache[ground_truth_uid]
+        )
+
+        # Compute average IC gain, maxlen should do the rest :)
+        self._last_ic_gains.append(realized_ic_gain)
+        avg_ic_gain = sum(self._last_ic_gains) / float(len(self._last_ic_gains))
+
+        # Assume that increasing the threshold decreases the possible IC gain
+        # e.g. if average IC is 0.3 too much, increase the threshold by 0.3 (lr=1.0)
+        step = self._learning_rate * (avg_ic_gain - self._ic_gain_target)
+        self._threshold = max(
+            self._min_threshold, min(self._max_threshold, self._threshold + step)
+        )
+
+        return return_value
+
+    def reporting_report(self, current_step):
+        """We want to have a look at the thresholds."""
+        self.report_metric(
+            "extrapolation_current_threshold", self._threshold, current_step
+        )
+        super().reporting_report(current_step)
+
+
 class CHILLAXExtrapolatorFactory(components.Factory):
     name_to_class_mapping = {
         "do_nothing": DoNothingCHILLAXExtrapolator,
         "simple_threshold": SimpleThresholdCHILLAXExtrapolator,
         "depth_steps": DepthStepsCHILLAXExtrapolator,
         "force_prediction_target": ForcePredictionTargetCHILLAXExtrapolator,
+        "adaptive_ic_gain": AdaptiveICGainCHILLAXExtrapolator,
+        "ic_gain_range": ICGainRangeCHILLAXExtrapolator,
     }