grad_scaler.py 958 B

1234567891011121314151617181920212223242526272829303132333435
  1. from typing_extensions import deprecated
  2. import torch
  # Names exported by a star-import of this module: only the deprecated
  # CPU ``GradScaler`` shim defined below.
  __all__ = [
      "GradScaler",
  ]
  class GradScaler(torch.amp.GradScaler):
      r"""Deprecated CPU-specific alias of :class:`torch.amp.GradScaler`.

      Instantiating ``torch.cpu.amp.GradScaler(args...)`` emits a ``FutureWarning``.
      Prefer ``torch.amp.GradScaler("cpu", args...)``, which is exactly what this
      class calls, with the device pinned to ``"cpu"``.
      """

      @deprecated(
          "`torch.cpu.amp.GradScaler(args...)` is deprecated. "
          "Please use `torch.amp.GradScaler('cpu', args...)` instead.",
          category=FutureWarning,
      )
      def __init__(
          self,
          init_scale: float = 2.0**16,
          growth_factor: float = 2.0,
          backoff_factor: float = 0.5,
          growth_interval: int = 2000,
          enabled: bool = True,
      ) -> None:
          """Create a gradient scaler bound to the ``"cpu"`` device.

          Every argument is passed through unchanged, by keyword, to
          :class:`torch.amp.GradScaler`; see that class for what each one
          controls. Calling this constructor triggers the ``FutureWarning``
          configured by the ``@deprecated`` decorator on this method.
          """
          # Gather the tunables once so the parent call only has to pin the device.
          scaler_kwargs = {
              "init_scale": init_scale,
              "growth_factor": growth_factor,
              "backoff_factor": backoff_factor,
              "growth_interval": growth_interval,
              "enabled": enabled,
          }
          super().__init__("cpu", **scaler_kwargs)