test_utils.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
import inspect
import io
from contextlib import contextmanager
from functools import partial
from typing import Iterator, Tuple

import chainer
  6. def get_class_that_defined_method(meth):
  7. if inspect.isfunction(meth):
  8. cls_name = meth.__qualname__.split('.<locals>', 1)[0].rsplit('.', 1)[0]
  9. return getattr(inspect.getmodule(meth), cls_name, None)
  10. def wrapper(func, key):
  11. def inner(self):
  12. return func(self, key)
  13. return inner
  14. def add_tests(func, model_list) -> None:
  15. cls = get_class_that_defined_method(func)
  16. for key in model_list:
  17. new_func = wrapper(func, key)
  18. name = f"test_{key.replace('.', '__')}_{func.__name__}"
  19. new_func.__name__ = name
  20. setattr(cls, name, new_func)
  21. def is_all_equal(model0: chainer.Chain, model1: chainer.Chain, strict: bool = False, exclude_clf: bool = False) -> bool:
  22. params0 = dict(model0.namedparams())
  23. params1 = dict(model1.namedparams())
  24. for name in params0:
  25. if exclude_clf and model0.clf_layer_name in name:
  26. continue
  27. param0, param1 = params0[name], params1[name]
  28. if param0.shape != param1.shape:
  29. if strict:
  30. return False, f"shape of {name} was not the same: {param0.shape} != {param1.shape}"
  31. else:
  32. continue
  33. if not (param0.array == param1.array).all():
  34. return False, f"array of {name} was not the same"
  35. return True, "All equal!"
  36. def is_any_different(model0: chainer.Chain, model1: chainer.Chain, strict: bool = False, exclude_clf: bool = False) -> bool:
  37. params0 = dict(model0.namedparams())
  38. params1 = dict(model1.namedparams())
  39. for name in params0:
  40. if exclude_clf and model0.clf_layer_name in name:
  41. continue
  42. param0, param1 = params0[name], params1[name]
  43. # print(name)
  44. if param0.shape != param1.shape:
  45. if strict:
  46. return False, f"shape of {name} was not the same: {param0.shape} != {param1.shape}"
  47. else:
  48. continue
  49. if (param0.array != param1.array).any():
  50. return True, f"Difference in array {name} found."
  51. return False, "All equal!"
  52. @contextmanager
  53. def memory_file() -> io.BytesIO:
  54. yield io.BytesIO()
  55. @contextmanager
  56. def clear_print(msg):
  57. print(msg)
  58. yield
  59. print("\033[A{}\033[A".format(" "*len(msg)))