test_utils.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
import inspect
import io
from contextlib import contextmanager
from functools import partial
from typing import Iterator, Tuple

import chainer
  6. def get_class_that_defined_method(meth):
  7. if inspect.isfunction(meth):
  8. cls_name = meth.__qualname__.split('.<locals>', 1)[0].rsplit('.', 1)[0]
  9. return getattr(inspect.getmodule(meth), cls_name, None)
  10. def wrapper(func, key):
  11. def inner(self):
  12. return func(self, key)
  13. return inner
  14. def add_tests(func, model_list) -> None:
  15. cls = get_class_that_defined_method(func)
  16. for key in model_list:
  17. new_func = wrapper(func, key)
  18. name = f"test_{key.replace('.', '__')}_{func.__name__}"
  19. new_func.__name__ = name
  20. setattr(cls, name, new_func)
  21. def is_all_equal(model0: chainer.Chain, model1: chainer.Chain, strict: bool = False) -> bool:
  22. params0 = dict(model0.namedparams())
  23. params1 = dict(model1.namedparams())
  24. for name in params0:
  25. param0, param1 = params0[name], params1[name]
  26. if param0.shape != param1.shape:
  27. if strict:
  28. return False, f"shape of {name} was not the same: {param0.shape} != {param1.shape}"
  29. else:
  30. continue
  31. if not (param0.array == param1.array).all():
  32. return False, f"array of {name} was not the same"
  33. return True, "All equal!"
  34. def is_any_different(model0: chainer.Chain, model1: chainer.Chain, strict: bool = False) -> bool:
  35. params0 = dict(model0.namedparams())
  36. params1 = dict(model1.namedparams())
  37. for name in params0:
  38. param0, param1 = params0[name], params1[name]
  39. # print(name)
  40. if param0.shape != param1.shape:
  41. if strict:
  42. return False, f"shape of {name} was not the same: {param0.shape} != {param1.shape}"
  43. else:
  44. continue
  45. if (param0.array != param1.array).any():
  46. return True, f"Difference in array {name} found."
  47. return False, "All equal!"
  48. @contextmanager
  49. def memory_file() -> io.BytesIO:
  50. yield io.BytesIO()
  51. @contextmanager
  52. def clear_print(msg):
  53. print(msg)
  54. yield
  55. print("\033[A{}\033[A".format(" "*len(msg)))