_vmap_internals.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. # mypy: allow-untyped-defs
  2. import functools
  3. from collections.abc import Callable
  4. from typing import Any
  5. from typing_extensions import deprecated
  6. import torch
  7. from torch import Tensor
  8. from torch.utils._pytree import _broadcast_to_and_flatten, tree_flatten, tree_unflatten
# Accepted forms of `in_dims`: a single int applied to every input, or a
# (potentially nested) tuple matched against the structure of the inputs,
# whose entries may be None for inputs that should not be mapped over.
in_dims_t = int | tuple
# Accepted forms of `out_dims`: a single int applied to every output, or a
# tuple with one int per output.
out_dims_t = int | tuple[int, ...]
  11. # Checks that all args-to-be-batched have the same batch dim size
  12. def _validate_and_get_batch_size(
  13. flat_in_dims: list[int | None],
  14. flat_args: list,
  15. ) -> int:
  16. batch_sizes = [
  17. arg.size(in_dim)
  18. for in_dim, arg in zip(flat_in_dims, flat_args)
  19. if in_dim is not None
  20. ]
  21. if batch_sizes and any(size != batch_sizes[0] for size in batch_sizes):
  22. raise ValueError(
  23. f"vmap: Expected all tensors to have the same size in the mapped "
  24. f"dimension, got sizes {batch_sizes} for the mapped dimension"
  25. )
  26. return batch_sizes[0]
  27. def _num_outputs(batched_outputs: Tensor | tuple[Tensor, ...]) -> int:
  28. if isinstance(batched_outputs, tuple):
  29. return len(batched_outputs)
  30. return 1
  31. # If value is a tuple, check it has length `num_elements`.
  32. # If value is not a tuple, make a tuple with `value` repeated `num_elements` times
  33. def _as_tuple(
  34. value: Any,
  35. num_elements: int,
  36. error_message_lambda: Callable[[], str],
  37. ) -> tuple:
  38. if not isinstance(value, tuple):
  39. return (value,) * num_elements
  40. if len(value) != num_elements:
  41. raise ValueError(error_message_lambda())
  42. return value
  43. # Creates BatchedTensors for every Tensor in arg that should be batched.
  44. # Returns the (potentially) batched arguments and the batch_size.
  45. def _create_batched_inputs(
  46. in_dims: in_dims_t,
  47. args: tuple,
  48. vmap_level: int,
  49. func: Callable,
  50. ) -> tuple[tuple, int]:
  51. if not isinstance(in_dims, int) and not isinstance(in_dims, tuple):
  52. raise ValueError(
  53. f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
  54. f"expected `in_dims` to be int or a (potentially nested) tuple "
  55. f"matching the structure of inputs, got: {type(in_dims)}."
  56. )
  57. if len(args) == 0:
  58. raise ValueError(
  59. f"vmap({_get_name(func)})(<inputs>): got no inputs. Maybe you forgot to add "
  60. f"inputs, or you are trying to vmap over a function with no inputs. "
  61. f"The latter is unsupported."
  62. )
  63. flat_args, args_spec = tree_flatten(args)
  64. flat_in_dims = _broadcast_to_and_flatten(in_dims, args_spec)
  65. if flat_in_dims is None:
  66. raise ValueError(
  67. f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
  68. f"in_dims is not compatible with the structure of `inputs`. "
  69. f"in_dims has structure {tree_flatten(in_dims)[1]} but inputs "
  70. f"has structure {args_spec}."
  71. )
  72. for arg, in_dim in zip(flat_args, flat_in_dims):
  73. if not isinstance(in_dim, int) and in_dim is not None:
  74. raise ValueError(
  75. f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
  76. f"Got in_dim={in_dim} for an input but in_dim must be either "
  77. f"an integer dimension or None."
  78. )
  79. if isinstance(in_dim, int) and not isinstance(arg, Tensor):
  80. raise ValueError(
  81. f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
  82. f"Got in_dim={in_dim} for an input but the input is of type "
  83. f"{type(arg)}. We cannot vmap over non-Tensor arguments, "
  84. f"please use None as the respective in_dim"
  85. )
  86. if in_dim is not None and (in_dim < 0 or in_dim >= arg.dim()):
  87. raise ValueError(
  88. f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
  89. f"Got in_dim={in_dim} for some input, but that input is a Tensor "
  90. f"of dimensionality {arg.dim()} so expected in_dim to satisfy "
  91. f"0 <= in_dim < {arg.dim()}."
  92. )
  93. batch_size = _validate_and_get_batch_size(flat_in_dims, flat_args)
  94. # See NOTE [Ignored _remove_batch_dim, _add_batch_dim]
  95. batched_inputs = [
  96. arg if in_dim is None else torch._add_batch_dim(arg, in_dim, vmap_level)
  97. for in_dim, arg in zip(flat_in_dims, flat_args)
  98. ]
  99. return tree_unflatten(batched_inputs, args_spec), batch_size
  100. # Undos the batching (and any batch dimensions) associated with the `vmap_level`.
  101. def _unwrap_batched(
  102. batched_outputs: Tensor | tuple[Tensor, ...],
  103. out_dims: out_dims_t,
  104. vmap_level: int,
  105. batch_size: int,
  106. func: Callable,
  107. allow_none_pass_through: bool = False,
  108. ) -> tuple:
  109. num_outputs = _num_outputs(batched_outputs)
  110. out_dims_as_tuple = _as_tuple(
  111. out_dims,
  112. num_outputs,
  113. lambda: f"vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must "
  114. f"have one dim per output (got {num_outputs} outputs) of {_get_name(func)}.",
  115. )
  116. # NOTE [Ignored _remove_batch_dim, _add_batch_dim]
  117. # There is something wrong with our type bindings for functions that begin
  118. # with '_', see #40397.
  119. if isinstance(batched_outputs, Tensor):
  120. out_dim = out_dims_as_tuple[0]
  121. return torch._remove_batch_dim(batched_outputs, vmap_level, batch_size, out_dim) # type: ignore[return-value]
  122. if allow_none_pass_through:
  123. return tuple(
  124. (
  125. torch._remove_batch_dim(out, vmap_level, batch_size, out_dim)
  126. if out is not None
  127. else None
  128. )
  129. for out, out_dim in zip(batched_outputs, out_dims_as_tuple)
  130. )
  131. else:
  132. return tuple(
  133. torch._remove_batch_dim(out, vmap_level, batch_size, out_dim)
  134. for out, out_dim in zip(batched_outputs, out_dims_as_tuple)
  135. )
  136. # Checks that `fn` returned one or more Tensors and nothing else.
  137. # NB: A python function that return multiple arguments returns a single tuple,
  138. # so we are effectively checking that `outputs` is a single Tensor or a tuple of
  139. # Tensors.
  140. def _validate_outputs(outputs: Any, func: Callable) -> None:
  141. if isinstance(outputs, Tensor):
  142. return
  143. if not isinstance(outputs, tuple):
  144. raise ValueError(
  145. f"vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return "
  146. f"Tensors, got type {type(outputs)} as the return."
  147. )
  148. for idx, output in enumerate(outputs):
  149. if isinstance(output, Tensor):
  150. continue
  151. raise ValueError(
  152. f"vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return "
  153. f"Tensors, got type {type(output)} for return {idx}."
  154. )
  155. def _check_out_dims_is_int_or_int_tuple(out_dims: out_dims_t, func: Callable) -> None:
  156. if isinstance(out_dims, int):
  157. return
  158. if not isinstance(out_dims, tuple) or not all(
  159. isinstance(out_dim, int) for out_dim in out_dims
  160. ):
  161. raise ValueError(
  162. f"vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must be "
  163. f"an int or a tuple of int representing where in the outputs the "
  164. f"vmapped dimension should appear."
  165. )
  166. def _get_name(func: Callable):
  167. if hasattr(func, "__name__"):
  168. return func.__name__
  169. # Not all callables have __name__, in fact, only static functions/methods do.
  170. # A callable created via functools.partial or an nn.Module, to name some
  171. # examples, don't have a __name__.
  172. return repr(func)
  173. # vmap(func)(inputs) wraps all Tensor inputs to be batched in BatchedTensors,
  174. # sends those into func, and then unwraps the output BatchedTensors. Operations
  175. # on BatchedTensors perform the batched operations that the user is asking for.
  176. @deprecated(
  177. "Please use `torch.vmap` instead of `torch._vmap_internals.vmap`.",
  178. category=FutureWarning,
  179. )
  180. def vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Callable:
  181. """
  182. Please use torch.vmap instead of this API.
  183. """
  184. return _vmap(func, in_dims, out_dims)
  185. # A version of vmap but without the initial "experimental prototype" warning
  186. def _vmap(
  187. func: Callable,
  188. in_dims: in_dims_t = 0,
  189. out_dims: out_dims_t = 0,
  190. allow_none_pass_through: bool = False,
  191. ) -> Callable:
  192. # The `allow_none_pass_through` argument is a temporary workaround may be removed.
  193. # Currently it enables us to wrap the call in `autograd.grad` to the autograd engine,
  194. # which may return None if any of the inputs are unused. See the issue discussing this:
  195. # https://github.com/pytorch/functorch/issues/159.
  196. @functools.wraps(func)
  197. def wrapped(*args):
  198. _check_out_dims_is_int_or_int_tuple(out_dims, func)
  199. vmap_level = torch._C._vmapmode_increment_nesting()
  200. try:
  201. batched_inputs, batch_size = _create_batched_inputs(
  202. in_dims, args, vmap_level, func
  203. )
  204. batched_outputs = func(*batched_inputs)
  205. if not allow_none_pass_through:
  206. _validate_outputs(batched_outputs, func)
  207. return _unwrap_batched(
  208. batched_outputs,
  209. out_dims,
  210. vmap_level,
  211. batch_size,
  212. func,
  213. allow_none_pass_through=allow_none_pass_through,
  214. )
  215. finally:
  216. torch._C._vmapmode_decrement_nesting()
  217. return wrapped