|
@@ -1,12 +1,43 @@
|
|
|
import pyaml
|
|
|
|
|
|
+from dataclasses import Field
|
|
|
+from dataclasses import MISSING
|
|
|
+from dataclasses import _is_dataclass_instance
|
|
|
+from dataclasses import asdict
|
|
|
from dataclasses import dataclass
|
|
|
from dataclasses import fields
|
|
|
-from dataclasses import MISSING
|
|
|
-from dataclasses import Field
|
|
|
|
|
|
from cvargparse import Arg
|
|
|
|
|
|
+def _set_attr(cls, attr, value):
|
|
|
+ if attr not in cls.__dict__:
|
|
|
+ setattr(cls, attr, value)
|
|
|
+
|
|
|
+def get_arglist(cls_or_instance) -> list:
|
|
|
+
|
|
|
+ arglist = []
|
|
|
+ for field in fields(cls_or_instance):
|
|
|
+ arglist.append(FieldWrapper(field).as_arg())
|
|
|
+
|
|
|
+ return arglist, getattr(cls_or_instance, "group_name", None)
|
|
|
+
|
|
|
+def as_args(instance) -> list:
|
|
|
+
|
|
|
+ dataclass_args = []
|
|
|
+ if not _is_dataclass_instance(instance):
|
|
|
+ return dataclass_args
|
|
|
+
|
|
|
+ data = asdict(instance)
|
|
|
+ for field in fields(instance):
|
|
|
+ key = field.name
|
|
|
+ value = data[key]
|
|
|
+ if value == field.default:
|
|
|
+ continue
|
|
|
+ dataclass_args.extend(f"--{key} {value}".split())
|
|
|
+
|
|
|
+ return dataclass_args
|
|
|
+
|
|
|
+
|
|
|
def cvdataclass(cls=None, *args, repr=False, **kwargs):
|
|
|
|
|
|
def _yaml_repr_(self) -> str:
|
|
@@ -14,8 +45,9 @@ def cvdataclass(cls=None, *args, repr=False, **kwargs):
|
|
|
return pyaml.dump({cls_name: self.__dict__}, sort_dicts=False)
|
|
|
|
|
|
def wrap(cls):
|
|
|
- if not repr and "__repr__" not in cls.__dict__:
|
|
|
- setattr(cls, "__repr__", _yaml_repr_)
|
|
|
+ if not repr:
|
|
|
+ _set_attr(cls, "__repr__", _yaml_repr_)
|
|
|
+
|
|
|
return dataclass(cls, *args, repr=repr, **kwargs)
|
|
|
|
|
|
# See if we're being called as @cvdataclass or @cvdataclass().
|
|
@@ -72,13 +104,6 @@ class FieldWrapper:
|
|
|
|
|
|
return self.field.type._choices
|
|
|
|
|
|
-def get_arglist_from_data_class(cls) -> list:
|
|
|
- arglist = []
|
|
|
- for field in fields(cls):
|
|
|
- arglist.append(FieldWrapper(field).as_arg())
|
|
|
-
|
|
|
- return arglist, getattr(cls, "group_name", None)
|
|
|
-
|
|
|
class Choices:
|
|
|
|
|
|
def __init__(self, choices, type):
|