| 12345678910111213141516171819202122232425262728293031323334 |
- from contextlib import nullcontext
- from functools import wraps
- from typing import Callable, Optional, Tuple, Type, TypeVar, Union, overload, ContextManager
- import torch
# Names exported by ``from <module> import *``.
__all__ = [
    "LayerType",
    "PadType",
    "nullwrap",
    "disable_compiler",
]

# Accepted forms for a layer argument: a string, a callable, or a
# ``torch.nn.Module`` subclass (the class itself, not an instance).
LayerType = Union[str, Callable, Type[torch.nn.Module]]

# Accepted forms for a padding argument: a string, a single int, or a pair
# of ints.  NOTE(review): the string form is presumably a named padding
# mode -- confirm against the consumers of this alias.
PadType = Union[
    str,
    int,
    Tuple[int, int],
]
F = TypeVar("F", bound=Callable[..., object])


class _NullWrapContext(nullcontext):
    """A ``nullcontext`` that can also be applied as an identity decorator.

    ``nullwrap()`` (called with no function) returns one of these so that both
    ``with nullwrap(): ...`` and ``@nullwrap()`` work.  Plain ``nullcontext``
    objects are not callable, so the decorator-factory form would otherwise
    raise ``TypeError``.
    """

    def __call__(self, fn: F) -> F:
        # Decorating with a no-op context is simply the identity.
        return fn


@overload
def nullwrap(fn: F, recursive: bool = ..., **kwargs: object) -> F: ...  # decorator form
@overload
def nullwrap(
    fn: None = ..., recursive: bool = ..., **kwargs: object
) -> _NullWrapContext: ...  # context-manager / decorator-factory form
def nullwrap(fn: Optional[F] = None, recursive: bool = True, **kwargs: object):
    """Do nothing; used as a stand-in when ``torch.compiler.disable`` is unavailable.

    Supported usages::

        @nullwrap                      # returns the function unchanged
        @nullwrap()                    # identity decorator
        @nullwrap(recursive=False)     # options are accepted and ignored
        with nullwrap():               # empty context
            ...

    Args:
        fn: The callable to "wrap". Pass ``None`` (or nothing) to get an
            object usable both as a context manager and as a decorator.
        recursive: Accepted only so call sites written for
            ``torch.compiler.disable`` also work with this fallback; ignored.
        **kwargs: Any further options (e.g. ``reason``); accepted and ignored
            for the same reason.

    Returns:
        ``fn`` itself when a callable is given. Otherwise a reusable no-op
        object whose ``__enter__`` returns ``None`` and whose ``__call__``
        returns its argument unchanged.
    """
    del recursive, kwargs  # present only for signature compatibility
    if fn is None:
        return _NullWrapContext()
    # No wrapper frame: a null decorator should add zero per-call overhead
    # and keep the original object (and its identity) intact.
    return fn
# Use torch's real "disable compilation" decorator when this torch build
# provides ``torch.compiler.disable``; otherwise fall back to the no-op
# ``nullwrap``.  A missing ``torch.compiler`` module or a missing
# ``disable`` attribute both surface as AttributeError here.
try:
    disable_compiler = torch.compiler.disable or nullwrap
except AttributeError:
    disable_compiler = nullwrap
|