|
@@ -47,6 +47,7 @@ class CHILLAXKerasHC(
|
|
|
self.extrapolator = None
|
|
|
|
|
|
self._reporting_step_counter = 0
|
|
|
+ self._last_reported_step = -1
|
|
|
self._running_sample_count = 0
|
|
|
self._running_changed_samples = 0
|
|
|
|
|
@@ -306,13 +307,15 @@ class CHILLAXKerasHC(
|
|
|
self.report_metric("gain_from_weighting", gain)
|
|
|
self.loss_weights /= gain
|
|
|
|
|
|
- def loss(self, feature_batch, ground_truth):
|
|
|
+ def loss(self, feature_batch, ground_truth, global_step):
|
|
|
if not self.is_updated:
|
|
|
raise RuntimeError(
|
|
|
"This classifier is not yet ready to compute a loss. "
|
|
|
"Check if it has been notified of a hyponymy relation."
|
|
|
)
|
|
|
|
|
|
+ self._reporting_step_counter = global_step
|
|
|
+
|
|
|
# (1) Predict
|
|
|
prediction = self.predict_embedded(feature_batch)
|
|
|
|
|
@@ -442,18 +445,19 @@ class CHILLAXKerasHC(
|
|
|
)
|
|
|
|
|
|
if self._reporting_step_counter % 10 == 9:
|
|
|
- if self._running_sample_count > 0:
|
|
|
- self.report_metric(
|
|
|
- "extrapolation_changed_sample_fraction",
|
|
|
- float(self._running_changed_samples)
|
|
|
- / float(self._running_sample_count),
|
|
|
- step=self._reporting_step_counter,
|
|
|
- )
|
|
|
+ if self._last_reported_step < self._reporting_step_counter:
|
|
|
+ if self._running_sample_count > 0:
|
|
|
+ self.report_metric(
|
|
|
+ "extrapolation_changed_sample_fraction",
|
|
|
+ float(self._running_changed_samples)
|
|
|
+ / float(self._running_sample_count),
|
|
|
+ step=self._reporting_step_counter,
|
|
|
+ )
|
|
|
|
|
|
- self._running_changed_samples = 0
|
|
|
- self._running_sample_count = 0
|
|
|
+ self._running_changed_samples = 0
|
|
|
+ self._running_sample_count = 0
|
|
|
|
|
|
- self._reporting_step_counter += 1
|
|
|
+ self._last_reported_step = self._reporting_step_counter
|
|
|
|
|
|
return extrapolated_ground_truth
|
|
|
else:
|