Clemens-Alexander Brust 5 years ago
Parent
Commit
970cb5c036
2 changed files with 30 additions and 20 deletions
  1. 9 4
      chillax/experiment_selfsupervised.py
  2. 21 16
      chillax/method.py

+ 9 - 4
chillax/experiment_selfsupervised.py

@@ -26,7 +26,7 @@ def main():
                 "augmentation": {},
                 "trainer": {
                     "name": "fast_single_shot",
-                    "batch_size": 8,
+                    "batch_size": 2,
                     "inner_steps": 2000,
                 },
                 "feature_extractor": {"side_length": 448},
@@ -47,12 +47,17 @@
     experiment_container = containers.ExperimentContainer(
         config, observers=(CheapObserver(),)
     )
-    experiment_container.knowledge_base.observe_concepts(
-        experiment_container.dataset.observable_concepts()
-    )
 
     dataset = experiment_container.dataset
 
+    # Get prediction targets
+    experiment_container.knowledge_base.add_prediction_targets(
+        dataset.prediction_targets()
+    )
+
+    # Add relation source
+    experiment_container.knowledge_base.add_hyponymy_relation([dataset.get_hyponymy_relation_source()])
+
     base_model = experiment_container.base_model
     training_samples = dataset.train_pool(0, "label_gt")
     base_model.observe(training_samples, "label_gt")

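For orientation: the experiment now registers two separate things with the knowledge base before training, namely the prediction targets (the concepts the classifier should actually output) and a hyponymy relation source supplying the is-a edges of the hierarchy. A minimal sketch of the underlying idea, independent of the chia API (the pair list and the build_hierarchy helper below are illustrative, not library code):

    import networkx as nx

    # Illustrative (hypernym, hyponym) pairs; in the experiment these come
    # from the dataset via dataset.get_hyponymy_relation_source().
    HYPONYMY_PAIRS = [
        ("animal", "dog"),
        ("animal", "cat"),
        ("dog", "terrier"),
    ]

    def build_hierarchy(pairs):
        # Edges run from hypernym to hyponym, i.e. from general to specific.
        graph = nx.DiGraph(pairs)
        # A usable hierarchy must be acyclic; method.py relies on a
        # topological sort of this graph later on.
        assert nx.is_directed_acyclic_graph(graph)
        return graph

    hierarchy = build_hierarchy(HYPONYMY_PAIRS)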
+ 21 - 16
chillax/method.py

@@ -1,5 +1,5 @@
 from chia.v2.components.classifiers import keras_hierarchicalclassification
-from chia.v2 import instrumentation
+from chia.v2 import instrumentation, knowledge
 
 import networkx as nx
 import numpy as np
@@ -41,7 +41,7 @@ class CHILLAXKerasHC(
         self.fc_layer = None
         self.uid_to_dimension = {}
         self.graph = None
-        self.observed_uids = None
+        self.prediction_target_uids = None
         self.topo_sorted_uids = None
         self.loss_weights = None
         self.update_embedding()
@@ -107,7 +107,7 @@ class CHILLAXKerasHC(
 
             if self._mlnp:
                 for i, (uid, p) in enumerate(sorted_tuples):
-                    if uid not in self.observed_uids:
+                    if uid not in self.prediction_target_uids:
                         sorted_tuples[i] = (uid, 0.0)
 
             if self._normalize_scores:
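The _mlnp branch enforces mandatory leaf-node prediction: every concept that is not a prediction target has its score zeroed before the optional renormalization, so only designated targets can win. The same loop in isolation:

    sorted_tuples = [("animal", 0.9), ("dog", 0.6), ("terrier", 0.5)]
    prediction_target_uids = {"terrier", "cat"}

    # Zero out everything that is not a prediction target.
    for i, (uid, p) in enumerate(sorted_tuples):
        if uid not in prediction_target_uids:
            sorted_tuples[i] = (uid, 0.0)

    # sorted_tuples == [("animal", 0.0), ("dog", 0.0), ("terrier", 0.5)]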
@@ -117,12 +117,12 @@ class CHILLAXKerasHC(
             return list(sorted_tuples)
 
     def update_embedding(self):
-        current_concepts = self.kb.all_concepts.values()
+        current_concepts = self.kb.concepts()
         current_concept_count = len(current_concepts)
         self.report_metric("current_concepts", current_concept_count)
 
         if current_concept_count == 0:
-            return
+            return True
 
         try:
             old_weights = self.fc_layer.get_weights()
@@ -146,20 +146,22 @@ class CHILLAXKerasHC(
             bias_initializer="zero",
         )
 
-        # We need to reverse the graph for comfort because "is-a" has the concepts
-        self.graph = self.kb.all_relations["hypernymy"]["graph"].reverse(copy=True)
+        try:
+            self.graph = self.kb.get_hyponymy_relation_rgraph()
+        except ValueError:
+            return False
 
         # Memorize topological sorting for later
         all_uids = nx.topological_sort(self.graph)
         self.topo_sorted_uids = list(all_uids)
-        assert len(self.kb.all_concepts) == len(self.topo_sorted_uids)
+        assert len(current_concepts) == len(self.topo_sorted_uids)
 
         self.uid_to_dimension = {
             uid: dimension for dimension, uid in enumerate(self.topo_sorted_uids)
         }
 
-        self.observed_uids = {
-            concept.data["uid"] for concept in self.kb.get_observed_concepts()
+        self.prediction_target_uids = {
+            concept.uid for concept in self.kb.concepts(flags={knowledge.ConceptFlagV2.PREDICTION_TARGET})
         }
 
         if len(old_weights) == 2:
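The dimension bookkeeping above is the heart of the embedding update: a topological sort of the hierarchy graph assigns every concept one output dimension, with each hypernym ordered before all of its hyponyms. A standalone sketch of that mapping (the toy graph is a placeholder):

    import networkx as nx

    graph = nx.DiGraph([("animal", "dog"), ("animal", "cat"), ("dog", "terrier")])

    # One output dimension per concept; parents always precede children.
    topo_sorted_uids = list(nx.topological_sort(graph))
    uid_to_dimension = {uid: dim for dim, uid in enumerate(topo_sorted_uids)}

    assert uid_to_dimension["animal"] < uid_to_dimension["dog"] < uid_to_dimension["terrier"]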
@@ -189,9 +191,10 @@ class CHILLAXKerasHC(
             self.fc_layer.set_weights([new_weights, new_biases])
 
         self.update_loss_weights()
+        return True
 
     def update_loss_weights(self):
-        if len(self.observed_uids) == 0:
+        if len(self.prediction_target_uids) == 0:
             self.log_debug("Skipping loss weight update, no concepts found.")
             self.loss_weights = []
             return
@@ -204,7 +207,7 @@ class CHILLAXKerasHC(
         # (1) Calculate "natural" weights by assuming uniform distribution
         # over observed concepts
         occurences = {uid: 0 for uid in self.topo_sorted_uids}
-        for uid in self.observed_uids:
+        for uid in self.prediction_target_uids:
             affected_uids = {uid}
             affected_uids |= nx.ancestors(self.graph, uid)
             for affected_uid in list(affected_uids):
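The counting loop assumes a uniform distribution over the prediction targets: each target adds one count to itself and to every hypernym above it, so inner nodes covering more targets accumulate proportionally more mass. The same step on the toy graph from above:

    import networkx as nx

    graph = nx.DiGraph([("animal", "dog"), ("animal", "cat"), ("dog", "terrier")])
    prediction_target_uids = {"cat", "terrier"}

    occurrences = {uid: 0 for uid in graph.nodes}
    for uid in prediction_target_uids:
        # A target counts for itself and for all of its hypernyms.
        for affected_uid in {uid} | nx.ancestors(graph, uid):
            occurrences[affected_uid] += 1

    # occurrences == {"animal": 2, "dog": 1, "cat": 1, "terrier": 1}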
@@ -254,7 +257,7 @@ class CHILLAXKerasHC(
 
                 for i, uid in enumerate(self.uid_to_dimension):
                     descendants = set(nx.descendants(self.graph, uid)) | {uid}
-                    reachable_leaf_nodes = descendants.intersection(self.observed_uids)
+                    reachable_leaf_nodes = descendants.intersection(self.prediction_target_uids)
                     self.loss_weights[i] *= len(reachable_leaf_nodes)
 
                     # Test if any leaf nodes are reachable
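This factor scales each dimension by the number of prediction targets still reachable below it and doubles as a consistency check: a dimension from which no target is reachable would end up with weight zero. Continuing the snippet above:

    for uid in graph.nodes:
        reachable = (set(nx.descendants(graph, uid)) | {uid}) & prediction_target_uids
        # e.g. "animal" reaches both targets, "dog" only reaches "terrier".
        assert len(reachable) > 0, f"{uid} cannot reach any prediction target"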
@@ -289,6 +292,10 @@ class CHILLAXKerasHC(
         self.loss_weights /= gain
 
     def loss(self, feature_batch, ground_truth):
+        if not self.is_updated:
+            raise RuntimeError("This classifier is not yet ready to compute a loss. "
+                               "Check if it has been notified of a hyponymy relation.")
+
         loss_mask = np.zeros((len(ground_truth), len(self.uid_to_dimension)))
         for i, label in enumerate(ground_truth):
             # Loss mask
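The remainder of loss() lies outside this hunk, but the mask shape points to a per-concept binary cross-entropy that is only applied where the hierarchy provides supervision. One common way to fill such a mask marks the label, its hypernyms, and the direct hyponyms of each hypernym; this is an assumption about the unshown body, not the committed implementation:

    import numpy as np
    import networkx as nx

    graph = nx.DiGraph([("animal", "dog"), ("animal", "cat"), ("dog", "terrier")])
    uid_to_dimension = {uid: d for d, uid in enumerate(nx.topological_sort(graph))}

    def loss_mask_for(label):
        mask = np.zeros(len(uid_to_dimension))
        mask[uid_to_dimension[label]] = 1.0
        for ancestor in nx.ancestors(graph, label):
            mask[uid_to_dimension[ancestor]] = 1.0
            # Also supervise the competing alternatives at every level.
            for successor in graph.successors(ancestor):
                mask[uid_to_dimension[successor]] = 1.0
        return mask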
@@ -326,9 +333,7 @@ class CHILLAXKerasHC(
         return tf.reduce_mean(sum_per_batch_element)
 
     def observe(self, samples, gt_resource_id):
-        if self.kb.get_concept_stamp() != self.last_observed_concept_stamp:
-            self.update_embedding()
-            self.last_observed_concept_stamp = self.kb.get_concept_stamp()
+        self.maybe_update_embedding()
 
     def regularization_losses(self):
         return self.fc_layer.losses
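maybe_update_embedding itself is not part of this diff. Given that update_embedding now returns False when the hyponymy relation is missing and True on success, a plausible shape for the helper, reconstructed from the removed stamp check and hypothetical rather than taken from the repository:

    def maybe_update_embedding(self):
        # Hypothetical sketch: re-embed only when the knowledge base has
        # changed since the last observation, and remember whether the
        # embedding is usable; the new guard in loss() checks this flag.
        if self.kb.get_concept_stamp() != self.last_observed_concept_stamp:
            self.is_updated = self.update_embedding()
            self.last_observed_concept_stamp = self.kb.get_concept_stamp()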