Эх сурвалжийг харах

Classifier: Compatible with CHIA 2.0rc18 weight batches

Clemens-Alexander Brust 4 жил өмнө
parent
commit
98cdbe05f1

+ 2 - 2
chillax/chillax_classifier.py

@@ -307,7 +307,7 @@ class CHILLAXKerasHC(
         self.report_metric("gain_from_weighting", gain)
         self.report_metric("gain_from_weighting", gain)
         self.loss_weights /= gain
         self.loss_weights /= gain
 
 
-    def loss(self, feature_batch, ground_truth, global_step):
+    def loss(self, feature_batch, ground_truth, weight_batch, global_step):
         if not self.is_updated:
         if not self.is_updated:
             raise RuntimeError(
             raise RuntimeError(
                 "This classifier is not yet ready to compute a loss. "
                 "This classifier is not yet ready to compute a loss. "
@@ -360,7 +360,7 @@ class CHILLAXKerasHC(
             the_loss * loss_mask * self.loss_weights, axis=1
             the_loss * loss_mask * self.loss_weights, axis=1
         )
         )
 
 
-        return tf.reduce_mean(sum_per_batch_element)
+        return tf.reduce_mean(sum_per_batch_element * weight_batch)
 
 
     def observe(self, samples, gt_resource_id):
     def observe(self, samples, gt_resource_id):
         self.maybe_update_embedding()
         self.maybe_update_embedding()