Przeglądaj źródła

Moved configuration to JSON file

Clemens-Alexander Brust 5 lat temu
rodzic
commit
35b5dce15f
3 zmienionych plików z 80 dodań i 42 usunięć
  1. 0 0
      chillax/chillax_classifier.py
  2. 5 42
      chillax/experiment_selfsupervised.py
  3. 75 0
      main.json

+ 0 - 0
chillax/method.py → chillax/chillax_classifier.py


+ 5 - 42
chillax/experiment_selfsupervised.py

@@ -1,7 +1,7 @@
 from chia.v2 import containers, instrumentation
 from chia.v2.components import classifiers
 from chia.v2 import helpers
-from chillax import method
+from chillax import chillax_classifier
 
 
 class CheapObserver(instrumentation.Observer):
@@ -10,40 +10,11 @@ class CheapObserver(instrumentation.Observer):
 
 
 def main():
-    config = {
-        "evaluators": [{"name": "accuracy"}],
-        "with_wordnet": True,
-        "interactor": {
-            "name": "noisy_oracle",
-            "noise_model": "Deng2014",
-            "relabel_fraction": 0.95
-        },
-        "dataset": {
-            "name": "nabirds",
-            "base_path": "/home/brust/datasets/nabirds",
-            "side_length": 512,
-            "use_lazy_mode": True,
-        },
-        "model": {
-            "classifier": {"name": "chillax"},
-            "base_model": {
-                "name": "keras",
-                "augmentation": {},
-                "trainer": {
-                    "name": "fast_single_shot",
-                    "batch_size": 2,
-                    "inner_steps": 2000,
-                },
-                "feature_extractor": {"side_length": 448},
-                "preprocessor": {},
-                "optimizer": {"name": "sgd"},
-                "learning_rate_schedule": {"name": "constant", "initial_lr": 0.1},
-            },
-        },
-    }
+    import json
+    config = json.load(open("main.json"))
 
     classifiers.ClassifierFactory.name_to_class_mapping.update(
-        {"chillax": method.CHILLAXKerasHC}
+        {"chillax": chillax_classifier.CHILLAXKerasHC}
     )
 
     # This shouldn't be necessary, but...
@@ -54,16 +25,8 @@ def main():
     )
 
     dataset = experiment_container.dataset
-
-    # Get prediction targets
-    experiment_container.knowledge_base.add_prediction_targets(
-        dataset.prediction_targets()
-    )
-
-    # Add relation source
-    experiment_container.knowledge_base.add_hyponymy_relation([dataset.get_hyponymy_relation_source()])
-
     base_model = experiment_container.base_model
+
     training_samples = dataset.train_pool(0, "label_gt")
     training_samples = experiment_container.interactor.query_annotations_for(training_samples, "label_gt", "label_ann")
 

+ 75 - 0
main.json

@@ -0,0 +1,75 @@
+{
+  "evaluators": [
+    {
+      "name": "accuracy"
+    }
+  ],
+  "with_wordnet": true,
+  "interactor": {
+    "name": "oracle"
+  },
+  "dataset": {
+    "name": "nabirds",
+    "base_path": "/home/brust/datasets/nabirds",
+    "side_length": 512,
+    "use_lazy_mode": true
+  },
+  "model": {
+    "classifier": {
+      "name": "chillax",
+      "l2": 5e-5
+    },
+    "base_model": {
+      "name": "keras",
+      "augmentation": {
+        "do_random_flip_vertical": false,
+        "do_random_scale": false,
+        "do_random_rotate": false,
+        "do_random_brightness_and_contrast": false,
+        "do_random_hue_and_saturation": false,
+        "do_random_crop": false
+      },
+      "trainer": {
+        "name": "fast_single_shot",
+        "batch_size": 16,
+        "sequential_training_batches": 2,
+        "inner_steps": 59760
+      },
+      "feature_extractor": {
+        "side_length": 448,
+        "trainable": true,
+        "architecture": "ResNet50V2",
+        "l2": 5e-5,
+        "use_pretrained_weights": "inat_features.h5"
+      },
+      "preprocessor": {
+        "random_crop_to_size": [
+          448,
+          448
+        ],
+        "channel_mean": [
+          125.30513277,
+          129.66606421,
+          118.45121113
+        ],
+        "channel_stddev": [
+          57.0045467,
+          56.70059436,
+          68.44430446
+        ]
+      },
+      "optimizer": {
+        "name": "sgd"
+      },
+      "learning_rate_schedule": {
+        "name": "sgdr",
+        "maximum_lr": 0.0044,
+        "T_0": 59760,
+        "T_mult": 1,
+        "minimum_lr": 1e-06,
+        "warmup_steps": 747,
+        "warmup_lr": 0.01
+      }
+    }
+  }
+}