浏览代码

minor fix in the group_name definition

Dimitri Korsch 3 年之前
父节点
当前提交
f5581ddf17
共有 2 个文件被更改,包括 6 次插入4 次删除
  1. 3 2
      cvargparse/__init__.py
  2. 3 2
      cvargparse/parser/base.py

+ 3 - 2
cvargparse/__init__.py

@@ -29,11 +29,12 @@ if __name__ == '__main__':
 	class Args:
 		group_name = "args"
 
-		arg1: float = None
+		train_samples: float = None
+		test_samples: int = 2
 		arg2: str = "something"
 
 		arg3: Choices([1, 2, 3], int) = 1
 		is_arg4: bool = False
 
-	parser = BaseParser(Args)
+	parser = BaseParser(Args(test_samples=20))
 	print(parser.parse_args("--is_arg4".split()))

+ 3 - 2
cvargparse/parser/base.py

@@ -86,8 +86,9 @@ class BaseParser(LoggerMixin, argparse.ArgumentParser):
 
 		elif is_dataclass(arglist):
 			self._dataclass_instance = arglist
-			arglist, group_name = get_arglist(arglist)
-
+			arglist, _group_name = get_arglist(arglist)
+			if group_name is None:
+				group_name = _group_name
 
 		if group_name is None:
 			group = self