lars.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. """ PyTorch LARS / LARC Optimizer
  2. An implementation of LARS (SGD) + LARC in PyTorch
  3. Based on:
  4. * PyTorch SGD: https://github.com/pytorch/pytorch/blob/1.7/torch/optim/sgd.py#L100
  5. * NVIDIA APEX LARC: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py
  6. Additional cleanup and modifications to properly support PyTorch XLA.
  7. Copyright 2021 Ross Wightman
  8. """
  9. import torch
  10. from torch.optim.optimizer import Optimizer
  11. class Lars(Optimizer):
  12. """ LARS for PyTorch
  13. Paper: `Large batch training of Convolutional Networks` - https://arxiv.org/pdf/1708.03888.pdf
  14. Args:
  15. params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
  16. lr (float, optional): learning rate (default: 1.0).
  17. momentum (float, optional): momentum factor (default: 0)
  18. weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
  19. dampening (float, optional): dampening for momentum (default: 0)
  20. nesterov (bool, optional): enables Nesterov momentum (default: False)
  21. trust_coeff (float): trust coefficient for computing adaptive lr / trust_ratio (default: 0.001)
  22. eps (float): eps for division denominator (default: 1e-8)
  23. trust_clip (bool): enable LARC trust ratio clipping (default: False)
  24. always_adapt (bool): always apply LARS LR adapt, otherwise only when group weight_decay != 0 (default: False)
  25. """
  26. def __init__(
  27. self,
  28. params,
  29. lr=1.0,
  30. momentum=0,
  31. dampening=0,
  32. weight_decay=0,
  33. nesterov=False,
  34. trust_coeff=0.001,
  35. eps=1e-8,
  36. trust_clip=False,
  37. always_adapt=False,
  38. ):
  39. if lr < 0.0:
  40. raise ValueError(f"Invalid learning rate: {lr}")
  41. if momentum < 0.0:
  42. raise ValueError(f"Invalid momentum value: {momentum}")
  43. if weight_decay < 0.0:
  44. raise ValueError(f"Invalid weight_decay value: {weight_decay}")
  45. if nesterov and (momentum <= 0 or dampening != 0):
  46. raise ValueError("Nesterov momentum requires a momentum and zero dampening")
  47. defaults = dict(
  48. lr=lr,
  49. momentum=momentum,
  50. dampening=dampening,
  51. weight_decay=weight_decay,
  52. nesterov=nesterov,
  53. trust_coeff=trust_coeff,
  54. eps=eps,
  55. trust_clip=trust_clip,
  56. always_adapt=always_adapt,
  57. )
  58. super().__init__(params, defaults)
  59. def __setstate__(self, state):
  60. super().__setstate__(state)
  61. for group in self.param_groups:
  62. group.setdefault("nesterov", False)
  63. @torch.no_grad()
  64. def step(self, closure=None):
  65. """Performs a single optimization step.
  66. Args:
  67. closure (callable, optional): A closure that reevaluates the model and returns the loss.
  68. """
  69. loss = None
  70. if closure is not None:
  71. with torch.enable_grad():
  72. loss = closure()
  73. for group in self.param_groups:
  74. weight_decay = group['weight_decay']
  75. momentum = group['momentum']
  76. dampening = group['dampening']
  77. nesterov = group['nesterov']
  78. trust_coeff = group['trust_coeff']
  79. eps = group['eps']
  80. for p in group['params']:
  81. if p.grad is None:
  82. continue
  83. grad = p.grad
  84. # apply LARS LR adaptation, LARC clipping, weight decay
  85. # ref: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py
  86. if weight_decay != 0 or group['always_adapt']:
  87. w_norm = p.norm(2.0)
  88. g_norm = grad.norm(2.0)
  89. trust_ratio = trust_coeff * w_norm / (g_norm + w_norm * weight_decay + eps)
  90. # FIXME nested where required since logical and/or not working in PT XLA
  91. # Set the ratio to 1.0 (no change) if either weight norm or grad norm is zero
  92. trust_ratio = torch.where(
  93. w_norm > 0,
  94. torch.where(g_norm > 0, trust_ratio, 1.0),
  95. 1.0,
  96. )
  97. if group['trust_clip']:
  98. trust_ratio = torch.clamp(trust_ratio / group['lr'], max=1.0)
  99. grad.add_(p, alpha=weight_decay)
  100. grad.mul_(trust_ratio)
  101. # apply SGD update https://github.com/pytorch/pytorch/blob/1.7/torch/optim/sgd.py#L100
  102. if momentum != 0:
  103. param_state = self.state[p]
  104. if 'momentum_buffer' not in param_state:
  105. buf = param_state['momentum_buffer'] = torch.clone(grad).detach()
  106. else:
  107. buf = param_state['momentum_buffer']
  108. buf.mul_(momentum).add_(grad, alpha=1. - dampening)
  109. if nesterov:
  110. grad = grad.add(buf, alpha=momentum)
  111. else:
  112. grad = buf
  113. p.add_(grad, alpha=-group['lr'])
  114. return loss