context.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. from __future__ import annotations
  2. import functools
  3. from contextlib import nullcontext
  4. from typing import Any, TYPE_CHECKING, TypeVar
  5. from typing_extensions import ParamSpec
  6. if TYPE_CHECKING:
  7. from collections.abc import Callable, Sequence
  8. import torch
  9. import torch._decomp
  10. import torch._prims
  11. import torch._refs
  12. import torch._refs.nn
  13. import torch._refs.nn.functional
  14. import torch._refs.special
  15. import torch.overrides
  16. from torch._prims_common import torch_function_passthrough
  17. _P = ParamSpec("_P")
  18. _R = TypeVar("_R")
  19. @functools.cache
  20. def torch_to_refs_map() -> dict[Any, Any]:
  21. """
  22. Mapping of torch API functions to torch._refs functions.
  23. E.g. torch_to_refs_map()[torch.add] == torch._refs.add
  24. """
  25. modules = [
  26. (torch, torch._refs),
  27. (torch.nn, torch._refs.nn),
  28. (torch.nn.functional, torch._refs.nn.functional),
  29. (torch.special, torch._refs.special),
  30. (torch.fft, torch._refs.fft),
  31. (torch.linalg, torch._refs.linalg),
  32. ]
  33. r: dict[Any, Any] = {
  34. torch.Tensor.__invert__: torch._refs.bitwise_not,
  35. torch.Tensor.__xor__: torch._refs.bitwise_xor,
  36. torch.Tensor.__and__: torch._refs.bitwise_and,
  37. torch.Tensor.__or__: torch._refs.bitwise_or,
  38. torch.Tensor.__eq__: torch._refs.eq,
  39. torch.Tensor.__rsub__: torch._refs.rsub,
  40. torch.Tensor.__rtruediv__: torch._refs.rtruediv,
  41. torch.Tensor.__floordiv__: torch._refs.floor_divide,
  42. torch.Tensor.__rfloordiv__: torch._refs.rfloordiv,
  43. torch.Tensor.__pow__: torch._refs.pow,
  44. torch.Tensor.__rpow__: torch._refs.rpow,
  45. torch.Tensor.new_empty: torch._refs.new_empty,
  46. torch.Tensor.new_full: torch._refs.new_full,
  47. torch.Tensor.new_zeros: torch._refs.new_zeros,
  48. torch.Tensor.new_ones: torch._refs.new_ones,
  49. torch.Tensor.fill_: torch._refs.fill_,
  50. torch.Tensor.zero_: torch._refs.zero_,
  51. torch.Tensor.to: torch._refs.to,
  52. torch.Tensor.sum_to_size: torch._refs.sum_to_size,
  53. # TODO: Should these methods be mapped some other way?
  54. torch.Tensor.copy_: torch._prims.copy_to,
  55. torch.Tensor.resize: torch._prims.resize,
  56. }
  57. for mod_torch, mod_refs in modules:
  58. for s in mod_refs.__all__: # type: ignore[attr-defined]
  59. r[mod_torch.__dict__.get(s)] = mod_refs.__dict__.get(s)
  60. # Support remapping torch.Tensor.foo to _refs.foo
  61. for s in dir(torch.Tensor):
  62. if s in torch._refs.__all__:
  63. r[getattr(torch.Tensor, s)] = torch._refs.__dict__.get(s)
  64. # Support conversions
  65. for s in torch._refs._conversions.__all__:
  66. tensor_attr = getattr(torch.Tensor, s, None) or getattr(torch, s)
  67. r[tensor_attr] = torch._refs._conversions.__dict__.get(s)
  68. return r
  69. @functools.cache
  70. def all_prims() -> set[Any]:
  71. """
  72. Set of all prim functions, e.g., torch._prims.add in all_prims()
  73. """
  74. return {torch._prims.__dict__.get(s) for s in torch._prims.__all__}
  75. class TorchRefsMode(torch.overrides.TorchFunctionMode):
  76. """
  77. Switches the interpretation of torch.* functions and Tensor methods to
  78. use PrimTorch refs in torch._refs. (Direct calls to _refs are unaffected.)
  79. >>> # xdoctest: +SKIP
  80. >>> with TorchRefsMode():
  81. ... torch.add(x, y) # calls torch._refs.add(x, y)
  82. By default, this context manager will fall back on the torch.* if the
  83. ref does not exist; set strict=True to error if this occurs.
  84. If the ref exists we still would like to fall back on the torch.* sometimes,
  85. this behavior can be customized by passing a function to should_fallback_fn.
  86. """
  87. def __init__(
  88. self,
  89. strict: bool = False,
  90. should_fallback_fn: Callable[..., bool] = lambda *_: False,
  91. prims_mode_cls: type = nullcontext,
  92. ) -> None:
  93. self.strict = strict
  94. self.should_fallback_fn = should_fallback_fn
  95. self.prims_mode_cls = prims_mode_cls
  96. def __torch_function__(
  97. self,
  98. orig_func: Callable[_P, _R],
  99. types: Sequence[type],
  100. args: Sequence[Any] = (),
  101. kwargs: dict[str, Any] | None = None,
  102. ) -> Any:
  103. if kwargs is None:
  104. kwargs = {}
  105. # For primitive operations, run them as is without interception
  106. # Unless we are in prims_mode, in which case we want to use nvprims
  107. if orig_func in torch_function_passthrough or orig_func in all_prims():
  108. with self.prims_mode_cls():
  109. # pyrefly: ignore [invalid-param-spec]
  110. return orig_func(*args, **kwargs)
  111. mapping = torch_to_refs_map()
  112. func = mapping.get(orig_func, None)
  113. # For torch.ops.aten.*, use registered decompositions from torch._decomp
  114. # torch._decomp.decomposition_table provides a mapping from
  115. # torch.ops.aten.* to torch._refs or torch._decomp.decompositions
  116. # implementations.
  117. # There're other ways to implement this functionality,
  118. # see https://github.com/pytorch/pytorch/pull/82657#discussion_r939776417
  119. if func is None and isinstance(orig_func, torch._ops.OpOverload):
  120. func = torch._decomp.decomposition_table.get(orig_func, None)
  121. elif func is None and isinstance(orig_func, torch._ops.OpOverloadPacket):
  122. default = getattr(orig_func, "default", None)
  123. if default is None and orig_func._dir:
  124. default = getattr(orig_func, orig_func._dir[0], None)
  125. if default is not None:
  126. func = torch._decomp.decomposition_table.get(default, None)
  127. if func is not None:
  128. # If the ref exists query whether we should use it or not
  129. if self.should_fallback_fn(self, orig_func, func, args, kwargs):
  130. # pyrefly: ignore [invalid-param-spec]
  131. return orig_func(*args, **kwargs)
  132. # torch calls inside func should be interpreted as refs calls
  133. with self:
  134. return func(*args, **kwargs)
  135. if self.strict:
  136. raise RuntimeError(
  137. f"no _refs support for {torch.overrides.resolve_name(orig_func)}"
  138. )
  139. # pyrefly: ignore [invalid-param-spec]
  140. return orig_func(*args, **kwargs)