|
@@ -12,6 +12,23 @@ from cvargparse.utils.logger_config import init_logging_handlers
|
|
|
|
|
|
from dataclasses import is_dataclass
|
|
|
|
|
|
+def is_notebook():
|
|
|
+ """ checks if we run this code in a notebook or what kind of shell """
|
|
|
+ try:
|
|
|
+ shell = get_ipython().__class__.__name__
|
|
|
+
|
|
|
+ if shell == 'ZMQInteractiveShell':
|
|
|
+ return True # Jupyter notebook or qtconsole
|
|
|
+
|
|
|
+ elif shell == 'TerminalInteractiveShell':
|
|
|
+ return False # Terminal running IPython
|
|
|
+
|
|
|
+ else:
|
|
|
+ return False # Other type (?)
|
|
|
+
|
|
|
+ except NameError:
|
|
|
+ return False # Probably standard Python interpreter
|
|
|
+
|
|
|
class LoggerMixin(abc.ABC):
|
|
|
|
|
|
def __init__(self, *args, nologging: bool = False, **kw):
|
|
@@ -105,6 +122,12 @@ class BaseParser(LoggerMixin, argparse.ArgumentParser):
|
|
|
|
|
|
def parse_args(self, args=None, namespace=None):
|
|
|
args = self._merge_args(args)
|
|
|
+ if args is None and is_notebook():
|
|
|
+ # we need to set this at some value other than None
|
|
|
+ # otherwise, we parse arguments of the jupyter notebook
|
|
|
+ # process
|
|
|
+ args = ""
|
|
|
+
|
|
|
self._args = super().parse_args(args, namespace)
|
|
|
|
|
|
if self.has_logging:
|
|
@@ -119,6 +142,9 @@ class BaseParser(LoggerMixin, argparse.ArgumentParser):
|
|
|
|
|
|
dataclass_args = as_args(self._dataclass_instance)
|
|
|
|
|
|
+ if dataclass_args is None:
|
|
|
+ return args
|
|
|
+
|
|
|
if args is None:
|
|
|
return dataclass_args
|
|
|
|