| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778 |
- """ CUDA / AMP utils
- Hacked together by / Copyright 2020 Ross Wightman
- """
- import torch
- try:
- from apex import amp
- has_apex = True
- except ImportError:
- amp = None
- has_apex = False
- from .clip_grad import dispatch_clip_grad
class ApexScaler:
    """Loss scaler that delegates to NVIDIA Apex AMP (``apex.amp``).

    Offers the same ``__call__`` / ``state_dict`` / ``load_state_dict``
    interface as ``NativeScaler``. Apex must be importable: when it is not,
    the module-level ``amp`` is ``None`` and every method below fails with
    ``AttributeError``.
    """

    # Key name for this scaler's state; presumably used by checkpoint code.
    state_dict_key = "amp"

    def __call__(
            self,
            loss,
            optimizer,
            clip_grad=None,
            clip_mode='norm',
            parameters=None,
            create_graph=False,
            need_update=True,
    ):
        """Run the Apex-scaled backward pass and optionally step ``optimizer``.

        Args:
            loss: Loss tensor to backpropagate through ``amp.scale_loss``.
            optimizer: Optimizer that was registered with Apex AMP.
            clip_grad: Clipping threshold; ``None`` disables clipping.
            clip_mode: Clipping mode forwarded to ``dispatch_clip_grad``.
            parameters: Unused here. Apex clips its own master params instead.
                Presumably accepted only for signature parity with
                ``NativeScaler``.
            create_graph: Forwarded to ``Tensor.backward``.
            need_update: When False, only the backward pass runs (no clipping,
                no optimizer step).
        """
        with amp.scale_loss(loss, optimizer) as scaled:
            scaled.backward(create_graph=create_graph)

        # Backward-only call: skip clipping and the optimizer step.
        if not need_update:
            return

        if clip_grad is not None:
            # Clip Apex's master params for this optimizer, not caller-supplied ones.
            master = amp.master_params(optimizer)
            dispatch_clip_grad(master, clip_grad, mode=clip_mode)
        optimizer.step()

    def state_dict(self):
        """Return Apex AMP's state dict, or ``None`` if ``amp`` has no ``state_dict``."""
        if 'state_dict' not in amp.__dict__:
            return None
        return amp.state_dict()

    def load_state_dict(self, state_dict):
        """Restore Apex AMP state if ``amp`` provides ``load_state_dict``; otherwise do nothing."""
        if 'load_state_dict' not in amp.__dict__:
            return
        amp.load_state_dict(state_dict)
class NativeScaler:
    """Loss scaler backed by PyTorch's native ``GradScaler``.

    A single call performs the scaled backward pass and, when requested,
    gradient clipping on unscaled gradients, the optimizer step and the
    loss-scale update.
    """

    # Key name for this scaler's state; presumably used by checkpoint code.
    state_dict_key = "amp_scaler"

    def __init__(self, device='cuda'):
        """Create the underlying gradient scaler.

        Args:
            device: Device type passed to ``torch.amp.GradScaler``.

        On PyTorch versions where ``torch.amp.GradScaler`` is missing
        (``AttributeError``) or does not accept ``device`` (``TypeError``),
        this falls back to ``torch.cuda.amp.GradScaler``. That fallback is
        CUDA-only and does not receive ``device``.
        """
        try:
            self._scaler = torch.amp.GradScaler(device=device)
        except (AttributeError, TypeError):
            self._scaler = torch.cuda.amp.GradScaler()

    def __call__(
            self,
            loss,
            optimizer,
            clip_grad=None,
            clip_mode='norm',
            parameters=None,
            create_graph=False,
            need_update=True,
    ):
        """Backpropagate the scaled ``loss`` and optionally step ``optimizer``.

        Args:
            loss: Loss tensor to scale and backpropagate.
            optimizer: Optimizer stepped through the scaler.
            clip_grad: Clipping threshold; ``None`` disables clipping.
            clip_mode: Clipping mode forwarded to ``dispatch_clip_grad``.
            parameters: Parameters whose gradients are clipped. Required when
                ``clip_grad`` is set and ``need_update`` is True.
            create_graph: Forwarded to ``Tensor.backward``.
            need_update: When False, only the scaled backward pass runs.
                Gradients accumulate in ``.grad``, with no unscale, clip,
                step, or scale update.

        Raises:
            ValueError: If clipping is requested but ``parameters`` is None.
        """
        # Validate before backward so a bad call leaves no scaled gradients
        # behind. An ``assert`` would also be stripped under ``python -O``.
        if need_update and clip_grad is not None and parameters is None:
            raise ValueError('`parameters` must be provided when `clip_grad` is set.')

        self._scaler.scale(loss).backward(create_graph=create_graph)
        if not need_update:
            return

        if clip_grad is not None:
            # Clip thresholds refer to true gradient magnitudes, so unscale the
            # gradients of the optimizer's assigned params in place first.
            self._scaler.unscale_(optimizer)
            dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
        self._scaler.step(optimizer)
        self._scaler.update()

    def state_dict(self):
        """Return the underlying ``GradScaler`` state dict."""
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        """Load ``state_dict`` into the underlying ``GradScaler``."""
        self._scaler.load_state_dict(state_dict)
|