cuda.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. """ CUDA / AMP utils
  2. Hacked together by / Copyright 2020 Ross Wightman
  3. """
  4. import torch
  5. try:
  6. from apex import amp
  7. has_apex = True
  8. except ImportError:
  9. amp = None
  10. has_apex = False
  11. from .clip_grad import dispatch_clip_grad
  12. class ApexScaler:
  13. state_dict_key = "amp"
  14. def __call__(
  15. self,
  16. loss,
  17. optimizer,
  18. clip_grad=None,
  19. clip_mode='norm',
  20. parameters=None,
  21. create_graph=False,
  22. need_update=True,
  23. ):
  24. with amp.scale_loss(loss, optimizer) as scaled_loss:
  25. scaled_loss.backward(create_graph=create_graph)
  26. if need_update:
  27. if clip_grad is not None:
  28. dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode)
  29. optimizer.step()
  30. def state_dict(self):
  31. if 'state_dict' in amp.__dict__:
  32. return amp.state_dict()
  33. def load_state_dict(self, state_dict):
  34. if 'load_state_dict' in amp.__dict__:
  35. amp.load_state_dict(state_dict)
  36. class NativeScaler:
  37. state_dict_key = "amp_scaler"
  38. def __init__(self, device='cuda'):
  39. try:
  40. self._scaler = torch.amp.GradScaler(device=device)
  41. except (AttributeError, TypeError) as e:
  42. self._scaler = torch.cuda.amp.GradScaler()
  43. def __call__(
  44. self,
  45. loss,
  46. optimizer,
  47. clip_grad=None,
  48. clip_mode='norm',
  49. parameters=None,
  50. create_graph=False,
  51. need_update=True,
  52. ):
  53. self._scaler.scale(loss).backward(create_graph=create_graph)
  54. if need_update:
  55. if clip_grad is not None:
  56. assert parameters is not None
  57. self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
  58. dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
  59. self._scaler.step(optimizer)
  60. self._scaler.update()
  61. def state_dict(self):
  62. return self._scaler.state_dict()
  63. def load_state_dict(self, state_dict):
  64. self._scaler.load_state_dict(state_dict)