| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484 |
- # NOTE: We allow Dynamo to see this file (via torch/_dynamo/trace_rules.py) so that it can
- # trace through functorch transforms.
- # Currently, we can't allow Dynamo to see `eager_transforms.py`/`vmap.py`, as that breaks a lot of things
- # and there isn't a mechanism to selectively expose only some functions (e.g. grad) from a file
- # to Dynamo.
- from __future__ import annotations
- import functools
- from typing import Any, TYPE_CHECKING
- from typing_extensions import ParamSpec, TypeVar
- from torch._functorch.utils import argnums_t, exposed_in
- from torch._functorch.vmap import (
- _check_out_dims_is_int_or_int_pytree,
- _check_randomness_arg,
- _chunked_vmap,
- _process_batched_inputs,
- Callable,
- in_dims_t,
- out_dims_t,
- vmap_impl,
- )
- if TYPE_CHECKING:
- from collections.abc import Iterable
- import torch
- _P = ParamSpec("_P")
- _R = TypeVar("_R")
- # vmap(func)(inputs) wraps all Tensor inputs to be batched in BatchedTensors,
- # sends those into func, and then unwraps the output BatchedTensors. Operations
- # on BatchedTensors perform the batched operations that the user is asking for.
- #
- # vmap's randomness behavior differs from JAX's, which would require a PRNG key
- # to be passed everywhere.
- @exposed_in("torch.func")
- def vmap(
- func: Callable[_P, _R],
- in_dims: in_dims_t = 0,
- out_dims: out_dims_t = 0,
- randomness: str = "error",
- *,
- chunk_size: int | None = None,
- ) -> Callable[_P, _R]:
- """
- vmap is the vectorizing map; ``vmap(func)`` returns a new function that
- maps ``func`` over some dimension of the inputs. Semantically, vmap
- pushes the map into PyTorch operations called by ``func``, effectively
- vectorizing those operations.
- vmap is useful for handling batch dimensions: one can write a function
- ``func`` that runs on examples and then lift it to a function that can
- take batches of examples with ``vmap(func)``. vmap can also be used to
- compute batched gradients when composed with autograd.
- .. note::
- :func:`torch.vmap` is aliased to :func:`torch.func.vmap` for
- convenience. Use whichever one you'd like.
- Args:
- func (function): A Python function that takes one or more arguments.
- Must return one or more Tensors.
- in_dims (int or nested structure): Specifies which dimension of the
- inputs should be mapped over. ``in_dims`` should have a
- structure like the inputs. If the ``in_dim`` for a particular
- input is None, then that indicates there is no map dimension.
- Default: 0.
- out_dims (int or Tuple[int]): Specifies where the mapped dimension
- should appear in the outputs. If ``out_dims`` is a Tuple, then
- it should have one element per output. Default: 0.
- randomness (str): Specifies whether the randomness in this
- vmap should be the same or different across batches. If 'different',
- the randomness for each batch will be different. If 'same', the
- randomness will be the same across batches. If 'error', any calls to
- random functions will error. Default: 'error'. WARNING: this flag
- only applies to random PyTorch operations and does not apply to
- Python's random module or numpy randomness.
- chunk_size (None or int): If None (default), apply a single vmap over inputs.
- If not None, then compute the vmap :attr:`chunk_size` samples at a time.
- Note that :attr:`chunk_size=1` is equivalent to computing the vmap with a for-loop.
- If you run into memory issues computing the vmap, please try a non-None chunk_size.
- Returns:
- Returns a new "batched" function. It takes the same inputs as
- ``func``, except each input has an extra dimension at the index
- specified by ``in_dims``. It takes returns the same outputs as
- ``func``, except each output has an extra dimension at the index
- specified by ``out_dims``.
- .. warning:
- :func:`vmap` works best with functional-style code. Please do not
- perform any side-effects in ``func``, with the exception of
- in-place PyTorch operations. Examples of side-effects include mutating
- Python data structures and assigning values to variables not captured
- in ``func``.
- One example of using :func:`vmap` is to compute batched dot products. PyTorch
- doesn't provide a batched ``torch.dot`` API; instead of unsuccessfully
- rummaging through docs, use :func:`vmap` to construct a new function.
- >>> torch.dot # [D], [D] -> []
- >>> batched_dot = torch.func.vmap(torch.dot) # [N, D], [N, D] -> [N]
- >>> x, y = torch.randn(2, 5), torch.randn(2, 5)
- >>> batched_dot(x, y)
- :func:`vmap` can be helpful in hiding batch dimensions, leading to a simpler
- model authoring experience.
- >>> batch_size, feature_size = 3, 5
- >>> weights = torch.randn(feature_size, requires_grad=True)
- >>>
- >>> def model(feature_vec):
- >>> # Very simple linear model with activation
- >>> return feature_vec.dot(weights).relu()
- >>>
- >>> examples = torch.randn(batch_size, feature_size)
- >>> result = torch.vmap(model)(examples)
- :func:`vmap` can also help vectorize computations that were previously difficult
- or impossible to batch. One example is higher-order gradient computation.
- The PyTorch autograd engine computes vjps (vector-Jacobian products).
- Computing a full Jacobian matrix for some function f: R^N -> R^N usually
- requires N calls to ``autograd.grad``, one per Jacobian row. Using :func:`vmap`,
- we can vectorize the whole computation, computing the Jacobian in a single
- call to ``autograd.grad``.
- >>> # Setup
- >>> N = 5
- >>> f = lambda x: x**2
- >>> x = torch.randn(N, requires_grad=True)
- >>> y = f(x)
- >>> I_N = torch.eye(N)
- >>>
- >>> # Sequential approach
- >>> jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
- >>> for v in I_N.unbind()]
- >>> jacobian = torch.stack(jacobian_rows)
- >>>
- >>> # vectorized gradient computation
- >>> def get_vjp(v):
- >>> return torch.autograd.grad(y, x, v)
- >>> jacobian = torch.vmap(get_vjp)(I_N)
- :func:`vmap` can also be nested, producing an output with multiple batched dimensions
- >>> torch.dot # [D], [D] -> []
- >>> batched_dot = torch.vmap(
- ... torch.vmap(torch.dot)
- ... ) # [N1, N0, D], [N1, N0, D] -> [N1, N0]
- >>> x, y = torch.randn(2, 3, 5), torch.randn(2, 3, 5)
- >>> batched_dot(x, y) # tensor of size [2, 3]
- If the inputs are not batched along the first dimension, ``in_dims`` specifies
- the dimension that each inputs are batched along as
- >>> torch.dot # [N], [N] -> []
- >>> batched_dot = torch.vmap(torch.dot, in_dims=1) # [N, D], [N, D] -> [D]
- >>> x, y = torch.randn(2, 5), torch.randn(2, 5)
- >>> batched_dot(
- ... x, y
- ... ) # output is [5] instead of [2] if batched along the 0th dimension
- If there are multiple inputs each of which is batched along different dimensions,
- ``in_dims`` must be a tuple with the batch dimension for each input as
- >>> torch.dot # [D], [D] -> []
- >>> batched_dot = torch.vmap(torch.dot, in_dims=(0, None)) # [N, D], [D] -> [N]
- >>> x, y = torch.randn(2, 5), torch.randn(5)
- >>> batched_dot(
- ... x, y
- ... ) # second arg doesn't have a batch dim because in_dim[1] was None
- If the input is a Python struct, ``in_dims`` must be a tuple containing a struct
- matching the shape of the input:
- >>> f = lambda dict: torch.dot(dict["x"], dict["y"])
- >>> x, y = torch.randn(2, 5), torch.randn(5)
- >>> input = {"x": x, "y": y}
- >>> batched_dot = torch.vmap(f, in_dims=({"x": 0, "y": None},))
- >>> batched_dot(input)
- By default, the output is batched along the first dimension. However, it can be batched
- along any dimension by using ``out_dims``
- >>> f = lambda x: x**2
- >>> x = torch.randn(2, 5)
- >>> batched_pow = torch.vmap(f, out_dims=1)
- >>> batched_pow(x) # [5, 2]
- For any function that uses kwargs, the returned function will not batch the kwargs but will
- accept kwargs
- >>> x = torch.randn([2, 5])
- >>> def fn(x, scale=4.):
- >>> return x * scale
- >>>
- >>> batched_pow = torch.vmap(fn)
- >>> assert torch.allclose(batched_pow(x), x * 4)
- >>> batched_pow(x, scale=x) # scale is not batched, output has shape [2, 2, 5]
- .. note::
- vmap does not provide general autobatching or handle variable-length
- sequences out of the box.
- """
- from torch.compiler import is_compiling
- _check_randomness_arg(randomness)
- if not (chunk_size is None or chunk_size > 0):
- raise ValueError(
- f"vmap: chunk_size should be None or greater than 0. (got {chunk_size})"
- )
- def wrapped(*args: _P.args, **kwargs: _P.kwargs) -> _R:
- # pyrefly: ignore[bad-argument-type]
- return vmap_impl(
- # pyrefly: ignore[bad-argument-type]
- func,
- in_dims,
- out_dims,
- randomness,
- chunk_size,
- *args,
- **kwargs,
- )
- if not is_compiling():
- wrapped = functools.wraps(func)(wrapped)
- return wrapped
def chunk_vmap(
    func: Callable[_P, _R],
    in_dims: in_dims_t = 0,
    out_dims: out_dims_t = 0,
    randomness: str = "error",
    chunks: int = 2,
) -> Callable[_P, _R]:
    """
    chunk_vmap is the vectorizing map (vmap) using chunks of input data. It is a mix of vmap (which vectorizes
    everything) and map (which executes things sequentially). ``chunk_vmap`` splits every mapped input into
    ``chunks`` pieces along its mapped dimension and hands those pieces to ``_chunked_vmap``.
    For more details about vectorizing map, see :func:`vmap`.

    .. note::
        Please use :func:`vmap` with ``chunk_size`` argument instead of this API.

    Args:
        func (function): A Python function that takes one or more arguments.
            Must return one or more Tensors.
        in_dims (int or nested structure): Specifies which dimension of the
            inputs should be mapped over. ``in_dims`` should have a
            structure like the inputs. If the ``in_dim`` for a particular
            input is None, then that input has no map dimension.
            Default: 0.
        out_dims (int or Tuple[int]): Specifies where the mapped dimension
            should appear in the outputs. If ``out_dims`` is a Tuple, then
            it should have one element per output. Default: 0.
        randomness (str): Specifies whether the randomness in this
            vmap should be the same or different across batches. If 'different',
            the randomness for each batch will be different. If 'same', the
            randomness will be the same across batches. If 'error', any calls to
            random functions will error. Default: 'error'. WARNING: this flag
            only applies to random PyTorch operations and does not apply to
            Python's random module or numpy randomness.
        chunks (int): Number of chunks to use to split the input data. Must be
            greater than 0. Default is 2. If equal to 1, :func:`vmap` is called
            instead.

    Returns:
        A new "batched" function. It takes the same inputs as ``func``, except
        each input has an extra dimension at the index specified by
        ``in_dims``. It returns the same outputs as ``func``, except each
        output has an extra dimension at the index specified by ``out_dims``.

    Raises:
        ValueError: If ``chunks`` is not greater than 0.
    """
    _check_randomness_arg(randomness)

    # Reject a non-positive chunk count up front, mirroring the ``chunk_size``
    # check in ``vmap``; otherwise the problem only surfaces later, when the
    # returned function tries to split its inputs.
    if not chunks > 0:
        raise ValueError(
            f"chunk_vmap: chunks should be greater than 0. (got {chunks})"
        )

    if chunks == 1:
        # A single chunk is just a regular vmap.
        return vmap(func, in_dims=in_dims, out_dims=out_dims, randomness=randomness)

    def _get_chunk_flat_args(
        flat_args_: Iterable[Any],
        flat_in_dims_: Iterable[int | None],
        chunks_: int,
    ) -> Iterable[Any]:
        # One entry per flat argument: mapped arguments are split along their
        # mapped dim, unmapped ones (in_dim is None) are repeated as-is so
        # that every chunk sees them.
        flat_args_chunks = tuple(
            t.chunk(chunks_, dim=in_dim)
            if in_dim is not None
            else [
                t,
            ]
            * chunks_
            for t, in_dim in zip(flat_args_, flat_in_dims_)
        )
        # Transpose from per-argument to per-chunk: the result is a lazy
        # ``zip`` iterator (not a list) yielding one tuple of flat args per
        # chunk.
        chunks_flat_args = zip(*flat_args_chunks)
        return chunks_flat_args

    @functools.wraps(func)
    def wrapped_with_chunks(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        _check_out_dims_is_int_or_int_pytree(out_dims, func)
        # The first returned value is not needed here.
        _, flat_in_dims, flat_args, args_spec = _process_batched_inputs(
            in_dims, args, func
        )
        # Chunk flat arguments
        chunks_flat_args = _get_chunk_flat_args(flat_args, flat_in_dims, chunks)

        # Apply vmap on chunks
        return _chunked_vmap(
            # pyrefly: ignore[bad-argument-type]
            func,
            flat_in_dims,
            chunks_flat_args,
            args_spec,
            out_dims,
            randomness,
            **kwargs,
        )

    return wrapped_with_chunks
- # TODO: Improve the return type of this function
- @exposed_in("torch.func")
- def grad(
- func: Callable[_P, Any], argnums: argnums_t = 0, has_aux: bool = False
- ) -> Callable[_P, Any]:
- """``grad`` operator helps computing gradients of ``func`` with respect to the
- input(s) specified by ``argnums``. This operator can be nested to
- compute higher-order gradients.
- Args:
- func (Callable): A Python function that takes one or more arguments.
- Must return a single-element Tensor. If specified ``has_aux`` equals ``True``,
- function can return a tuple of single-element Tensor and other auxiliary objects:
- ``(output, aux)``.
- argnums (int or Tuple[int]): Specifies arguments to compute gradients with respect to.
- ``argnums`` can be single integer or tuple of integers. Default: 0.
- has_aux (bool): Flag indicating that ``func`` returns a tensor and other
- auxiliary objects: ``(output, aux)``. Default: False.
- Returns:
- Function to compute gradients with respect to its inputs. By default, the output of
- the function is the gradient tensor(s) with respect to the first argument.
- If specified ``has_aux`` equals ``True``, tuple of gradients and output auxiliary objects
- is returned. If ``argnums`` is a tuple of integers, a tuple of output gradients with
- respect to each ``argnums`` value is returned.
- Example of using ``grad``:
- >>> # xdoctest: +SKIP
- >>> from torch.func import grad
- >>> x = torch.randn([])
- >>> cos_x = grad(lambda x: torch.sin(x))(x)
- >>> assert torch.allclose(cos_x, x.cos())
- >>>
- >>> # Second-order gradients
- >>> neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)
- >>> assert torch.allclose(neg_sin_x, -x.sin())
- When composed with ``vmap``, ``grad`` can be used to compute per-sample-gradients:
- >>> # xdoctest: +SKIP
- >>> from torch.func import grad, vmap
- >>> batch_size, feature_size = 3, 5
- >>>
- >>> def model(weights, feature_vec):
- >>> # Very simple linear model with activation
- >>> assert feature_vec.dim() == 1
- >>> return feature_vec.dot(weights).relu()
- >>>
- >>> def compute_loss(weights, example, target):
- >>> y = model(weights, example)
- >>> return ((y - target) ** 2).mean() # MSELoss
- >>>
- >>> weights = torch.randn(feature_size, requires_grad=True)
- >>> examples = torch.randn(batch_size, feature_size)
- >>> targets = torch.randn(batch_size)
- >>> inputs = (weights, examples, targets)
- >>> grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(
- ... *inputs
- ... )
- Example of using ``grad`` with ``has_aux`` and ``argnums``:
- >>> # xdoctest: +SKIP
- >>> from torch.func import grad
- >>> def my_loss_func(y, y_pred):
- >>> loss_per_sample = (0.5 * y_pred - y) ** 2
- >>> loss = loss_per_sample.mean()
- >>> return loss, (y_pred, loss_per_sample)
- >>>
- >>> fn = grad(my_loss_func, argnums=(0, 1), has_aux=True)
- >>> y_true = torch.rand(4)
- >>> y_preds = torch.rand(4, requires_grad=True)
- >>> out = fn(y_true, y_preds)
- >>> # > output is ((grads w.r.t y_true, grads w.r.t y_preds), (y_pred, loss_per_sample))
- .. note::
- Using PyTorch ``torch.no_grad`` together with ``grad``.
- Case 1: Using ``torch.no_grad`` inside a function:
- >>> # xdoctest: +SKIP
- >>> def f(x):
- >>> with torch.no_grad():
- >>> c = x ** 2
- >>> return x - c
- In this case, ``grad(f)(x)`` will respect the inner ``torch.no_grad``.
- Case 2: Using ``grad`` inside ``torch.no_grad`` context manager:
- >>> # xdoctest: +SKIP
- >>> with torch.no_grad():
- >>> grad(f)(x)
- In this case, ``grad`` will respect the inner ``torch.no_grad``, but not the
- outer one. This is because ``grad`` is a "function transform": its result
- should not depend on the result of a context manager outside of ``f``.
- """
- # To avoid cyclical dependency.
- import torch._functorch.eager_transforms as eager_transforms
- from torch.compiler import is_compiling
- def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> tuple[Any, torch.Tensor]:
- return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs)
- if not is_compiling():
- wrapper = functools.wraps(func)(wrapper)
- return wrapper
- # TODO: Improve the return type of this function
- @exposed_in("torch.func")
- def grad_and_value(
- func: Callable[_P, Any], argnums: argnums_t = 0, has_aux: bool = False
- ) -> Callable[_P, tuple[Any, Any]]:
- """
- Returns a function to compute a tuple of the gradient and primal, or
- forward, computation.
- Args:
- func (Callable): A Python function that takes one or more arguments.
- Must return a single-element Tensor. If specified ``has_aux``
- equals ``True``, function can return a tuple of single-element
- Tensor and other auxiliary objects: ``(output, aux)``.
- argnums (int or Tuple[int]): Specifies arguments to compute gradients
- with respect to. ``argnums`` can be single integer or tuple of
- integers. Default: 0.
- has_aux (bool): Flag indicating that ``func`` returns a tensor and
- other auxiliary objects: ``(output, aux)``. Default: False.
- Returns:
- Function to compute a tuple of gradients with respect to its inputs
- and the forward computation. By default, the output of the function is
- a tuple of the gradient tensor(s) with respect to the first argument
- and the primal computation. If specified ``has_aux`` equals
- ``True``, tuple of gradients and tuple of the forward computation with
- output auxiliary objects is returned. If ``argnums`` is a tuple of
- integers, a tuple of a tuple of the output gradients with respect to
- each ``argnums`` value and the forward computation is returned.
- See :func:`grad` for examples
- """
- from torch._functorch import eager_transforms
- from torch.compiler import is_compiling
- def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> tuple[Any, torch.Tensor]:
- return eager_transforms.grad_and_value_impl(
- func, argnums, has_aux, args, kwargs
- )
- if not is_compiling():
- wrapper = functools.wraps(func)(wrapper)
- return wrapper
|