Explorar o código

Classifier: Compatible with CHIA 2.0rc18 weight batches

Clemens-Alexander Brust %!s(int64=4) %!d(string=hai) anos
pai
achega
98cdbe05f1
Modificáronse 1 ficheiros con 2 adicións e 2 borrados
  1. 2 2
      chillax/chillax_classifier.py

+ 2 - 2
chillax/chillax_classifier.py

@@ -307,7 +307,7 @@ class CHILLAXKerasHC(
         self.report_metric("gain_from_weighting", gain)
         self.loss_weights /= gain
 
-    def loss(self, feature_batch, ground_truth, global_step):
+    def loss(self, feature_batch, ground_truth, weight_batch, global_step):
         if not self.is_updated:
             raise RuntimeError(
                 "This classifier is not yet ready to compute a loss. "
@@ -360,7 +360,7 @@ class CHILLAXKerasHC(
             the_loss * loss_mask * self.loss_weights, axis=1
         )
 
-        return tf.reduce_mean(sum_per_batch_element)
+        return tf.reduce_mean(sum_per_batch_element * weight_batch)
 
     def observe(self, samples, gt_resource_id):
         self.maybe_update_embedding()