typing.py 903 B

12345678910111213141516171819202122232425262728293031323334
from contextlib import ContextDecorator, nullcontext
from functools import wraps
from typing import Callable, ContextManager, Optional, Tuple, Type, TypeVar, Union, overload

import torch
  5. __all__ = ["LayerType", "PadType", "nullwrap", "disable_compiler"]
  6. LayerType = Union[str, Callable, Type[torch.nn.Module]]
  7. PadType = Union[str, int, Tuple[int, int]]
  8. F = TypeVar("F", bound=Callable[..., object])
  9. @overload
  10. def nullwrap(fn: F) -> F: ... # decorator form
  11. @overload
  12. def nullwrap(fn: None = ...) -> ContextManager: ... # context‑manager form
  13. def nullwrap(fn: Optional[F] = None):
  14. # as a context manager
  15. if fn is None:
  16. return nullcontext() # `with nullwrap():`
  17. # as a decorator
  18. @wraps(fn)
  19. def wrapper(*args, **kwargs):
  20. return fn(*args, **kwargs)
  21. return wrapper # `@nullwrap`
  22. disable_compiler = getattr(getattr(torch, "compiler", None), "disable", None) or nullwrap