12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667 |
- import inspect
- import chainer
- from contextlib import contextmanager
- from functools import partial
def get_class_that_defined_method(meth):
    """Return the class whose body defines the plain function *meth*.

    The owner is found by taking ``meth.__qualname__``, dropping the final
    component (the function's own name) and resolving the remaining dotted
    path attribute-by-attribute starting from ``meth``'s module. This also
    handles nested classes such as ``Outer.Inner.method``.

    Only the part of ``__qualname__`` before any ``.<locals>`` marker is
    used, because names created inside a function call are not reachable
    from the module namespace.

    Returns:
        The owning class, or ``None`` when *meth* is not a plain Python
        function, has no owner path (module-level function), or its owner
        path does not resolve to a class.
    """
    if not inspect.isfunction(meth):
        return None
    qualname = meth.__qualname__.split('.<locals>', 1)[0]
    owner_path, dot, _ = qualname.rpartition('.')
    if not dot:
        # Nothing precedes the function's own name: there is no owning class.
        return None
    owner = inspect.getmodule(meth)  # may be None; getattr(None, ..., None) -> None
    for part in owner_path.split('.'):
        owner = getattr(owner, part, None)
        if owner is None:
            return None
    # The path may resolve to a non-class (e.g. an enclosing function).
    return owner if inspect.isclass(owner) else None
def wrapper(func, key):
    """Return a method-like closure ``inner(self)`` that calls ``func(self, key)``.

    The closure captures *func* and *key*, so it can be attached to a class
    as a method that takes only ``self``.
    """
    target, bound_key = func, key

    def inner(self):
        # Forward the instance together with the captured key.
        return target(self, bound_key)

    return inner
def add_tests(func, model_list) -> None:
    """Attach one copy of *func* per key to the class that defines *func*.

    For every ``key`` in *model_list*, a method named
    ``test_<key with '.' replaced by '__'>_<func.__name__>`` is set on the
    class returned by ``get_class_that_defined_method(func)``. Each generated
    method takes only ``self`` and calls ``func(self, key)``.

    Args:
        func: A function defined in a class body, taking ``(self, key)``.
        model_list: Iterable of string keys (``str.replace`` is called on each).

    Raises:
        ValueError: If the class defining *func* cannot be resolved.
    """
    cls = get_class_that_defined_method(func)
    # The resolver may return None or a non-class object. setattr on those
    # either fails with a confusing message or silently attaches the tests
    # to the wrong object, so fail loudly here instead.
    if not inspect.isclass(cls):
        raise ValueError(
            f"cannot determine the class that defines {func!r}; "
            "add_tests expects a function defined in a module-level class body"
        )
    for key in model_list:
        new_func = wrapper(func, key)
        name = f"test_{key.replace('.', '__')}_{func.__name__}"
        new_func.__name__ = name
        # Keep __qualname__ consistent with __name__ so reprs read Class.test_name.
        new_func.__qualname__ = f"{cls.__qualname__}.{name}"
        setattr(cls, name, new_func)
def is_all_equal(model0: chainer.Chain, model1: chainer.Chain, strict: bool = False) -> bool:
    """Return True if every parameter of *model0* equals its counterpart in *model1*.

    Parameters are paired by the names yielded by ``namedparams()``. The
    comparison is driven by *model0*: parameters that exist only in *model1*
    are not inspected.

    Args:
        model0: Model whose parameters are checked.
        model1: Model to compare against.
        strict: When False, a parameter whose shape differs, or whose name is
            absent from *model1*, is skipped. When True, either condition
            makes the result False.

    Returns:
        True if every compared parameter pair is element-wise equal.
    """
    params0 = dict(model0.namedparams())
    params1 = dict(model1.namedparams())
    for name, param0 in params0.items():
        param1 = params1.get(name)
        # A name absent from model1 is an incomparable pair, like a shape mismatch.
        if param1 is None or param0.shape != param1.shape:
            if strict:
                return False
            continue
        if not (param0.array == param1.array).all():
            return False
    return True
def is_any_different(model0: chainer.Chain, model1: chainer.Chain, strict: bool = False) -> bool:
    """Return True if some parameter of *model0* differs from its counterpart in *model1*.

    Parameters are paired by the names yielded by ``namedparams()``. The
    comparison is driven by *model0*: parameters that exist only in *model1*
    are not inspected.

    Args:
        model0: Model whose parameters are checked.
        model1: Model to compare against.
        strict: When False, a parameter whose shape differs, or whose name is
            absent from *model1*, is skipped. When True, either condition
            counts as a difference.

    Returns:
        True if any compared parameter pair differs in at least one element
        (or, in strict mode, is structurally mismatched).
    """
    params0 = dict(model0.namedparams())
    params1 = dict(model1.namedparams())
    for name, param0 in params0.items():
        param1 = params1.get(name)
        if param1 is None or param0.shape != param1.shape:
            if strict:
                # Structurally mismatched parameters are, by definition, different.
                return True
            continue
        if (param0.array != param1.array).any():
            return True
    return False
@contextmanager
def clear_print(msg):
    """Show *msg* while the ``with`` body runs, then blank it out on exit.

    On normal exit, the ANSI "cursor up" sequence (``ESC [ A``) moves back to
    the message line, which is overwritten with spaces. A second "cursor up"
    plus ``print``'s newline then leaves the cursor on that line, so later
    output reuses it. This assumes the body printed nothing to the same
    terminal and that the message fits on a single line.

    If the body raises, the clearing step is not reached and the message stays
    visible.

    Args:
        msg: Any object; it is printed (and measured) via ``str(msg)``.
    """
    # print() stringifies msg, so len() must measure that same text; len(msg)
    # on a non-str would raise TypeError after the body already ran.
    text = str(msg)
    print(text)
    yield
    print("\033[A{}\033[A".format(" " * len(text)))
|