ソースを参照

Fixed broken extrapolator

Clemens-Alexander Brust 5 年 前
コミット
661c46eb53
3 ファイル変更35 行追加7 行削除
  1. 28 0
      chillax/chillax_classifier.py
  2. 5 4
      chillax/experiment_selfsupervised.py
  3. 2 3
      main.json

+ 28 - 0
chillax/chillax_classifier.py

@@ -46,6 +46,10 @@ class CHILLAXKerasHC(
 
         self.extrapolator = None
 
+        self._reporting_step_counter = 0
+        self._running_sample_count = 0
+        self._running_changed_samples = 0
+
     def predict_embedded(self, feature_batch):
         return self.fc_layer(feature_batch)
 
@@ -427,6 +431,30 @@ class CHILLAXKerasHC(
                     )
                 ]
 
+            # Handle reporting
+            self._running_sample_count += len(ground_truth)
+            self._running_changed_samples += sum(
+                [
+                    1
+                    for egt, gt in zip(extrapolated_ground_truth, ground_truth)
+                    if egt != gt
+                ]
+            )
+
+            if self._reporting_step_counter % 10 == 9:
+                if self._running_sample_count > 0:
+                    self.report_metric(
+                        "extrapolation_changed_sample_fraction",
+                        float(self._running_changed_samples)
+                        / float(self._running_sample_count),
+                        step=self._reporting_step_counter,
+                    )
+
+                self._running_changed_samples = 0
+                self._running_sample_count = 0
+
+            self._reporting_step_counter += 1
+
             return extrapolated_ground_truth
         else:
             return ground_truth

+ 5 - 4
chillax/experiment_selfsupervised.py

@@ -13,18 +13,19 @@ class CheapObserver(instrumentation.Observer):
 
 
 def main(config_files):
+    # This shouldn't be necessary, but...
+    helpers.setup_environment()
+
     configs = [
         pcfg.config_from_json(config_file, read_from_file=True)
         for config_file in config_files
-    ]
+    ] + [helpers.get_user_config()]
     config = pcfg.ConfigurationSet(*configs)
 
     classifiers.ClassifierFactory.name_to_class_mapping.update(
         {"chillax": chillax_classifier.CHILLAXKerasHC}
     )
 
-    # This shouldn't be necessary, but...
-    helpers.setup_environment()
     obs = instrumentation.NamedObservable("Experiment")
 
     experiment_container = containers.ExperimentContainer(config, outer_observable=obs)
@@ -33,7 +34,7 @@ def main(config_files):
         obs.log_info("Hello!")
 
         # Now, build the extrapolator
-        if "extrapolator" in config.keys():
+        if "extrapolator" in config.keys(levels=1):
             extrapolator = chillax_extrapolator.CHILLAXExtrapolatorFactory.create(
                 config["extrapolator"],
                 knowledge_base=experiment_container.knowledge_base,

+ 2 - 3
main.json

@@ -33,7 +33,6 @@
   },
   "dataset": {
     "name": "nabirds",
-    "base_path": "/home/brust/datasets/nabirds",
     "side_length": 512,
     "use_lazy_mode": true
   },
@@ -41,8 +40,8 @@
     "classifier": {
       "name": "chillax",
       "l2": 5e-5,
-      "mlnp": true,
-      "normalize_scores": true
+      "force_prediction_targets": true,
+      "raw_output": false
     },
     "base_model": {
       "name": "keras",