فهرست منبع

first way of adding support for dataclass as contructor of the parser

Dimitri Korsch 3 سال پیش
والد
کامیت
f9f97361ac
5فایلهای تغییر یافته به همراه130 افزوده شده و 33 حذف شده
  1. 21 4
      cvargparse/__init__.py
  2. 0 29
      cvargparse/argument.py
  3. 6 0
      cvargparse/parser/base.py
  4. 102 0
      cvargparse/utils/dataclass.py
  5. 1 0
      requirements.txt

+ 21 - 4
cvargparse/__init__.py

@@ -1,20 +1,37 @@
 from cvargparse.argument import Argument
 from cvargparse.argument import Argument as Arg
 from cvargparse.argument import FileArgument
-from cvargparse.argument import JupyterArguments
 from cvargparse.factory import ArgFactory
 from cvargparse.factory import BaseFactory
 from cvargparse.parser.base import BaseParser
 from cvargparse.parser.gpu_parser import GPUParser
 from cvargparse.parser.mode_parser import ModeParserFactory
+from cvargparse.utils.dataclass import cvdataclass
+from cvargparse.utils.dataclass import Choices
 
 __all__ = [
 	"Arg",
-	"Argument",
-	"FileArgument",
 	"ArgFactory",
+	"Choices",
+	"Argument",
 	"BaseFactory",
 	"BaseParser",
-	"ModeParserFactory",
+	"cvdataclass",
+	"FileArgument",
 	"GPUParser",
+	"ModeParserFactory",
 ]
+
+
+
+if __name__ == '__main__':
+	@cvdataclass
+	class Args:
+
+		arg1: float = None
+		arg2: str = "something"
+
+		arg3: Choices([1, 2, 3], int) = 1
+
+	parser = BaseParser(Args(arg3=2))
+	print(parser.parse_args())

+ 0 - 29
cvargparse/argument.py

@@ -1,6 +1,3 @@
-import pyaml
-
-from dataclasses import dataclass
 
 class Argument(object):
 	def __init__(self, *args, **kw):
@@ -20,29 +17,3 @@ class FileArgument(Argument):
 			obj.kw["type"] = argparse.FileType(file_mode, encoding=encoding)
 			return obj
 		return wrapper
-
-def JupyterArguments(cls=None, *args, repr=False, **kwargs):
-
-	def _yaml_repr_(self) -> str:
-		cls_name = type(self).__name__
-		return pyaml.dump({cls_name: self.__dict__}, sort_dicts=False)
-
-	def wrap(cls):
-		if not repr and "__repr__" not in cls.__dict__:
-			setattr(cls, "__repr__", _yaml_repr_)
-		return dataclass(cls, *args, repr=repr, **kwargs)
-
-	# See if we're being called as @dataclass or @dataclass().
-	if cls is None:
-		return wrap
-
-	return wrap(cls)
-
-if __name__ == '__main__':
-
-	@JupyterArguments
-	class Args:
-		arg1: int = 0
-		arg2: int = 1
-
-	print(Args())

+ 6 - 0
cvargparse/parser/base.py

@@ -6,8 +6,11 @@ import warnings
 
 from cvargparse.argument import Argument as Arg
 from cvargparse.factory import BaseFactory
+from cvargparse.utils.dataclass import get_arglist_from_data_class
 from cvargparse.utils.logger_config import init_logging_handlers
 
+from dataclasses import is_dataclass
+
 class LoggerMixin(abc.ABC):
 
 	def __init__(self, *args, nologging: bool = False, **kw):
@@ -79,6 +82,9 @@ class BaseParser(LoggerMixin, argparse.ArgumentParser):
 		if isinstance(arglist, BaseFactory):
 			arglist = arglist.get()
 
+		elif is_dataclass(arglist):
+			arglist, group_name = get_arglist_from_data_class(arglist)
+
 		if group_name is None:
 			group = self
 		elif self.has_group(group_name):

+ 102 - 0
cvargparse/utils/dataclass.py

@@ -0,0 +1,102 @@
+import pyaml
+
+from dataclasses import dataclass
+from dataclasses import fields
+from dataclasses import MISSING
+from dataclasses import Field
+
+from cvargparse import Arg
+
+def cvdataclass(cls=None, *args, repr=False, **kwargs):
+
+	def _yaml_repr_(self) -> str:
+		cls_name = type(self).__name__
+		return pyaml.dump({cls_name: self.__dict__}, sort_dicts=False)
+
+	def wrap(cls):
+		if not repr and "__repr__" not in cls.__dict__:
+			setattr(cls, "__repr__", _yaml_repr_)
+		return dataclass(cls, *args, repr=repr, **kwargs)
+
+	# See if we're being called as @cvdataclass or @cvdataclass().
+	if cls is None:
+		return wrap
+
+	return wrap(cls)
+
+class FieldWrapper:
+
+	def __init__(self, field: Field):
+		super().__init__()
+		self._field = field
+
+	def as_arg(self) -> Arg:
+		return Arg(
+			self.name,
+			type=self.type,
+			default=self.default,
+			choices=self.choices
+		)
+
+	@property
+	def field(self):
+		return self._field
+
+	@property
+	def name(self):
+		return f"--{self.field.name}"
+
+	@property
+	def is_choice(self):
+		return isinstance(self.field.type, Choices)
+
+	@property
+	def type(self):
+		if self.is_choice:
+			return self.field.type._type
+
+		return self.field.type
+
+	@property
+	def default(self):
+
+		if self.field.default == MISSING:
+			return self.type()
+
+		return self.field.default
+
+	@property
+	def choices(self):
+		if not self.is_choice:
+			return None
+
+		return self.field.type._choices
+
+def get_arglist_from_data_class(cls) -> list:
+	arglist = []
+	for field in fields(cls):
+		arglist.append(FieldWrapper(field).as_arg())
+
+	return arglist, getattr(cls, "group_name", None)
+
+class Choices:
+
+	def __init__(self, choices, type):
+		self._choices = choices
+		self._type = type
+
+	def __contains__(self, value):
+		return value in self._choices
+
+	def __call__(self, *args, **kwargs):
+		return self._type(*args, **kwargs)
+
+
+if __name__ == '__main__':
+
+	@cvdataclass
+	class Args:
+		arg1: int = 0
+		arg2: int = 1
+
+	print(Args())

+ 1 - 0
requirements.txt

@@ -0,0 +1 @@
+pyaml