# Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. from __future__ import annotations import contextlib import functools import itertools from collections.abc import Callable # noqa: TC003 from functools import partial from typing import Any, cast, NoReturn, TYPE_CHECKING from typing_extensions import ParamSpec, TypeVar import torch from torch import Tensor from torch._C._functorch import is_batchedtensor from torch._functorch.predispatch import ( _add_batch_dim, _remove_batch_dim, _vmap_decrement_nesting, _vmap_increment_nesting, lazy_load_decompositions, ) from torch.utils._pytree import ( _broadcast_to_and_flatten, tree_flatten, tree_map_, tree_unflatten, TreeSpec, ) if TYPE_CHECKING: from collections.abc import Generator, Iterable _P = ParamSpec("_P") _R = TypeVar("_R") in_dims_t = int | tuple[Any, ...] out_dims_t = int | tuple[int, ...] | None def doesnt_support_saved_tensors_hooks(f: Callable[_P, _R]) -> Callable[_P, _R]: message = ( "torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. " "Please open an issue with your use case." 
) @functools.wraps(f) def fn(*args: _P.args, **kwargs: _P.kwargs) -> _R: with torch.autograd.graph.disable_saved_tensors_hooks(message): return f(*args, **kwargs) return fn # Checks that all args-to-be-batched have the same batch dim size def _validate_and_get_batch_size( flat_in_dims: list[int | None], flat_args: list[Any] ) -> int: batch_sizes = [ arg.size(in_dim) for in_dim, arg in zip(flat_in_dims, flat_args) if in_dim is not None ] if len(batch_sizes) == 0: raise ValueError("vmap: Expected at least one Tensor to vmap over") if batch_sizes and any(size != batch_sizes[0] for size in batch_sizes): raise ValueError( f"vmap: Expected all tensors to have the same size in the mapped " f"dimension, got sizes {batch_sizes} for the mapped dimension" ) return batch_sizes[0] def _num_outputs(batched_outputs: Tensor | tuple[Tensor, ...]) -> int: if isinstance(batched_outputs, tuple): return len(batched_outputs) return 1 # If value is a tuple, check it has length `num_elements`. # If value is not a tuple, make a tuple with `value` repeated `num_elements` times def _as_tuple( value: tuple[_R, ...] | _R, num_elements: int, error_message_lambda: Callable[[], str], ) -> tuple[_R, ...]: if not isinstance(value, tuple): return (value,) * num_elements if len(value) != num_elements: raise ValueError(error_message_lambda()) return value def _process_batched_inputs( in_dims: in_dims_t, args: tuple[Any, ...], func: Callable[..., Any] ) -> tuple[int, list[int | None], list[Any], TreeSpec]: if not isinstance(in_dims, int) and not isinstance(in_dims, tuple): raise ValueError( f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(): " f"expected `in_dims` to be int or a (potentially nested) tuple " f"matching the structure of inputs, got: {type(in_dims)}." ) if len(args) == 0: raise ValueError( f"vmap({_get_name(func)})(): got no inputs. Maybe you forgot to add " f"inputs, or you are trying to vmap over a function with no inputs. " f"The latter is unsupported." 
) flat_args, args_spec = tree_flatten(args) flat_in_dims = _broadcast_to_and_flatten(in_dims, args_spec) if flat_in_dims is None: raise ValueError( f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(): " f"in_dims is not compatible with the structure of `inputs`. " f"in_dims has structure {tree_flatten(in_dims)[1]} but inputs " f"has structure {args_spec}." ) for i, (arg, in_dim) in enumerate(zip(flat_args, flat_in_dims)): if not isinstance(in_dim, int) and in_dim is not None: raise ValueError( f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(): " f"Got in_dim={in_dim} for an input but in_dim must be either " f"an integer dimension or None." ) if isinstance(in_dim, int) and not isinstance(arg, Tensor): raise ValueError( f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(): " f"Got in_dim={in_dim} for an input but the input is of type " f"{type(arg)}. We cannot vmap over non-Tensor arguments, " f"please use None as the respective in_dim" ) if in_dim is not None and (in_dim < -arg.dim() or in_dim >= arg.dim()): raise ValueError( f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(): " f"Got in_dim={in_dim} for some input, but that input is a Tensor " f"of dimensionality {arg.dim()} so expected in_dim to satisfy " f"-{arg.dim()} <= in_dim < {arg.dim()}." ) if in_dim is not None and in_dim < 0: flat_in_dims[i] = in_dim % arg.dim() return ( _validate_and_get_batch_size(flat_in_dims, flat_args), flat_in_dims, flat_args, args_spec, ) # Creates BatchedTensors for every Tensor in arg that should be batched. # Returns the (potentially) batched arguments and the batch_size. 
# TODO: See if we can explain how flat works to the type checker def _create_batched_inputs( flat_in_dims: list[int | None], flat_args: list[Any], vmap_level: int, args_spec: TreeSpec, ) -> tuple[Any, ...]: # See NOTE [Ignored _remove_batch_dim, _add_batch_dim] batched_inputs = [ arg if in_dim is None else _add_batch_dim(arg, in_dim, vmap_level) for in_dim, arg in zip(flat_in_dims, flat_args) ] return tree_unflatten(batched_inputs, args_spec) def _maybe_remove_batch_dim( name: str, batched_output: Any, vmap_level: int, batch_size: int, out_dim: int | None, ) -> torch.Tensor: if out_dim is None: if isinstance(batched_output, torch.Tensor) and is_batchedtensor( batched_output ): raise ValueError( f"vmap({name}, ...): `{name}` can not return a " f"BatchedTensor when out_dim is None" ) return batched_output # out_dim is non None if not isinstance(batched_output, torch.Tensor): raise ValueError( f"vmap({name}, ...): `{name}` must only return " f"Tensors, got type {type(batched_output)}. " "Did you mean to set out_dims= to None for output?" ) return _remove_batch_dim(batched_output, vmap_level, batch_size, out_dim) # Undos the batching (and any batch dimensions) associated with the `vmap_level`. def _unwrap_batched( batched_outputs: Tensor | tuple[Tensor, ...], out_dims: out_dims_t, vmap_level: int, batch_size: int, func: Callable[..., Any], ) -> tuple[Any, ...]: flat_batched_outputs, output_spec = tree_flatten(batched_outputs) def incompatible_error() -> NoReturn: raise ValueError( f"vmap({_get_name(func)}, ..., out_dims={out_dims})(): " f"out_dims is not compatible with the structure of `outputs`. " f"out_dims has structure {tree_flatten(out_dims)[1]} but outputs " f"has structure {output_spec}." 
) flat_out_dims: list[int | None] = [] if isinstance(batched_outputs, torch.Tensor): # Some weird edge case requires us to spell out the following # see test_out_dims_edge_case if isinstance(out_dims, int): flat_out_dims = [out_dims] elif isinstance(out_dims, tuple) and len(out_dims) == 1: flat_out_dims = list(out_dims) elif out_dims is None: flat_out_dims = [out_dims] else: incompatible_error() else: broadcast_result = _broadcast_to_and_flatten(out_dims, output_spec) if broadcast_result is None: incompatible_error() else: flat_out_dims = broadcast_result flat_outputs = [ _maybe_remove_batch_dim( _get_name(func), batched_output, vmap_level, batch_size, out_dim ) for batched_output, out_dim in zip(flat_batched_outputs, flat_out_dims) ] return tree_unflatten(flat_outputs, output_spec) def _check_int_or_none(x: Any, func: Callable[..., Any], out_dims: out_dims_t) -> None: if isinstance(x, int): return if x is None: return raise ValueError( f"vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must be " f"an int, None or a python collection of ints representing where in the outputs the " f"vmapped dimension should appear." ) def _check_out_dims_is_int_or_int_pytree( out_dims: out_dims_t, func: Callable[..., Any] ) -> None: if isinstance(out_dims, int): return tree_map_(partial(_check_int_or_none, func=func, out_dims=out_dims), out_dims) def _get_name(func: Callable[..., Any]) -> str: if hasattr(func, "__name__"): return func.__name__ if isinstance(func, functools.partial): return f"functools.partial({_get_name(func.func)}, ...)" # Not all callables have __name__, in fact, only static functions/methods # do. A callable created via nn.Module, to name one example, doesn't have a # __name__. 
return repr(func) def vmap_impl( func: Callable[_P, Tensor | tuple[Tensor, ...]], in_dims: in_dims_t, out_dims: out_dims_t, randomness: str, chunk_size: int | None, *args: _P.args, **kwargs: _P.kwargs, ) -> Any: lazy_load_decompositions() _check_out_dims_is_int_or_int_pytree(out_dims, func) batch_size, flat_in_dims, flat_args, args_spec = _process_batched_inputs( in_dims, args, func ) if chunk_size is not None: chunks_flat_args = _get_chunked_inputs( flat_args, flat_in_dims, batch_size, chunk_size ) return _chunked_vmap( func, flat_in_dims, chunks_flat_args, args_spec, out_dims, randomness, **kwargs, ) # If chunk_size is not specified. return _flat_vmap( func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs, ) def get_chunk_sizes(total_elems: int, chunk_size: int) -> list[int]: n_chunks = total_elems // chunk_size chunk_sizes = [chunk_size] * n_chunks # remainder chunk remainder = total_elems % chunk_size if remainder != 0: chunk_sizes.append(remainder) return chunk_sizes def _get_chunked_inputs( flat_args: list[Any], flat_in_dims: list[int | None], batch_size: int, chunk_size: int | None, ) -> Iterable[tuple[Any, ...]]: split_idxs = (batch_size,) if chunk_size is not None: chunk_sizes = get_chunk_sizes(batch_size, chunk_size) split_idxs = tuple(itertools.accumulate(chunk_sizes)) flat_args_chunks = tuple( ( t.tensor_split(split_idxs, dim=in_dim) if in_dim is not None else [ t, ] * len(split_idxs) ) for t, in_dim in zip(flat_args, flat_in_dims) ) # transpose chunk dim and flatten structure # chunks_flat_args is a list of flatten args chunks_flat_args = zip(*flat_args_chunks) return chunks_flat_args def _flatten_chunks_output( chunks_output_: list[Any], ) -> tuple[list[tuple[Any, ...]], TreeSpec]: # chunks_output is a list of chunked outputs # flatten chunked outputs: flat_chunks_output: list[list[Any]] = [] arg_spec: TreeSpec | None = None for output in chunks_output_: flat_output, arg_specs = tree_flatten(output) 
flat_chunks_output.append(flat_output) if arg_spec is None: arg_spec = arg_specs # transpose chunk dim and flatten structure # flat_output_chunks is flat list of chunks flat_output_chunks = list(zip(*flat_chunks_output)) if arg_spec is None: raise AssertionError("arg_spec must not be None") return flat_output_chunks, arg_spec def _concat_chunked_outputs( out_dims: out_dims_t, arg_spec: TreeSpec, flat_output_chunks: list[tuple[Any, ...] | None], ) -> list[Tensor]: # concat chunks on out_dim flat_out_dims = _broadcast_to_and_flatten(out_dims, arg_spec) if flat_out_dims is None: raise AssertionError("flat_out_dims must not be None") if len(flat_out_dims) != len(flat_output_chunks): raise AssertionError( f"len(flat_out_dims)={len(flat_out_dims)} != len(flat_output_chunks)={len(flat_output_chunks)}" ) flat_output: list[Tensor] = [] for idx, out_dim in enumerate(flat_out_dims): chunk = flat_output_chunks[idx] if chunk is None: raise AssertionError(f"chunk at index {idx} must not be None") flat_output.append(torch.cat(chunk, dim=out_dim)) # release tensors flat_output_chunks[idx] = None return flat_output # Applies vmap on chunked_input and returns concatenated output over the chunks. def _chunked_vmap( func: Callable[_P, Tensor | tuple[Tensor, ...]], flat_in_dims: list[int | None], chunks_flat_args: Iterable[tuple[Any, ...]], args_spec: TreeSpec, out_dims: out_dims_t, randomness: str, **kwargs: Any, ) -> Any: chunks_output: list[Any] = [] rs = torch.get_rng_state() if randomness == "same" else None for flat_args_tuple in chunks_flat_args: flat_args = list(flat_args_tuple) batch_size = _validate_and_get_batch_size(flat_in_dims, flat_args) # The way we compute split the input in `_get_chunked_inputs`, # we may get a tensor with `0` batch-size. We skip any computation # in that case. # Eg. 
# >>> chunk_size = 1 # >>> batch_size = 6 # >>> t = torch.zeros(batch_size, 1) # >>> t.tensor_split([1, 2, 3, 4, 5, 6]) # (tensor([[0.]]), tensor([[0.]]), tensor([[0.]]), tensor([[0.]]), # tensor([[0.]]), tensor([[0.]]), tensor([], size=(0, 1))) if batch_size == 0: continue if rs is not None: torch.set_rng_state(rs) chunks_output.append( _flat_vmap( func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs, ) ) flat_output_chunks, arg_spec = _flatten_chunks_output(chunks_output) # chunked output tensors are held by both `flat_output_chunks` and `chunks_output`. # eagerly remove the reference from `chunks_output`. del chunks_output # concat chunks on out_dim # Note: We use cast since flat_output_chunks is modified in _concat_chunked_outputs # to set elements to None after processing flat_output = _concat_chunked_outputs( out_dims, arg_spec, cast(list[tuple[Any, ...] | None], flat_output_chunks) ) # finally unflatten the output return tree_unflatten(flat_output, arg_spec) # Vmap refactored helper functions: def _check_randomness_arg(randomness: str) -> None: if randomness not in ["error", "different", "same"]: raise RuntimeError( f"Only allowed values for randomness are 'error', 'different', or 'same'. 
Got {randomness}" ) @contextlib.contextmanager def vmap_increment_nesting( batch_size: int, randomness: str ) -> Generator[int, None, None]: try: vmap_level = _vmap_increment_nesting(batch_size, randomness) yield vmap_level finally: _vmap_decrement_nesting() def _flat_vmap( func: Callable[..., Tensor | tuple[Tensor, ...]], batch_size: int, flat_in_dims: list[int | None], flat_args: list[Any], args_spec: TreeSpec, out_dims: out_dims_t, randomness: str, **kwargs: Any, ) -> Any: with vmap_increment_nesting(batch_size, randomness) as vmap_level: batched_inputs = _create_batched_inputs( flat_in_dims, flat_args, vmap_level, args_spec ) batched_outputs = func(*batched_inputs, **kwargs) return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func) # `restore_vmap` is a private helper function. It is vmap but has the following # differences: # - instead of returning outputs, it returns an (outputs, out_dims) tuple. # out_dims is a pytree of same shape as outputs and contains Optional[int] # specifying where the vmapped dimension, if it exists, is in the corresponding output. # - does no validation on in_dims or inputs (vmap expects at least one Tensor to be vmapped). # restore_vmap allows for no inputs to have the vmap dimension # - does no validation on outputs (vmap expects only Tensor outputs) # restore_vmap allows for return of arbitrary outputs (not just Tensors) # # The TL;DR is that restore_vmap is more general than vmap and has a slightly # different API. The relaxations are so that we can "pause" vmap in the middle # of its execution and then "restore" it later (this is what we do in # the generate_vmap_rule=True implementation of autograd.Function). 
#
# restore_vmap can be technically used in the implementation of vmap, but doing
# that refactor is a bit technically challenging because:
# - vmap couples the tensor-wrapping code with error checking
# - vmap's tensor unwrapping code is in C++; we would need to rewrite part of it
#   in python because it overlaps with unwrap_batched
def restore_vmap(
    func: Callable[..., _R], in_dims: in_dims_t, batch_size: int, randomness: str
) -> Callable[..., tuple[Any, Any]]:
    """Return a callable that runs ``func`` inside a new vmap level.

    The returned callable wraps its positional arguments according to
    ``in_dims``, calls ``func`` and returns the ``(outputs, out_dims)`` pair
    produced by ``unwrap_batched``.
    """

    def inner(*args: Any, **kwargs: Any) -> tuple[Any, Any]:
        with vmap_increment_nesting(batch_size, randomness) as level:
            wrapped_args = wrap_batched(args, in_dims, level)
            raw_outputs = func(*wrapped_args, **kwargs)
            return unwrap_batched(raw_outputs, level)

    return inner


def wrap_batched(
    args: tuple[Any, ...], bdims: in_dims_t, level: int
) -> tuple[Any, ...]:
    """Broadcast ``bdims`` over ``args`` and wrap the mapped leaves at ``level``.

    Raises ``AssertionError`` when ``bdims`` cannot be broadcast to the
    structure of ``args``.
    """
    leaves, treespec = tree_flatten(args)
    leaf_bdims = _broadcast_to_and_flatten(bdims, treespec)
    if leaf_bdims is None:
        raise AssertionError("flat_bdims must not be None")
    return _create_batched_inputs(leaf_bdims, leaves, level, treespec)


def unwrap_batched(args: Any, level: int) -> tuple[Any, Any]:
    """Unwrap every leaf of ``args`` at ``level``.

    Returns ``(outputs, bdims)``: two pytrees shaped like ``args`` holding the
    unwrapped values and their batch dims (``None`` for non-Tensor leaves). A
    tree without leaves is returned unchanged together with ``()``.
    """
    leaves, treespec = tree_flatten(args)
    if not leaves:
        return args, ()

    values: list[Any] = []
    batch_dims: list[Any] = []
    for leaf in leaves:
        if isinstance(leaf, torch.Tensor):
            value, bdim = torch._C._functorch._unwrap_batched(leaf, level)
        else:
            # Non-Tensor leaves pass through and carry no batch dim.
            value, bdim = leaf, None
        values.append(value)
        batch_dims.append(bdim)

    return (
        tree_unflatten(tuple(values), treespec),
        tree_unflatten(tuple(batch_dims), treespec),
    )