|
@@ -3,15 +3,20 @@ from chia.v2.components import classifiers
|
|
|
from chia.v2 import helpers
|
|
from chia.v2 import helpers
|
|
|
from chillax import chillax_classifier
|
|
from chillax import chillax_classifier
|
|
|
|
|
|
|
|
|
|
+import config as pcfg
|
|
|
|
|
+import argparse
|
|
|
|
|
+
|
|
|
|
|
|
|
|
class CheapObserver(instrumentation.Observer):
|
|
class CheapObserver(instrumentation.Observer):
|
|
|
def update(self, message: instrumentation.Message):
|
|
def update(self, message: instrumentation.Message):
|
|
|
print(f"Message: {message}")
|
|
print(f"Message: {message}")
|
|
|
|
|
|
|
|
|
|
|
|
|
-def main():
|
|
|
|
|
- import json
|
|
|
|
|
- config = json.load(open("main.json"))
|
|
|
|
|
|
|
+def main(config_files):
|
|
|
|
|
+ configs = [pcfg.config_from_json(config_file, read_from_file=True) for config_file in config_files]
|
|
|
|
|
+ config = pcfg.ConfigurationSet(
|
|
|
|
|
+ *configs
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
classifiers.ClassifierFactory.name_to_class_mapping.update(
|
|
classifiers.ClassifierFactory.name_to_class_mapping.update(
|
|
|
{"chillax": chillax_classifier.CHILLAXKerasHC}
|
|
{"chillax": chillax_classifier.CHILLAXKerasHC}
|
|
@@ -19,28 +24,22 @@ def main():
|
|
|
|
|
|
|
|
# This shouldn't be necessary, but...
|
|
# This shouldn't be necessary, but...
|
|
|
helpers.setup_environment()
|
|
helpers.setup_environment()
|
|
|
|
|
+ obs = instrumentation.NamedObservable("Experiment")
|
|
|
|
|
|
|
|
experiment_container = containers.ExperimentContainer(
|
|
experiment_container = containers.ExperimentContainer(
|
|
|
- config, observers=(CheapObserver(),)
|
|
|
|
|
|
|
+ config, outer_observable=obs
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
- dataset = experiment_container.dataset
|
|
|
|
|
- base_model = experiment_container.base_model
|
|
|
|
|
-
|
|
|
|
|
- training_samples = dataset.train_pool(0, "label_gt")
|
|
|
|
|
- training_samples = experiment_container.interactor.query_annotations_for(training_samples, "label_gt", "label_ann")
|
|
|
|
|
|
|
+ obs.log_info("Hello!")
|
|
|
|
|
|
|
|
- base_model.observe(training_samples, "label_ann")
|
|
|
|
|
|
|
+ experiment_container.runner.run()
|
|
|
|
|
|
|
|
- test_samples = dataset.test_pool(0, "label_gt")[:100]
|
|
|
|
|
- test_samples = base_model.predict(test_samples, "label_pred")
|
|
|
|
|
- result_dict = {}
|
|
|
|
|
- for evaluator in experiment_container.evaluators:
|
|
|
|
|
- evaluator.update(test_samples, "label_gt", "label_pred")
|
|
|
|
|
- result_dict.update(evaluator.result())
|
|
|
|
|
- evaluator.reset()
|
|
|
|
|
- print(result_dict)
|
|
|
|
|
|
|
+ # Make sure all the data is saved
|
|
|
|
|
+ obs.send_shutdown()
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
if __name__ == "__main__":
|
|
|
- main()
|
|
|
|
|
|
|
+ parser = argparse.ArgumentParser(prog="chillax.experiment_selfsupervised")
|
|
|
|
|
+ parser.add_argument("config_file", type=str, nargs='+')
|
|
|
|
|
+ args = parser.parse_args()
|
|
|
|
|
+ main(config_files=args.config_file)
|