Dimitri Korsch 4 жил өмнө
parent
commit
660aa60725

+ 21 - 0
cvmodelz/model_info.py

@@ -0,0 +1,21 @@
+#!/usr/bin/env python
+if __name__ != '__main__': raise Exception("Do not import me!")
+
+from cvargparse import Arg
+from cvargparse import BaseParser
+
+from cvmodelz import models
+from cvmodelz import utils
+
+def main(args):
+
+	model = models.new(args.model_type, pretrained_model=None)
+	utils.print_model_info(model)
+
+parser = BaseParser()
+
+parser.add_args([
+	Arg("model_type", choices=models.get_all_models()),
+])
+
+main(parser.parse_args())

+ 34 - 6
cvmodelz/models/__init__.py

@@ -21,7 +21,7 @@ supported = dict(
 
 	chainercv2=(),
 
-	custom=(
+	cvmodelz=(
 		pretrained.VGG16,
 		pretrained.VGG19,
 
@@ -48,15 +48,43 @@ def is_cv_model(model):
 def is_cv2_model(model):
 	return _check(model, "chainercv2")
 
-def is_custom_model(model):
-	return _check(model, "custom")
+def is_cvmodelz_model(model):
+	return _check(model, "cvmodelz")
+
+
+def get_all_models(key=None):
+	global supported
+	if key is not None:
+		return [f"{key}.{cls.__name__}" for cls in supported[key]]
+
+	res = []
+	for key in supported:
+		res += get_all_models(key)
+
+	return res
+
+def new(model_type, *args, **kwargs):
+	global supported
+	key, cls_name = model_type.split(".")
+
+
+	for cls in supported[key]:
+		if cls.__name__ == cls_name:
+			break
+	else:
+		raise ValueError(f"Could not find {model_type}!")
+
+	return cls(*args, **kwargs)
+
 
 
 if __name__ == '__main__':
-	from cvmodelz import utils
+	# from cvmodelz import utils
+
+	print(get_all_models())
 
 	# model = L.VGG19Layers(pretrained_model=None)
-	model = pretrained.ResNet35()
+	# model = pretrained.ResNet35()
 	# print(model.pool)
-	utils.print_model_info(model)
+	# utils.print_model_info(model)
 

+ 1 - 0
requirements.txt

@@ -4,3 +4,4 @@ chainercv~=0.13
 chainercv2~=0.0
 
 pyaml~=20.4
+cvargparse~=0.3