|
@@ -0,0 +1,136 @@
|
|
|
+from chia.components.sample_transformers.sample_transformer import SampleTransformer
|
|
|
+from chia import knowledge
|
|
|
+from chia import instrumentation
|
|
|
+from chia import data
|
|
|
+
|
|
|
+from chillax import information_content
|
|
|
+
|
|
|
+import typing
|
|
|
+
|
|
|
+import numpy as np
|
|
|
+
|
|
|
+
|
|
|
class SampleWeightByICSampleTransfomer(SampleTransformer, instrumentation.Observer):
    """
    SampleWeightByICSampleTransformer: Adds a training_weight to samples based on IC.

    For every concept in the knowledge base the information content (IC) is
    computed once and cached. During ``transform`` each training sample gets
    a ``training_weight`` resource equal to the softmax of ``coef_a * IC``
    over the distinct sample uids of the batch.

    The update pattern for the KB is taken from chillax_extrapolator.

    NOTE: the class name is spelled "Transfomer" (sic). It is left unchanged
    here so that existing references to the class keep working.
    """

    def __init__(
        self,
        kb: knowledge.KnowledgeBase,
        ic_method: typing.Optional[str] = None,
        coef_a: float = 0.0,
    ):
        """
        Build the IC calculator, subscribe to the KB and fill the IC cache.

        :param kb: Knowledge base; forwarded to ``SampleTransformer.__init__``
            and observed for change messages.
        :param ic_method: Name handed to ``InformationContentCalculatorFactory``.
            ``None`` selects ``"zhou_2008_modified"``.
        :param coef_a: Factor the IC is multiplied with before exponentiation.
            With the default of ``0.0`` every exponent is zero, so all
            distinct uids receive the same weight.
        """
        SampleTransformer.__init__(self, kb=kb)
        instrumentation.Observer.__init__(self)

        self.coef_a = coef_a

        # Information Content
        # All state that update() touches is created *before* registering on
        # the KB, so a message arriving right after registration never sees a
        # half-initialized object, and a failing factory call does not leave
        # a dangling observer behind.
        calculator_name = ic_method if ic_method is not None else "zhou_2008_modified"
        self._ic_calc: information_content.InformationContentCalculator = (
            information_content.InformationContentCalculatorFactory.create(
                {"name": calculator_name}
            )
        )
        self._ic_cache = dict()
        self.is_updated = False

        self.kb.register(self)
        self.update_relations_and_concepts()

    def transform(
        self,
        samples: typing.List[data.Sample],
        is_training: bool,
        label_resource_id: str,
    ):
        """
        Attach a ``training_weight`` resource to every training sample.

        :param samples: Samples to process. Each one is asked for its ``"uid"``
            resource and for the resource named by ``label_resource_id``.
        :param is_training: If false, only IC statistics are logged and the
            samples are returned untouched.
        :param label_resource_id: Resource id under which the concept uid
            (the key into the IC cache) is stored on each sample.
        :return: ``samples`` itself if it is empty or ``is_training`` is false,
            otherwise a new list holding the result of ``add_resource`` for
            each sample.
        :raises RuntimeError: If the IC cache is not up to date.
        :raises KeyError: If a sample's label has no entry in the IC cache.
        """
        if not self.is_updated:
            raise RuntimeError(
                "This sample_weight_by_ic is not updated. "
                "Please check if it is subscribed to "
                "RelationChange and ConceptChange messages."
            )

        # A lot of this stuff assumes len(samples) > 0, so:
        if len(samples) == 0:
            return samples

        # Calculate the IC of each sample. Keyed by uid, so samples sharing a
        # uid collapse into one entry (the last one wins).
        sample_ics = {
            sample.get_resource("uid"): self._ic_cache[
                sample.get_resource(label_resource_id)
            ]
            for sample in samples
        }
        sample_uids = list(sample_ics.keys())
        ic_array = np.asarray(list(sample_ics.values()))

        self.log_info(f'Information for {"training" if is_training else "test"} data:')
        self.log_info(
            f"IC mean {np.mean(ic_array)}, std {np.std(ic_array)}, median {np.median(ic_array)}"
        )
        if not is_training:
            self.log_debug("Not touching test data.")
            return samples

        # These two arrays only feed the log output below.
        ic_exp_array = np.exp(ic_array)
        scaled_ic_array = self.coef_a * ic_array
        ic_exp_tf_array = np.exp(scaled_ic_array)

        self.log_info(
            f"IC exp. mean {np.mean(ic_exp_array)}, std {np.std(ic_exp_array)}, median {np.median(ic_exp_array)}"
        )
        self.log_info(
            f"IC exp. tf mean {np.mean(ic_exp_tf_array)}, std {np.std(ic_exp_tf_array)}, median {np.median(ic_exp_tf_array)}"
        )

        # Apply softmax. Subtracting the maximum exponent first leaves the
        # result mathematically unchanged but keeps np.exp from overflowing
        # to inf (which would turn the weights into nan) for large coef_a.
        shifted_exp_array = np.exp(scaled_ic_array - np.max(scaled_ic_array))
        weight_array = shifted_exp_array / np.sum(shifted_exp_array)
        sample_weights = dict(zip(sample_uids, weight_array))

        # Scaled by the sample count for logging only, so that a uniform
        # weighting shows up as roughly 1.0.
        sample_weight_array = weight_array * len(samples)
        self.log_info(
            f"Weight*cnt mean {np.mean(sample_weight_array)}, std {np.std(sample_weight_array)}, median {np.median(sample_weight_array)}"
        )

        return [
            sample.add_resource(
                self.__class__.__name__,
                "training_weight",
                sample_weights[sample.get_resource("uid")],
            )
            for sample in samples
        ]

    def update_relations_and_concepts(self):
        """
        Recompute the IC of every KB concept and swap in the new cache.

        The new cache is built in a local dict and only published once every
        concept has been processed. If a ``ValueError`` is raised on the way,
        a warning is logged, the cache is emptied and ``is_updated`` is set to
        false, so ``transform`` refuses to run on incomplete data.
        """
        try:
            # Update Information Content Cache
            rgraph = self.kb.get_hyponymy_relation_rgraph()
            fresh_cache = {
                concept.uid: self._ic_calc.calculate_information_content(
                    concept.uid, rgraph
                )
                for concept in self.kb.concepts()
            }
        except ValueError as verr:
            # Never leave a partially filled cache flagged as valid.
            self._ic_cache = dict()
            self.is_updated = False
            self.log_warning(f"Could not update sample_weight_by_ic. {verr.args}")
        else:
            self._ic_cache = fresh_cache
            self.is_updated = True

    def update(self, message: instrumentation.Message):
        """
        Observer callback: refresh the IC cache when the KB changes.

        Only ``RelationChangeMessage`` and ``ConceptChangeMessage`` trigger a
        refresh; every other message is ignored.
        """
        if isinstance(
            message,
            (knowledge.RelationChangeMessage, knowledge.ConceptChangeMessage),
        ):
            # Invalidate first so a failing refresh leaves the flag cleared.
            self.is_updated = False
            self.update_relations_and_concepts()
|