فهرست منبع

added JupyterArguments decorator for pretty printing of dataclasses in jupyter notebooks

Dimitri Korsch 3 سال پیش
والد
کامیت
f5bb26890c
6فایلهای تغییر یافته به همراه83 افزوده شده و 45 حذف شده
  1. 2 2
      cvargparse/__init__.py
  2. 1 1
      cvargparse/_version.py
  3. 30 0
      cvargparse/argument.py
  4. 4 4
      cvargparse/factory.py
  5. 43 35
      cvargparse/parser/base.py
  6. 3 3
      cvargparse/parser/gpu_parser.py

+ 2 - 2
cvargparse/__init__.py

@@ -1,7 +1,7 @@
 from cvargparse.argument import Argument
+from cvargparse.argument import Argument as Arg
 from cvargparse.argument import FileArgument
-Arg = Argument
-
+from cvargparse.argument import JupyterArguments
 from cvargparse.factory import ArgFactory
 from cvargparse.factory import BaseFactory
 from cvargparse.parser.base import BaseParser

+ 1 - 1
cvargparse/_version.py

@@ -1 +1 @@
-__version__ = "0.3.2"
+__version__ = "0.4.0"

+ 30 - 0
cvargparse/argument.py

@@ -1,3 +1,7 @@
+import pyaml
+
+from dataclasses import dataclass
+
 class Argument(object):
 	def __init__(self, *args, **kw):
 		super(Argument, self).__init__()
@@ -16,3 +20,29 @@ class FileArgument(Argument):
 			obj.kw["type"] = argparse.FileType(file_mode, encoding=encoding)
 			return obj
 		return wrapper
+
+def JupyterArguments(cls=None, *args, repr=False, **kwargs):
+
+	def _yaml_repr_(self) -> str:
+		cls_name = type(self).__name__
+		return pyaml.dump({cls_name: self.__dict__}, sort_dicts=False)
+
+	def wrap(cls):
+		if not repr and "__repr__" not in cls.__dict__:
+			setattr(cls, "__repr__", _yaml_repr_)
+		return dataclass(cls, *args, repr=repr, **kwargs)
+
+	# See if we're being called as @dataclass or @dataclass().
+	if cls is None:
+		return wrap
+
+	return wrap(cls)
+
+if __name__ == '__main__':
+
+	@JupyterArguments
+	class Args:
+		arg1: int = 0
+		arg2: int = 1
+
+	print(Args())

+ 4 - 4
cvargparse/factory.py

@@ -1,14 +1,15 @@
-from abc import ABC
+import abc
+import typing as T
 
 from cvargparse.utils import factory
 from cvargparse.argument import Argument as Arg
 
 
-class BaseFactory(ABC):
+class BaseFactory(abc.ABC):
 	'''
 
 	'''
-	def __init__(self, initial=None):
+	def __init__(self, initial: T.Optional[T.List[Arg]] = None):
 		super(BaseFactory, self).__init__()
 		self.args = initial or []
 
@@ -16,7 +17,6 @@ class BaseFactory(ABC):
 	def add(self, *args, **kwargs):
 		self.args.append(Arg(*args, **kwargs))
 
-
 	def get(self):
 		return self.args
 

+ 43 - 35
cvargparse/parser/base.py

@@ -1,21 +1,18 @@
+import abc
 import argparse
 import logging
+import typing as T
 import warnings
 
 from cvargparse.argument import Argument as Arg
 from cvargparse.factory import BaseFactory
 from cvargparse.utils.logger_config import init_logging_handlers
 
-class BaseParser(argparse.ArgumentParser):
+class LoggerMixin(abc.ABC):
 
-	def __init__(self, arglist=[], nologging=False, *args, **kw):
+	def __init__(self, *args, nologging: bool = False, **kw):
 		self._nologging = nologging
-		self._groups = {}
-		self._args = None
-
-		super(BaseParser, self).__init__(*args, **kw)
-
-		self.add_args(arglist)
+		super().__init__(*args, **kw)
 
 		if not self.has_logging: return
 
@@ -26,6 +23,42 @@ class BaseParser(argparse.ArgumentParser):
 				help='logging level. see logging module for more information'),
 		], group_name="Logger arguments")
 
+	@property
+	def has_logging(self):
+		return not self._nologging
+
+	def _logging_config(self, simple=False):
+
+		if self._args.logfile:
+			handler = logging.FileHandler(self._args.logfile, mode="w")
+		else:
+			handler = logging.StreamHandler()
+
+		# fmt = '%(message)s' if simple else '%(levelname)s - [%(asctime)s] %(filename)s:%(lineno)d [%(funcName)s]: %(message)s'
+		fmt = '{message}' if simple else '{levelname: ^7s} - [{asctime}] {filename}:{lineno} [{funcName}]: {message}'
+		if getattr(self._args, "debug", False):
+			lvl = logging.DEBUG
+		else:
+			lvl = getattr(logging, self._args.loglevel.upper(), logging.WARNING)
+
+		self._logger = init_logging_handlers([(handler, fmt, lvl)])
+
+	def init_logger(self, simple=False):
+		warnings.warn("This method is deprecated and does nothing since v0.3.0!",
+			DeprecationWarning, stacklevel=2)
+
+
+class BaseParser(LoggerMixin, argparse.ArgumentParser):
+
+	def __init__(self, arglist: T.Union[T.List[Arg], BaseFactory] = [], *args, **kw):
+		self._groups = {}
+		self._args = None
+
+		super().__init__(*args, **kw)
+
+		self.add_args(arglist)
+
+
 	@property
 	def args(self):
 		return self._args
@@ -36,13 +69,11 @@ class BaseParser(argparse.ArgumentParser):
 	def has_group(self, name):
 		return name in self._groups
 
-
 	def add_argument_group(self, title, description=None, *args, **kwargs):
-		group = super(BaseParser, self).add_argument_group(title=title, description=description, *args, **kwargs)
+		group = super().add_argument_group(title=title, description=description, *args, **kwargs)
 		self._groups[title] = group
 		return group
 
-
 	def add_args(self, arglist, group_name=None, group_kwargs={}):
 
 		if isinstance(arglist, BaseFactory):
@@ -61,34 +92,11 @@ class BaseParser(argparse.ArgumentParser):
 			else:
 				group.add_argument(*arg[0], **arg[1])
 
-	@property
-	def has_logging(self):
-		return not self._nologging
-
 	def parse_args(self, args=None, namespace=None):
-		self._args = super(BaseParser, self).parse_args(args, namespace)
+		self._args = super().parse_args(args, namespace)
 
 		if self.has_logging:
 			self._logging_config()
 
 		return self._args
 
-	def _logging_config(self, simple=False):
-
-		if self._args.logfile:
-			handler = logging.FileHandler(self._args.logfile, mode="w")
-		else:
-			handler = logging.StreamHandler()
-
-		# fmt = '%(message)s' if simple else '%(levelname)s - [%(asctime)s] %(filename)s:%(lineno)d [%(funcName)s]: %(message)s'
-		fmt = '{message}' if simple else '{levelname: ^7s} - [{asctime}] {filename}:{lineno} [{funcName}]: {message}'
-		if getattr(self._args, "debug", False):
-			lvl = logging.DEBUG
-		else:
-			lvl = getattr(logging, self._args.loglevel.upper(), logging.WARNING)
-
-		self._logger = init_logging_handlers([(handler, fmt, lvl)])
-
-	def init_logger(self, simple=False):
-		warnings.warn("This method does nothing since v0.3.0!")
-

+ 3 - 3
cvargparse/parser/gpu_parser.py

@@ -1,8 +1,8 @@
-from cvargparse.parser.base import BaseParser
+from cvargparse.parser import base
 
-class GPUParser(BaseParser):
+class GPUParser(base.BaseParser):
 	def __init__(self, *args, **kw):
-		super(GPUParser, self).__init__(*args, **kw)
+		super().__init__(*args, **kw)
 		self.add_argument(
 			"--gpu", "-g", type=int, nargs="+", default=[-1],
 			help="which GPU to use. select -1 for CPU only")