autocast_mode.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529
  1. # mypy: allow-untyped-defs
  2. import collections
  3. import functools
  4. import warnings
  5. from typing import Any, Optional
  6. import torch
  7. from torch.types import _dtype
  8. try:
  9. import numpy as np
  10. HAS_NUMPY = True
  11. except ModuleNotFoundError:
  12. HAS_NUMPY = False
  13. np = None # type: ignore[assignment]
  14. __all__ = [
  15. "autocast_decorator",
  16. "autocast",
  17. "is_autocast_available",
  18. "custom_fwd",
  19. "custom_bwd",
  20. ]
  21. def is_autocast_available(device_type: str) -> bool:
  22. r"""
  23. Return a bool indicating if autocast is available on :attr:`device_type`.
  24. Args:
  25. device_type(str): Device type to use. Possible values are: 'cuda', 'cpu', 'mtia', 'maia', 'xpu', and so on.
  26. The type is the same as the `type` attribute of a :class:`torch.device`.
  27. Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
  28. """
  29. return torch._C._is_autocast_available(device_type)
  30. def autocast_decorator(autocast_instance, func):
  31. @functools.wraps(func)
  32. def decorate_autocast(*args, **kwargs):
  33. with autocast_instance:
  34. return func(*args, **kwargs)
  35. decorate_autocast.__script_unsupported = ( # type: ignore[attr-defined]
  36. "@autocast() decorator is not supported in script mode"
  37. )
  38. return decorate_autocast
  39. class autocast:
  40. r"""
  41. Instances of :class:`autocast` serve as context managers or decorators that
  42. allow regions of your script to run in mixed precision.
  43. In these regions, ops run in an op-specific dtype chosen by autocast
  44. to improve performance while maintaining accuracy.
  45. See the :ref:`Autocast Op Reference<autocast-op-reference>` for details.
  46. When entering an autocast-enabled region, Tensors may be any type.
  47. You should not call ``half()`` or ``bfloat16()`` on your model(s) or inputs when using autocasting.
  48. :class:`autocast` should wrap only the forward pass(es) of your network, including the loss
  49. computation(s). Backward passes under autocast are not recommended.
  50. Backward ops run in the same type that autocast used for corresponding forward ops.
  51. Example for CUDA Devices::
  52. # Creates model and optimizer in default precision
  53. model = Net().cuda()
  54. optimizer = optim.SGD(model.parameters(), ...)
  55. for input, target in data:
  56. optimizer.zero_grad()
  57. # Enables autocasting for the forward pass (model + loss)
  58. with torch.autocast(device_type="cuda"):
  59. output = model(input)
  60. loss = loss_fn(output, target)
  61. # Exits the context manager before backward()
  62. loss.backward()
  63. optimizer.step()
  64. See the :ref:`Automatic Mixed Precision examples<amp-examples>` for usage (along with gradient scaling)
  65. in more complex scenarios (e.g., gradient penalty, multiple models/losses, custom autograd functions).
  66. :class:`autocast` can also be used as a decorator, e.g., on the ``forward`` method of your model::
  67. class AutocastModel(nn.Module):
  68. ...
  69. @torch.autocast(device_type="cuda")
  70. def forward(self, input): ...
  71. Floating-point Tensors produced in an autocast-enabled region may be ``float16``.
  72. After returning to an autocast-disabled region, using them with floating-point
  73. Tensors of different dtypes may cause type mismatch errors. If so, cast the Tensor(s)
  74. produced in the autocast region back to ``float32`` (or other dtype if desired).
  75. If a Tensor from the autocast region is already ``float32``, the cast is a no-op,
  76. and incurs no additional overhead.
  77. CUDA Example::
  78. # Creates some tensors in default dtype (here assumed to be float32)
  79. a_float32 = torch.rand((8, 8), device="cuda")
  80. b_float32 = torch.rand((8, 8), device="cuda")
  81. c_float32 = torch.rand((8, 8), device="cuda")
  82. d_float32 = torch.rand((8, 8), device="cuda")
  83. with torch.autocast(device_type="cuda"):
  84. # torch.mm is on autocast's list of ops that should run in float16.
  85. # Inputs are float32, but the op runs in float16 and produces float16 output.
  86. # No manual casts are required.
  87. e_float16 = torch.mm(a_float32, b_float32)
  88. # Also handles mixed input types
  89. f_float16 = torch.mm(d_float32, e_float16)
  90. # After exiting autocast, calls f_float16.float() to use with d_float32
  91. g_float32 = torch.mm(d_float32, f_float16.float())
  92. CPU Training Example::
  93. # Creates model and optimizer in default precision
  94. model = Net()
  95. optimizer = optim.SGD(model.parameters(), ...)
  96. for epoch in epochs:
  97. for input, target in data:
  98. optimizer.zero_grad()
  99. # Runs the forward pass with autocasting.
  100. with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
  101. output = model(input)
  102. loss = loss_fn(output, target)
  103. loss.backward()
  104. optimizer.step()
  105. CPU Inference Example::
  106. # Creates model in default precision
  107. model = Net().eval()
  108. with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
  109. for input in data:
  110. # Runs the forward pass with autocasting.
  111. output = model(input)
  112. CPU Inference Example with Jit Trace::
  113. class TestModel(nn.Module):
  114. def __init__(self, input_size, num_classes):
  115. super().__init__()
  116. self.fc1 = nn.Linear(input_size, num_classes)
  117. def forward(self, x):
  118. return self.fc1(x)
  119. input_size = 2
  120. num_classes = 2
  121. model = TestModel(input_size, num_classes).eval()
  122. # For now, we suggest to disable the Jit Autocast Pass,
  123. # As the issue: https://github.com/pytorch/pytorch/issues/75956
  124. torch._C._jit_set_autocast_mode(False)
  125. with torch.cpu.amp.autocast(cache_enabled=False):
  126. model = torch.jit.trace(model, torch.randn(1, input_size))
  127. model = torch.jit.freeze(model)
  128. # Models Run
  129. for _ in range(3):
  130. model(torch.randn(1, input_size))
  131. Type mismatch errors *in* an autocast-enabled region are a bug; if this is what you observe,
  132. please file an issue.
  133. ``autocast(enabled=False)`` subregions can be nested in autocast-enabled regions.
  134. Locally disabling autocast can be useful, for example, if you want to force a subregion
  135. to run in a particular ``dtype``. Disabling autocast gives you explicit control over
  136. the execution type. In the subregion, inputs from the surrounding region
  137. should be cast to ``dtype`` before use::
  138. # Creates some tensors in default dtype (here assumed to be float32)
  139. a_float32 = torch.rand((8, 8), device="cuda")
  140. b_float32 = torch.rand((8, 8), device="cuda")
  141. c_float32 = torch.rand((8, 8), device="cuda")
  142. d_float32 = torch.rand((8, 8), device="cuda")
  143. with torch.autocast(device_type="cuda"):
  144. e_float16 = torch.mm(a_float32, b_float32)
  145. with torch.autocast(device_type="cuda", enabled=False):
  146. # Calls e_float16.float() to ensure float32 execution
  147. # (necessary because e_float16 was created in an autocasted region)
  148. f_float32 = torch.mm(c_float32, e_float16.float())
  149. # No manual casts are required when re-entering the autocast-enabled region.
  150. # torch.mm again runs in float16 and produces float16 output, regardless of input types.
  151. g_float16 = torch.mm(d_float32, f_float32)
  152. The autocast state is thread-local. If you want it enabled in a new thread, the context manager or decorator
  153. must be invoked in that thread. This affects :class:`torch.nn.DataParallel` and
  154. :class:`torch.nn.parallel.DistributedDataParallel` when used with more than one GPU per process
  155. (see :ref:`Working with Multiple GPUs<amp-multigpu>`).
  156. Args:
  157. device_type(str, required): Device type to use. Possible values are: 'cuda', 'cpu', 'mtia', 'maia', 'xpu', and 'hpu'.
  158. The type is the same as the `type` attribute of a :class:`torch.device`.
  159. Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
  160. enabled(bool, optional): Whether autocasting should be enabled in the region.
  161. Default: ``True``
  162. dtype(torch_dtype, optional): Data type for ops run in autocast. It uses the default value
  163. (``torch.float16`` for CUDA and ``torch.bfloat16`` for CPU), given by
  164. :func:`~torch.get_autocast_dtype`, if :attr:`dtype` is ``None``.
  165. Default: ``None``
  166. cache_enabled(bool, optional): Whether the weight cache inside autocast should be enabled.
  167. Default: ``True``
  168. """
  169. def __init__(
  170. self,
  171. device_type: str,
  172. dtype: Optional[_dtype] = None,
  173. enabled: bool = True,
  174. cache_enabled: Optional[bool] = None,
  175. ):
  176. if not isinstance(device_type, str):
  177. raise ValueError(
  178. f"Expected `device_type` of type `str`, got: `{type(device_type)}`"
  179. )
  180. self.fast_dtype = (
  181. torch.get_autocast_dtype(device_type) if dtype is None else dtype
  182. )
  183. if torch._jit_internal.is_scripting():
  184. self._enabled = enabled
  185. self.device = device_type
  186. if self.fast_dtype is None:
  187. raise AssertionError("fast_dtype must not be None in scripting mode")
  188. return
  189. self.device = device_type
  190. if not is_autocast_available(self.device):
  191. raise RuntimeError(
  192. f"User specified an unsupported autocast device_type '{self.device}'"
  193. )
  194. device_supported_dtypes = [torch.bfloat16, torch.float16]
  195. self.custom_backend_name = torch._C._get_privateuse1_backend_name()
  196. if self.device == self.custom_backend_name:
  197. necessary_funcs = [
  198. "get_amp_supported_dtype",
  199. ]
  200. message = f"Tried to use AMP with the `{self.custom_backend_name}` backend, but the backend has not "
  201. message += "registered a module or the module miss some necessary funcs. The backend should register "
  202. message += "a module by `torch._register_device_module`, and the module must have these funcs: \n"
  203. message += "`get_amp_supported_dtype() -> List[torch.dtype]`. \n"
  204. if not hasattr(torch, self.custom_backend_name):
  205. raise AssertionError(message)
  206. self.custom_device_mod = getattr(torch, self.custom_backend_name)
  207. for func in necessary_funcs:
  208. if not hasattr(self.custom_device_mod, func):
  209. raise AssertionError(
  210. message + f"But the func `{func}` is missing. \n"
  211. )
  212. device_supported_dtypes = self.custom_device_mod.get_amp_supported_dtype()
  213. self._cache_enabled = (
  214. torch.is_autocast_cache_enabled()
  215. if cache_enabled is None
  216. else cache_enabled
  217. )
  218. device_name = (
  219. self.device
  220. if self.device == self.custom_backend_name
  221. else self.device.upper()
  222. )
  223. if enabled:
  224. # Special case for CUDA AMP and bfloat16 support
  225. if self.device == "cuda":
  226. if torch.cuda.amp.common.amp_definitely_not_available():
  227. warnings.warn(
  228. "CUDA is not available or torch_xla is imported. Disabling autocast.",
  229. stacklevel=2,
  230. )
  231. enabled = False
  232. elif (
  233. self.fast_dtype == torch.bfloat16
  234. and not torch.cuda.is_bf16_supported()
  235. ):
  236. raise RuntimeError(
  237. "Current CUDA Device does not support bfloat16. Please switch dtype to float16."
  238. )
  239. elif self.fast_dtype not in device_supported_dtypes:
  240. error_message = (
  241. f"In {device_name} autocast, but the target dtype is not supported. Disabling autocast.\n"
  242. f"{device_name} Autocast only supports dtypes of "
  243. + ", ".join(map(str, device_supported_dtypes))
  244. + " currently."
  245. )
  246. warnings.warn(error_message, stacklevel=2)
  247. enabled = False
  248. # Special case for MPS bfloat16 support on macOS < 14
  249. if (
  250. self.device == "mps"
  251. and self.fast_dtype == torch.bfloat16
  252. and not torch.backends.mps.is_macos_or_newer(14, 0)
  253. ):
  254. error_message = (
  255. "In MPS autocast, but the target dtype torch.bfloat16 is not supported "
  256. "on macOS versions below 14. Disabling autocast."
  257. )
  258. warnings.warn(error_message, stacklevel=2)
  259. enabled = False
  260. self._enabled = enabled
  261. def __enter__(self):
  262. if torch._jit_internal.is_scripting():
  263. if self.fast_dtype is None:
  264. raise AssertionError("fast_dtype must not be None in scripting mode")
  265. return self
  266. self.prev_cache_enabled = torch.is_autocast_cache_enabled()
  267. self.prev = torch.is_autocast_enabled(self.device)
  268. self.prev_fastdtype = torch.get_autocast_dtype(self.device)
  269. torch.set_autocast_enabled(self.device, self._enabled)
  270. torch.set_autocast_dtype(self.device, self.fast_dtype) # type: ignore[arg-type]
  271. torch.autocast_increment_nesting()
  272. torch.set_autocast_cache_enabled(self._cache_enabled)
  273. # only dispatch to PreDispatchTorchFunctionMode to avoid exposing this
  274. # API to other functional modes. We only expose to PreDispatchTorchFunctionMode
  275. # for preserving autocast in torch.export.export.
  276. if torch._C._is_torch_function_mode_enabled():
  277. stacks = torch.overrides._get_current_function_mode_stack()
  278. for mode in stacks:
  279. if isinstance(
  280. mode,
  281. torch.fx.experimental.proxy_tensor.PreDispatchTorchFunctionMode,
  282. ):
  283. args = (
  284. self.device,
  285. self.fast_dtype,
  286. self._enabled,
  287. self._cache_enabled,
  288. )
  289. mode.__torch_function__(torch.amp._enter_autocast, (), args)
  290. return self
  291. return self
  292. def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): # type: ignore[override]
  293. if torch._jit_internal.is_scripting():
  294. return
  295. # Drop the cache when we exit to a nesting level that's outside any instance of autocast.
  296. if torch.autocast_decrement_nesting() == 0:
  297. torch.clear_autocast_cache()
  298. torch.set_autocast_enabled(self.device, self.prev)
  299. torch.set_autocast_dtype(self.device, self.prev_fastdtype)
  300. torch.set_autocast_cache_enabled(self.prev_cache_enabled)
  301. # only dispatch to PreDispatchTorchFunctionMode to avoid exposing this
  302. # API to other functional modes. We only expose to PreDispatchTorchFunctionMode
  303. # for preserving autocast in torch.export.export.
  304. if torch._C._is_torch_function_mode_enabled():
  305. stacks = torch.overrides._get_current_function_mode_stack()
  306. for mode in stacks:
  307. if isinstance(
  308. mode,
  309. torch.fx.experimental.proxy_tensor.PreDispatchTorchFunctionMode,
  310. ):
  311. mode.__torch_function__(torch.amp._exit_autocast, (), ())
  312. # This is very important because the above line actually doesn't
  313. # run exit code so it end up swallowing exceptions.
  314. return False
  315. return False
  316. def __call__(self, func):
  317. if torch._jit_internal.is_scripting():
  318. return func
  319. return autocast_decorator(self, func)
  320. # These functions aren't meant for public usage.
  321. # They are what we trace into a graph during pre_dispatch tracing
  322. # when we encounter an autocast context manager.
  323. def _enter_autocast(*vals):
  324. # For pre-dispatch tracing, if a TorchFunction mode is active, we'll want to trace this into a graph.
  325. if torch._C._is_torch_function_mode_enabled():
  326. return torch.overrides.handle_torch_function(
  327. torch.amp._enter_autocast, [], *vals
  328. )
  329. mode = torch.amp.autocast(*vals)
  330. mode.__enter__()
  331. return mode
  332. def _exit_autocast(mode):
  333. if torch._C._is_torch_function_mode_enabled():
  334. return torch.overrides.handle_torch_function(torch.amp._exit_autocast, [], mode)
  335. mode.__exit__(None, None, None)
  336. # Casts Tensors and containers of Tensors. Special-cases passthroughs for strings and np.ndarrays, which
  337. # may be falsely detected as "Iterables."
  338. def _cast(value, device_type: str, dtype: _dtype):
  339. if isinstance(value, torch.Tensor):
  340. is_eligible = (
  341. value.is_floating_point()
  342. and value.device.type == device_type
  343. and (value.dtype is not torch.float64)
  344. )
  345. return value.to(dtype) if is_eligible else value
  346. elif isinstance(value, (str, bytes)):
  347. return value
  348. elif HAS_NUMPY and isinstance(
  349. value,
  350. # pyrefly: ignore [missing-attribute]
  351. np.ndarray,
  352. ):
  353. return value
  354. elif isinstance(value, collections.abc.Mapping):
  355. return {
  356. _cast(k, device_type, dtype): _cast(v, device_type, dtype)
  357. for k, v in value.items()
  358. }
  359. elif isinstance(value, collections.abc.Iterable):
  360. iterable = (_cast(v, device_type, dtype) for v in value)
  361. if isinstance(value, (list, tuple)):
  362. return type(value)(iterable)
  363. else:
  364. return iterable
  365. else:
  366. return value
  367. def custom_fwd(
  368. fwd=None,
  369. *,
  370. device_type: str,
  371. cast_inputs: Optional[_dtype] = None,
  372. ):
  373. """
  374. Create a helper decorator for ``forward`` methods of custom autograd functions.
  375. Autograd functions are subclasses of :class:`torch.autograd.Function`.
  376. See the :ref:`example page<amp-custom-examples>` for more detail.
  377. Args:
  378. device_type(str): Device type to use. 'cuda', 'cpu', 'mtia', 'maia', 'xpu' and so on.
  379. The type is the same as the `type` attribute of a :class:`torch.device`.
  380. Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
  381. cast_inputs (:class:`torch.dtype` or None, optional, default=None): If not ``None``,
  382. when ``forward`` runs in an autocast-enabled region, casts incoming
  383. floating-point Tensors to the target dtype (non-floating-point Tensors are not affected),
  384. then executes ``forward`` with autocast disabled.
  385. If ``None``, ``forward``'s internal ops execute with the current autocast state.
  386. .. note::
  387. If the decorated ``forward`` is called outside an autocast-enabled region,
  388. :func:`custom_fwd<custom_fwd>` is a no-op and ``cast_inputs`` has no effect.
  389. """
  390. if not isinstance(device_type, str):
  391. raise ValueError(
  392. f"Expected `device_type` of type `str`, got: `{type(device_type)}`"
  393. )
  394. if fwd is None:
  395. return functools.partial(
  396. custom_fwd, device_type=device_type, cast_inputs=cast_inputs
  397. )
  398. @functools.wraps(fwd)
  399. def decorate_fwd(*args, **kwargs):
  400. args[0]._dtype = torch.get_autocast_dtype(device_type)
  401. if cast_inputs is None:
  402. args[0]._fwd_used_autocast = torch.is_autocast_enabled(device_type)
  403. return fwd(*args, **kwargs) # pyrefly: ignore [not-callable]
  404. else:
  405. autocast_context = torch.is_autocast_enabled(device_type)
  406. args[0]._fwd_used_autocast = False
  407. if autocast_context:
  408. with autocast(device_type=device_type, enabled=False):
  409. return fwd( # pyrefly: ignore # not-callable
  410. *_cast(args, device_type, cast_inputs),
  411. **_cast(kwargs, device_type, cast_inputs),
  412. )
  413. else:
  414. return fwd(*args, **kwargs) # pyrefly: ignore [not-callable]
  415. return decorate_fwd
  416. # Autograd ensures incoming gradients are the same type as forward outputs. Allowing a separate
  417. # cast_inputs argument on custom_bwd is unnecessary and could cause errors if it doesn't match
  418. # cast_inputs supplied to custom_fwd.
  419. def custom_bwd(bwd=None, *, device_type: str):
  420. """Create a helper decorator for backward methods of custom autograd functions.
  421. Autograd functions are subclasses of :class:`torch.autograd.Function`.
  422. Ensures that ``backward`` executes with the same autocast state as ``forward``.
  423. See the :ref:`example page<amp-custom-examples>` for more detail.
  424. Args:
  425. device_type(str): Device type to use. 'cuda', 'cpu', 'mtia', 'maia', 'xpu' and so on.
  426. The type is the same as the `type` attribute of a :class:`torch.device`.
  427. Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
  428. """
  429. if not isinstance(device_type, str):
  430. raise ValueError(
  431. f"Expected `device_type` of type `str`, got: `{type(device_type)}`"
  432. )
  433. if bwd is None:
  434. return functools.partial(custom_bwd, device_type=device_type)
  435. @functools.wraps(bwd)
  436. def decorate_bwd(*args, **kwargs):
  437. with autocast(
  438. device_type=device_type,
  439. enabled=args[0]._fwd_used_autocast,
  440. dtype=args[0]._dtype,
  441. ):
  442. return bwd(*args, **kwargs) # pyrefly: ignore [not-callable]
  443. return decorate_bwd