_stats.py 1.0 KB

12345678910111213141516171819202122232425262728293031
  1. # NOTE! PLEASE KEEP THIS FILE *FREE* OF TORCH DEPS! IT SHOULD BE IMPORTABLE ANYWHERE.
  2. # IF YOU FEEL AN OVERWHELMING URGE TO ADD A TORCH DEP, MAKE A TRAMPOLINE FILE A LA torch._dynamo.utils
  3. # AND SCRUB AWAY TORCH NOTIONS THERE.
  4. import collections
  5. import functools
  6. from collections import OrderedDict
  7. from collections.abc import Callable
  8. from typing import TypeVar
  9. from typing_extensions import ParamSpec
  10. simple_call_counter: OrderedDict[str, int] = collections.OrderedDict()
  11. _P = ParamSpec("_P")
  12. _R = TypeVar("_R")
  13. def count_label(label: str) -> None:
  14. prev = simple_call_counter.setdefault(label, 0)
  15. simple_call_counter[label] = prev + 1
  16. def count(fn: Callable[_P, _R]) -> Callable[_P, _R]:
  17. @functools.wraps(fn)
  18. def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
  19. if fn.__qualname__ not in simple_call_counter:
  20. simple_call_counter[fn.__qualname__] = 0
  21. simple_call_counter[fn.__qualname__] = simple_call_counter[fn.__qualname__] + 1
  22. return fn(*args, **kwargs)
  23. return wrapper