vmap.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567
  1. # Copyright (c) Facebook, Inc. and its affiliates.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the BSD-style license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. from __future__ import annotations
  7. import contextlib
  8. import functools
  9. import itertools
  10. from collections.abc import Callable # noqa: TC003
  11. from functools import partial
  12. from typing import Any, cast, NoReturn, TYPE_CHECKING
  13. from typing_extensions import ParamSpec, TypeVar
  14. import torch
  15. from torch import Tensor
  16. from torch._C._functorch import is_batchedtensor
  17. from torch._functorch.predispatch import (
  18. _add_batch_dim,
  19. _remove_batch_dim,
  20. _vmap_decrement_nesting,
  21. _vmap_increment_nesting,
  22. lazy_load_decompositions,
  23. )
  24. from torch.utils._pytree import (
  25. _broadcast_to_and_flatten,
  26. tree_flatten,
  27. tree_map_,
  28. tree_unflatten,
  29. TreeSpec,
  30. )
  31. if TYPE_CHECKING:
  32. from collections.abc import Generator, Iterable
  33. _P = ParamSpec("_P")
  34. _R = TypeVar("_R")
  35. in_dims_t = int | tuple[Any, ...]
  36. out_dims_t = int | tuple[int, ...] | None
  37. def doesnt_support_saved_tensors_hooks(f: Callable[_P, _R]) -> Callable[_P, _R]:
  38. message = (
  39. "torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. "
  40. "Please open an issue with your use case."
  41. )
  42. @functools.wraps(f)
  43. def fn(*args: _P.args, **kwargs: _P.kwargs) -> _R:
  44. with torch.autograd.graph.disable_saved_tensors_hooks(message):
  45. return f(*args, **kwargs)
  46. return fn
  47. # Checks that all args-to-be-batched have the same batch dim size
  48. def _validate_and_get_batch_size(
  49. flat_in_dims: list[int | None], flat_args: list[Any]
  50. ) -> int:
  51. batch_sizes = [
  52. arg.size(in_dim)
  53. for in_dim, arg in zip(flat_in_dims, flat_args)
  54. if in_dim is not None
  55. ]
  56. if len(batch_sizes) == 0:
  57. raise ValueError("vmap: Expected at least one Tensor to vmap over")
  58. if batch_sizes and any(size != batch_sizes[0] for size in batch_sizes):
  59. raise ValueError(
  60. f"vmap: Expected all tensors to have the same size in the mapped "
  61. f"dimension, got sizes {batch_sizes} for the mapped dimension"
  62. )
  63. return batch_sizes[0]
  64. def _num_outputs(batched_outputs: Tensor | tuple[Tensor, ...]) -> int:
  65. if isinstance(batched_outputs, tuple):
  66. return len(batched_outputs)
  67. return 1
  68. # If value is a tuple, check it has length `num_elements`.
  69. # If value is not a tuple, make a tuple with `value` repeated `num_elements` times
  70. def _as_tuple(
  71. value: tuple[_R, ...] | _R,
  72. num_elements: int,
  73. error_message_lambda: Callable[[], str],
  74. ) -> tuple[_R, ...]:
  75. if not isinstance(value, tuple):
  76. return (value,) * num_elements
  77. if len(value) != num_elements:
  78. raise ValueError(error_message_lambda())
  79. return value
  80. def _process_batched_inputs(
  81. in_dims: in_dims_t, args: tuple[Any, ...], func: Callable[..., Any]
  82. ) -> tuple[int, list[int | None], list[Any], TreeSpec]:
  83. if not isinstance(in_dims, int) and not isinstance(in_dims, tuple):
  84. raise ValueError(
  85. f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
  86. f"expected `in_dims` to be int or a (potentially nested) tuple "
  87. f"matching the structure of inputs, got: {type(in_dims)}."
  88. )
  89. if len(args) == 0:
  90. raise ValueError(
  91. f"vmap({_get_name(func)})(<inputs>): got no inputs. Maybe you forgot to add "
  92. f"inputs, or you are trying to vmap over a function with no inputs. "
  93. f"The latter is unsupported."
  94. )
  95. flat_args, args_spec = tree_flatten(args)
  96. flat_in_dims = _broadcast_to_and_flatten(in_dims, args_spec)
  97. if flat_in_dims is None:
  98. raise ValueError(
  99. f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
  100. f"in_dims is not compatible with the structure of `inputs`. "
  101. f"in_dims has structure {tree_flatten(in_dims)[1]} but inputs "
  102. f"has structure {args_spec}."
  103. )
  104. for i, (arg, in_dim) in enumerate(zip(flat_args, flat_in_dims)):
  105. if not isinstance(in_dim, int) and in_dim is not None:
  106. raise ValueError(
  107. f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
  108. f"Got in_dim={in_dim} for an input but in_dim must be either "
  109. f"an integer dimension or None."
  110. )
  111. if isinstance(in_dim, int) and not isinstance(arg, Tensor):
  112. raise ValueError(
  113. f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
  114. f"Got in_dim={in_dim} for an input but the input is of type "
  115. f"{type(arg)}. We cannot vmap over non-Tensor arguments, "
  116. f"please use None as the respective in_dim"
  117. )
  118. if in_dim is not None and (in_dim < -arg.dim() or in_dim >= arg.dim()):
  119. raise ValueError(
  120. f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
  121. f"Got in_dim={in_dim} for some input, but that input is a Tensor "
  122. f"of dimensionality {arg.dim()} so expected in_dim to satisfy "
  123. f"-{arg.dim()} <= in_dim < {arg.dim()}."
  124. )
  125. if in_dim is not None and in_dim < 0:
  126. flat_in_dims[i] = in_dim % arg.dim()
  127. return (
  128. _validate_and_get_batch_size(flat_in_dims, flat_args),
  129. flat_in_dims,
  130. flat_args,
  131. args_spec,
  132. )
  133. # Creates BatchedTensors for every Tensor in arg that should be batched.
  134. # Returns the (potentially) batched arguments and the batch_size.
  135. # TODO: See if we can explain how flat works to the type checker
  136. def _create_batched_inputs(
  137. flat_in_dims: list[int | None],
  138. flat_args: list[Any],
  139. vmap_level: int,
  140. args_spec: TreeSpec,
  141. ) -> tuple[Any, ...]:
  142. # See NOTE [Ignored _remove_batch_dim, _add_batch_dim]
  143. batched_inputs = [
  144. arg if in_dim is None else _add_batch_dim(arg, in_dim, vmap_level)
  145. for in_dim, arg in zip(flat_in_dims, flat_args)
  146. ]
  147. return tree_unflatten(batched_inputs, args_spec)
  148. def _maybe_remove_batch_dim(
  149. name: str,
  150. batched_output: Any,
  151. vmap_level: int,
  152. batch_size: int,
  153. out_dim: int | None,
  154. ) -> torch.Tensor:
  155. if out_dim is None:
  156. if isinstance(batched_output, torch.Tensor) and is_batchedtensor(
  157. batched_output
  158. ):
  159. raise ValueError(
  160. f"vmap({name}, ...): `{name}` can not return a "
  161. f"BatchedTensor when out_dim is None"
  162. )
  163. return batched_output
  164. # out_dim is non None
  165. if not isinstance(batched_output, torch.Tensor):
  166. raise ValueError(
  167. f"vmap({name}, ...): `{name}` must only return "
  168. f"Tensors, got type {type(batched_output)}. "
  169. "Did you mean to set out_dims= to None for output?"
  170. )
  171. return _remove_batch_dim(batched_output, vmap_level, batch_size, out_dim)
  172. # Undos the batching (and any batch dimensions) associated with the `vmap_level`.
  173. def _unwrap_batched(
  174. batched_outputs: Tensor | tuple[Tensor, ...],
  175. out_dims: out_dims_t,
  176. vmap_level: int,
  177. batch_size: int,
  178. func: Callable[..., Any],
  179. ) -> tuple[Any, ...]:
  180. flat_batched_outputs, output_spec = tree_flatten(batched_outputs)
  181. def incompatible_error() -> NoReturn:
  182. raise ValueError(
  183. f"vmap({_get_name(func)}, ..., out_dims={out_dims})(<inputs>): "
  184. f"out_dims is not compatible with the structure of `outputs`. "
  185. f"out_dims has structure {tree_flatten(out_dims)[1]} but outputs "
  186. f"has structure {output_spec}."
  187. )
  188. flat_out_dims: list[int | None] = []
  189. if isinstance(batched_outputs, torch.Tensor):
  190. # Some weird edge case requires us to spell out the following
  191. # see test_out_dims_edge_case
  192. if isinstance(out_dims, int):
  193. flat_out_dims = [out_dims]
  194. elif isinstance(out_dims, tuple) and len(out_dims) == 1:
  195. flat_out_dims = list(out_dims)
  196. elif out_dims is None:
  197. flat_out_dims = [out_dims]
  198. else:
  199. incompatible_error()
  200. else:
  201. broadcast_result = _broadcast_to_and_flatten(out_dims, output_spec)
  202. if broadcast_result is None:
  203. incompatible_error()
  204. else:
  205. flat_out_dims = broadcast_result
  206. flat_outputs = [
  207. _maybe_remove_batch_dim(
  208. _get_name(func), batched_output, vmap_level, batch_size, out_dim
  209. )
  210. for batched_output, out_dim in zip(flat_batched_outputs, flat_out_dims)
  211. ]
  212. return tree_unflatten(flat_outputs, output_spec)
  213. def _check_int_or_none(x: Any, func: Callable[..., Any], out_dims: out_dims_t) -> None:
  214. if isinstance(x, int):
  215. return
  216. if x is None:
  217. return
  218. raise ValueError(
  219. f"vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must be "
  220. f"an int, None or a python collection of ints representing where in the outputs the "
  221. f"vmapped dimension should appear."
  222. )
  223. def _check_out_dims_is_int_or_int_pytree(
  224. out_dims: out_dims_t, func: Callable[..., Any]
  225. ) -> None:
  226. if isinstance(out_dims, int):
  227. return
  228. tree_map_(partial(_check_int_or_none, func=func, out_dims=out_dims), out_dims)
  229. def _get_name(func: Callable[..., Any]) -> str:
  230. if hasattr(func, "__name__"):
  231. return func.__name__
  232. if isinstance(func, functools.partial):
  233. return f"functools.partial({_get_name(func.func)}, ...)"
  234. # Not all callables have __name__, in fact, only static functions/methods
  235. # do. A callable created via nn.Module, to name one example, doesn't have a
  236. # __name__.
  237. return repr(func)
  238. def vmap_impl(
  239. func: Callable[_P, Tensor | tuple[Tensor, ...]],
  240. in_dims: in_dims_t,
  241. out_dims: out_dims_t,
  242. randomness: str,
  243. chunk_size: int | None,
  244. *args: _P.args,
  245. **kwargs: _P.kwargs,
  246. ) -> Any:
  247. lazy_load_decompositions()
  248. _check_out_dims_is_int_or_int_pytree(out_dims, func)
  249. batch_size, flat_in_dims, flat_args, args_spec = _process_batched_inputs(
  250. in_dims, args, func
  251. )
  252. if chunk_size is not None:
  253. chunks_flat_args = _get_chunked_inputs(
  254. flat_args, flat_in_dims, batch_size, chunk_size
  255. )
  256. return _chunked_vmap(
  257. func,
  258. flat_in_dims,
  259. chunks_flat_args,
  260. args_spec,
  261. out_dims,
  262. randomness,
  263. **kwargs,
  264. )
  265. # If chunk_size is not specified.
  266. return _flat_vmap(
  267. func,
  268. batch_size,
  269. flat_in_dims,
  270. flat_args,
  271. args_spec,
  272. out_dims,
  273. randomness,
  274. **kwargs,
  275. )
  276. def get_chunk_sizes(total_elems: int, chunk_size: int) -> list[int]:
  277. n_chunks = total_elems // chunk_size
  278. chunk_sizes = [chunk_size] * n_chunks
  279. # remainder chunk
  280. remainder = total_elems % chunk_size
  281. if remainder != 0:
  282. chunk_sizes.append(remainder)
  283. return chunk_sizes
  284. def _get_chunked_inputs(
  285. flat_args: list[Any],
  286. flat_in_dims: list[int | None],
  287. batch_size: int,
  288. chunk_size: int | None,
  289. ) -> Iterable[tuple[Any, ...]]:
  290. split_idxs = (batch_size,)
  291. if chunk_size is not None:
  292. chunk_sizes = get_chunk_sizes(batch_size, chunk_size)
  293. split_idxs = tuple(itertools.accumulate(chunk_sizes))
  294. flat_args_chunks = tuple(
  295. (
  296. t.tensor_split(split_idxs, dim=in_dim)
  297. if in_dim is not None
  298. else [
  299. t,
  300. ]
  301. * len(split_idxs)
  302. )
  303. for t, in_dim in zip(flat_args, flat_in_dims)
  304. )
  305. # transpose chunk dim and flatten structure
  306. # chunks_flat_args is a list of flatten args
  307. chunks_flat_args = zip(*flat_args_chunks)
  308. return chunks_flat_args
  309. def _flatten_chunks_output(
  310. chunks_output_: list[Any],
  311. ) -> tuple[list[tuple[Any, ...]], TreeSpec]:
  312. # chunks_output is a list of chunked outputs
  313. # flatten chunked outputs:
  314. flat_chunks_output: list[list[Any]] = []
  315. arg_spec: TreeSpec | None = None
  316. for output in chunks_output_:
  317. flat_output, arg_specs = tree_flatten(output)
  318. flat_chunks_output.append(flat_output)
  319. if arg_spec is None:
  320. arg_spec = arg_specs
  321. # transpose chunk dim and flatten structure
  322. # flat_output_chunks is flat list of chunks
  323. flat_output_chunks = list(zip(*flat_chunks_output))
  324. if arg_spec is None:
  325. raise AssertionError("arg_spec must not be None")
  326. return flat_output_chunks, arg_spec
  327. def _concat_chunked_outputs(
  328. out_dims: out_dims_t,
  329. arg_spec: TreeSpec,
  330. flat_output_chunks: list[tuple[Any, ...] | None],
  331. ) -> list[Tensor]:
  332. # concat chunks on out_dim
  333. flat_out_dims = _broadcast_to_and_flatten(out_dims, arg_spec)
  334. if flat_out_dims is None:
  335. raise AssertionError("flat_out_dims must not be None")
  336. if len(flat_out_dims) != len(flat_output_chunks):
  337. raise AssertionError(
  338. f"len(flat_out_dims)={len(flat_out_dims)} != len(flat_output_chunks)={len(flat_output_chunks)}"
  339. )
  340. flat_output: list[Tensor] = []
  341. for idx, out_dim in enumerate(flat_out_dims):
  342. chunk = flat_output_chunks[idx]
  343. if chunk is None:
  344. raise AssertionError(f"chunk at index {idx} must not be None")
  345. flat_output.append(torch.cat(chunk, dim=out_dim))
  346. # release tensors
  347. flat_output_chunks[idx] = None
  348. return flat_output
  349. # Applies vmap on chunked_input and returns concatenated output over the chunks.
  350. def _chunked_vmap(
  351. func: Callable[_P, Tensor | tuple[Tensor, ...]],
  352. flat_in_dims: list[int | None],
  353. chunks_flat_args: Iterable[tuple[Any, ...]],
  354. args_spec: TreeSpec,
  355. out_dims: out_dims_t,
  356. randomness: str,
  357. **kwargs: Any,
  358. ) -> Any:
  359. chunks_output: list[Any] = []
  360. rs = torch.get_rng_state() if randomness == "same" else None
  361. for flat_args_tuple in chunks_flat_args:
  362. flat_args = list(flat_args_tuple)
  363. batch_size = _validate_and_get_batch_size(flat_in_dims, flat_args)
  364. # The way we compute split the input in `_get_chunked_inputs`,
  365. # we may get a tensor with `0` batch-size. We skip any computation
  366. # in that case.
  367. # Eg.
  368. # >>> chunk_size = 1
  369. # >>> batch_size = 6
  370. # >>> t = torch.zeros(batch_size, 1)
  371. # >>> t.tensor_split([1, 2, 3, 4, 5, 6])
  372. # (tensor([[0.]]), tensor([[0.]]), tensor([[0.]]), tensor([[0.]]),
  373. # tensor([[0.]]), tensor([[0.]]), tensor([], size=(0, 1)))
  374. if batch_size == 0:
  375. continue
  376. if rs is not None:
  377. torch.set_rng_state(rs)
  378. chunks_output.append(
  379. _flat_vmap(
  380. func,
  381. batch_size,
  382. flat_in_dims,
  383. flat_args,
  384. args_spec,
  385. out_dims,
  386. randomness,
  387. **kwargs,
  388. )
  389. )
  390. flat_output_chunks, arg_spec = _flatten_chunks_output(chunks_output)
  391. # chunked output tensors are held by both `flat_output_chunks` and `chunks_output`.
  392. # eagerly remove the reference from `chunks_output`.
  393. del chunks_output
  394. # concat chunks on out_dim
  395. # Note: We use cast since flat_output_chunks is modified in _concat_chunked_outputs
  396. # to set elements to None after processing
  397. flat_output = _concat_chunked_outputs(
  398. out_dims, arg_spec, cast(list[tuple[Any, ...] | None], flat_output_chunks)
  399. )
  400. # finally unflatten the output
  401. return tree_unflatten(flat_output, arg_spec)
  402. # Vmap refactored helper functions:
  403. def _check_randomness_arg(randomness: str) -> None:
  404. if randomness not in ["error", "different", "same"]:
  405. raise RuntimeError(
  406. f"Only allowed values for randomness are 'error', 'different', or 'same'. Got {randomness}"
  407. )
  408. @contextlib.contextmanager
  409. def vmap_increment_nesting(
  410. batch_size: int, randomness: str
  411. ) -> Generator[int, None, None]:
  412. try:
  413. vmap_level = _vmap_increment_nesting(batch_size, randomness)
  414. yield vmap_level
  415. finally:
  416. _vmap_decrement_nesting()
  417. def _flat_vmap(
  418. func: Callable[..., Tensor | tuple[Tensor, ...]],
  419. batch_size: int,
  420. flat_in_dims: list[int | None],
  421. flat_args: list[Any],
  422. args_spec: TreeSpec,
  423. out_dims: out_dims_t,
  424. randomness: str,
  425. **kwargs: Any,
  426. ) -> Any:
  427. with vmap_increment_nesting(batch_size, randomness) as vmap_level:
  428. batched_inputs = _create_batched_inputs(
  429. flat_in_dims, flat_args, vmap_level, args_spec
  430. )
  431. batched_outputs = func(*batched_inputs, **kwargs)
  432. return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
  433. # `restore_vmap` is a private helper function. It is vmap but has the following
  434. # differences:
  435. # - instead of returning outputs, it returns an (outputs, out_dims) tuple.
  436. # out_dims is a pytree of same shape as outputs and contains Optional[int]
  437. # specifying where the vmapped dimension, if it exists, is in the corresponding output.
  438. # - does no validation on in_dims or inputs (vmap expects at least one Tensor to be vmapped).
  439. # restore_vmap allows for no inputs to have the vmap dimension
  440. # - does no validation on outputs (vmap expects only Tensor outputs)
  441. # restore_vmap allows for return of arbitrary outputs (not just Tensors)
  442. #
  443. # The TL;DR is that restore_vmap is more general than vmap and has a slightly
  444. # different API. The relaxations are so that we can "pause" vmap in the middle
  445. # of its execution and then "restore" it later (this is what we do in
  446. # the generate_vmap_rule=True implementation of autograd.Function).
  447. #
  448. # restore_vmap can be technically used in the implementation of vmap, but doing
  449. # that refactor is a bit technically challenging because:
  450. # - vmap couples the tensor-wrapping code with error checking
  451. # - vmap's tensor unwrapping code is in C++; we would need to rewrite part of it
  452. # in python because it overlaps with unwrap_batched
  453. def restore_vmap(
  454. func: Callable[..., _R], in_dims: in_dims_t, batch_size: int, randomness: str
  455. ) -> Callable[..., tuple[Any, Any]]:
  456. def inner(*args: Any, **kwargs: Any) -> tuple[Any, Any]:
  457. with vmap_increment_nesting(batch_size, randomness) as vmap_level:
  458. batched_inputs = wrap_batched(args, in_dims, vmap_level)
  459. batched_outputs = func(*batched_inputs, **kwargs)
  460. return unwrap_batched(batched_outputs, vmap_level)
  461. return inner
  462. def wrap_batched(
  463. args: tuple[Any, ...], bdims: in_dims_t, level: int
  464. ) -> tuple[Any, ...]:
  465. flat_args, spec = tree_flatten(args)
  466. flat_bdims = _broadcast_to_and_flatten(bdims, spec)
  467. if flat_bdims is None:
  468. raise AssertionError("flat_bdims must not be None")
  469. result = _create_batched_inputs(flat_bdims, flat_args, level, spec)
  470. return result
  471. def unwrap_batched(args: Any, level: int) -> tuple[Any, Any]:
  472. flat_args, spec = tree_flatten(args)
  473. if len(flat_args) == 0:
  474. return args, ()
  475. result = [
  476. (
  477. torch._C._functorch._unwrap_batched(arg, level)
  478. if isinstance(arg, torch.Tensor)
  479. else (arg, None)
  480. )
  481. for arg in flat_args
  482. ]
  483. output, bdims = zip(*result)
  484. return tree_unflatten(output, spec), tree_unflatten(bdims, spec)