Clemens-Alexander Brust 5 лет назад
Родитель
Сommit
a40b0a07df
2 измененных файлов с 45 добавлено и 22 удалено
  1. 18 19
      chillax/experiment_selfsupervised.py
  2. 27 3
      main.json

+ 18 - 19
chillax/experiment_selfsupervised.py

@@ -3,15 +3,20 @@ from chia.v2.components import classifiers
 from chia.v2 import helpers
 from chia.v2 import helpers
 from chillax import chillax_classifier
 from chillax import chillax_classifier
 
 
+import config as pcfg
+import argparse
+
 
 
 class CheapObserver(instrumentation.Observer):
 class CheapObserver(instrumentation.Observer):
     def update(self, message: instrumentation.Message):
     def update(self, message: instrumentation.Message):
         print(f"Message: {message}")
         print(f"Message: {message}")
 
 
 
 
-def main():
-    import json
-    config = json.load(open("main.json"))
+def main(config_files):
+    configs = [pcfg.config_from_json(config_file, read_from_file=True) for config_file in config_files]
+    config = pcfg.ConfigurationSet(
+        *configs
+    )
 
 
     classifiers.ClassifierFactory.name_to_class_mapping.update(
     classifiers.ClassifierFactory.name_to_class_mapping.update(
         {"chillax": chillax_classifier.CHILLAXKerasHC}
         {"chillax": chillax_classifier.CHILLAXKerasHC}
@@ -19,28 +24,22 @@ def main():
 
 
     # This shouldn't be necessary, but...
     # This shouldn't be necessary, but...
     helpers.setup_environment()
     helpers.setup_environment()
+    obs = instrumentation.NamedObservable("Experiment")
 
 
     experiment_container = containers.ExperimentContainer(
     experiment_container = containers.ExperimentContainer(
-        config, observers=(CheapObserver(),)
+        config, outer_observable=obs
     )
     )
 
 
-    dataset = experiment_container.dataset
-    base_model = experiment_container.base_model
-
-    training_samples = dataset.train_pool(0, "label_gt")
-    training_samples = experiment_container.interactor.query_annotations_for(training_samples, "label_gt", "label_ann")
+    obs.log_info("Hello!")
 
 
-    base_model.observe(training_samples, "label_ann")
+    experiment_container.runner.run()
 
 
-    test_samples = dataset.test_pool(0, "label_gt")[:100]
-    test_samples = base_model.predict(test_samples, "label_pred")
-    result_dict = {}
-    for evaluator in experiment_container.evaluators:
-        evaluator.update(test_samples, "label_gt", "label_pred")
-        result_dict.update(evaluator.result())
-        evaluator.reset()
-    print(result_dict)
+    # Make sure all the data is saved
+    obs.send_shutdown()
 
 
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
-    main()
+    parser = argparse.ArgumentParser(prog="chillax.experiment_selfsupervised")
+    parser.add_argument("config_file", type=str, nargs='+')
+    args = parser.parse_args()
+    main(config_files=args.config_file)

+ 27 - 3
main.json

@@ -1,13 +1,34 @@
 {
 {
+  "meta": {
+    "name": "chillax-selfsupervised"
+  },
   "evaluators": [
   "evaluators": [
     {
     {
       "name": "accuracy"
       "name": "accuracy"
+    },
+    {
+      "name": "hierarchical"
+    },
+    {
+      "name": "topk_accuracy"
     }
     }
   ],
   ],
   "with_wordnet": true,
   "with_wordnet": true,
   "interactor": {
   "interactor": {
     "name": "oracle"
     "name": "oracle"
   },
   },
+  "observers": [
+    {
+      "name": "stream"
+    },
+    {
+      "name": "json"
+    }
+  ],
+  "runner": {
+    "name": "epoch",
+    "epochs": 10
+  },
   "dataset": {
   "dataset": {
     "name": "nabirds",
     "name": "nabirds",
     "base_path": "/home/brust/datasets/nabirds",
     "base_path": "/home/brust/datasets/nabirds",
@@ -17,7 +38,9 @@
   "model": {
   "model": {
     "classifier": {
     "classifier": {
       "name": "chillax",
       "name": "chillax",
-      "l2": 5e-5
+      "l2": 5e-5,
+      "mlnp": true,
+      "normalize_scores": true
     },
     },
     "base_model": {
     "base_model": {
       "name": "keras",
       "name": "keras",
@@ -33,7 +56,7 @@
         "name": "fast_single_shot",
         "name": "fast_single_shot",
         "batch_size": 16,
         "batch_size": 16,
         "sequential_training_batches": 2,
         "sequential_training_batches": 2,
-        "inner_steps": 59760
+        "inner_steps": 5976
       },
       },
       "feature_extractor": {
       "feature_extractor": {
         "side_length": 448,
         "side_length": 448,
@@ -59,7 +82,8 @@
         ]
         ]
       },
       },
       "optimizer": {
       "optimizer": {
-        "name": "sgd"
+        "name": "sgd",
+        "momentum": 0.9
       },
       },
       "learning_rate_schedule": {
       "learning_rate_schedule": {
         "name": "sgdr",
         "name": "sgdr",