deprecated.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. """
  2. The APIs in this file are exposed as `functorch.*`. They are thin wrappers
  3. around the torch.func.* APIs that have deprecation warnings -- we're trying
  4. to move people to the torch.func.* equivalents.
  5. NB: We don't use *args, **kwargs in the signatures because that changes the
  6. documentation.
  7. """
  8. from __future__ import annotations
  9. import textwrap
  10. import warnings
  11. from typing import Any, TYPE_CHECKING
  12. import torch._functorch.apis as apis
  13. import torch._functorch.eager_transforms as _impl
  14. import torch._functorch.make_functional as _nn_impl
  15. import torch.nn as nn
  16. if TYPE_CHECKING:
  17. from collections.abc import Callable
  18. from torch._functorch.eager_transforms import argnums_t
  19. from torch._functorch.vmap import in_dims_t, out_dims_t
  20. def get_warning(
  21. api: str, new_api: str | None = None, replace_newlines: bool = False
  22. ) -> str:
  23. if new_api is None:
  24. new_api = f"torch.func.{api}"
  25. warning = (
  26. f"We've integrated functorch into PyTorch. As the final step of the \n"
  27. f"integration, `functorch.{api}` is deprecated as of PyTorch \n"
  28. f"2.0 and will be deleted in a future version of PyTorch >= 2.3. \n"
  29. f"Please use `{new_api}` instead; see the PyTorch 2.0 release notes \n"
  30. f"and/or the `torch.func` migration guide for more details \n"
  31. f"https://pytorch.org/docs/main/func.migrating.html"
  32. )
  33. if replace_newlines:
  34. warning = warning.replace("\n", "")
  35. return warning
  36. def warn_deprecated(api: str, new_api: str | None = None) -> None:
  37. warning = get_warning(api, new_api, replace_newlines=True)
  38. warnings.warn(warning, FutureWarning, stacklevel=3)
  39. def setup_docs(
  40. functorch_api: Callable[..., Any],
  41. torch_func_api: Callable[..., Any] | None = None,
  42. new_api_name: str | None = None,
  43. ) -> None:
  44. api_name = functorch_api.__name__
  45. if torch_func_api is None:
  46. torch_func_api = getattr(_impl, api_name)
  47. # See https://docs.python.org/3/using/cmdline.html#cmdoption-OO
  48. if torch_func_api.__doc__ is None:
  49. return
  50. warning = get_warning(api_name, new_api_name)
  51. warning_note = "\n.. warning::\n\n" + textwrap.indent(warning, " ")
  52. warning_note = textwrap.indent(warning_note, " ")
  53. functorch_api.__doc__ = torch_func_api.__doc__ + warning_note
  54. def vmap(
  55. func: Callable[..., Any],
  56. in_dims: in_dims_t = 0,
  57. out_dims: out_dims_t = 0,
  58. randomness: str = "error",
  59. *,
  60. chunk_size: int | None = None,
  61. ) -> Callable[..., Any]:
  62. warn_deprecated("vmap", "torch.vmap")
  63. return apis.vmap(func, in_dims, out_dims, randomness, chunk_size=chunk_size)
  64. def grad(
  65. func: Callable[..., Any], argnums: argnums_t = 0, has_aux: bool = False
  66. ) -> Callable[..., Any]:
  67. warn_deprecated("grad")
  68. return apis.grad(func, argnums, has_aux)
  69. def grad_and_value(
  70. func: Callable[..., Any], argnums: argnums_t = 0, has_aux: bool = False
  71. ) -> Callable[..., Any]:
  72. warn_deprecated("grad_and_value")
  73. return apis.grad_and_value(func, argnums, has_aux)
  74. def vjp(func: Callable[..., Any], *primals: Any, has_aux: bool = False) -> Any:
  75. warn_deprecated("vjp")
  76. return _impl.vjp(func, *primals, has_aux=has_aux)
  77. def jvp(
  78. func: Callable[..., Any],
  79. primals: Any,
  80. tangents: Any,
  81. *,
  82. strict: bool = False,
  83. has_aux: bool = False,
  84. ) -> Any:
  85. warn_deprecated("jvp")
  86. return _impl.jvp(func, primals, tangents, strict=strict, has_aux=has_aux)
  87. def jacrev(
  88. func: Callable[..., Any],
  89. argnums: int | tuple[int, ...] = 0,
  90. *,
  91. has_aux: bool = False,
  92. chunk_size: int | None = None,
  93. _preallocate_and_copy: bool = False,
  94. ) -> Callable[..., Any]:
  95. warn_deprecated("jacrev")
  96. return _impl.jacrev(
  97. func,
  98. argnums,
  99. has_aux=has_aux,
  100. chunk_size=chunk_size,
  101. _preallocate_and_copy=_preallocate_and_copy,
  102. )
  103. def jacfwd(
  104. func: Callable[..., Any],
  105. argnums: argnums_t = 0,
  106. has_aux: bool = False,
  107. *,
  108. randomness: str = "error",
  109. ) -> Callable[..., Any]:
  110. warn_deprecated("jacfwd")
  111. return _impl.jacfwd(func, argnums, has_aux, randomness=randomness)
  112. def hessian(func: Callable[..., Any], argnums: int = 0) -> Callable[..., Any]:
  113. warn_deprecated("hessian")
  114. return _impl.hessian(func, argnums=argnums)
  115. def functionalize(
  116. func: Callable[..., Any], *, remove: str = "mutations"
  117. ) -> Callable[..., Any]:
  118. warn_deprecated("functionalize")
  119. return _impl.functionalize(func, remove=remove)
  120. def make_functional(model: nn.Module, disable_autograd_tracking: bool = False) -> Any:
  121. warn_deprecated("make_functional", "torch.func.functional_call")
  122. return _nn_impl.make_functional(model, disable_autograd_tracking)
  123. def make_functional_with_buffers(
  124. model: nn.Module, disable_autograd_tracking: bool = False
  125. ) -> Any:
  126. warn_deprecated("make_functional_with_buffers", "torch.func.functional_call")
  127. return _nn_impl.make_functional_with_buffers(model, disable_autograd_tracking)
  128. def combine_state_for_ensemble(models: list[nn.Module]) -> Any:
  129. warn_deprecated("combine_state_for_ensemble", "torch.func.stack_module_state")
  130. return _nn_impl.combine_state_for_ensemble(models)
  131. setup_docs(vmap, apis.vmap, "torch.vmap")
  132. setup_docs(grad, apis.grad)
  133. setup_docs(grad_and_value, apis.grad_and_value)
  134. setup_docs(vjp)
  135. setup_docs(jvp)
  136. setup_docs(jacrev)
  137. setup_docs(jacfwd)
  138. setup_docs(hessian)
  139. setup_docs(functionalize)
  140. setup_docs(make_functional, _nn_impl.make_functional, "torch.func.functional_call")
  141. setup_docs(
  142. make_functional_with_buffers, _nn_impl.make_functional, "torch.func.functional_call"
  143. )
  144. setup_docs(
  145. combine_state_for_ensemble,
  146. _nn_impl.combine_state_for_ensemble,
  147. "torch.func.stack_module_state",
  148. )