fake_impl.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. # mypy: allow-untyped-defs
  2. import contextlib
  3. import functools
  4. from collections.abc import Callable
  5. from typing_extensions import deprecated
  6. import torch
  7. from torch._library.utils import Kernel, RegistrationHandle
  8. class FakeImplHolder:
  9. """A holder where one can register an fake impl to."""
  10. def __init__(self, qualname: str):
  11. self.qualname: str = qualname
  12. # kernels stores all registered fake kernels, ordered by registration
  13. # time ascendingly (newer registration after older registration). If an
  14. # operator library gets loaded that overrides an existing fake kernel,
  15. # both kernels will be in the list, but the newest one will be the one
  16. # that is run. If the library is unloaded, we will remove the kernel
  17. # from this list.
  18. self.kernels: list[Kernel] = []
  19. @property
  20. def kernel(self):
  21. if len(self.kernels) == 0:
  22. return None
  23. return self.kernels[-1]
  24. @kernel.setter
  25. def kernel(self, value):
  26. raise RuntimeError("Unable to directly set kernel.")
  27. def register(
  28. self, func: Callable, source: str, lib, *, allow_override=False
  29. ) -> RegistrationHandle:
  30. """Register an fake impl.
  31. Returns a RegistrationHandle that one can use to de-register this
  32. fake impl.
  33. """
  34. if not allow_override:
  35. if self.kernel is not None:
  36. raise RuntimeError(
  37. f"register_fake(...): the operator {self.qualname} "
  38. f"already has an fake impl registered at "
  39. f"{self.kernel.source}."
  40. )
  41. if torch._C._dispatch_has_kernel_for_dispatch_key(self.qualname, "Meta"):
  42. raise RuntimeError(
  43. f"register_fake(...): the operator {self.qualname} "
  44. f"already has an DispatchKey::Meta implementation via a "
  45. f"pre-existing torch.library or TORCH_LIBRARY registration. "
  46. f"Please either remove that registration or don't call "
  47. f"register_fake."
  48. )
  49. if torch._C._dispatch_has_kernel_for_dispatch_key(
  50. self.qualname, "CompositeImplicitAutograd"
  51. ):
  52. raise RuntimeError(
  53. f"register_fake(...): the operator {self.qualname} "
  54. f"already has an implementation for this device type via a "
  55. f"pre-existing registration to "
  56. f"DispatchKey::CompositeImplicitAutograd."
  57. f"CompositeImplicitAutograd operators do not need an fake "
  58. f"impl; "
  59. f"instead, the operator will decompose into its constituents "
  60. f"and those "
  61. f"can have fake impls defined on them."
  62. )
  63. # Store the kernel in this holder
  64. kernel = Kernel(func, source)
  65. self.kernels.append(kernel)
  66. def deregister_fake_kernel():
  67. self.kernels.remove(kernel)
  68. meta_kernel = construct_meta_kernel(self.qualname, self)
  69. lib.impl(self.qualname, meta_kernel, "Meta", allow_override=allow_override)
  70. handle = RegistrationHandle(deregister_fake_kernel)
  71. return handle
  72. def construct_meta_kernel(qualname: str, fake_impl_holder: FakeImplHolder) -> Callable:
  73. if fake_impl_holder.kernel is None:
  74. raise AssertionError("fake_impl_holder.kernel must not be None")
  75. @functools.wraps(fake_impl_holder.kernel.func)
  76. def meta_kernel(*args, **kwargs):
  77. if fake_impl_holder.kernel is None:
  78. raise AssertionError("fake_impl_holder.kernel must not be None")
  79. source = fake_impl_holder.kernel.source
  80. def error_on_ctx():
  81. raise RuntimeError(
  82. f"{qualname} ({source}): You're trying to run this operator "
  83. f"with meta Tensors (as opposed to FakeTensors), but this "
  84. f"operator may return an output Tensor with data-dependent shape. Meta "
  85. f"Tensors don't support operators with outputs that have data-dependent shapes "
  86. f"but FakeTensors do. "
  87. f"If your operator does not return an output with data-dependent shape, "
  88. f"make sure the FakeTensor and/or meta kernel does not call "
  89. f"torch.library.get_ctx(). Otherwise, please use FakeTensors."
  90. )
  91. with set_ctx_getter(error_on_ctx):
  92. return fake_impl_holder.kernel(*args, **kwargs)
  93. return meta_kernel
  94. def get_none():
  95. return None
  96. global_ctx_getter: Callable = get_none
  97. @contextlib.contextmanager
  98. def set_ctx_getter(ctx_getter):
  99. global global_ctx_getter
  100. prev = global_ctx_getter
  101. try:
  102. global_ctx_getter = ctx_getter
  103. yield
  104. finally:
  105. global_ctx_getter = prev
  106. class FakeImplCtx:
  107. """
  108. Context object for writing fake implementations for custom operators.
  109. """
  110. def __init__(self, _fake_mode, _op):
  111. self._fake_mode = _fake_mode
  112. self._shape_env = _fake_mode.shape_env
  113. self._op = _op
  114. @deprecated(
  115. "`create_unbacked_symint` is deprecated, please use `new_dynamic_size` instead",
  116. category=FutureWarning,
  117. )
  118. def create_unbacked_symint(self, *, min=2, max=None) -> torch.SymInt:
  119. return self.new_dynamic_size(min=min, max=max)
  120. def new_dynamic_size(self, *, min=0, max=None) -> torch.SymInt:
  121. """Constructs a new symint (symbolic int) representing a data-dependent value.
  122. This is useful for writing the fake implementation (which is necessary
  123. for torch.compile) for a CustomOp where an output Tensor has a size
  124. that depends on the data of the input Tensors.
  125. Args:
  126. min (int): A statically known inclusive lower bound for this symint. Default: 0
  127. max (Optional[int]): A statically known inclusive upper bound for this
  128. symint. Default: None
  129. .. warning:
  130. It is important that the ``min`` and ``max`` (if not None) values are set
  131. correctly, otherwise, there will be undefined behavior under
  132. torch.compile. The default value of ``min`` is 2 due to torch.compile
  133. specializing on 0/1 sizes.
  134. You must also verify that your implementation on concrete Tensors
  135. (e.g. CPU/CUDA) only returns Tensors where the size that corresponds
  136. to the symint also has respects these constraint.
  137. The easiest way to do this is to add an assertion in the CPU/CUDA/etc
  138. implementation that the size follows these bounds.
  139. Example::
  140. >>> # An operator with data-dependent output shape
  141. >>> lib = torch.library.Library("mymodule", "FRAGMENT")
  142. >>> lib.define("mymodule::custom_nonzero(Tensor x) -> Tensor")
  143. >>>
  144. >>> @torch.library.register_fake("mymodule::custom_nonzero")
  145. >>> def _(x):
  146. >>> # Number of nonzero-elements is data-dependent.
  147. >>> # Since we cannot peek at the data in an fake impl,
  148. >>> # we use the ctx object to construct a new symint that
  149. >>> # represents the data-dependent size.
  150. >>> ctx = torch.library.get_ctx()
  151. >>> nnz = ctx.new_dynamic_size()
  152. >>> shape = [nnz, x.dim()]
  153. >>> result = x.new_empty(shape, dtype=torch.int64)
  154. >>> return result
  155. >>>
  156. >>> @torch.library.impl(lib, "custom_nonzero", "CPU")
  157. >>> def _(x):
  158. >>> x_np = x.numpy()
  159. >>> res = np.stack(np.nonzero(x_np), axis=1)
  160. >>> return torch.tensor(res, device=x.device)
  161. """
  162. if (
  163. self._shape_env is None
  164. or not self._shape_env.allow_dynamic_output_shape_ops
  165. ):
  166. raise torch._subclasses.fake_tensor.DynamicOutputShapeException(self._op)
  167. if isinstance(min, torch.SymInt) or isinstance(max, torch.SymInt):
  168. raise ValueError(
  169. f"ctx.new_dynamic_size(min={min}, max={max}): expected "
  170. f"min and max to be statically known ints but got SymInt. "
  171. f"This is not supported."
  172. )
  173. if min < 0:
  174. raise ValueError(
  175. f"ctx.new_dynamic_size(min={min}, ...): expected min to be "
  176. f"greater than or equal to 0: this API can only create "
  177. f"non-negative sizes."
  178. )
  179. return allocate_size(self._shape_env, min, max)
  180. def allocate_size(shape_env, min_val=0, max_val=None):
  181. result = shape_env.create_unbacked_symint()
  182. torch.fx.experimental.symbolic_shapes._constrain_range_for_size(
  183. result, min=min_val, max=max_val
  184. )
  185. return result