|
@@ -46,6 +46,10 @@ class CHILLAXKerasHC(
|
|
|
|
|
|
self.extrapolator = None
|
|
|
|
|
|
+ self._reporting_step_counter = 0
|
|
|
+ self._running_sample_count = 0
|
|
|
+ self._running_changed_samples = 0
|
|
|
+
|
|
|
def predict_embedded(self, feature_batch):
|
|
|
return self.fc_layer(feature_batch)
|
|
|
|
|
@@ -427,6 +431,30 @@ class CHILLAXKerasHC(
|
|
|
)
|
|
|
]
|
|
|
|
|
|
+ # Handle reporting
|
|
|
+ self._running_sample_count += len(ground_truth)
|
|
|
+ self._running_changed_samples += sum(
|
|
|
+ [
|
|
|
+ 1
|
|
|
+ for egt, gt in zip(extrapolated_ground_truth, ground_truth)
|
|
|
+ if egt != gt
|
|
|
+ ]
|
|
|
+ )
|
|
|
+
|
|
|
+ if self._reporting_step_counter % 10 == 9:
|
|
|
+ if self._running_sample_count > 0:
|
|
|
+ self.report_metric(
|
|
|
+ "extrapolation_changed_sample_fraction",
|
|
|
+ float(self._running_changed_samples)
|
|
|
+ / float(self._running_sample_count),
|
|
|
+ step=self._reporting_step_counter,
|
|
|
+ )
|
|
|
+
|
|
|
+ self._running_changed_samples = 0
|
|
|
+ self._running_sample_count = 0
|
|
|
+
|
|
|
+ self._reporting_step_counter += 1
|
|
|
+
|
|
|
return extrapolated_ground_truth
|
|
|
else:
|
|
|
return ground_truth
|