apis.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484
  1. # NOTE: We allow Dynamo to see this file (via torch/_dynamo/trace_rules.py) so that it can
  2. # trace through functorch transforms.
  3. # Currently, we can't allow Dynamo to see `eager_transforms.py`/`vmap.py` as that breaks a lot of things
  4. # and there isn't a mechanism to selectively expose only some functions (e.g. grad) from a file
  5. # to Dynamo.
  6. from __future__ import annotations
  7. import functools
  8. from typing import Any, TYPE_CHECKING
  9. from typing_extensions import ParamSpec, TypeVar
  10. from torch._functorch.utils import argnums_t, exposed_in
  11. from torch._functorch.vmap import (
  12. _check_out_dims_is_int_or_int_pytree,
  13. _check_randomness_arg,
  14. _chunked_vmap,
  15. _process_batched_inputs,
  16. Callable,
  17. in_dims_t,
  18. out_dims_t,
  19. vmap_impl,
  20. )
  21. if TYPE_CHECKING:
  22. from collections.abc import Iterable
  23. import torch
  24. _P = ParamSpec("_P")
  25. _R = TypeVar("_R")
  26. # vmap(func)(inputs) wraps all Tensor inputs to be batched in BatchedTensors,
  27. # sends those into func, and then unwraps the output BatchedTensors. Operations
  28. # on BatchedTensors perform the batched operations that the user is asking for.
  29. #
  30. # vmap's randomness behavior differs from JAX's, which would require a PRNG key
  31. # to be passed everywhere.
  32. @exposed_in("torch.func")
  33. def vmap(
  34. func: Callable[_P, _R],
  35. in_dims: in_dims_t = 0,
  36. out_dims: out_dims_t = 0,
  37. randomness: str = "error",
  38. *,
  39. chunk_size: int | None = None,
  40. ) -> Callable[_P, _R]:
  41. """
  42. vmap is the vectorizing map; ``vmap(func)`` returns a new function that
  43. maps ``func`` over some dimension of the inputs. Semantically, vmap
  44. pushes the map into PyTorch operations called by ``func``, effectively
  45. vectorizing those operations.
  46. vmap is useful for handling batch dimensions: one can write a function
  47. ``func`` that runs on examples and then lift it to a function that can
  48. take batches of examples with ``vmap(func)``. vmap can also be used to
  49. compute batched gradients when composed with autograd.
  50. .. note::
  51. :func:`torch.vmap` is aliased to :func:`torch.func.vmap` for
  52. convenience. Use whichever one you'd like.
  53. Args:
  54. func (function): A Python function that takes one or more arguments.
  55. Must return one or more Tensors.
  56. in_dims (int or nested structure): Specifies which dimension of the
  57. inputs should be mapped over. ``in_dims`` should have a
  58. structure like the inputs. If the ``in_dim`` for a particular
  59. input is None, then that indicates there is no map dimension.
  60. Default: 0.
  61. out_dims (int or Tuple[int]): Specifies where the mapped dimension
  62. should appear in the outputs. If ``out_dims`` is a Tuple, then
  63. it should have one element per output. Default: 0.
  64. randomness (str): Specifies whether the randomness in this
  65. vmap should be the same or different across batches. If 'different',
  66. the randomness for each batch will be different. If 'same', the
  67. randomness will be the same across batches. If 'error', any calls to
  68. random functions will error. Default: 'error'. WARNING: this flag
  69. only applies to random PyTorch operations and does not apply to
  70. Python's random module or numpy randomness.
  71. chunk_size (None or int): If None (default), apply a single vmap over inputs.
  72. If not None, then compute the vmap :attr:`chunk_size` samples at a time.
  73. Note that :attr:`chunk_size=1` is equivalent to computing the vmap with a for-loop.
  74. If you run into memory issues computing the vmap, please try a non-None chunk_size.
  75. Returns:
  76. Returns a new "batched" function. It takes the same inputs as
  77. ``func``, except each input has an extra dimension at the index
  78. specified by ``in_dims``. It takes returns the same outputs as
  79. ``func``, except each output has an extra dimension at the index
  80. specified by ``out_dims``.
  81. .. warning:
  82. :func:`vmap` works best with functional-style code. Please do not
  83. perform any side-effects in ``func``, with the exception of
  84. in-place PyTorch operations. Examples of side-effects include mutating
  85. Python data structures and assigning values to variables not captured
  86. in ``func``.
  87. One example of using :func:`vmap` is to compute batched dot products. PyTorch
  88. doesn't provide a batched ``torch.dot`` API; instead of unsuccessfully
  89. rummaging through docs, use :func:`vmap` to construct a new function.
  90. >>> torch.dot # [D], [D] -> []
  91. >>> batched_dot = torch.func.vmap(torch.dot) # [N, D], [N, D] -> [N]
  92. >>> x, y = torch.randn(2, 5), torch.randn(2, 5)
  93. >>> batched_dot(x, y)
  94. :func:`vmap` can be helpful in hiding batch dimensions, leading to a simpler
  95. model authoring experience.
  96. >>> batch_size, feature_size = 3, 5
  97. >>> weights = torch.randn(feature_size, requires_grad=True)
  98. >>>
  99. >>> def model(feature_vec):
  100. >>> # Very simple linear model with activation
  101. >>> return feature_vec.dot(weights).relu()
  102. >>>
  103. >>> examples = torch.randn(batch_size, feature_size)
  104. >>> result = torch.vmap(model)(examples)
  105. :func:`vmap` can also help vectorize computations that were previously difficult
  106. or impossible to batch. One example is higher-order gradient computation.
  107. The PyTorch autograd engine computes vjps (vector-Jacobian products).
  108. Computing a full Jacobian matrix for some function f: R^N -> R^N usually
  109. requires N calls to ``autograd.grad``, one per Jacobian row. Using :func:`vmap`,
  110. we can vectorize the whole computation, computing the Jacobian in a single
  111. call to ``autograd.grad``.
  112. >>> # Setup
  113. >>> N = 5
  114. >>> f = lambda x: x**2
  115. >>> x = torch.randn(N, requires_grad=True)
  116. >>> y = f(x)
  117. >>> I_N = torch.eye(N)
  118. >>>
  119. >>> # Sequential approach
  120. >>> jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
  121. >>> for v in I_N.unbind()]
  122. >>> jacobian = torch.stack(jacobian_rows)
  123. >>>
  124. >>> # vectorized gradient computation
  125. >>> def get_vjp(v):
  126. >>> return torch.autograd.grad(y, x, v)
  127. >>> jacobian = torch.vmap(get_vjp)(I_N)
  128. :func:`vmap` can also be nested, producing an output with multiple batched dimensions
  129. >>> torch.dot # [D], [D] -> []
  130. >>> batched_dot = torch.vmap(
  131. ... torch.vmap(torch.dot)
  132. ... ) # [N1, N0, D], [N1, N0, D] -> [N1, N0]
  133. >>> x, y = torch.randn(2, 3, 5), torch.randn(2, 3, 5)
  134. >>> batched_dot(x, y) # tensor of size [2, 3]
  135. If the inputs are not batched along the first dimension, ``in_dims`` specifies
  136. the dimension that each inputs are batched along as
  137. >>> torch.dot # [N], [N] -> []
  138. >>> batched_dot = torch.vmap(torch.dot, in_dims=1) # [N, D], [N, D] -> [D]
  139. >>> x, y = torch.randn(2, 5), torch.randn(2, 5)
  140. >>> batched_dot(
  141. ... x, y
  142. ... ) # output is [5] instead of [2] if batched along the 0th dimension
  143. If there are multiple inputs each of which is batched along different dimensions,
  144. ``in_dims`` must be a tuple with the batch dimension for each input as
  145. >>> torch.dot # [D], [D] -> []
  146. >>> batched_dot = torch.vmap(torch.dot, in_dims=(0, None)) # [N, D], [D] -> [N]
  147. >>> x, y = torch.randn(2, 5), torch.randn(5)
  148. >>> batched_dot(
  149. ... x, y
  150. ... ) # second arg doesn't have a batch dim because in_dim[1] was None
  151. If the input is a Python struct, ``in_dims`` must be a tuple containing a struct
  152. matching the shape of the input:
  153. >>> f = lambda dict: torch.dot(dict["x"], dict["y"])
  154. >>> x, y = torch.randn(2, 5), torch.randn(5)
  155. >>> input = {"x": x, "y": y}
  156. >>> batched_dot = torch.vmap(f, in_dims=({"x": 0, "y": None},))
  157. >>> batched_dot(input)
  158. By default, the output is batched along the first dimension. However, it can be batched
  159. along any dimension by using ``out_dims``
  160. >>> f = lambda x: x**2
  161. >>> x = torch.randn(2, 5)
  162. >>> batched_pow = torch.vmap(f, out_dims=1)
  163. >>> batched_pow(x) # [5, 2]
  164. For any function that uses kwargs, the returned function will not batch the kwargs but will
  165. accept kwargs
  166. >>> x = torch.randn([2, 5])
  167. >>> def fn(x, scale=4.):
  168. >>> return x * scale
  169. >>>
  170. >>> batched_pow = torch.vmap(fn)
  171. >>> assert torch.allclose(batched_pow(x), x * 4)
  172. >>> batched_pow(x, scale=x) # scale is not batched, output has shape [2, 2, 5]
  173. .. note::
  174. vmap does not provide general autobatching or handle variable-length
  175. sequences out of the box.
  176. """
  177. from torch.compiler import is_compiling
  178. _check_randomness_arg(randomness)
  179. if not (chunk_size is None or chunk_size > 0):
  180. raise ValueError(
  181. f"vmap: chunk_size should be None or greater than 0. (got {chunk_size})"
  182. )
  183. def wrapped(*args: _P.args, **kwargs: _P.kwargs) -> _R:
  184. # pyrefly: ignore[bad-argument-type]
  185. return vmap_impl(
  186. # pyrefly: ignore[bad-argument-type]
  187. func,
  188. in_dims,
  189. out_dims,
  190. randomness,
  191. chunk_size,
  192. *args,
  193. **kwargs,
  194. )
  195. if not is_compiling():
  196. wrapped = functools.wraps(func)(wrapped)
  197. return wrapped
  198. def chunk_vmap(
  199. func: Callable[_P, _R],
  200. in_dims: in_dims_t = 0,
  201. out_dims: out_dims_t = 0,
  202. randomness: str = "error",
  203. chunks: int = 2,
  204. ) -> Callable[_P, _R]:
  205. """
  206. chunk_vmap is the vectorizing map (vmap) using chunks of input data. It is a mix of vmap (which vectorizes
  207. everything) and map (which executes things sequentially). ``chunk_vmap`` vectorizes the input with number of
  208. chunks at a time. For more details about vectorizing map, see :func:`vmap`.
  209. .. note::
  210. Please use :func:`vmap` with ``chunk_size`` argument instead of this API.
  211. Args:
  212. func (function): A Python function that takes one or more arguments.
  213. Must return one or more Tensors.
  214. in_dims (int or nested structure): Specifies which dimension of the
  215. inputs should be mapped over. ``in_dims`` should have a
  216. structure like the inputs. If the ``in_dim`` for a particular
  217. input is None, then that indicates there is no map dimension.
  218. Default: 0.
  219. out_dims (int or Tuple[int]): Specifies where the mapped dimension
  220. should appear in the outputs. If ``out_dims`` is a Tuple, then
  221. it should have one element per output. Default: 0.
  222. randomness (str): Specifies whether the randomness in this
  223. vmap should be the same or different across batches. If 'different',
  224. the randomness for each batch will be different. If 'same', the
  225. randomness will be the same across batches. If 'error', any calls to
  226. random functions will error. Default: 'error'. WARNING: this flag
  227. only applies to random PyTorch operations and does not apply to
  228. Python's random module or numpy randomness.
  229. chunks (int): Number of chunks to use to split the input data. Default is 2.
  230. If equals to 1 then :func:`vmap` is called.
  231. Returns:
  232. Returns a new "batched" function. It takes the same inputs as
  233. ``func``, except each input has an extra dimension at the index
  234. specified by ``in_dims``. It takes returns the same outputs as
  235. ``func``, except each output has an extra dimension at the index
  236. specified by ``out_dims``.
  237. """
  238. _check_randomness_arg(randomness)
  239. if chunks == 1:
  240. return vmap(func, in_dims=in_dims, out_dims=out_dims, randomness=randomness)
  241. def _get_chunk_flat_args(
  242. flat_args_: Iterable[Any],
  243. flat_in_dims_: Iterable[int | None],
  244. chunks_: int,
  245. ) -> Iterable[Any]:
  246. flat_args_chunks = tuple(
  247. t.chunk(chunks_, dim=in_dim)
  248. if in_dim is not None
  249. else [
  250. t,
  251. ]
  252. * chunks_
  253. for t, in_dim in zip(flat_args_, flat_in_dims_)
  254. )
  255. # transpose chunk dim and flatten structure
  256. # chunks_flat_args is a list of flatten args
  257. chunks_flat_args = zip(*flat_args_chunks)
  258. return chunks_flat_args
  259. @functools.wraps(func)
  260. def wrapped_with_chunks(*args: _P.args, **kwargs: _P.kwargs) -> _R:
  261. _check_out_dims_is_int_or_int_pytree(out_dims, func)
  262. _, flat_in_dims, flat_args, args_spec = _process_batched_inputs(
  263. in_dims, args, func
  264. )
  265. # Chunk flat arguments
  266. chunks_flat_args = _get_chunk_flat_args(flat_args, flat_in_dims, chunks)
  267. # Apply vmap on chunks
  268. return _chunked_vmap(
  269. # pyrefly: ignore[bad-argument-type]
  270. func,
  271. flat_in_dims,
  272. chunks_flat_args,
  273. args_spec,
  274. out_dims,
  275. randomness,
  276. **kwargs,
  277. )
  278. return wrapped_with_chunks
  279. # TODO: Improve the return type of this function
  280. @exposed_in("torch.func")
  281. def grad(
  282. func: Callable[_P, Any], argnums: argnums_t = 0, has_aux: bool = False
  283. ) -> Callable[_P, Any]:
  284. """``grad`` operator helps computing gradients of ``func`` with respect to the
  285. input(s) specified by ``argnums``. This operator can be nested to
  286. compute higher-order gradients.
  287. Args:
  288. func (Callable): A Python function that takes one or more arguments.
  289. Must return a single-element Tensor. If specified ``has_aux`` equals ``True``,
  290. function can return a tuple of single-element Tensor and other auxiliary objects:
  291. ``(output, aux)``.
  292. argnums (int or Tuple[int]): Specifies arguments to compute gradients with respect to.
  293. ``argnums`` can be single integer or tuple of integers. Default: 0.
  294. has_aux (bool): Flag indicating that ``func`` returns a tensor and other
  295. auxiliary objects: ``(output, aux)``. Default: False.
  296. Returns:
  297. Function to compute gradients with respect to its inputs. By default, the output of
  298. the function is the gradient tensor(s) with respect to the first argument.
  299. If specified ``has_aux`` equals ``True``, tuple of gradients and output auxiliary objects
  300. is returned. If ``argnums`` is a tuple of integers, a tuple of output gradients with
  301. respect to each ``argnums`` value is returned.
  302. Example of using ``grad``:
  303. >>> # xdoctest: +SKIP
  304. >>> from torch.func import grad
  305. >>> x = torch.randn([])
  306. >>> cos_x = grad(lambda x: torch.sin(x))(x)
  307. >>> assert torch.allclose(cos_x, x.cos())
  308. >>>
  309. >>> # Second-order gradients
  310. >>> neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)
  311. >>> assert torch.allclose(neg_sin_x, -x.sin())
  312. When composed with ``vmap``, ``grad`` can be used to compute per-sample-gradients:
  313. >>> # xdoctest: +SKIP
  314. >>> from torch.func import grad, vmap
  315. >>> batch_size, feature_size = 3, 5
  316. >>>
  317. >>> def model(weights, feature_vec):
  318. >>> # Very simple linear model with activation
  319. >>> assert feature_vec.dim() == 1
  320. >>> return feature_vec.dot(weights).relu()
  321. >>>
  322. >>> def compute_loss(weights, example, target):
  323. >>> y = model(weights, example)
  324. >>> return ((y - target) ** 2).mean() # MSELoss
  325. >>>
  326. >>> weights = torch.randn(feature_size, requires_grad=True)
  327. >>> examples = torch.randn(batch_size, feature_size)
  328. >>> targets = torch.randn(batch_size)
  329. >>> inputs = (weights, examples, targets)
  330. >>> grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(
  331. ... *inputs
  332. ... )
  333. Example of using ``grad`` with ``has_aux`` and ``argnums``:
  334. >>> # xdoctest: +SKIP
  335. >>> from torch.func import grad
  336. >>> def my_loss_func(y, y_pred):
  337. >>> loss_per_sample = (0.5 * y_pred - y) ** 2
  338. >>> loss = loss_per_sample.mean()
  339. >>> return loss, (y_pred, loss_per_sample)
  340. >>>
  341. >>> fn = grad(my_loss_func, argnums=(0, 1), has_aux=True)
  342. >>> y_true = torch.rand(4)
  343. >>> y_preds = torch.rand(4, requires_grad=True)
  344. >>> out = fn(y_true, y_preds)
  345. >>> # > output is ((grads w.r.t y_true, grads w.r.t y_preds), (y_pred, loss_per_sample))
  346. .. note::
  347. Using PyTorch ``torch.no_grad`` together with ``grad``.
  348. Case 1: Using ``torch.no_grad`` inside a function:
  349. >>> # xdoctest: +SKIP
  350. >>> def f(x):
  351. >>> with torch.no_grad():
  352. >>> c = x ** 2
  353. >>> return x - c
  354. In this case, ``grad(f)(x)`` will respect the inner ``torch.no_grad``.
  355. Case 2: Using ``grad`` inside ``torch.no_grad`` context manager:
  356. >>> # xdoctest: +SKIP
  357. >>> with torch.no_grad():
  358. >>> grad(f)(x)
  359. In this case, ``grad`` will respect the inner ``torch.no_grad``, but not the
  360. outer one. This is because ``grad`` is a "function transform": its result
  361. should not depend on the result of a context manager outside of ``f``.
  362. """
  363. # To avoid cyclical dependency.
  364. import torch._functorch.eager_transforms as eager_transforms
  365. from torch.compiler import is_compiling
  366. def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> tuple[Any, torch.Tensor]:
  367. return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs)
  368. if not is_compiling():
  369. wrapper = functools.wraps(func)(wrapper)
  370. return wrapper
  371. # TODO: Improve the return type of this function
  372. @exposed_in("torch.func")
  373. def grad_and_value(
  374. func: Callable[_P, Any], argnums: argnums_t = 0, has_aux: bool = False
  375. ) -> Callable[_P, tuple[Any, Any]]:
  376. """
  377. Returns a function to compute a tuple of the gradient and primal, or
  378. forward, computation.
  379. Args:
  380. func (Callable): A Python function that takes one or more arguments.
  381. Must return a single-element Tensor. If specified ``has_aux``
  382. equals ``True``, function can return a tuple of single-element
  383. Tensor and other auxiliary objects: ``(output, aux)``.
  384. argnums (int or Tuple[int]): Specifies arguments to compute gradients
  385. with respect to. ``argnums`` can be single integer or tuple of
  386. integers. Default: 0.
  387. has_aux (bool): Flag indicating that ``func`` returns a tensor and
  388. other auxiliary objects: ``(output, aux)``. Default: False.
  389. Returns:
  390. Function to compute a tuple of gradients with respect to its inputs
  391. and the forward computation. By default, the output of the function is
  392. a tuple of the gradient tensor(s) with respect to the first argument
  393. and the primal computation. If specified ``has_aux`` equals
  394. ``True``, tuple of gradients and tuple of the forward computation with
  395. output auxiliary objects is returned. If ``argnums`` is a tuple of
  396. integers, a tuple of a tuple of the output gradients with respect to
  397. each ``argnums`` value and the forward computation is returned.
  398. See :func:`grad` for examples
  399. """
  400. from torch._functorch import eager_transforms
  401. from torch.compiler import is_compiling
  402. def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> tuple[Any, torch.Tensor]:
  403. return eager_transforms.grad_and_value_impl(
  404. func, argnums, has_aux, args, kwargs
  405. )
  406. if not is_compiling():
  407. wrapper = functools.wraps(func)(wrapper)
  408. return wrapper