|
@@ -307,7 +307,7 @@ class CHILLAXKerasHC(
|
|
|
self.report_metric("gain_from_weighting", gain)
|
|
|
self.loss_weights /= gain
|
|
|
|
|
|
- def loss(self, feature_batch, ground_truth, global_step):
|
|
|
+ def loss(self, feature_batch, ground_truth, weight_batch, global_step):
|
|
|
if not self.is_updated:
|
|
|
raise RuntimeError(
|
|
|
"This classifier is not yet ready to compute a loss. "
|
|
@@ -360,7 +360,7 @@ class CHILLAXKerasHC(
|
|
|
the_loss * loss_mask * self.loss_weights, axis=1
|
|
|
)
|
|
|
|
|
|
- return tf.reduce_mean(sum_per_batch_element)
|
|
|
+ return tf.reduce_mean(sum_per_batch_element * weight_batch)
|
|
|
|
|
|
def observe(self, samples, gt_resource_id):
|
|
|
self.maybe_update_embedding()
|