|
@@ -7,6 +7,10 @@ import logging
|
|
from functools import partial
|
|
from functools import partial
|
|
from tabulate import tabulate
|
|
from tabulate import tabulate
|
|
|
|
|
|
|
|
+from chainer.function_hooks import TimerHook
|
|
|
|
+from chainer.function_hooks import CupyMemoryProfileHook
|
|
|
|
+
|
|
|
|
+
|
|
def get_attr_from_path(obj, path, *, sep="/"):
|
|
def get_attr_from_path(obj, path, *, sep="/"):
|
|
def getter(o, attr):
|
|
def getter(o, attr):
|
|
return
|
|
return
|
|
@@ -66,7 +70,15 @@ def print_model_info(model, file=sys.stdout, input_size=None, input_var=None):
|
|
_print(f"Printing some information about \"{name}\" model")
|
|
_print(f"Printing some information about \"{name}\" model")
|
|
_print(tabulate(rows, tablefmt="fancy_grid"))
|
|
_print(tabulate(rows, tablefmt="fancy_grid"))
|
|
|
|
|
|
- shapes = _get_activation_shapes(model, input_size or default_size, input_var)
|
|
|
|
|
|
+ timer_hook = TimerHook()
|
|
|
|
+ memory_hook = CupyMemoryProfileHook()
|
|
|
|
+
|
|
|
|
+ with timer_hook, memory_hook:
|
|
|
|
+ shapes = _get_activation_shapes(model, input_size or default_size, input_var)
|
|
|
|
+
|
|
|
|
+ timer_hook.print_report()
|
|
|
|
+ memory_hook.print_report()
|
|
|
|
+
|
|
_print("In/Out activation shapes:")
|
|
_print("In/Out activation shapes:")
|
|
_print(tabulate(shapes,
|
|
_print(tabulate(shapes,
|
|
headers=["Link name", "Input", "Output"],
|
|
headers=["Link name", "Input", "Output"],
|