grad_scaler.py 1.0 KB

1234567891011121314151617181920212223242526272829303132333435363738
  1. from typing_extensions import deprecated
  2. import torch
  3. # We need to keep this unused import for BC reasons
  4. from torch.amp.grad_scaler import OptState # noqa: F401
  5. __all__ = ["GradScaler"]
  6. class GradScaler(torch.amp.GradScaler):
  7. r"""
  8. See :class:`torch.amp.GradScaler`.
  9. ``torch.cuda.amp.GradScaler(args...)`` is deprecated. Please use ``torch.amp.GradScaler("cuda", args...)`` instead.
  10. """
  11. @deprecated(
  12. "`torch.cuda.amp.GradScaler(args...)` is deprecated. "
  13. "Please use `torch.amp.GradScaler('cuda', args...)` instead.",
  14. category=FutureWarning,
  15. )
  16. def __init__(
  17. self,
  18. init_scale: float = 2.0**16,
  19. growth_factor: float = 2.0,
  20. backoff_factor: float = 0.5,
  21. growth_interval: int = 2000,
  22. enabled: bool = True,
  23. ) -> None:
  24. super().__init__(
  25. "cuda",
  26. init_scale=init_scale,
  27. growth_factor=growth_factor,
  28. backoff_factor=backoff_factor,
  29. growth_interval=growth_interval,
  30. enabled=enabled,
  31. )