
Added extrapolators and IC methods

Clemens-Alexander Brust, 5 years ago
parent
commit
7a59012b99

+ 115 - 51
chillax/chillax_classifier.py

@@ -14,8 +14,7 @@ class CHILLAXKerasHC(
         self,
         kb,
         l2=5e-5,
-        mlnp=True,
-        normalize_scores=True,
+        force_prediction_targets=True,
         raw_output=False,
         weighting="default",
         gain_compensation="simple",
@@ -26,13 +25,12 @@ class CHILLAXKerasHC(
         # Configuration
         self._l2_regularization_coefficient = l2
 
-        self._mlnp = mlnp
-        self._normalize_scores = normalize_scores
+        self._force_prediction_targets = force_prediction_targets
 
         self._raw_output = raw_output
-        if self._raw_output and (self._mlnp or self._normalize_scores):
+        if self._raw_output and self._force_prediction_targets:
             raise ValueError(
-                "Cannot use raw output and MLNP or normalization at the same time!"
+                "Cannot use raw output and forced prediction targets at the same time!"
             )
 
         self._weighting = weighting
@@ -46,6 +44,8 @@ class CHILLAXKerasHC(
         self.loss_weights = None
         self.update_embedding()
 
+        self.extrapolator = None
+
     def predict_embedded(self, feature_batch):
         return self.fc_layer(feature_batch)
 
@@ -67,55 +67,61 @@ class CHILLAXKerasHC(
         ]
 
     def _deembed_single(self, embedded_label):
-        conditional_probabilities = {
-            uid: embedded_label[i] for uid, i in self.uid_to_dimension.items()
-        }
+        conditional_probabilities = self._calculate_conditional_probabilities(
+            embedded_label
+        )
 
         if self._raw_output:
             # Directly output conditional probabilities
             return list(conditional_probabilities.items())
         else:
-            # Stage 1 calculates the unconditional probabilities
-            unconditional_probabilities = {}
-
-            for uid in self.topo_sorted_uids:
-                unconditional_probability = conditional_probabilities[uid]
-
-                no_parent_probability = 1.0
-                has_parents = False
-                for parent in self.graph.predecessors(uid):
-                    has_parents = True
-                    no_parent_probability *= 1.0 - unconditional_probabilities[parent]
-
-                if has_parents:
-                    unconditional_probability *= 1.0 - no_parent_probability
-
-                unconditional_probabilities[uid] = unconditional_probability
-
-            # Stage 2 calculates the joint probability of the synset and "no children"
-            joint_probabilities = {}
-            for uid in reversed(self.topo_sorted_uids):
-                joint_probability = unconditional_probabilities[uid]
-                no_child_probability = 1.0
-                for child in self.graph.successors(uid):
-                    no_child_probability *= 1.0 - unconditional_probabilities[child]
-
-                joint_probabilities[uid] = joint_probability * no_child_probability
+            unconditional_probabilities = self._calculate_unconditional_probabilities(
+                conditional_probabilities
+            )
 
-            tuples = joint_probabilities.items()
+            # Note: Stage 2 from IDK is intentionally omitted here.
+            tuples = unconditional_probabilities.items()
             sorted_tuples = list(sorted(tuples, key=lambda tup: tup[1], reverse=True))
 
-            if self._mlnp:
+            # If requested, only output scores for the forced prediction targets
+            if self._force_prediction_targets:
                 for i, (uid, p) in enumerate(sorted_tuples):
                     if uid not in self.prediction_target_uids:
                         sorted_tuples[i] = (uid, 0.0)
 
-            if self._normalize_scores:
                 total_scores = sum([p for uid, p in sorted_tuples])
-                sorted_tuples = [(uid, p / total_scores) for uid, p in sorted_tuples]
+                if total_scores > 0:
+                    sorted_tuples = [
+                        (uid, p / total_scores) for uid, p in sorted_tuples
+                    ]
 
             return list(sorted_tuples)
 
+    def _calculate_conditional_probabilities(self, embedded_label):
+        conditional_probabilities = {
+            uid: embedded_label[i] for uid, i in self.uid_to_dimension.items()
+        }
+        return conditional_probabilities
+
+    def _calculate_unconditional_probabilities(self, conditional_probabilities):
+        # Calculate the unconditional probabilities
+        unconditional_probabilities = {}
+        for uid in self.topo_sorted_uids:
+            unconditional_probability = conditional_probabilities[uid]
+
+            no_parent_probability = 1.0
+            has_parents = False
+            for parent in self.graph.predecessors(uid):
+                has_parents = True
+                no_parent_probability *= 1.0 - unconditional_probabilities[parent]
+
+            if has_parents:
+                unconditional_probability *= 1.0 - no_parent_probability
+
+            unconditional_probabilities[uid] = unconditional_probability
+
+        return unconditional_probabilities
+
     def update_embedding(self):
         current_concepts = self.kb.concepts()
         current_concept_count = len(current_concepts)
@@ -161,7 +167,10 @@ class CHILLAXKerasHC(
         }
 
         self.prediction_target_uids = {
-            concept.uid for concept in self.kb.concepts(flags={knowledge.ConceptFlagV2.PREDICTION_TARGET})
+            concept.uid
+            for concept in self.kb.concepts(
+                flags={knowledge.ConceptFlagV2.PREDICTION_TARGET}
+            )
         }
 
         if len(old_weights) == 2:
@@ -257,7 +266,9 @@ class CHILLAXKerasHC(
 
                 for i, uid in enumerate(self.uid_to_dimension):
                     descendants = set(nx.descendants(self.graph, uid)) | {uid}
-                    reachable_leaf_nodes = descendants.intersection(self.prediction_target_uids)
+                    reachable_leaf_nodes = descendants.intersection(
+                        self.prediction_target_uids
+                    )
                     self.loss_weights[i] *= len(reachable_leaf_nodes)
 
                     # Test if any leaf nodes are reachable
@@ -293,11 +304,22 @@ class CHILLAXKerasHC(
 
     def loss(self, feature_batch, ground_truth):
         if not self.is_updated:
-            raise RuntimeError("This classifier is not yet ready to compute a loss. "
-                               "Check if it has been notified of a hyponymy relation.")
+            raise RuntimeError(
+                "This classifier is not yet ready to compute a loss. "
+                "Check if it has been notified of a hyponymy relation."
+            )
 
-        loss_mask = np.zeros((len(ground_truth), len(self.uid_to_dimension)))
-        for i, label in enumerate(ground_truth):
+        # (1) Predict
+        prediction = self.predict_embedded(feature_batch)
+
+        # (2) Extrapolate ground truth
+        extrapolated_ground_truth = self._extrapolate(ground_truth, prediction)
+
+        # (3) Compute loss mask
+        loss_mask = np.zeros(
+            (len(extrapolated_ground_truth), len(self.uid_to_dimension))
+        )
+        for i, label in enumerate(extrapolated_ground_truth):
             # Loss mask
             loss_mask[i, self.uid_to_dimension[label]] = 1.0
 
@@ -307,7 +329,7 @@ class CHILLAXKerasHC(
                     loss_mask[i, self.uid_to_dimension[successor]] = 1.0
                     # This also covers the node itself, which was already set explicitly above
 
-            if not self._mlnp:
+            if not self._force_prediction_targets:
                 # Learn direct successors in order to "stop"
                 # prediction at these nodes.
-                # If MLNP is active, then this can be ignored.
+                # If forced prediction targets are active, then this can be ignored.
@@ -317,19 +339,20 @@ class CHILLAXKerasHC(
                 for successor in self.graph.successors(label):
                     loss_mask[i, self.uid_to_dimension[successor]] = 1.0
 
-        embedding = self.embed(ground_truth)
-        prediction = self.predict_embedded(feature_batch)
+        # (4) Embed ground truth
+        embedded_ground_truth = self.embed(extrapolated_ground_truth)
 
-        # Binary cross entropy loss function
+        # (5) Compute the binary cross-entropy loss
         clipped_probs = tf.clip_by_value(prediction, 1e-7, (1.0 - 1e-7))
         the_loss = -(
-            embedding * tf.math.log(clipped_probs)
-            + (1.0 - embedding) * tf.math.log(1.0 - clipped_probs)
+            embedded_ground_truth * tf.math.log(clipped_probs)
+            + (1.0 - embedded_ground_truth) * tf.math.log(1.0 - clipped_probs)
         )
 
         sum_per_batch_element = tf.reduce_sum(
             the_loss * loss_mask * self.loss_weights, axis=1
         )
+
         return tf.reduce_mean(sum_per_batch_element)
 
     def observe(self, samples, gt_resource_id):
@@ -366,3 +389,44 @@ class CHILLAXKerasHC(
             (self.uid_to_dimension,) = pickle.load(target)
 
         self.update_embedding()
+
+    def _extrapolate(self, ground_truth, embedded_prediction):
+        # Only do anything if there is an extrapolator
+        if self.extrapolator is not None:
+            epn = embedded_prediction.numpy()
+            extrapolated_ground_truth = []
+            for i, ground_truth_element in enumerate(ground_truth):
+                # Get the raw scores
+                conditional_probabilities = self._calculate_conditional_probabilities(
+                    epn[i]
+                )
+
+                # If the extrapolator wants it, apply the ground truth to the prediction at the
+                # conditional probability level.
+                if self.extrapolator.apply_ground_truth:
+                    label_true = {ground_truth_element}
+                    known = {ground_truth_element}
+                    for ancestor in nx.ancestors(self.graph, ground_truth_element):
+                        label_true |= {ancestor}
+                        known |= {ancestor}
+                        for child in self.graph.successors(ancestor):
+                            known |= {child}
+
+                    for uid in known:
+                        conditional_probabilities[uid] = (
+                            1.0 if uid in label_true else 0.0
+                        )
+
+                # Calculate unconditionals and extrapolate
+                unconditional_probabilities = self._calculate_unconditional_probabilities(
+                    conditional_probabilities
+                )
+                extrapolated_ground_truth += [
+                    self.extrapolator.extrapolate(
+                        ground_truth_element, unconditional_probabilities
+                    )
+                ]
+
+            return extrapolated_ground_truth
+        else:
+            return ground_truth
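
For reference, the parent marginalization factored out into _calculate_unconditional_probabilities works as in this self-contained sketch (the toy hierarchy and sigmoid scores are invented for illustration): a concept's unconditional probability is its conditional score times the probability that at least one parent is present.

import networkx as nx

# Toy hierarchy and conditional (sigmoid) outputs, illustrative only.
graph = nx.DiGraph()
graph.add_edges_from([("animal", "dog"), ("animal", "cat")])
conditional = {"animal": 0.9, "dog": 0.8, "cat": 0.1}

unconditional = {}
for uid in nx.topological_sort(graph):
    p = conditional[uid]
    no_parent_probability = 1.0
    has_parents = False
    for parent in graph.predecessors(uid):
        has_parents = True
        no_parent_probability *= 1.0 - unconditional[parent]
    if has_parents:
        # P(uid) = P(uid | parents) * P(at least one parent present)
        p *= 1.0 - no_parent_probability
    unconditional[uid] = p

print(unconditional)  # animal: 0.9, dog: ~0.72, cat: ~0.09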

+ 221 - 0
chillax/chillax_extrapolator.py

@@ -0,0 +1,221 @@
+from chia.v2 import components, knowledge, instrumentation
+
+from chillax import information_content
+
+import abc
+
+import networkx as nx
+import numpy as np
+
+
+class CHILLAXExtrapolator(instrumentation.Observer, abc.ABC):
+    def __init__(
+        self, knowledge_base: knowledge.KnowledgeBaseV2, apply_ground_truth: bool
+    ):
+        self.knowledge_base = knowledge_base
+        self.knowledge_base.register(self)
+
+        self.apply_ground_truth = apply_ground_truth
+
+        self.is_updated = False
+        self._update_relations_and_concepts()
+
+    def extrapolate(self, ground_truth_uid, unconditional_probabilities):
+        if not self.is_updated:
+            raise RuntimeError(
+                "This extrapolator is not updated. "
+                "Please check if it is subscribed to "
+                "RelationChange and ConceptChange messages."
+            )
+
+        return self._extrapolate(ground_truth_uid, unconditional_probabilities)
+
+    def update(self, message: instrumentation.Message):
+        if isinstance(message, knowledge.RelationChangeMessage) or isinstance(
+            message, knowledge.ConceptChangeMessage
+        ):
+            self.is_updated = False
+            self._update_relations_and_concepts()
+
+    @abc.abstractmethod
+    def _extrapolate(self, ground_truth_uid, unconditional_probabilities):
+        pass
+
+    def _update_relations_and_concepts(self):
+        self.is_updated = True
+
+
+class DoNothingCHILLAXExtrapolator(CHILLAXExtrapolator):
+    def _extrapolate(self, ground_truth_uid, unconditional_probabilities):
+        return ground_truth_uid
+
+
+class SimpleThresholdCHILLAXExtrapolator(CHILLAXExtrapolator):
+    def __init__(
+        self,
+        knowledge_base,
+        apply_ground_truth,
+        ic_method: str = "sanchez_2011_modified",
+        threshold=0.55,
+    ):
+        self._ic_calc: information_content.InformationContentCalculator = (
+            information_content.InformationContentCalculatorFactory.create(
+                {"name": ic_method}
+            )
+        )
+
+        self._threshold = threshold
+        self._ic_cache = dict()
+
+        # This must come last: the base constructor calls
+        # _update_relations_and_concepts, which requires _ic_calc.
+        super().__init__(
+            knowledge_base=knowledge_base, apply_ground_truth=apply_ground_truth
+        )
+
+    def _extrapolate(self, ground_truth_uid, unconditional_probabilities):
+        candidates = [
+            uid
+            for (uid, probability) in unconditional_probabilities.items()
+            if probability >= self._threshold
+        ]
+
+        if len(candidates) > 0:
+            candidates_with_ic = [(uid, self._ic_cache[uid]) for uid in candidates]
+
+            # Sort by probability first, see other methods for explanation of noise
+            candidates_with_ic = list(
+                sorted(
+                    candidates_with_ic,
+                    key=lambda x: unconditional_probabilities[x[0]]
+                    + np.random.normal(0, 0.0001),
+                    reverse=True,
+                )
+            )
+
+            # Sort by IC second. Stable sorting is guaranteed by Python.
+            candidates_with_ic = list(
+                sorted(candidates_with_ic, key=lambda x: x[1], reverse=True)
+            )
+            return candidates_with_ic[0][0]
+        else:
+            return ground_truth_uid
+
+    def _update_relations_and_concepts(self):
+        try:
+            self._ic_cache = dict()
+            rgraph = self.knowledge_base.get_hyponymy_relation_rgraph()
+            for concept in self.knowledge_base.concepts():
+                self._ic_cache[
+                    concept.uid
+                ] = self._ic_calc.calculate_information_content(concept.uid, rgraph)
+            self.is_updated = True
+
+        except ValueError:
+            return
+
+
+class DepthStepsCHILLAXExtrapolator(CHILLAXExtrapolator):
+    def __init__(self, knowledge_base, apply_ground_truth, steps=1):
+        self._steps = steps
+        self.rgraph = nx.DiGraph()
+        self.uid_to_depth = dict()
+        self.prediction_targets = set()
+
+        # These defaults must be set before the base constructor runs,
+        # because it calls _update_relations_and_concepts, which fills them in.
+        super().__init__(
+            knowledge_base=knowledge_base, apply_ground_truth=apply_ground_truth
+        )
+
+    def _extrapolate(self, ground_truth_uid, unconditional_probabilities):
+        original_depth = self.uid_to_depth[ground_truth_uid]
+        allowed_depth = original_depth + self._steps
+
+        allowed_uids = set()
+        for descendant in nx.descendants(self.rgraph, ground_truth_uid):
+            if self.uid_to_depth[descendant] == allowed_depth:
+                allowed_uids |= {descendant}
+            elif (
+                self.uid_to_depth[descendant] < allowed_depth
+                and descendant in self.prediction_targets
+            ):
+                # We need to allow leaf nodes if they are shallower than the allowed depth.
+                # Otherwise, we won't have any candidates sometimes.
+                allowed_uids |= {descendant}
+
+        candidates = [
+            uid
+            for (uid, probability) in unconditional_probabilities.items()
+            if uid in allowed_uids
+        ]
+
+        if len(candidates) > 0:
+            # When sorting by probability, add a very small amount of noise because of the nodes
+            # that return exactly 0.5. Otherwise, the sorting is done alphabetically or topologically,
+            # creating a bias.
+            candidates = list(
+                sorted(
+                    candidates,
+                    key=lambda x: unconditional_probabilities[x]
+                    + np.random.normal(0, 0.0001),
+                    reverse=True,
+                )
+            )
+            return candidates[0]
+        else:
+            return ground_truth_uid
+
+    def _update_relations_and_concepts(self):
+        try:
+            self.rgraph = self.knowledge_base.get_hyponymy_relation_rgraph()
+            self.prediction_targets = {
+                concept.uid
+                for concept in self.knowledge_base.concepts(
+                    flags={knowledge.ConceptFlagV2.PREDICTION_TARGET}
+                )
+            }
+
+            root = next(nx.topological_sort(self.rgraph))
+            self.uid_to_depth = {
+                concept.uid: nx.shortest_path_length(self.rgraph, root, concept.uid)
+                for concept in self.knowledge_base.concepts()
+            }
+
+            self.is_updated = True
+        except ValueError:
+            return
+
+
+class ForcePredictionTargetCHILLAXExtrapolator(CHILLAXExtrapolator):
+    def __init__(self, knowledge_base, apply_ground_truth):
+        self.prediction_targets = set()
+
+        # The default must be set before the base constructor runs,
+        # because it calls _update_relations_and_concepts, which fills it in.
+        super().__init__(
+            knowledge_base=knowledge_base, apply_ground_truth=apply_ground_truth
+        )
+
+    def _update_relations_and_concepts(self):
+        try:
+            self.prediction_targets = {
+                concept.uid
+                for concept in self.knowledge_base.concepts(
+                    flags={knowledge.ConceptFlagV2.PREDICTION_TARGET}
+                )
+            }
+            self.is_updated = True
+        except ValueError:
+            return
+
+    def _extrapolate(self, ground_truth_uid, unconditional_probabilities):
+        candidates = [
+            uid
+            for (uid, probability) in unconditional_probabilities.items()
+            if uid in self.prediction_targets
+        ]
+
+        if len(candidates) > 0:
+            # When sorting by probability, add a very small amount of noise because of the nodes
+            # that return exactly 0.5. Otherwise, the sorting is done alphabetically or topologically,
+            # creating a bias.
+            candidates = list(
+                sorted(
+                    candidates,
+                    key=lambda x: unconditional_probabilities[x]
+                    + np.random.normal(0, 0.0001),
+                    reverse=True,
+                )
+            )
+            return candidates[0]
+        else:
+            return ground_truth_uid
+
+
+class CHILLAXExtrapolatorFactory(components.Factory):
+    name_to_class_mapping = {
+        "do_nothing": DoNothingCHILLAXExtrapolator,
+        "simple_threshold": SimpleThresholdCHILLAXExtrapolator,
+        "depth_steps": DepthStepsCHILLAXExtrapolator,
+        "force_prediction_target": ForcePredictionTargetCHILLAXExtrapolator,
+    }
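
Two details of the candidate selection above are easy to miss: ties are broken with tiny Gaussian noise (so nodes stuck at exactly 0.5 are not always resolved alphabetically or topologically), and SimpleThresholdCHILLAXExtrapolator relies on Python's stable sort to rank by information content first and by probability among equal-IC candidates. A standalone sketch with invented UIDs, probabilities, and IC values:

import numpy as np

# Hypothetical candidates as (uid, information_content) pairs.
candidates_with_ic = [("n01", 2.0), ("n02", 3.5), ("n03", 3.5)]
probabilities = {"n01": 0.9, "n02": 0.6, "n03": 0.8}

# Pass 1: sort by noise-perturbed probability, so exact ties do not
# fall back to list order.
candidates_with_ic.sort(
    key=lambda x: probabilities[x[0]] + np.random.normal(0, 0.0001),
    reverse=True,
)

# Pass 2: sort by IC. The sort is stable, so equal-IC candidates keep
# their probability ordering from pass 1.
candidates_with_ic.sort(key=lambda x: x[1], reverse=True)

print(candidates_with_ic[0][0])  # "n03": highest IC, beats "n02" on probability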

+ 19 - 13
chillax/experiment_selfsupervised.py

@@ -1,7 +1,7 @@
 from chia.v2 import containers, instrumentation
 from chia.v2.components import classifiers
 from chia.v2 import helpers
-from chillax import chillax_classifier
+from chillax import chillax_classifier, chillax_extrapolator
 
 import config as pcfg
 import argparse
@@ -13,10 +13,11 @@ class CheapObserver(instrumentation.Observer):
 
 
 def main(config_files):
-    configs = [pcfg.config_from_json(config_file, read_from_file=True) for config_file in config_files]
-    config = pcfg.ConfigurationSet(
-        *configs
-    )
+    configs = [
+        pcfg.config_from_json(config_file, read_from_file=True)
+        for config_file in config_files
+    ]
+    config = pcfg.ConfigurationSet(*configs)
 
     classifiers.ClassifierFactory.name_to_class_mapping.update(
         {"chillax": chillax_classifier.CHILLAXKerasHC}
@@ -26,20 +27,25 @@ def main(config_files):
     helpers.setup_environment()
     obs = instrumentation.NamedObservable("Experiment")
 
-    experiment_container = containers.ExperimentContainer(
-        config, outer_observable=obs
-    )
-
-    obs.log_info("Hello!")
+    experiment_container = containers.ExperimentContainer(config, outer_observable=obs)
 
-    experiment_container.runner.run()
+    with experiment_container.exception_shroud:
+        obs.log_info("Hello!")
+        # Now, build the extrapolator
+        extrapolator = chillax_extrapolator.CHILLAXExtrapolatorFactory.create(
+            config["extrapolator"],
+            knowledge_base=experiment_container.knowledge_base,
+            observers=experiment_container.observers,
+        )
+        experiment_container.classifier.extrapolator = extrapolator
+        experiment_container.runner.run()
 
     # Make sure all the data is saved
-    obs.send_shutdown()
+    obs.send_shutdown(successful=True)
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(prog="chillax.experiment_selfsupervised")
-    parser.add_argument("config_file", type=str, nargs='+')
+    parser.add_argument("config_file", type=str, nargs="+")
     args = parser.parse_args()
     main(config_files=args.config_file)

+ 81 - 0
chillax/information_content.py

@@ -0,0 +1,81 @@
+import math
+import abc
+import networkx as nx
+
+from chia.v2 import components
+
+
+class InformationContentCalculator(abc.ABC):
+    @abc.abstractmethod
+    def calculate_information_content(
+        self, concept_uid: str, rgraph: nx.DiGraph
+    ) -> float:
+        pass
+
+
+class Sanchez2011OriginalICC(InformationContentCalculator):
+    def calculate_information_content(self, concept_uid: str, rgraph: nx.DiGraph):
+        exclusive_leaves = set(
+            filter(
+                lambda n: rgraph.out_degree[n] == 0, nx.descendants(rgraph, concept_uid)
+            )
+        ) - {concept_uid}
+
+        all_leaves = set(filter(lambda n: rgraph.out_degree[n] == 0, rgraph.nodes))
+
+        ancestors = set(nx.ancestors(rgraph, concept_uid)) | {concept_uid}
+
+        index = -math.log(
+            ((len(exclusive_leaves) / float(len(ancestors))) + 1.0)
+            / (float(len(all_leaves)) + 1.0)
+        )
+
+        return math.fabs(index)
+
+
+class Sanchez2011ModifiedICC(InformationContentCalculator):
+    def calculate_information_content(self, concept_uid: str, rgraph: nx.DiGraph):
+        all_leaves = set(filter(lambda n: rgraph.out_degree[n] == 0, rgraph.nodes))
+
+        non_exclusive_leaves = (
+            set(nx.descendants(rgraph, concept_uid)) | {concept_uid}
+        ) & all_leaves
+
+        ancestors = set(nx.ancestors(rgraph, concept_uid)) | {concept_uid}
+
+        index = -math.log(
+            ((len(non_exclusive_leaves) / float(len(ancestors))) + 1.0)
+            / (float(len(all_leaves)) + 1.0)
+        )
+
+        return math.fabs(index)
+
+
+class Zhou2008ModifiedICC(InformationContentCalculator):
+    def calculate_information_content(self, concept_uid: str, rgraph: nx.DiGraph):
+        root = next(nx.topological_sort(rgraph))
+
+        all_leaves = set(filter(lambda n: rgraph.out_degree[n] == 0, rgraph.nodes))
+        all_leaf_depths = [
+            nx.shortest_path_length(rgraph, root, leaf) for leaf in all_leaves
+        ]
+        highest_depth = max(all_leaf_depths)
+        uid_depth = nx.shortest_path_length(rgraph, root, concept_uid)
+        descendants = set(nx.descendants(rgraph, concept_uid)) | {concept_uid}
+
+        k = 0.6  # Harispe et al. 2015, page 55 claims that this is the "original" value.
+        index1 = 1.0 - (math.log(len(descendants)) / math.log(len(rgraph.nodes)))
+        index2 = math.log(uid_depth + 1) / math.log(highest_depth + 1)
+
+        index = k * index1 + (1.0 - k) * index2
+
+        return math.fabs(index)
+
+
+class InformationContentCalculatorFactory(components.Factory):
+    name_to_class_mapping = {
+        "sanchez_2011_original": Sanchez2011OriginalICC,
+        "sanchez_2011_modified": Sanchez2011ModifiedICC,
+        "zhou_2008_modified": Zhou2008ModifiedICC,
+    }
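
As a sanity check, the modified Sánchez et al. (2011) measure behaves as expected on a small invented rgraph (edges point from hypernym to hyponym, as in the code above): the root is uninformative and deeper concepts score higher.

import networkx as nx

from chillax import information_content

# Invented toy hierarchy: root -> {a, b}, a -> {c, d}; leaves are b, c, d.
rgraph = nx.DiGraph()
rgraph.add_edges_from([("root", "a"), ("root", "b"), ("a", "c"), ("a", "d")])

icc = information_content.Sanchez2011ModifiedICC()
print(icc.calculate_information_content("root", rgraph))  # 0.0
print(icc.calculate_information_content("a", rgraph))  # -log((2/2 + 1) / 4) ~ 0.693
print(icc.calculate_information_content("c", rgraph))  # -log((1/3 + 1) / 4) ~ 1.099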

+ 8 - 1
main.json

@@ -15,7 +15,9 @@
   ],
   "with_wordnet": true,
   "interactor": {
-    "name": "oracle"
+    "name": "noisy_oracle",
+    "noise_model": "Poisson",
+    "lambda_": 1.0
   },
   "observers": [
     {
@@ -95,5 +97,10 @@
         "warmup_lr": 0.01
       }
     }
+  },
+  "extrapolator": {
+    "name": "simple_threshold",
+    "threshold": 0.8,
+    "apply_ground_truth": true
   }
 }
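
With this block in place, the experiment script above resolves and attaches the extrapolator roughly as follows (a sketch; config and experiment_container are set up as in experiment_selfsupervised.py):

from chillax import chillax_extrapolator

# "simple_threshold" selects SimpleThresholdCHILLAXExtrapolator; the
# threshold and apply_ground_truth values come from the JSON above.
extrapolator = chillax_extrapolator.CHILLAXExtrapolatorFactory.create(
    config["extrapolator"],
    knowledge_base=experiment_container.knowledge_base,
    observers=experiment_container.observers,
)
experiment_container.classifier.extrapolator = extrapolator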