test_utils.py 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. import inspect
  2. import chainer
  3. from contextlib import contextmanager
  4. from functools import partial
  5. def get_class_that_defined_method(meth):
  6. if inspect.isfunction(meth):
  7. cls_name = meth.__qualname__.split('.<locals>', 1)[0].rsplit('.', 1)[0]
  8. return getattr(inspect.getmodule(meth), cls_name, None)
  9. def wrapper(func, key):
  10. def inner(self):
  11. return func(self, key)
  12. return inner
  13. def add_tests(func, model_list) -> None:
  14. cls = get_class_that_defined_method(func)
  15. for key in model_list:
  16. new_func = wrapper(func, key)
  17. name = f"test_{key.replace('.', '__')}_{func.__name__}"
  18. new_func.__name__ = name
  19. setattr(cls, name, new_func)
  20. def is_all_equal(model0: chainer.Chain, model1: chainer.Chain, strict: bool = False) -> bool:
  21. params0 = dict(model0.namedparams())
  22. params1 = dict(model1.namedparams())
  23. for name in params0:
  24. param0, param1 = params0[name], params1[name]
  25. if param0.shape != param1.shape:
  26. if strict:
  27. return False
  28. else:
  29. continue
  30. if not (param0.array == param1.array).all():
  31. return False
  32. return True
  33. def is_any_different(model0: chainer.Chain, model1: chainer.Chain, strict: bool = False) -> bool:
  34. params0 = dict(model0.namedparams())
  35. params1 = dict(model1.namedparams())
  36. for name in params0:
  37. param0, param1 = params0[name], params1[name]
  38. # print(name)
  39. if param0.shape != param1.shape:
  40. if strict:
  41. return False
  42. else:
  43. continue
  44. if (param0.array != param1.array).any():
  45. return True
  46. return False
  47. @contextmanager
  48. def clear_print(msg):
  49. print(msg)
  50. yield
  51. print("\033[A{}\033[A".format(" "*len(msg)))