ソースを参照

Refactored and with tests

Clemens-Alexander Brust 4 年 前
コミット
6b8742c34d

+ 4 - 9
chillax/experiment_selfsupervised.py

@@ -1,7 +1,6 @@
 from chia import containers, instrumentation
-from chia.components import classifiers, interactors
 from chia import helpers
-from chillax import chillax_classifier, chillax_extrapolator, noisy_oracle
+from chillax import methods
 
 import config as pcfg
 import argparse
@@ -20,12 +19,8 @@ def main(config_files):
     ] + [helpers.get_user_config()]
     config = pcfg.ConfigurationSet(*configs)
 
-    classifiers.ClassifierFactory.name_to_class_mapping.update(
-        {"chillax": chillax_classifier.CHILLAXKerasHC}
-    )
-    interactors.InteractorFactory.name_to_class_mapping.update(
-        {"noisy_oracle": noisy_oracle.NoisyOracleInteractor}
-    )
+    # Need to register our fancy new methods
+    methods.update_chia_factories()
 
     obs = instrumentation.NamedObservable("Experiment")
 
@@ -39,7 +34,7 @@ def main(config_files):
 
         # Now, build the extrapolator
         if "extrapolator" in config.keys(levels=1):
-            extrapolator = chillax_extrapolator.CHILLAXExtrapolatorFactory.create(
+            extrapolator = methods.CHILLAXExtrapolatorFactory.create(
                 config["extrapolator"],
                 knowledge_base=experiment_container.knowledge_base,
                 observers=experiment_container.observers,

+ 17 - 0
chillax/methods/__init__.py

@@ -0,0 +1,17 @@
+from chia.components import classifiers, interactors
+
+from chillax.methods import chillax_classifier, noisy_oracle
+
+from chillax.methods.chillax_extrapolator import CHILLAXExtrapolatorFactory
+
+
+def update_chia_factories():
+    classifiers.ClassifierFactory.name_to_class_mapping.update(
+        {"chillax": chillax_classifier.CHILLAXKerasHC}
+    )
+    interactors.InteractorFactory.name_to_class_mapping.update(
+        {"noisy_oracle": noisy_oracle.NoisyOracleInteractor}
+    )
+
+
+__all__ = ["CHILLAXExtrapolatorFactory", "update_chia_factories"]

+ 0 - 0
chillax/chillax_classifier.py → chillax/methods/chillax_classifier.py


+ 0 - 0
chillax/chillax_extrapolator.py → chillax/methods/chillax_extrapolator.py


+ 0 - 0
chillax/noisy_oracle.py → chillax/methods/noisy_oracle.py


+ 68 - 0
examples/configuration.json

@@ -0,0 +1,68 @@
+{
+  "meta": {
+    "name": "example-experiment"
+  },
+  "evaluators": [
+    {
+      "name": "accuracy"
+    }
+  ],
+  "with_wordnet": true,
+  "interactor": {
+    "name": "noisy_oracle",
+    "noise_model": "Inaccuracy"
+  },
+  "observers": [
+    {
+      "name": "stream"
+    }
+  ],
+  "sample_transformers": [
+    {
+      "name": "identity"
+    }
+  ],
+  "runner": {
+    "name": "epoch",
+    "epochs": 2,
+    "max_test_samples": 4
+  },
+  "dataset": {
+    "name": "icifar"
+  },
+  "extrapolator": {
+    "name": "do_nothing",
+    "apply_ground_truth": true
+  },
+  "model": {
+    "classifier": {
+      "name": "chillax",
+      "l2": 5e-5,
+      "force_prediction_targets": true,
+      "raw_output": false
+    },
+    "base_model": {
+      "name": "keras",
+      "trainer": {
+        "name": "fast_single_shot",
+        "batch_size": 2,
+        "inner_steps": 11
+      },
+      "feature_extractor": {
+        "side_length": 32,
+        "trainable": true,
+        "architecture": "ResNet50V2",
+        "l2": 5e-5,
+        "use_pretrained_weights": null
+      },
+      "optimizer": {
+        "name": "sgd",
+        "momentum": 0.9
+      },
+      "learning_rate_schedule": {
+        "name": "constant",
+        "initial_lr": 0.01
+      }
+    }
+  }
+}

+ 16 - 0
tests/test_experiment.py

@@ -0,0 +1,16 @@
+import os
+
+import config as pcfg
+import pytest
+
+from chia import containers, helpers, instrumentation
+from chia.components import classifiers
+
+
+def test_experiment():
+    """This tests runs the self-supervised experiment configuration once."""
+    from chillax import experiment_selfsupervised
+
+    example_config_files = ["examples/configuration.json"]
+
+    experiment_selfsupervised.main(example_config_files)