rng_prims.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398
  1. # mypy: allow-untyped-defs
  2. from typing import cast
  3. import torch
  4. import torch.utils._pytree as pytree
  5. from torch import _prims
  6. from torch._C import DispatchKey
  7. from torch._higher_order_ops.utils import autograd_not_implemented
  8. from torch._ops import HigherOrderOperator
  9. from torch._prims_common import CUDARngStateHelper, make_contiguous_strides_for
  10. from torch._subclasses.fake_tensor import FakeTensorMode
  11. from torch.fx.experimental.proxy_tensor import (
  12. disable_proxy_modes_tracing,
  13. ProxyTorchDispatchMode,
  14. track_tensor_tree,
  15. )
  16. from torch.types import _device, _dtype
  17. def throw_on_non_cuda(device):
  18. raise RuntimeError(
  19. f"You are trying to functionalize a {device.type} RNG operator but {device.type} does not "
  20. f"use Philox/counter-based RNG. Therefore, functionalizing a {device.type} RNG operator is "
  21. "not supported. We are discussing the possibility of a Philox-based RNG implementation for CPU."
  22. )
  23. def register_rng_prim(name, schema, impl_aten, impl_meta, doc, tags=None):
  24. rngprim_def = torch.library.custom_op(
  25. "rngprims::" + name, impl_aten, mutates_args=(), schema=schema
  26. )
  27. rngprim_def.register_fake(impl_meta)
  28. prim_packet = getattr(torch._ops.ops.rngprims, name)
  29. prim = prim_packet.default
  30. if tags:
  31. prim._tags = tags
  32. for p in (prim_packet, prim):
  33. p.__doc__ = doc
  34. p.return_type = torch._prims_common.RETURN_TYPE.NEW # type: ignore[attr-defined]
  35. p.schema = name + schema
  36. p.impl_aten = impl_aten
  37. p.prim_meta_impl = impl_meta
  38. # Philox rand offsets could be shared in future with other philox ops, so
  39. # keeping these functions in global scope.
  40. def philox_rand_offset_meta(
  41. shape: torch.Size,
  42. ):
  43. return _prims.TensorLike(torch.tensor(0, dtype=torch.int64))
  44. def philox_rand_offset(
  45. shape: torch.Size,
  46. ):
  47. # For impl, look at the function calc_execution_policy in the file
  48. # aten/src/ATen/native/cuda/DistributionTemplates.h. The impl was copied at
  49. # commit hash 72aa0667bd16707d50eb8fa337092a1f5d11dfb6
  50. numel_scalar = 1
  51. for dim_size in shape:
  52. numel_scalar *= dim_size
  53. numel = torch.scalar_tensor(numel_scalar, dtype=torch.int64)
  54. block_size = 256
  55. unroll = 4
  56. curand4_engine_calls = 4
  57. device_property = torch.cuda.get_device_properties(torch.cuda.current_device())
  58. blocks_per_sm = device_property.max_threads_per_multi_processor // block_size
  59. num = cast(int, numel)
  60. grid_size = (num + block_size - 1) // block_size
  61. grid_size = min(grid_size, device_property.multi_processor_count * blocks_per_sm)
  62. return ((num - 1) // (block_size * grid_size * unroll) + 1) * curand4_engine_calls
  63. def register_philox_rand():
  64. name = "philox_rand"
  65. schema = "(SymInt[] size, Tensor seed, Tensor offset, int[]? stride, Device? device=None, ScalarType? dtype=None) -> (Tensor, Tensor)" # noqa: B950
  66. def _philox_rand_meta(
  67. shape: torch.Size,
  68. seed: torch.Tensor,
  69. offset: torch.Tensor,
  70. stride: tuple[int, ...] | None,
  71. device: _device,
  72. dtype: _dtype,
  73. ):
  74. # stride arg will be useful for distributed usecase. Currently, its unused.
  75. if stride is not None:
  76. raise AssertionError(f"stride must be None, got {stride}")
  77. stride = make_contiguous_strides_for(shape)
  78. random_values = _prims.TensorMeta(
  79. shape=shape, strides=stride, dtype=dtype, device=device
  80. )
  81. offset = philox_rand_offset_meta(shape)
  82. return (random_values, offset)
  83. def _philox_rand(
  84. shape: torch.Size,
  85. seed: torch.Tensor,
  86. offset: torch.Tensor,
  87. stride: tuple[int, ...] | None,
  88. device: _device,
  89. dtype: _dtype,
  90. ):
  91. # stride arg will be useful for distributed usecase. Currently, its unused.
  92. if stride is not None:
  93. raise AssertionError(f"stride must be None, got {stride}")
  94. if device.type == "cpu":
  95. devices = []
  96. else:
  97. devices = [device]
  98. if device.type != "cuda":
  99. raise throw_on_non_cuda(device)
  100. with torch.random.fork_rng(devices):
  101. CUDARngStateHelper.set_torch_state_tensor(seed, offset)
  102. random_values = torch.rand(shape, device=device, dtype=dtype)
  103. return random_values, philox_rand_offset(shape)
  104. register_rng_prim(
  105. name=name,
  106. schema=schema,
  107. impl_aten=_philox_rand,
  108. impl_meta=_philox_rand_meta,
  109. doc="Philox based stateless rand operator",
  110. tags=(torch.Tag.nondeterministic_seeded,),
  111. )
  112. def get_device(args, kwargs):
  113. if kwargs.get("device"):
  114. device = kwargs.get("device")
  115. if isinstance(device, str):
  116. device = torch.device(device)
  117. return device.type
  118. devices = {arg.device.type for arg in args if isinstance(arg, torch.Tensor)}
  119. if any(dev == "cuda" for dev in devices):
  120. return "cuda"
  121. elif any(dev == "xpu" for dev in devices):
  122. return "xpu"
  123. elif any(dev == "hpu" for dev in devices):
  124. return "hpu"
  125. elif any(dev == "cpu" for dev in devices):
  126. return "cpu"
  127. return None
  128. def register_run_and_save_rng_state_op():
  129. class RunAndSaveRngState(HigherOrderOperator):
  130. def __init__(self):
  131. super().__init__("run_and_save_rng_state", cacheable=True)
  132. def __call__(self, op, *args, **kwargs):
  133. # pyrefly: ignore [missing-attribute]
  134. return super().__call__(op, *args, **kwargs)
  135. run_and_save_rng_state = RunAndSaveRngState()
  136. run_and_save_rng_state.py_impl(DispatchKey.Autograd)(
  137. autograd_not_implemented(run_and_save_rng_state, deferred_error=True)
  138. )
  139. @run_and_save_rng_state.py_impl(DispatchKey.CUDA)
  140. def impl_cuda(op, *args, **kwargs):
  141. return torch.cuda.get_rng_state(), op(*args, **kwargs)
  142. @run_and_save_rng_state.py_impl(DispatchKey.CPU)
  143. def impl_cpu(op, *args, **kwargs):
  144. return torch.get_rng_state(), op(*args, **kwargs)
  145. @run_and_save_rng_state.py_impl(DispatchKey.HPU)
  146. def impl_hpu(op, *args, **kwargs):
  147. if hasattr(torch, "hpu"):
  148. return torch.hpu.get_rng_state(), op(*args, **kwargs)
  149. raise RuntimeError("functionalize a hpu RNG operator is not supported.")
  150. @run_and_save_rng_state.py_impl(DispatchKey.XPU)
  151. def impl_xpu(op, *args, **kwargs):
  152. return torch.xpu.get_rng_state(), op(*args, **kwargs)
  153. @run_and_save_rng_state.py_impl(DispatchKey.BackendSelect)
  154. def impl_backend_select(op, *args, **kwargs):
  155. impl_map = {
  156. "cuda": impl_cuda,
  157. "cpu": impl_cpu,
  158. "hpu": impl_hpu,
  159. "xpu": impl_xpu,
  160. }
  161. device = get_device(args, kwargs)
  162. if device not in impl_map:
  163. raise AssertionError(f"Backend not supported for {device}")
  164. impl = impl_map[device]
  165. return impl(op, *args, **kwargs)
  166. @run_and_save_rng_state.py_impl(FakeTensorMode)
  167. def impl_fake_tensor_mode(mode, op, *args, **kwargs):
  168. # Check device to call the right impl
  169. with mode:
  170. return impl_backend_select(op, *args, **kwargs)
  171. @run_and_save_rng_state.py_impl(ProxyTorchDispatchMode)
  172. def impl_proxy_dispatch_mode(mode, op, *args, **kwargs):
  173. out = impl_backend_select(op, *args, **kwargs)
  174. proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, (op, *args))
  175. proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
  176. out_proxy = mode.tracer.create_proxy(
  177. "call_function", run_and_save_rng_state, proxy_args, proxy_kwargs
  178. )
  179. return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
  180. return run_and_save_rng_state
  181. def register_run_with_rng_state_op():
  182. class RunWithRngState(HigherOrderOperator):
  183. def __init__(self):
  184. super().__init__("run_with_rng_state", cacheable=True)
  185. def __call__(self, rng_state, op, *args, **kwargs):
  186. # pyrefly: ignore [missing-attribute]
  187. return super().__call__(rng_state, op, *args, **kwargs)
  188. run_with_rng_state = RunWithRngState()
  189. run_with_rng_state.py_impl(DispatchKey.Autograd)(
  190. autograd_not_implemented(run_with_rng_state, deferred_error=True)
  191. )
  192. @run_with_rng_state.py_impl(DispatchKey.CUDA)
  193. def impl_cuda(rng_state, op, *args, **kwargs):
  194. current_state = torch.cuda.get_rng_state()
  195. torch.cuda.set_rng_state(rng_state.cpu())
  196. out = op(*args, **kwargs)
  197. torch.cuda.set_rng_state(current_state)
  198. return out
  199. @run_with_rng_state.py_impl(DispatchKey.CPU)
  200. def impl_cpu(rng_state, op, *args, **kwargs):
  201. current_state = torch.get_rng_state()
  202. torch.set_rng_state(rng_state)
  203. out = op(*args, **kwargs)
  204. torch.set_rng_state(current_state)
  205. return out
  206. @run_with_rng_state.py_impl(DispatchKey.HPU)
  207. def impl_hpu(rng_state, op, *args, **kwargs):
  208. if hasattr(torch, "hpu"):
  209. current_state = torch.hpu.get_rng_state()
  210. torch.hpu.set_rng_state(rng_state)
  211. out = op(*args, **kwargs)
  212. torch.hpu.set_rng_state(current_state)
  213. return out
  214. raise RuntimeError("functionalize a hpu RNG operator is not supported.")
  215. @run_with_rng_state.py_impl(DispatchKey.XPU)
  216. def impl_xpu(rng_state, op, *args, **kwargs):
  217. current_state = torch.xpu.get_rng_state()
  218. torch.xpu.set_rng_state(rng_state)
  219. out = op(*args, **kwargs)
  220. torch.xpu.set_rng_state(current_state)
  221. return out
  222. @run_with_rng_state.py_impl(ProxyTorchDispatchMode)
  223. def impl_proxy_dispatch_mode(mode, rng_state, op, *args, **kwargs):
  224. # TODO: you don't need to do this, the dispatch here already disabled
  225. # it
  226. with disable_proxy_modes_tracing():
  227. out = run_with_rng_state(rng_state, op, *args, **kwargs)
  228. proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, (rng_state, op, *args))
  229. proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
  230. out_proxy = mode.tracer.create_proxy(
  231. "call_function", run_with_rng_state, proxy_args, proxy_kwargs
  232. )
  233. return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
  234. @run_with_rng_state.py_impl(DispatchKey.BackendSelect)
  235. def impl_backend_select(rng_state, op, *args, **kwargs):
  236. impl_map = {
  237. "cuda": impl_cuda,
  238. "cpu": impl_cpu,
  239. "hpu": impl_hpu,
  240. "xpu": impl_xpu,
  241. }
  242. device = get_device(args, kwargs)
  243. if device not in impl_map:
  244. raise AssertionError(f"Backend not supported for {device}")
  245. impl = impl_map[device]
  246. return impl(rng_state, op, *args, **kwargs)
  247. @run_with_rng_state.py_impl(FakeTensorMode)
  248. def impl_fake_tensor_mode(mode, rng_state, op, *args, **kwargs):
  249. # Skip setting the set_rng_state as it does not work well with fake tensors.
  250. # And it does not matter for the fake tensor mode.
  251. with mode:
  252. return op(*args, **kwargs)
  253. @run_with_rng_state.py_functionalize_impl
  254. def impl_functional(ctx, rng_state, op, *args, **kwargs):
  255. unwrapped_rng_state = ctx.unwrap_tensors(rng_state)
  256. unwrapped_args = ctx.unwrap_tensors(args)
  257. unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
  258. with ctx.redispatch_to_next():
  259. out = run_with_rng_state(
  260. unwrapped_rng_state, op, *unwrapped_args, **unwrapped_kwargs
  261. )
  262. return ctx.wrap_tensors(out)
  263. return run_with_rng_state
  264. run_and_save_rng_state = register_run_and_save_rng_state_op()
  265. run_with_rng_state = register_run_with_rng_state_op()
  266. def register_graphsafe_run_with_rng_state_op():
  267. class GraphSafeRunWithRngState(HigherOrderOperator):
  268. def __init__(self):
  269. super().__init__("graphsafe_run_with_rng_state")
  270. def __call__(self, op, *args, rng_state=None, **kwargs):
  271. # pyrefly: ignore [missing-attribute]
  272. return super().__call__(op, *args, rng_state=rng_state, **kwargs)
  273. graphsafe_run_with_rng_state = GraphSafeRunWithRngState()
  274. graphsafe_run_with_rng_state.py_impl(DispatchKey.Autograd)(
  275. autograd_not_implemented(graphsafe_run_with_rng_state, deferred_error=True)
  276. )
  277. @graphsafe_run_with_rng_state.py_impl(DispatchKey.CUDA)
  278. def impl_cuda(op, *args, rng_state=None, **kwargs):
  279. # pyrefly: ignore [missing-attribute]
  280. device_idx = rng_state.device.index
  281. generator = torch.cuda.default_generators[device_idx]
  282. current_state = generator.graphsafe_get_state()
  283. generator.graphsafe_set_state(rng_state)
  284. out = op(*args, **kwargs)
  285. generator.graphsafe_set_state(current_state)
  286. return out
  287. @graphsafe_run_with_rng_state.py_impl(DispatchKey.BackendSelect)
  288. def impl_backend_select(op, *args, rng_state=None, **kwargs):
  289. device = get_device(args, kwargs)
  290. if device != "cuda":
  291. raise AssertionError(
  292. f"GraphSafe RNG operations only supported for CUDA, got {device}"
  293. )
  294. return impl_cuda(op, *args, rng_state=rng_state, **kwargs)
  295. @graphsafe_run_with_rng_state.py_impl(FakeTensorMode)
  296. def impl_fake_tensor_mode(mode, op, *args, rng_state=None, **kwargs):
  297. with mode:
  298. return op(*args, **kwargs)
  299. @graphsafe_run_with_rng_state.py_impl(ProxyTorchDispatchMode)
  300. def impl_proxy_dispatch_mode(mode, op, *args, rng_state=None, **kwargs):
  301. with disable_proxy_modes_tracing():
  302. out = graphsafe_run_with_rng_state(op, *args, rng_state=rng_state, **kwargs)
  303. proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, (op, *args))
  304. proxy_kwargs = pytree.tree_map(
  305. mode.tracer.unwrap_proxy, {"rng_state": rng_state, **kwargs}
  306. )
  307. out_proxy = mode.tracer.create_proxy(
  308. "call_function", graphsafe_run_with_rng_state, proxy_args, proxy_kwargs
  309. )
  310. return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
  311. @graphsafe_run_with_rng_state.py_functionalize_impl
  312. def impl_functional(ctx, op, *args, rng_state=None, **kwargs):
  313. unwrapped_rng_state = (
  314. ctx.unwrap_tensors(rng_state) if rng_state is not None else None
  315. )
  316. unwrapped_args = ctx.unwrap_tensors(args)
  317. unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
  318. with ctx.redispatch_to_next():
  319. out = graphsafe_run_with_rng_state(
  320. op, *unwrapped_args, rng_state=unwrapped_rng_state, **unwrapped_kwargs
  321. )
  322. return ctx.wrap_tensors(out)
  323. return graphsafe_run_with_rng_state
  324. graphsafe_run_with_rng_state = register_graphsafe_run_with_rng_state_op()
  325. def register_rng_prims():
  326. register_philox_rand()