|
@@ -21,7 +21,7 @@ supported = dict(
|
|
|
|
|
|
chainercv2=(),
|
|
|
|
|
|
- custom=(
|
|
|
+ cvmodelz=(
|
|
|
pretrained.VGG16,
|
|
|
pretrained.VGG19,
|
|
|
|
|
@@ -48,15 +48,43 @@ def is_cv_model(model):
|
|
|
def is_cv2_model(model):
|
|
|
return _check(model, "chainercv2")
|
|
|
|
|
|
-def is_custom_model(model):
|
|
|
- return _check(model, "custom")
|
|
|
+def is_cvmodelz_model(model):
|
|
|
+ return _check(model, "cvmodelz")
|
|
|
+
|
|
|
+
|
|
|
+def get_all_models(key=None):
|
|
|
+ global supported
|
|
|
+ if key is not None:
|
|
|
+ return [f"{key}.{cls.__name__}" for cls in supported[key]]
|
|
|
+
|
|
|
+ res = []
|
|
|
+ for key in supported:
|
|
|
+ res += get_all_models(key)
|
|
|
+
|
|
|
+ return res
|
|
|
+
|
|
|
+def new(model_type, *args, **kwargs):
|
|
|
+ global supported
|
|
|
+ key, cls_name = model_type.split(".")
|
|
|
+
|
|
|
+
|
|
|
+ for cls in supported[key]:
|
|
|
+ if cls.__name__ == cls_name:
|
|
|
+ break
|
|
|
+ else:
|
|
|
+ raise ValueError(f"Could not find {model_type}!")
|
|
|
+
|
|
|
+ return cls(*args, **kwargs)
|
|
|
+
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
- from cvmodelz import utils
|
|
|
+ # from cvmodelz import utils
|
|
|
+
|
|
|
+ print(get_all_models())
|
|
|
|
|
|
# model = L.VGG19Layers(pretrained_model=None)
|
|
|
- model = pretrained.ResNet35()
|
|
|
+ # model = pretrained.ResNet35()
|
|
|
# print(model.pool)
|
|
|
- utils.print_model_info(model)
|
|
|
+ # utils.print_model_info(model)
|
|
|
|