grad_scaler.py 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714
  1. # mypy: allow-untyped-defs
  2. from __future__ import annotations
  3. import inspect
  4. import warnings
  5. from collections import abc, defaultdict
  6. from enum import Enum
  7. from typing import Any, cast, Optional, overload, TYPE_CHECKING, Union
  8. import torch
  9. if TYPE_CHECKING:
  10. from collections.abc import Iterable
  # Names exported by ``from <module> import *``; the underscore-prefixed helpers
  # (_MultiDeviceReplicator, _refresh_per_optimizer_state) are intentionally left out.
  __all__ = [
      "OptState",
      "GradScaler",
  ]
  class _MultiDeviceReplicator:
      """Serve copies of one tensor to whichever devices ask for them.

      Each device's copy is made lazily, the first time :meth:`get` is called with
      that device, and is cached so later requests for the same device reuse it.
      """

      def __init__(self, master_tensor: torch.Tensor) -> None:
          # The tensor every per-device copy is made from.
          self.master = master_tensor
          # Cache: device -> copy of ``master`` already materialized on that device.
          self._per_device_tensors: dict[torch.device, torch.Tensor] = {}

      def get(self, device: torch.device) -> torch.Tensor:
          """Return the copy of ``master`` living on ``device``, creating it on first use."""
          cache = self._per_device_tensors
          if device not in cache:
              # ``copy=True`` forces a distinct tensor even when ``master`` is already on ``device``.
              cache[device] = self.master.to(device=device, non_blocking=True, copy=True)
          return cache[device]
  # OptState and _refresh_per_optimizer_state are deliberately module-level rather
  # than nested inside GradScaler:
  # - _refresh_per_optimizer_state is the default_factory of GradScaler's
  #   _per_optimizer_states defaultdict; lambdas can't be pickled, so a named
  #   top-level function is used instead.
  # - Defining READY, UNSCALED, STEPPED and _refresh_per_optimizer_state within
  #   GradScaler would create a circular reference, which we'd rather avoid.
  class OptState(Enum):
      """Progress of a single optimizer through the current GradScaler iteration."""

      # Fresh record: neither unscale_() nor step() has run for this optimizer yet.
      READY = 0
      # unscale_() has run; a second unscale_() raises until update() resets the record.
      UNSCALED = 1
      # step() has run; another step() or unscale_() raises until update() resets the record.
      STEPPED = 2
  def _refresh_per_optimizer_state() -> dict[str, Any]:
      """Return a brand-new bookkeeping record for one optimizer.

      The record starts at ``OptState.READY`` with an empty
      ``found_inf_per_device`` mapping (no inf/NaN checks recorded yet).
      """
      return dict(stage=OptState.READY, found_inf_per_device={})
  class GradScaler:
      """Dynamically scale losses and gradients to make mixed-precision training robust.

      An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient
      scaling conveniently:

      * ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor.
      * ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``.
      * ``scaler.update()`` updates ``scaler``'s scale factor.

      Example::

          # Creates a GradScaler once at the beginning of training.
          scaler = GradScaler()

          for epoch in epochs:
              for input, target in data:
                  optimizer.zero_grad()
                  output = model(input)
                  loss = loss_fn(output, target)

                  # Scales loss. Calls backward() on scaled loss to create scaled gradients.
                  scaler.scale(loss).backward()

                  # scaler.step() first unscales gradients of the optimizer's params.
                  # If gradients don't contain infs/NaNs, optimizer.step() is then called,
                  # otherwise, optimizer.step() is skipped.
                  scaler.step(optimizer)

                  # Updates the scale for next iteration.
                  scaler.update()

      See the :ref:`Automatic Mixed Precision examples<amp-examples>` for usage
      (along with autocasting) in more complex cases like gradient clipping, gradient
      accumulation, gradient penalty, and multiple losses/optimizers.

      ``scaler`` dynamically estimates the scale factor each iteration. To minimize gradient
      underflow, a large scale factor should be used. However, ``float16`` values can "overflow"
      (become inf or NaN) if the scale factor is too large. Therefore, the optimal scale factor
      is the largest factor that can be used without incurring inf or NaN gradient values.
      ``scaler`` approximates the optimal scale factor over time by checking the gradients for
      infs and NaNs during every ``scaler.step(optimizer)`` (or optional separate
      ``scaler.unscale_(optimizer)``, see :meth:`unscale_`).

      * If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying
        ``optimizer.step()`` (so the params themselves remain uncorrupted) and ``update()``
        multiplies the scale by ``backoff_factor``.
      * If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying
        ``optimizer.step()`` as usual. If ``growth_interval`` unskipped iterations occur
        consecutively, ``update()`` multiplies the scale by ``growth_factor``.

      The scale factor often causes infs/NaNs to appear in gradients for the first few
      iterations as its value calibrates. ``scaler.step`` will skip the underlying
      ``optimizer.step()`` for these iterations. After that, step skipping should occur rarely
      (once every few hundred or thousand iterations).

      Args:
          device (str, optional, default="cuda"): Device type to use. Possible values are:
              'cuda' and 'cpu'. The type is the same as the `type` attribute of a
              :class:`torch.device`. Thus, you may obtain the device type of a tensor using
              `Tensor.device.type`.
          init_scale (float, optional, default=2.**16): Initial scale factor.
          growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied
              during :meth:`update` if no inf/NaN gradients occur for ``growth_interval``
              consecutive iterations.
          backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied
              during :meth:`update` if inf/NaN gradients occur in an iteration.
          growth_interval (int, optional, default=2000): Number of consecutive iterations without
              inf/NaN gradients that must occur for the scale to be multiplied by
              ``growth_factor``.
          enabled (bool, optional): If ``False``, disables gradient scaling. :meth:`step` simply
              invokes the underlying ``optimizer.step()``, and other methods become no-ops.
              Default: ``True``

      Raises:
          AssertionError: if scaling is enabled and ``growth_factor <= 1.0`` or
              ``backoff_factor >= 1.0``.
      """

      def __init__(
          self,
          device: str = "cuda",
          init_scale: float = 2.0**16,
          growth_factor: float = 2.0,
          backoff_factor: float = 0.5,
          growth_interval: int = 2000,
          enabled: bool = True,
      ) -> None:
          self._device = device
          self._enabled = enabled
          if self._device == "cuda":
              if enabled and torch.cuda.amp.common.amp_definitely_not_available():
                  warnings.warn(
                      "torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.",
                      stacklevel=2,
                  )
                  self._enabled = False

          # NOTE: the attributes below are only created when scaling is enabled, so the
          # getters that read them (e.g. get_growth_factor) fail on a disabled scaler.
          if self._enabled:
              if growth_factor <= 1.0:
                  raise AssertionError("The growth factor must be > 1.0.")
              if backoff_factor >= 1.0:
                  raise AssertionError("The backoff factor must be < 1.0.")

              self._init_scale = init_scale
              # self._scale will be lazily initialized during the first call to scale()
              self._scale: Optional[torch.Tensor] = None
              self._growth_factor = growth_factor
              self._backoff_factor = backoff_factor
              self._growth_interval = growth_interval
              self._init_growth_tracker = 0
              # self._growth_tracker will be lazily initialized during the first call to scale()
              self._growth_tracker: Optional[torch.Tensor] = None
              # id(optimizer) -> {"stage": OptState, "found_inf_per_device": {device: Tensor}}
              self._per_optimizer_states: dict[int, dict[str, Any]] = defaultdict(
                  _refresh_per_optimizer_state
              )

      def _check_scale_growth_tracker(
          self, funcname: str
      ) -> tuple[torch.Tensor, torch.Tensor]:
          """Return ``(scale, growth_tracker)``, raising if scale() has not initialized them yet."""
          fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration."
          if self._scale is None:
              raise AssertionError(f"Attempted {funcname} but _scale is None. " + fix)
          if self._growth_tracker is None:
              raise AssertionError(
                  f"Attempted {funcname} but _growth_tracker is None. " + fix
              )
          return (self._scale, self._growth_tracker)

      def _lazy_init_scale_growth_tracker(self, dev: torch.device) -> None:
          """Create the 0-dim float32 scale and int32 growth-tracker tensors on ``dev``."""
          if self._growth_tracker is not None:
              raise AssertionError("_growth_tracker initialized before _scale")
          self._scale = torch.full((), self._init_scale, dtype=torch.float32, device=dev)
          self._growth_tracker = torch.full(
              (), self._init_growth_tracker, dtype=torch.int32, device=dev
          )

      @overload
      def scale(self, outputs: torch.Tensor) -> torch.Tensor: ...

      @overload
      def scale(self, outputs: list[torch.Tensor]) -> list[torch.Tensor]: ...

      @overload
      def scale(self, outputs: tuple[torch.Tensor, ...]) -> tuple[torch.Tensor, ...]: ...

      @overload
      def scale(self, outputs: Iterable[torch.Tensor]) -> Iterable[torch.Tensor]: ...

      def scale(
          self,
          outputs: Union[torch.Tensor, Iterable[torch.Tensor]],
      ) -> Union[torch.Tensor, Iterable[torch.Tensor]]:
          """
          Multiplies ('scales') a tensor or list of tensors by the scale factor.

          Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs
          are returned unmodified. Lists and tuples are returned as the same container type;
          any other iterable is returned as a lazy ``map`` object.

          Args:
              outputs (Tensor or iterable of Tensors): Outputs to scale.

          Raises:
              ValueError: if ``outputs`` (or a nested element) is neither a Tensor nor an iterable.
          """
          if not self._enabled:
              return outputs

          # Short-circuit for the common case.
          if isinstance(outputs, torch.Tensor):
              if self._scale is None:
                  self._lazy_init_scale_growth_tracker(outputs.device)
              # _scale is now guaranteed to be set
              scale = self._scale
              if scale is None:
                  raise AssertionError("_scale should not be None after lazy init")
              return outputs * scale.to(device=outputs.device, non_blocking=True)

          # Invoke the more complex machinery only if we're treating multiple outputs.
          # Holds one replicator, created on the first tensor seen, so the scale is copied to
          # each device at most once across all (possibly nested) outputs.
          stash: list[_MultiDeviceReplicator] = []

          def apply_scale(val: Union[torch.Tensor, Iterable[torch.Tensor]]):
              if isinstance(val, torch.Tensor):
                  if len(stash) == 0:
                      if self._scale is None:
                          self._lazy_init_scale_growth_tracker(val.device)
                      # _scale is now guaranteed to be set
                      scale = self._scale
                      if scale is None:
                          raise AssertionError(
                              "_scale should not be None after lazy init"
                          )
                      stash.append(_MultiDeviceReplicator(scale))
                  return val * stash[0].get(val.device)
              if isinstance(val, abc.Iterable):
                  iterable = map(apply_scale, val)
                  if isinstance(val, (list, tuple)):
                      return type(val)(iterable)
                  return iterable
              raise ValueError("outputs must be a Tensor or an iterable of Tensors")

          return apply_scale(outputs)

      def _unscale_grads_(
          self,
          optimizer: torch.optim.Optimizer,
          inv_scale: torch.Tensor,
          found_inf: torch.Tensor,
          allow_fp16: bool,
      ) -> dict[torch.device, torch.Tensor]:
          """Multiply every grad of ``optimizer`` by ``inv_scale`` and record infs/NaNs.

          Returns a mapping from device to that device's copy of ``found_inf``, which
          ``torch._amp_foreach_non_finite_check_and_unscale_`` updates in place.

          Raises:
              ValueError: if ``allow_fp16`` is False and an fp16 gradient is encountered.
          """
          per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
          per_device_found_inf = _MultiDeviceReplicator(found_inf)

          # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
          # There could be hundreds of grads, so we'd like to iterate through them just once.
          # However, we don't know their devices or dtypes in advance.

          # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
          # Google says mypy struggles with defaultdicts type annotations.
          per_device_and_dtype_grads: dict[
              torch.device, dict[torch.dtype, list[torch.Tensor]]
          ] = defaultdict(lambda: defaultdict(list))
          with torch.no_grad():
              for group in optimizer.param_groups:
                  for param in group["params"]:
                      if not isinstance(param, torch.Tensor):
                          raise AssertionError(
                              f"expected param to be torch.Tensor, got {type(param).__name__}"
                          )
                      if param.grad is None:
                          continue
                      if (not allow_fp16) and param.grad.dtype == torch.float16:
                          raise ValueError("Attempting to unscale FP16 gradients.")
                      if param.grad.is_sparse:
                          # is_coalesced() == False means the sparse grad has values with duplicate indices.
                          # coalesce() deduplicates indices and adds all values that have the same index.
                          # For scaled fp16 values, there's a good chance coalescing will cause overflow,
                          # so we should check the coalesced _values().
                          if param.grad.dtype is torch.float16:
                              param.grad = param.grad.coalesce()
                          to_unscale = param.grad._values()
                      else:
                          to_unscale = param.grad

                      # TODO: is there a way to split by device and dtype without appending in the inner loop?
                      per_device_and_dtype_grads[to_unscale.device][
                          to_unscale.dtype
                      ].append(to_unscale)

              for device, per_dtype_grads in per_device_and_dtype_grads.items():
                  for grads in per_dtype_grads.values():
                      torch._amp_foreach_non_finite_check_and_unscale_(
                          grads,
                          per_device_found_inf.get(device),
                          per_device_inv_scale.get(device),
                      )

          return per_device_found_inf._per_device_tensors

      def unscale_(self, optimizer: torch.optim.Optimizer) -> None:
          """
          Divides ("unscales") the optimizer's gradient tensors by the scale factor.

          :meth:`unscale_` is optional, serving cases where you need to
          :ref:`modify or inspect gradients<working-with-unscaled-gradients>`
          between the backward pass(es) and :meth:`step`.
          If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically
          during :meth:`step`.

          Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::

              ...
              scaler.scale(loss).backward()
              scaler.unscale_(optimizer)
              torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
              scaler.step(optimizer)
              scaler.update()

          Args:
              optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.

          Raises:
              RuntimeError: if called twice for the same optimizer before :meth:`update`,
                  or after :meth:`step` for that optimizer.

          .. note::
              :meth:`unscale_` does not incur a CPU-GPU sync.

          .. warning::
              :meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
              and only after all gradients for that optimizer's assigned parameters have been
              accumulated. Calling :meth:`unscale_` twice for a given optimizer between each
              :meth:`step` triggers a RuntimeError.

          .. warning::
              :meth:`unscale_` may unscale sparse gradients out of place, replacing the
              ``.grad`` attribute.
          """
          if not self._enabled:
              return

          _scale, _ = self._check_scale_growth_tracker("unscale_")

          optimizer_state = self._per_optimizer_states[id(optimizer)]

          if optimizer_state["stage"] is OptState.UNSCALED:
              raise RuntimeError(
                  "unscale_() has already been called on this optimizer since the last update()."
              )
          elif optimizer_state["stage"] is OptState.STEPPED:
              raise RuntimeError("unscale_() is being called after step().")

          # FP32 division can be imprecise for certain compile options, so we carry out the
          # reciprocal in FP64. The MPS backend has no float64 support, so there the reciprocal
          # is taken directly in the scale's own dtype. Compare the device *type* rather than a
          # hard-coded "mps:0" so the check does not depend on the device index.
          inv_scale = (
              _scale.reciprocal()
              if _scale.device.type == "mps"
              else _scale.double().reciprocal().float()
          )
          found_inf = torch.full((), 0.0, dtype=torch.float32, device=_scale.device)

          optimizer_state["found_inf_per_device"] = self._unscale_grads_(
              optimizer, inv_scale, found_inf, False
          )
          optimizer_state["stage"] = OptState.UNSCALED

      def _maybe_opt_step(
          self,
          optimizer: torch.optim.Optimizer,
          optimizer_state: dict[str, Any],
          *args: Any,
          **kwargs: Any,
      ) -> Optional[float]:
          """Run ``optimizer.step`` only if no device recorded an inf/NaN (syncs on each flag)."""
          retval: Optional[float] = None
          if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()):
              retval = optimizer.step(*args, **kwargs)
          return retval

      def step(
          self, optimizer: torch.optim.Optimizer, *args: Any, **kwargs: Any
      ) -> Optional[float]:
          """Invoke ``unscale_(optimizer)`` followed by parameter update, if gradients are not infs/NaN.

          :meth:`step` carries out the following two operations:

          1. Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly
             called for ``optimizer`` earlier in the iteration). As part of the :meth:`unscale_`,
             gradients are checked for infs/NaNs.
          2. If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled
             gradients. Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params.

          ``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``.

          Returns the return value of ``optimizer.step(*args, **kwargs)``.

          Args:
              optimizer (torch.optim.Optimizer): Optimizer that applies the gradients.
              args: Any arguments.
              kwargs: Any keyword arguments.

          Raises:
              RuntimeError: if a ``closure`` kwarg is passed while enabled, or if :meth:`step`
                  was already called for this optimizer since the last :meth:`update`.

          .. warning::
              Closure use is not currently supported.
          """
          if not self._enabled:
              return optimizer.step(*args, **kwargs)

          if "closure" in kwargs:
              raise RuntimeError(
                  "Closure use is not currently supported if GradScaler is enabled."
              )

          self._check_scale_growth_tracker("step")

          optimizer_state = self._per_optimizer_states[id(optimizer)]

          if optimizer_state["stage"] is OptState.STEPPED:
              raise RuntimeError(
                  "step() has already been called since the last update()."
              )

          retval: Optional[float] = None

          if getattr(optimizer, "_step_supports_amp_scaling", False):
              # This optimizer has customized scale-handling logic, so we can call optimizer.step() directly.
              # The contract with custom optimizers is that their step() should accept an additional,
              # optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information:
              # it can query its own state, invoke unscale_ on itself, etc
              # The contract above is being deprecated to avoid introducing `grad_scaler: GradScaler` argument
              # to `Optimizer.step`. The new behavior is going to add two Tensor attributes of `grad_scale`
              # and `found_inf` to the passed optimizer so that the optimizer can utilize those
              # to skip the parameter updates or unscale gradients before updating parameters in
              # the fused kernel, e.g. `FusedAdamMathFunctor`.
              # In this behavior, `GradScaler._check_inf_per_device` is called if `OptState.READY`,
              # while the method is expected to be called by users side, i.e. their optimizers.
              kwargs_ = kwargs
              has_grad_scaler_kwarg = (
                  "grad_scaler" in inspect.signature(optimizer.step).parameters
              )
              if has_grad_scaler_kwarg:
                  warnings.warn(
                      "GradScaler is going to stop passing itself as a keyword argument to the passed "
                      "optimizer. In the near future GradScaler registers `grad_scale: Tensor` and "
                      "`found_inf: Tensor` to the passed optimizer and let the optimizer use them directly.",
                      FutureWarning,
                      stacklevel=2,
                  )
                  kwargs_.update({"grad_scaler": self})
              else:
                  if optimizer_state["stage"] is OptState.READY:
                      self._check_inf_per_device(optimizer)
                  scaler = self._get_scale_async()
                  if scaler is None:
                      raise AssertionError("_get_scale_async returned None")
                  found_inf = cast(
                      torch.Tensor,
                      sum(
                          [  # noqa: C419
                              t.to(scaler.device, non_blocking=True)
                              for t in optimizer_state["found_inf_per_device"].values()
                          ]
                      ),
                  )
                  # Take the product of the scales, if the user has already set `optimizer.grad_scale`.
                  optimizer.grad_scale = (  # type: ignore[attr-defined]
                      getattr(optimizer, "grad_scale", None)
                      if optimizer_state["stage"] == OptState.UNSCALED
                      else scaler * getattr(optimizer, "grad_scale", 1)
                  )
                  optimizer.found_inf = found_inf  # type: ignore[attr-defined]
              try:
                  retval = optimizer.step(*args, **kwargs_)
              finally:
                  # Detach the temporary attributes even if optimizer.step() raised. Otherwise a
                  # retried step() would multiply the current scale by the stale ``grad_scale``
                  # left on the optimizer, compounding the scale.
                  if not has_grad_scaler_kwarg:
                      del optimizer.grad_scale  # type: ignore[attr-defined]
                      del optimizer.found_inf  # type: ignore[attr-defined]
              optimizer_state["stage"] = OptState.STEPPED
              return retval

          if optimizer_state["stage"] is OptState.READY:
              self.unscale_(optimizer)

          if len(optimizer_state["found_inf_per_device"]) == 0:
              raise AssertionError("No inf checks were recorded for this optimizer.")

          retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs)

          optimizer_state["stage"] = OptState.STEPPED

          return retval

      def update(self, new_scale: Optional[Union[float, torch.Tensor]] = None) -> None:
          """Update the scale factor.

          If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
          to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
          the scale is multiplied by ``growth_factor`` to increase it.

          Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
          used directly, it's used to fill GradScaler's internal scale tensor. So if
          ``new_scale`` was a tensor, later in-place changes to that tensor will not further
          affect the scale GradScaler uses internally.)

          Args:
              new_scale (float, int or :class:`torch.Tensor`, optional, default=None): New scale factor.

          Raises:
              AssertionError: if a tensor ``new_scale`` is on a different device type, has more
                  than one element, or requires grad; or if no inf checks were recorded.

          .. warning::
              :meth:`update` should only be called at the end of the iteration, after
              ``scaler.step(optimizer)`` has been invoked for all optimizers used this iteration.

          .. warning::
              For performance reasons, we do not check the scale factor value to avoid
              synchronizations, so the scale factor is not guaranteed to be above 1. If the scale
              falls below 1 and/or you are seeing NaNs in your gradients or loss, something is
              likely wrong. For example, bf16-pretrained models are often incompatible with
              AMP/fp16 due to differing dynamic ranges.
          """
          if not self._enabled:
              return

          _scale, _growth_tracker = self._check_scale_growth_tracker("update")

          if new_scale is not None:
              # Accept a new user-defined scale. Python ints are accepted alongside floats
              # (the ``float`` annotation admits them); previously they fell through to the
              # tensor branch and crashed on ``.device``.
              if isinstance(new_scale, (int, float)):
                  _scale.fill_(new_scale)
              else:
                  reason = (
                      "new_scale should be a float or a 1-element torch.cuda.FloatTensor or "
                      "torch.FloatTensor with requires_grad=False."
                  )
                  if new_scale.device.type != self._device:
                      raise AssertionError(reason)
                  if new_scale.numel() != 1:
                      raise AssertionError(reason)
                  if new_scale.requires_grad is True:
                      raise AssertionError(reason)
                  _scale.copy_(new_scale)
          else:
              # Consume shared inf/nan data collected from optimizers to update the scale.
              # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
              found_infs = [
                  found_inf.to(device=_scale.device, non_blocking=True)
                  for state in self._per_optimizer_states.values()
                  for found_inf in state["found_inf_per_device"].values()
              ]

              if len(found_infs) == 0:
                  raise AssertionError("No inf checks were recorded prior to update.")

              # Accumulate out of place: ``.to`` returns the very same tensor when no move is
              # needed, so an in-place ``+=`` would mutate an optimizer's recorded found_inf.
              found_inf_combined = found_infs[0]
              for extra_found_inf in found_infs[1:]:
                  found_inf_combined = found_inf_combined + extra_found_inf

              torch._amp_update_scale_(
                  _scale,
                  _growth_tracker,
                  found_inf_combined,
                  self._growth_factor,
                  self._backoff_factor,
                  self._growth_interval,
              )

          # To prepare for next iteration, clear the data collected from optimizers this iteration.
          self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)

      def _get_scale_async(self) -> Optional[torch.Tensor]:
          """Return the scale tensor itself (no device sync), or None before the first scale()."""
          return self._scale

      def get_scale(self) -> float:
          """Return a Python float containing the current scale, or 1.0 if scaling is disabled.

          .. warning::
              :meth:`get_scale` incurs a CPU-GPU sync.
          """
          if self._enabled:
              return (
                  self._init_scale
                  if (scale := self._get_scale_async()) is None
                  else cast(float, scale.item())
              )
          return 1.0

      def get_growth_factor(self) -> float:
          r"""Return a Python float containing the scale growth factor."""
          return self._growth_factor

      def set_growth_factor(self, new_factor: float) -> None:
          r"""Set a new scale growth factor.

          Args:
              new_factor (float): Value to use as the new scale growth factor.
          """
          self._growth_factor = new_factor

      def get_backoff_factor(self) -> float:
          r"""Return a Python float containing the scale backoff factor."""
          return self._backoff_factor

      def set_backoff_factor(self, new_factor: float) -> None:
          r"""Set a new scale backoff factor.

          Args:
              new_factor (float): Value to use as the new scale backoff factor.
          """
          self._backoff_factor = new_factor

      def get_growth_interval(self) -> int:
          r"""Return a Python int containing the growth interval."""
          return self._growth_interval

      def set_growth_interval(self, new_interval: int) -> None:
          r"""Set a new growth interval.

          Args:
              new_interval (int): Value to use as the new growth interval.
          """
          self._growth_interval = new_interval

      def _get_growth_tracker(self) -> int:
          """Return the consecutive-unskipped-step counter as an int (0 when disabled)."""
          if self._enabled:
              return (
                  self._init_growth_tracker
                  if self._growth_tracker is None
                  else cast(int, self._growth_tracker.item())
              )
          return 0

      def is_enabled(self) -> bool:
          r"""Return a bool indicating whether this instance is enabled."""
          return self._enabled

      def state_dict(self) -> dict[str, Any]:
          r"""Return the state of the scaler as a :class:`dict`.

          It contains five entries:

          * ``"scale"`` - a Python float containing the current scale
          * ``"growth_factor"`` - a Python float containing the current growth factor
          * ``"backoff_factor"`` - a Python float containing the current backoff factor
          * ``"growth_interval"`` - a Python int containing the current growth interval
          * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps.

          If this instance is not enabled, returns an empty dict.

          .. note::
             If you wish to checkpoint the scaler's state after a particular iteration,
             :meth:`state_dict` should be called after :meth:`update`.
          """
          if self._enabled:
              return {
                  "scale": self.get_scale(),
                  "growth_factor": self._growth_factor,
                  "backoff_factor": self._backoff_factor,
                  "growth_interval": self._growth_interval,
                  "_growth_tracker": self._get_growth_tracker(),
              }
          return {}

      def load_state_dict(self, state_dict: dict[str, Any]) -> None:
          r"""Load the scaler state.

          If this instance is disabled, :meth:`load_state_dict` is a no-op.

          Args:
             state_dict(dict): scaler state. Should be an object returned from a call to
                 :meth:`state_dict`.

          Raises:
              RuntimeError: if ``state_dict`` is empty (e.g. saved from a disabled scaler).
          """
          if not self._enabled:
              return

          if len(state_dict) == 0:
              raise RuntimeError(
                  "The source state dict is empty, possibly because it was saved "
                  "from a disabled instance of GradScaler."
              )

          # Update both the lazy-init seeds and, if they already exist, the live tensors.
          self._init_scale = cast(float, state_dict["scale"])
          if self._scale is not None:
              self._scale.fill_(state_dict["scale"])
          self._growth_factor = cast(float, state_dict["growth_factor"])
          self._backoff_factor = cast(float, state_dict["backoff_factor"])
          self._growth_interval = cast(int, state_dict["growth_interval"])
          self._init_growth_tracker = cast(int, state_dict["_growth_tracker"])
          if self._growth_tracker is not None:
              self._growth_tracker.fill_(state_dict["_growth_tracker"])

      def __getstate__(self) -> dict[str, Any]:
          state = self.__dict__.copy()
          if self._enabled:
              if len(self._per_optimizer_states) != 0:
                  raise AssertionError(
                      "A GradScaler instance may only be pickled at the beginning "
                      "of an iteration, or at the end after scaler.update()."
                  )
              # Pickling _scale and _growth_tracker Tensors directly triggers
              # "warnings.warn("pickle support for Storage will be removed in 1.5..."
              # so instead, we set the unpickled instance up to reinitialize them lazily.
              state["_init_scale"] = self.get_scale()
              state["_init_growth_tracker"] = self._get_growth_tracker()
              state["_scale"] = None
              state["_growth_tracker"] = None
          return state

      def __setstate__(self, state: dict[str, Any]) -> None:
          self.__dict__.update(state)

      def _check_inf_per_device(
          self, optimizer: torch.optim.Optimizer
      ) -> dict[torch.device, torch.Tensor]:
          """Record per-device inf/NaN flags for ``optimizer``'s grads without unscaling them.

          Uses an inverse scale of 1.0 (and allows fp16 grads), so gradients are only checked.
          """
          _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device")

          dummy_inv_scale = torch.full((), 1.0, dtype=torch.float32, device=_scale.device)
          found_inf = torch.full((), 0.0, dtype=torch.float32, device=_scale.device)

          self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = (
              self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)
          )

          return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]

      def _found_inf_per_device(
          self, optimizer: torch.optim.Optimizer
      ) -> dict[torch.device, torch.Tensor]:
          """Return the device -> found_inf mapping recorded for ``optimizer`` this iteration."""
          return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]