|
@@ -1,5 +1,5 @@
|
|
|
from chia.v2.components.classifiers import keras_hierarchicalclassification
|
|
from chia.v2.components.classifiers import keras_hierarchicalclassification
|
|
|
-from chia.v2 import instrumentation
|
|
|
|
|
|
|
+from chia.v2 import instrumentation, knowledge
|
|
|
|
|
|
|
|
import networkx as nx
|
|
import networkx as nx
|
|
|
import numpy as np
|
|
import numpy as np
|
|
@@ -41,7 +41,7 @@ class CHILLAXKerasHC(
|
|
|
self.fc_layer = None
|
|
self.fc_layer = None
|
|
|
self.uid_to_dimension = {}
|
|
self.uid_to_dimension = {}
|
|
|
self.graph = None
|
|
self.graph = None
|
|
|
- self.observed_uids = None
|
|
|
|
|
|
|
+ self.prediction_target_uids = None
|
|
|
self.topo_sorted_uids = None
|
|
self.topo_sorted_uids = None
|
|
|
self.loss_weights = None
|
|
self.loss_weights = None
|
|
|
self.update_embedding()
|
|
self.update_embedding()
|
|
@@ -107,7 +107,7 @@ class CHILLAXKerasHC(
|
|
|
|
|
|
|
|
if self._mlnp:
|
|
if self._mlnp:
|
|
|
for i, (uid, p) in enumerate(sorted_tuples):
|
|
for i, (uid, p) in enumerate(sorted_tuples):
|
|
|
- if uid not in self.observed_uids:
|
|
|
|
|
|
|
+ if uid not in self.prediction_target_uids:
|
|
|
sorted_tuples[i] = (uid, 0.0)
|
|
sorted_tuples[i] = (uid, 0.0)
|
|
|
|
|
|
|
|
if self._normalize_scores:
|
|
if self._normalize_scores:
|
|
@@ -117,12 +117,12 @@ class CHILLAXKerasHC(
|
|
|
return list(sorted_tuples)
|
|
return list(sorted_tuples)
|
|
|
|
|
|
|
|
def update_embedding(self):
|
|
def update_embedding(self):
|
|
|
- current_concepts = self.kb.all_concepts.values()
|
|
|
|
|
|
|
+ current_concepts = self.kb.concepts()
|
|
|
current_concept_count = len(current_concepts)
|
|
current_concept_count = len(current_concepts)
|
|
|
self.report_metric("current_concepts", current_concept_count)
|
|
self.report_metric("current_concepts", current_concept_count)
|
|
|
|
|
|
|
|
if current_concept_count == 0:
|
|
if current_concept_count == 0:
|
|
|
- return
|
|
|
|
|
|
|
+ return True
|
|
|
|
|
|
|
|
try:
|
|
try:
|
|
|
old_weights = self.fc_layer.get_weights()
|
|
old_weights = self.fc_layer.get_weights()
|
|
@@ -146,20 +146,22 @@ class CHILLAXKerasHC(
|
|
|
bias_initializer="zero",
|
|
bias_initializer="zero",
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
- # We need to reverse the graph for comfort because "is-a" has the concepts
|
|
|
|
|
- self.graph = self.kb.all_relations["hypernymy"]["graph"].reverse(copy=True)
|
|
|
|
|
|
|
+ try:
|
|
|
|
|
+ self.graph = self.kb.get_hyponymy_relation_rgraph()
|
|
|
|
|
+ except ValueError:
|
|
|
|
|
+ return False
|
|
|
|
|
|
|
|
# Memorize topological sorting for later
|
|
# Memorize topological sorting for later
|
|
|
all_uids = nx.topological_sort(self.graph)
|
|
all_uids = nx.topological_sort(self.graph)
|
|
|
self.topo_sorted_uids = list(all_uids)
|
|
self.topo_sorted_uids = list(all_uids)
|
|
|
- assert len(self.kb.all_concepts) == len(self.topo_sorted_uids)
|
|
|
|
|
|
|
+ assert len(current_concepts) == len(self.topo_sorted_uids)
|
|
|
|
|
|
|
|
self.uid_to_dimension = {
|
|
self.uid_to_dimension = {
|
|
|
uid: dimension for dimension, uid in enumerate(self.topo_sorted_uids)
|
|
uid: dimension for dimension, uid in enumerate(self.topo_sorted_uids)
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- self.observed_uids = {
|
|
|
|
|
- concept.data["uid"] for concept in self.kb.get_observed_concepts()
|
|
|
|
|
|
|
+ self.prediction_target_uids = {
|
|
|
|
|
+ concept.uid for concept in self.kb.concepts(flags={knowledge.ConceptFlagV2.PREDICTION_TARGET})
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
if len(old_weights) == 2:
|
|
if len(old_weights) == 2:
|
|
@@ -189,9 +191,10 @@ class CHILLAXKerasHC(
|
|
|
self.fc_layer.set_weights([new_weights, new_biases])
|
|
self.fc_layer.set_weights([new_weights, new_biases])
|
|
|
|
|
|
|
|
self.update_loss_weights()
|
|
self.update_loss_weights()
|
|
|
|
|
+ return True
|
|
|
|
|
|
|
|
def update_loss_weights(self):
|
|
def update_loss_weights(self):
|
|
|
- if len(self.observed_uids) == 0:
|
|
|
|
|
|
|
+ if len(self.prediction_target_uids) == 0:
|
|
|
self.log_debug("Skipping loss weight update, no concepts found.")
|
|
self.log_debug("Skipping loss weight update, no concepts found.")
|
|
|
self.loss_weights = []
|
|
self.loss_weights = []
|
|
|
return
|
|
return
|
|
@@ -204,7 +207,7 @@ class CHILLAXKerasHC(
|
|
|
# (1) Calculate "natural" weights by assuming uniform distribution
|
|
# (1) Calculate "natural" weights by assuming uniform distribution
|
|
|
# over observed concepts
|
|
# over observed concepts
|
|
|
occurences = {uid: 0 for uid in self.topo_sorted_uids}
|
|
occurences = {uid: 0 for uid in self.topo_sorted_uids}
|
|
|
- for uid in self.observed_uids:
|
|
|
|
|
|
|
+ for uid in self.prediction_target_uids:
|
|
|
affected_uids = {uid}
|
|
affected_uids = {uid}
|
|
|
affected_uids |= nx.ancestors(self.graph, uid)
|
|
affected_uids |= nx.ancestors(self.graph, uid)
|
|
|
for affected_uid in list(affected_uids):
|
|
for affected_uid in list(affected_uids):
|
|
@@ -254,7 +257,7 @@ class CHILLAXKerasHC(
|
|
|
|
|
|
|
|
for i, uid in enumerate(self.uid_to_dimension):
|
|
for i, uid in enumerate(self.uid_to_dimension):
|
|
|
descendants = set(nx.descendants(self.graph, uid)) | {uid}
|
|
descendants = set(nx.descendants(self.graph, uid)) | {uid}
|
|
|
- reachable_leaf_nodes = descendants.intersection(self.observed_uids)
|
|
|
|
|
|
|
+ reachable_leaf_nodes = descendants.intersection(self.prediction_target_uids)
|
|
|
self.loss_weights[i] *= len(reachable_leaf_nodes)
|
|
self.loss_weights[i] *= len(reachable_leaf_nodes)
|
|
|
|
|
|
|
|
# Test if any leaf nodes are reachable
|
|
# Test if any leaf nodes are reachable
|
|
@@ -289,6 +292,10 @@ class CHILLAXKerasHC(
|
|
|
self.loss_weights /= gain
|
|
self.loss_weights /= gain
|
|
|
|
|
|
|
|
def loss(self, feature_batch, ground_truth):
|
|
def loss(self, feature_batch, ground_truth):
|
|
|
|
|
+ if not self.is_updated:
|
|
|
|
|
+ raise RuntimeError("This classifier is not yet ready to compute a loss. "
|
|
|
|
|
+ "Check if it has been notified of a hyponymy relation.")
|
|
|
|
|
+
|
|
|
loss_mask = np.zeros((len(ground_truth), len(self.uid_to_dimension)))
|
|
loss_mask = np.zeros((len(ground_truth), len(self.uid_to_dimension)))
|
|
|
for i, label in enumerate(ground_truth):
|
|
for i, label in enumerate(ground_truth):
|
|
|
# Loss mask
|
|
# Loss mask
|
|
@@ -326,9 +333,7 @@ class CHILLAXKerasHC(
|
|
|
return tf.reduce_mean(sum_per_batch_element)
|
|
return tf.reduce_mean(sum_per_batch_element)
|
|
|
|
|
|
|
|
def observe(self, samples, gt_resource_id):
|
|
def observe(self, samples, gt_resource_id):
|
|
|
- if self.kb.get_concept_stamp() != self.last_observed_concept_stamp:
|
|
|
|
|
- self.update_embedding()
|
|
|
|
|
- self.last_observed_concept_stamp = self.kb.get_concept_stamp()
|
|
|
|
|
|
|
+ self.maybe_update_embedding()
|
|
|
|
|
|
|
|
def regularization_losses(self):
|
|
def regularization_losses(self):
|
|
|
return self.fc_layer.losses
|
|
return self.fc_layer.losses
|