hooks.py 666 B

12345678910111213141516171819202122232425262728293031
  1. # mypy: allow-untyped-defs
  2. import contextlib
  3. from collections.abc import Callable
  4. from typing import TYPE_CHECKING
  5. if TYPE_CHECKING:
  6. import torch
  7. # Executed in the order they're registered
  8. INTERMEDIATE_HOOKS: list[Callable[[str, "torch.Tensor"], None]] = []
  9. @contextlib.contextmanager
  10. def intermediate_hook(fn):
  11. INTERMEDIATE_HOOKS.append(fn)
  12. try:
  13. yield
  14. finally:
  15. INTERMEDIATE_HOOKS.pop()
  16. def run_intermediate_hooks(name, val):
  17. global INTERMEDIATE_HOOKS
  18. hooks = INTERMEDIATE_HOOKS
  19. INTERMEDIATE_HOOKS = []
  20. try:
  21. for hook in hooks:
  22. hook(name, val)
  23. finally:
  24. INTERMEDIATE_HOOKS = hooks