@@ -29,11 +29,12 @@ if __name__ == '__main__':
class Args:
group_name = "args"
- arg1: float = None
+ train_samples: float = None
+ test_samples: int = 2
arg2: str = "something"
arg3: Choices([1, 2, 3], int) = 1
is_arg4: bool = False
- parser = BaseParser(Args)
+ parser = BaseParser(Args(test_samples=20))
print(parser.parse_args("--is_arg4".split()))
@@ -86,8 +86,9 @@ class BaseParser(LoggerMixin, argparse.ArgumentParser):
elif is_dataclass(arglist):
self._dataclass_instance = arglist
- arglist, group_name = get_arglist(arglist)
-
+ arglist, _group_name = get_arglist(arglist)
+ if group_name is None:
+ group_name = _group_name
if group_name is None:
group = self