|
@@ -1,9 +1,11 @@
|
|
import pyaml
|
|
import pyaml
|
|
|
|
+import types
|
|
|
|
|
|
from dataclasses import dataclass
|
|
from dataclasses import dataclass
|
|
from typing import Tuple
|
|
from typing import Tuple
|
|
from typing import Callable
|
|
from typing import Callable
|
|
|
|
|
|
|
|
+pyaml.add_representer(types.FunctionType, lambda cls, func: cls.represent_data(str(func)))
|
|
|
|
|
|
@dataclass
|
|
@dataclass
|
|
class ModelInfo(object):
|
|
class ModelInfo(object):
|
|
@@ -18,7 +20,7 @@ class ModelInfo(object):
|
|
|
|
|
|
classifier_layers: Tuple[str] = ("fc",)
|
|
classifier_layers: Tuple[str] = ("fc",)
|
|
|
|
|
|
- prepare_func: Callable = None
|
|
|
|
|
|
+ prepare_func: Callable = lambda x: x
|
|
|
|
|
|
def __str__(self):
|
|
def __str__(self):
|
|
obj = dict(ModelInfo=self.__dict__)
|
|
obj = dict(ModelInfo=self.__dict__)
|
|
@@ -26,4 +28,5 @@ class ModelInfo(object):
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
if __name__ == '__main__':
|
|
- print(ModelInfo())
|
|
|
|
|
|
+ info = ModelInfo()
|
|
|
|
+ print(info)
|