|
@@ -61,14 +61,11 @@ class ModelFactory(abc.ABC):
|
|
raise ValueError(f"Could not find {model_type}!")
|
|
raise ValueError(f"Could not find {model_type}!")
|
|
|
|
|
|
if model_cls in cls.supported["chainer"]:
|
|
if model_cls in cls.supported["chainer"]:
|
|
- if "pretrained_model" not in kwargs:
|
|
|
|
- kwargs["pretrained_model"] = None
|
|
|
|
|
|
+ kwargs["pretrained_model"] = kwargs.get("pretrained_model", None)
|
|
kwargs.pop("input_size", None)
|
|
kwargs.pop("input_size", None)
|
|
|
|
|
|
elif model_cls in cls.supported["chainercv2"]:
|
|
elif model_cls in cls.supported["chainercv2"]:
|
|
- if "pretrained" not in kwargs:
|
|
|
|
- kwargs["pretrained"] = False
|
|
|
|
-
|
|
|
|
|
|
+ kwargs["pretrained"] = kwargs.get("pretrained", False)
|
|
input_size = kwargs.pop("input_size", None)
|
|
input_size = kwargs.pop("input_size", None)
|
|
return ModelWrapper(model_cls(*args, **kwargs), input_size=input_size)
|
|
return ModelWrapper(model_cls(*args, **kwargs), input_size=input_size)
|
|
|
|
|
|
@@ -100,13 +97,19 @@ class ModelFactory(abc.ABC):
|
|
if key is not None:
|
|
if key is not None:
|
|
return [f"{key}.{model_cls.__name__}" for model_cls in cls.supported[key]]
|
|
return [f"{key}.{model_cls.__name__}" for model_cls in cls.supported[key]]
|
|
|
|
|
|
|
|
+ return cls.get_models(cls.supported.keys())
|
|
|
|
+
|
|
|
|
+ @classmethod
|
|
|
|
+ def get_models(cls, keys=None):
|
|
|
|
+ if keys is None:
|
|
|
|
+ keys = cls.supported.keys()
|
|
|
|
+
|
|
res = []
|
|
res = []
|
|
- for key in cls.supported:
|
|
|
|
|
|
+ for key in keys:
|
|
res += cls.get_all_models(key)
|
|
res += cls.get_all_models(key)
|
|
|
|
|
|
return res
|
|
return res
|
|
|
|
|
|
|
|
|
|
-
|
|
|
|
if __name__ == '__main__':
|
|
if __name__ == '__main__':
|
|
print(pyaml.dump(dict(Models=ModelFactory.get_all_models()), indent=2))
|
|
print(pyaml.dump(dict(Models=ModelFactory.get_all_models()), indent=2))
|