# experiment_selfsupervised.py (1.7 KB)
  1. from chia import containers, instrumentation
  2. from chia.components import classifiers
  3. from chia import helpers
  4. from chillax import chillax_classifier, chillax_extrapolator
  5. import config as pcfg
  6. import argparse
  7. class CheapObserver(instrumentation.Observer):
  8. def update(self, message: instrumentation.Message):
  9. print(f"Message: {message}")
  10. def main(config_files):
  11. configs = [
  12. pcfg.config_from_json(config_file, read_from_file=True)
  13. for config_file in config_files
  14. ]
  15. config = pcfg.ConfigurationSet(*configs)
  16. classifiers.ClassifierFactory.name_to_class_mapping.update(
  17. {"chillax": chillax_classifier.CHILLAXKerasHC}
  18. )
  19. # This shouldn't be necessary, but...
  20. helpers.setup_environment()
  21. obs = instrumentation.NamedObservable("Experiment")
  22. experiment_container = containers.ExperimentContainer(config, outer_observable=obs)
  23. with experiment_container.exception_shroud:
  24. obs.log_info("Hello!")
  25. # Now, build the extrapolator
  26. if "extrapolator" in config.keys():
  27. extrapolator = chillax_extrapolator.CHILLAXExtrapolatorFactory.create(
  28. config["extrapolator"],
  29. knowledge_base=experiment_container.knowledge_base,
  30. observers=experiment_container.observers,
  31. )
  32. experiment_container.classifier.extrapolator = extrapolator
  33. experiment_container.runner.run()
  34. # Make sure all the data is saved
  35. obs.send_shutdown(successful=True)
  36. if __name__ == "__main__":
  37. parser = argparse.ArgumentParser(prog="chillax.experiment_selfsupervised")
  38. parser.add_argument("config_file", type=str, nargs="+")
  39. args = parser.parse_args()
  40. main(config_files=args.config_file)