Selaa lähdekoodia

Added sample_weight_by_ic method for new CHIA version

Clemens-Alexander Brust 4 vuotta sitten
vanhempi
commit
9d109c0257
4 muutettua tiedostoa jossa 143 lisäystä ja 4 poistoa
  1. 5 2
      chillax/methods/__init__.py
  2. 136 0
      chillax/methods/sample_weight_by_ic.py
  3. 1 1
      examples/configuration.json
  4. 1 1
      setup.py

+ 5 - 2
chillax/methods/__init__.py

@@ -1,6 +1,6 @@
-from chia.components import classifiers, interactors
+from chia.components import classifiers, interactors, sample_transformers
 
-from chillax.methods import chillax_classifier, noisy_oracle
+from chillax.methods import chillax_classifier, noisy_oracle, sample_weight_by_ic
 
 from chillax.methods.chillax_extrapolator import CHILLAXExtrapolatorFactory
 
@@ -12,6 +12,9 @@ def update_chia_factories():
     interactors.InteractorFactory.name_to_class_mapping.update(
         {"noisy_oracle": noisy_oracle.NoisyOracleInteractor}
     )
+    sample_transformers.SampleTransformerFactory.name_to_class_mapping.update(
+        {"sample_weight_by_ic": sample_weight_by_ic.SampleWeightByICSampleTransfomer}
+    )
 
 
 __all__ = ["CHILLAXExtrapolatorFactory", "update_chia_factories"]

+ 136 - 0
chillax/methods/sample_weight_by_ic.py

@@ -0,0 +1,136 @@
+from chia.components.sample_transformers.sample_transformer import SampleTransformer
+from chia import knowledge
+from chia import instrumentation
+from chia import data
+
+from chillax import information_content
+
+import typing
+
+import numpy as np
+
+
+class SampleWeightByICSampleTransfomer(SampleTransformer, instrumentation.Observer):
+    """
+    SampleWeightByICSampleTransformer: Adds a training_weight to samples based on IC.
+
+    The update pattern for the KB is taken from chillax_extrapolator.
+    """
+
+    def __init__(
+        self,
+        kb: knowledge.KnowledgeBase,
+        ic_method: typing.Optional[str] = None,
+        coef_a=0.0,
+    ):
+        SampleTransformer.__init__(self, kb=kb)
+        instrumentation.Observer.__init__(self)
+
+        self.coef_a = coef_a
+
+        # Information Content
+        self.kb.register(self)
+        self.is_updated = False
+
+        self._ic_calc: information_content.InformationContentCalculator = (
+            information_content.InformationContentCalculatorFactory.create(
+                {"name": ic_method if ic_method is not None else "zhou_2008_modified"}
+            )
+        )
+        self._ic_cache = dict()
+        self.update_relations_and_concepts()
+
+    def transform(
+        self,
+        samples: typing.List[data.Sample],
+        is_training: bool,
+        label_resource_id: str,
+    ):
+        if not self.is_updated:
+            raise RuntimeError(
+                "This sample_weight_by_ic is not updated. "
+                "Please check if it is subscribed to "
+                "RelationChange and ConceptChange messages."
+            )
+
+        # A lot of this stuff assumes len(samples) > 0, so:
+        if len(samples) == 0:
+            return samples
+
+        # Calculate the IC of each sample
+        sample_ics = {
+            sample.get_resource("uid"): self._ic_cache[
+                sample.get_resource(label_resource_id)
+            ]
+            for sample in samples
+        }
+        ic_array = np.asarray(list(sample_ics.values()))
+
+        self.log_info(f'Information for {"training" if is_training else "test"} data:')
+        self.log_info(
+            f"IC mean        {np.mean(ic_array)}, std {np.std(ic_array)}, median {np.median(ic_array)}"
+        )
+        if not is_training:
+            self.log_debug(f"Not touching test data.")
+            return samples
+
+        sample_ics_exp = {
+            sample_uid: np.exp(sample_ic)
+            for sample_uid, sample_ic in sample_ics.items()
+        }
+        sample_ics_exp_tf = {
+            sample_uid: np.exp(self.coef_a * sample_ic)
+            for sample_uid, sample_ic in sample_ics.items()
+        }
+
+        ic_exp_array = np.asarray(list(sample_ics_exp.values()))
+        ic_exp_tf_array = np.asarray(list(sample_ics_exp_tf.values()))
+
+        self.log_info(
+            f"IC exp.    mean {np.mean(ic_exp_array)}, std {np.std(ic_exp_array)}, median {np.median(ic_exp_array)}"
+        )
+        self.log_info(
+            f"IC exp. tf mean {np.mean(ic_exp_tf_array)}, std {np.std(ic_exp_tf_array)}, median {np.median(ic_exp_tf_array)}"
+        )
+
+        ic_exp_tf_sum = np.sum(ic_exp_tf_array)
+
+        # Apply softmax
+        sample_weights = {
+            sample_uid: sample_ic_exp_tf / ic_exp_tf_sum
+            for sample_uid, sample_ic_exp_tf in sample_ics_exp_tf.items()
+        }
+        sample_weight_array = np.asarray(list(sample_weights.values())) * len(samples)
+        self.log_info(
+            f"Weight*cnt mean {np.mean(sample_weight_array)}, std {np.std(sample_weight_array)}, median {np.median(sample_weight_array)}"
+        )
+
+        return [
+            sample.add_resource(
+                self.__class__.__name__,
+                "training_weight",
+                sample_weights[sample.get_resource("uid")],
+            )
+            for sample in samples
+        ]
+
+    def update_relations_and_concepts(self):
+        try:
+            # Update Information Content Cache
+            self._ic_cache = dict()
+            rgraph = self.kb.get_hyponymy_relation_rgraph()
+            for concept in self.kb.concepts():
+                self._ic_cache[
+                    concept.uid
+                ] = self._ic_calc.calculate_information_content(concept.uid, rgraph)
+
+            self.is_updated = True
+        except ValueError as verr:
+            self.log_warning(f"Could not update sample_weight_by_ic. {verr.args}")
+
+    def update(self, message: instrumentation.Message):
+        if isinstance(message, knowledge.RelationChangeMessage) or isinstance(
+            message, knowledge.ConceptChangeMessage
+        ):
+            self.is_updated = False
+            self.update_relations_and_concepts()

+ 1 - 1
examples/configuration.json

@@ -19,7 +19,7 @@
   ],
   "sample_transformers": [
     {
-      "name": "identity"
+      "name": "sample_weight_by_ic"
     }
   ],
   "runner": {

+ 1 - 1
setup.py

@@ -17,7 +17,7 @@ setup(
     packages=find_packages(),
     python_requires=">=3.7",
     install_requires=[
-        "chia>=2.0rc18",
+        "chia>=2.0rc20",
     ],
     # metadata to display on PyPI
     author="Clemens-Alexander Brust",