Browse Source

updated support for parser creation from a dataclass

Dimitri Korsch 3 years ago
parent
commit
b63bbb3fa6
4 changed files with 63 additions and 16 deletions
  1. 2 1
      cvargparse/__init__.py
  2. 23 2
      cvargparse/parser/base.py
  3. 36 11
      cvargparse/utils/dataclass.py
  4. 2 2
      setup.py

+ 2 - 1
cvargparse/__init__.py

@@ -27,6 +27,7 @@ __all__ = [
 if __name__ == '__main__':
 	@cvdataclass
 	class Args:
+		group_name = "args"
 
 		arg1: float = None
 		arg2: str = "something"
@@ -34,4 +35,4 @@ if __name__ == '__main__':
 		arg3: Choices([1, 2, 3], int) = 1
 
 	parser = BaseParser(Args(arg3=2))
-	print(parser.parse_args())
+	print(parser.parse_args("--help".split()))

+ 23 - 2
cvargparse/parser/base.py

@@ -6,7 +6,8 @@ import warnings
 
 from cvargparse.argument import Argument as Arg
 from cvargparse.factory import BaseFactory
-from cvargparse.utils.dataclass import get_arglist_from_data_class
+from cvargparse.utils.dataclass import as_args
+from cvargparse.utils.dataclass import get_arglist
 from cvargparse.utils.logger_config import init_logging_handlers
 
 from dataclasses import is_dataclass
@@ -56,6 +57,7 @@ class BaseParser(LoggerMixin, argparse.ArgumentParser):
 	def __init__(self, arglist: T.Union[T.List[Arg], BaseFactory] = [], *args, **kw):
 		self._groups = {}
 		self._args = None
+		self._dataclass_instance = None
 
 		super().__init__(*args, **kw)
 
@@ -83,7 +85,9 @@ class BaseParser(LoggerMixin, argparse.ArgumentParser):
 			arglist = arglist.get()
 
 		elif is_dataclass(arglist):
-			arglist, group_name = get_arglist_from_data_class(arglist)
+			self._dataclass_instance = arglist
+			arglist, group_name = get_arglist(arglist)
+
 
 		if group_name is None:
 			group = self
@@ -99,10 +103,27 @@ class BaseParser(LoggerMixin, argparse.ArgumentParser):
 				group.add_argument(*arg[0], **arg[1])
 
 	def parse_args(self, args=None, namespace=None):
+		args = self._merge_args(args)
 		self._args = super().parse_args(args, namespace)
 
 		if self.has_logging:
 			self._logging_config()
 
+
 		return self._args
 
+	def _merge_args(self, args):
+		if self._dataclass_instance is None:
+			return args
+
+		dataclass_args = as_args(self._dataclass_instance)
+
+		if args is None:
+			return dataclass_args
+
+		return dataclass_args + args
+
+
+
+
+

+ 36 - 11
cvargparse/utils/dataclass.py

@@ -1,12 +1,43 @@
 import pyaml
 
+from dataclasses import Field
+from dataclasses import MISSING
+from dataclasses import _is_dataclass_instance
+from dataclasses import asdict
 from dataclasses import dataclass
 from dataclasses import fields
-from dataclasses import MISSING
-from dataclasses import Field
 
 from cvargparse import Arg
 
+def _set_attr(cls, attr, value):
+	if attr not in cls.__dict__:
+		setattr(cls, attr, value)
+
+def get_arglist(cls_or_instance) -> list:
+
+	arglist = []
+	for field in fields(cls_or_instance):
+		arglist.append(FieldWrapper(field).as_arg())
+
+	return arglist, getattr(cls_or_instance, "group_name", None)
+
+def as_args(instance) -> list:
+
+	dataclass_args = []
+	if not _is_dataclass_instance(instance):
+		return dataclass_args
+
+	data = asdict(instance)
+	for field in fields(instance):
+		key = field.name
+		value = data[key]
+		if value == field.default:
+			continue
+		dataclass_args.extend(f"--{key} {value}".split())
+
+	return dataclass_args
+
+
 def cvdataclass(cls=None, *args, repr=False, **kwargs):
 
 	def _yaml_repr_(self) -> str:
@@ -14,8 +45,9 @@ def cvdataclass(cls=None, *args, repr=False, **kwargs):
 		return pyaml.dump({cls_name: self.__dict__}, sort_dicts=False)
 
 	def wrap(cls):
-		if not repr and "__repr__" not in cls.__dict__:
-			setattr(cls, "__repr__", _yaml_repr_)
+		if not repr:
+			_set_attr(cls, "__repr__", _yaml_repr_)
+
 		return dataclass(cls, *args, repr=repr, **kwargs)
 
 	# See if we're being called as @cvdataclass or @cvdataclass().
@@ -72,13 +104,6 @@ class FieldWrapper:
 
 		return self.field.type._choices
 
-def get_arglist_from_data_class(cls) -> list:
-	arglist = []
-	for field in fields(cls):
-		arglist.append(FieldWrapper(field).as_arg())
-
-	return arglist, getattr(cls, "group_name", None)
-
 class Choices:
 
 	def __init__(self, choices, type):

+ 2 - 2
setup.py

@@ -15,7 +15,7 @@ cwd = Path(__file__).parent.resolve()
 with open(str(cwd / pkg_name / '_version.py')) as version_file:
 	exec(version_file.read())
 
-# install_requires = [line.strip() for line in open("requirements.txt").readlines()]
+install_requires = [line.strip() for line in open("requirements.txt").readlines()]
 
 setup(
 	name=pkg_name,
@@ -28,7 +28,7 @@ setup(
 	zip_safe=False,
 	setup_requires=[],
 	# no requirements yet
-	# install_requires=install_requires,
+	install_requires=install_requires,
 	package_data={'': ['requirements.txt']},
 	data_files=[('.',['requirements.txt'])],
 	include_package_data=True,