Browse Source

fixed argument parsing inside a jupyter notebook

Dimitri Korsch 3 years ago
parent
commit
bd9e1fcbc5
2 changed files with 28 additions and 2 deletions
  1. 26 0
      cvargparse/parser/base.py
  2. 2 2
      cvargparse/utils/dataclass.py

+ 26 - 0
cvargparse/parser/base.py

@@ -12,6 +12,23 @@ from cvargparse.utils.logger_config import init_logging_handlers
 
 
 from dataclasses import is_dataclass
 from dataclasses import is_dataclass
 
 
+def is_notebook():
+	""" checks if we run this code in a notebook or what kind of shell """
+	try:
+		shell = get_ipython().__class__.__name__
+
+		if shell == 'ZMQInteractiveShell':
+			return True   # Jupyter notebook or qtconsole
+
+		elif shell == 'TerminalInteractiveShell':
+			return False  # Terminal running IPython
+
+		else:
+			return False  # Other type (?)
+
+	except NameError:
+		return False      # Probably standard Python interpreter
+
 class LoggerMixin(abc.ABC):
 class LoggerMixin(abc.ABC):
 
 
 	def __init__(self, *args, nologging: bool = False, **kw):
 	def __init__(self, *args, nologging: bool = False, **kw):
@@ -105,6 +122,12 @@ class BaseParser(LoggerMixin, argparse.ArgumentParser):
 
 
 	def parse_args(self, args=None, namespace=None):
 	def parse_args(self, args=None, namespace=None):
 		args = self._merge_args(args)
 		args = self._merge_args(args)
+		if args is None and is_notebook():
+			# we need to set this at some value other than None
+			# otherwise, we parse arguments of the jupyter notebook
+			# process
+			args = ""
+
 		self._args = super().parse_args(args, namespace)
 		self._args = super().parse_args(args, namespace)
 
 
 		if self.has_logging:
 		if self.has_logging:
@@ -119,6 +142,9 @@ class BaseParser(LoggerMixin, argparse.ArgumentParser):
 
 
 		dataclass_args = as_args(self._dataclass_instance)
 		dataclass_args = as_args(self._dataclass_instance)
 
 
+		if dataclass_args is None:
+			return args
+
 		if args is None:
 		if args is None:
 			return dataclass_args
 			return dataclass_args
 
 

+ 2 - 2
cvargparse/utils/dataclass.py

@@ -23,10 +23,10 @@ def get_arglist(cls_or_instance) -> list:
 
 
 def as_args(instance) -> list:
 def as_args(instance) -> list:
 
 
-	dataclass_args = []
 	if not _is_dataclass_instance(instance):
 	if not _is_dataclass_instance(instance):
-		return dataclass_args
+		return None
 
 
+	dataclass_args = []
 	data = asdict(instance)
 	data = asdict(instance)
 	for field in fields(instance):
 	for field in fields(instance):
 		wrapped = FieldWrapper(field)
 		wrapped = FieldWrapper(field)