|
|
@@ -1,7 +1,7 @@
|
|
|
from chia.v2 import containers, instrumentation
|
|
|
from chia.v2.components import classifiers
|
|
|
from chia.v2 import helpers
|
|
|
-from chillax import method
|
|
|
+from chillax import chillax_classifier
|
|
|
|
|
|
|
|
|
class CheapObserver(instrumentation.Observer):
|
|
|
@@ -10,40 +10,11 @@ class CheapObserver(instrumentation.Observer):
|
|
|
|
|
|
|
|
|
def main():
|
|
|
- config = {
|
|
|
- "evaluators": [{"name": "accuracy"}],
|
|
|
- "with_wordnet": True,
|
|
|
- "interactor": {
|
|
|
- "name": "noisy_oracle",
|
|
|
- "noise_model": "Deng2014",
|
|
|
- "relabel_fraction": 0.95
|
|
|
- },
|
|
|
- "dataset": {
|
|
|
- "name": "nabirds",
|
|
|
- "base_path": "/home/brust/datasets/nabirds",
|
|
|
- "side_length": 512,
|
|
|
- "use_lazy_mode": True,
|
|
|
- },
|
|
|
- "model": {
|
|
|
- "classifier": {"name": "chillax"},
|
|
|
- "base_model": {
|
|
|
- "name": "keras",
|
|
|
- "augmentation": {},
|
|
|
- "trainer": {
|
|
|
- "name": "fast_single_shot",
|
|
|
- "batch_size": 2,
|
|
|
- "inner_steps": 2000,
|
|
|
- },
|
|
|
- "feature_extractor": {"side_length": 448},
|
|
|
- "preprocessor": {},
|
|
|
- "optimizer": {"name": "sgd"},
|
|
|
- "learning_rate_schedule": {"name": "constant", "initial_lr": 0.1},
|
|
|
- },
|
|
|
- },
|
|
|
- }
|
|
|
+ import json
|
|
|
+ config = json.load(open("main.json"))
|
|
|
|
|
|
classifiers.ClassifierFactory.name_to_class_mapping.update(
|
|
|
- {"chillax": method.CHILLAXKerasHC}
|
|
|
+ {"chillax": chillax_classifier.CHILLAXKerasHC}
|
|
|
)
|
|
|
|
|
|
# This shouldn't be necessary, but...
|
|
|
@@ -54,16 +25,8 @@ def main():
|
|
|
)
|
|
|
|
|
|
dataset = experiment_container.dataset
|
|
|
-
|
|
|
- # Get prediction targets
|
|
|
- experiment_container.knowledge_base.add_prediction_targets(
|
|
|
- dataset.prediction_targets()
|
|
|
- )
|
|
|
-
|
|
|
- # Add relation source
|
|
|
- experiment_container.knowledge_base.add_hyponymy_relation([dataset.get_hyponymy_relation_source()])
|
|
|
-
|
|
|
base_model = experiment_container.base_model
|
|
|
+
|
|
|
training_samples = dataset.train_pool(0, "label_gt")
|
|
|
training_samples = experiment_container.interactor.query_annotations_for(training_samples, "label_gt", "label_ann")
|
|
|
|