kron.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563
  1. """ PyTorch Implementation of the Kron (PSGD) optimizer
  2. This is a PSGD optimizer using a Kronecker-factored preconditioner.
  3. This impl was adapted from https://github.com/evanatyourservice/kron_torch
  4. by Evan Walters, licensed CC-BY-4.0.
  5. Contributions to above also made by
  6. * Lucas Nestler, added to his https://github.com/ClashLuke/HeavyBall implementation.
  7. * Omead Pooladzandi https://github.com/opooladz
  8. The above work drew from https://github.com/lixilinx/psgd_torch by Xi-Lin Li
  9. References for added functionality:
  10. Cautious Optimizers: https://arxiv.org/abs/2411.16085
  11. Why Gradients Rapidly Increase Near the End of Training: https://arxiv.org/abs/2506.02285
  12. This `timm` impl
  13. * works with a wider variety of torch versions
  14. * fixes some checkpoint save/restore (resume issues)
  15. * adds decoupled weight-decay option
  16. * has some refactoring, cleanup of args, default/group items
  17. * warning about not having opt_einsum (unusable without)
  18. """
  19. import logging
  20. import string
  21. import random
  22. import warnings
  23. from typing import Any, Callable, Dict, Optional, Tuple, Union
  24. import numpy as np
  25. import torch
  26. try:
  27. # NOTE opt_einsum needed to avoid blowing up memory with einsum ops
  28. import opt_einsum
  29. import torch.backends.opt_einsum
  30. torch.backends.opt_einsum.enabled = True
  31. torch.backends.opt_einsum.strategy = "auto-hq"
  32. has_opt_einsum = True
  33. except ImportError:
  34. has_opt_einsum = False
  35. try:
  36. torch._dynamo.config.cache_size_limit = 1_000_000
  37. has_dynamo = True
  38. except AttributeError:
  39. has_dynamo = False
  40. from ._types import ParamsT
  41. _logger = logging.getLogger(__name__)
  42. def precond_update_prob_schedule(
  43. n: float,
  44. max_prob: float = 1.0,
  45. min_prob: float = 0.03,
  46. decay: float = 0.001,
  47. flat_start: float = 500,
  48. ) -> torch.Tensor:
  49. """Anneal preconditioner update probability during beginning of training.
  50. PSGD benefits from more preconditioner updates at the beginning of training,
  51. but once the preconditioner is learned the update probability can drop low.
  52. This schedule is an exponential anneal with a flat start. Default settings keep
  53. update probability at 1.0 for 200 steps then exponentially anneal down to
  54. `min_prob` by 4000 steps. Default settings work very well for most models and
  55. training regimes.
  56. """
  57. """Exponential anneal with flat start."""
  58. n = torch.tensor(n, dtype=torch.float32)
  59. prob = max_prob * torch.exp(-decay * (n - flat_start))
  60. prob.clamp_(min=min_prob, max=max_prob)
  61. return prob
  62. class Kron(torch.optim.Optimizer):
  63. """Implements PSGD Kron from https://github.com/lixilinx/psgd_torch.
  64. Args:
  65. params: Iterable of parameters to optimize or dicts defining parameter groups.
  66. lr: Learning rate.
  67. momentum: Momentum parameter.
  68. weight_decay: Weight decay.
  69. preconditioner_update_probability: Probability of updating the preconditioner.
  70. If None, defaults to a schedule that anneals from 1.0 to 0.03 by 4000 steps.
  71. max_size_triangular: Max size for dim's preconditioner to be triangular.
  72. min_ndim_triangular: Minimum number of dimensions a layer needs to have triangular preconditioners.
  73. memory_save_mode: 'one_diag', 'smart_one_diag', or 'all_diag', None is default
  74. to set all preconditioners to be triangular, 'one_diag' sets the largest
  75. or last dim to be diagonal per layer, and 'all_diag' sets all preconditioners to be diagonal.
  76. momentum_into_precond_update: whether to send momentum into preconditioner
  77. update instead of raw gradients.
  78. mu_dtype: Dtype of the momentum accumulator.
  79. precond_dtype: Dtype of the preconditioner.
  80. decoupled_decay: AdamW style decoupled weight decay
  81. corrected_weight_decay: apply corrected weight decay when using decoupled_decay (lr**2 / max_lr)
  82. flatten: Flatten dimensions instead of fully relying on expressions for higher rank params
  83. flatten_start_dim: Start of flatten range, defaults to 2. Seems good tradeoff for ConvNets.
  84. flatten_end_dim: End of flatten range, defaults to -1.
  85. stochastic_weight_decay: Enable random modulation of weight decay
  86. deterministic: Deterministic behaviour across save / load (resume). FIXME slow, needs work
  87. """
  88. def __init__(
  89. self,
  90. params: ParamsT,
  91. lr: float = 0.001,
  92. momentum: float = 0.9,
  93. weight_decay: float = 0.0,
  94. preconditioner_update_probability: Optional[Union[Callable, float]] = None,
  95. max_size_triangular: int = 2048,
  96. min_ndim_triangular: int = 2,
  97. memory_save_mode: Optional[str] = None,
  98. momentum_into_precond_update: bool = True,
  99. precond_lr: float = 0.1,
  100. precond_init_scale: float = 1.0,
  101. mu_dtype: Optional[torch.dtype] = None,
  102. precond_dtype: Optional[torch.dtype] = None,
  103. decoupled_decay: bool = False,
  104. corrected_weight_decay: bool = False,
  105. flatten: bool = False,
  106. flatten_start_dim: int = 2,
  107. flatten_end_dim: int = -1,
  108. stochastic_weight_decay: bool = False,
  109. deterministic: bool = False,
  110. ):
  111. if not has_opt_einsum:
  112. warnings.warn("It is highly recommended to have 'opt_einsum' installed for this optimizer.")
  113. if not 0.0 <= lr:
  114. raise ValueError(f"Invalid learning rate: {lr}")
  115. if not 0.0 <= momentum < 1.0:
  116. raise ValueError(f"Invalid beta parameter: {momentum}")
  117. if not 0.0 <= weight_decay:
  118. raise ValueError(f"Invalid weight_decay value: {weight_decay}")
  119. defaults = dict(
  120. lr=lr,
  121. momentum=momentum,
  122. weight_decay=weight_decay,
  123. preconditioner_update_probability=preconditioner_update_probability,
  124. max_size_triangular=max_size_triangular,
  125. min_ndim_triangular=min_ndim_triangular,
  126. memory_save_mode=memory_save_mode,
  127. momentum_into_precond_update=momentum_into_precond_update,
  128. precond_lr=precond_lr,
  129. precond_init_scale=precond_init_scale,
  130. mu_dtype=mu_dtype,
  131. precond_dtype=precond_dtype,
  132. decoupled_decay=decoupled_decay,
  133. corrected_weight_decay=corrected_weight_decay,
  134. flatten=flatten,
  135. flatten_start_dim=flatten_start_dim,
  136. flatten_end_dim=flatten_end_dim,
  137. stochastic_weight_decay=stochastic_weight_decay,
  138. )
  139. super(Kron, self).__init__(params, defaults)
  140. self._param_exprs = {} # cache for einsum expr
  141. self._tiny = torch.finfo(torch.bfloat16).tiny
  142. self.rng = random.Random(1337)
  143. self.deterministic = deterministic
  144. # make compile optional (for bwd compat)
  145. if has_dynamo:
  146. self._calc_A_and_conjB = torch.compile(_calc_A_and_conjB, fullgraph=True, dynamic=False)
  147. self._q_terms = torch.compile(_q_terms, fullgraph=True, dynamic=False)
  148. self._precond_grad = torch.compile(_precond_grad, fullgraph=True, dynamic=False)
  149. self._balance_Q = torch.compile(_balance_Q, fullgraph=True, dynamic=False)
  150. else:
  151. self._calc_A_and_conjB = _calc_A_and_conjB
  152. self._q_terms = _q_terms
  153. self._precond_grad = _precond_grad
  154. self._balance_Q = _balance_Q
  155. def __setstate__(self, state):
  156. super().__setstate__(state)
  157. for group in self.param_groups:
  158. group.setdefault('corrected_weight_decay', False)
  159. def __getstate__(self):
  160. _dict = super().__getstate__()
  161. _dict["rng"] = self.rng
  162. return _dict
  163. def state_dict(self) -> Dict[str, Any]:
  164. # Get the optimizer's state dict
  165. optimizer_state = super().state_dict()
  166. # Add the generator state
  167. optimizer_state['rng_state'] = self.rng.getstate()
  168. return optimizer_state
  169. def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
  170. # Extract and remove the RNG state from the state dict
  171. rng_states = {}
  172. if 'rng_state' in state_dict:
  173. rng_states['rng_state'] = state_dict.pop('rng_state')
  174. # Load the optimizer state
  175. super().load_state_dict(state_dict)
  176. state_dict.update(rng_states) # add back
  177. # Restore the RNG state if it exists
  178. if 'rng_state' in rng_states:
  179. self.rng.setstate(rng_states['rng_state'])
  180. def __setstate__(self, state):
  181. super().__setstate__(state)
  182. self._param_exprs = {}
  183. @torch.no_grad()
  184. def step(self, closure=None):
  185. loss = None
  186. if closure is not None:
  187. with torch.enable_grad():
  188. loss = closure()
  189. total_momentum_size = 0
  190. total_momentum_mb = 0
  191. total_precond_size = 0
  192. total_precond_mb = 0
  193. for group in self.param_groups:
  194. mu_dtype = group.get("mu_dtype")
  195. precond_dtype = group.get("precond_dtype", torch.float32)
  196. momentum_into_precond_update = group.get("momentum_into_precond_update", True)
  197. update_prob = group.get("preconditioner_update_probability", None)
  198. for p in group["params"]:
  199. if p.grad is None:
  200. continue
  201. grad = p.grad
  202. state = self.state[p]
  203. flattened = False
  204. if group['flatten']:
  205. grad = safe_flatten(grad, group["flatten_start_dim"], group["flatten_end_dim"])
  206. flattened = True
  207. if len(state) == 0:
  208. state["step"] = 0
  209. state["update_counter"] = 0
  210. state["momentum_buffer"] = torch.zeros_like(grad, dtype=mu_dtype or grad.dtype)
  211. # init Q and einsum expressions on first step
  212. state["Q"], exprs = _init_Q_exprs(
  213. grad,
  214. group["precond_init_scale"],
  215. group["max_size_triangular"],
  216. group["min_ndim_triangular"],
  217. group["memory_save_mode"],
  218. dtype=precond_dtype,
  219. )
  220. self._param_exprs[p] = exprs
  221. # Accumulate sizes for log
  222. momentum_size = state["momentum_buffer"].numel()
  223. momentum_mb = momentum_size * state["momentum_buffer"].element_size() / 2**20
  224. total_momentum_size += momentum_size
  225. total_momentum_mb += momentum_mb
  226. precond_size = sum(q.numel() for q in state["Q"])
  227. precond_mb = sum(q.numel() * q.element_size() for q in state["Q"]) / 2**20
  228. total_precond_size += precond_size
  229. total_precond_mb += precond_mb
  230. elif p not in self._param_exprs:
  231. # init only the einsum expressions, called after state load, Q are loaded from state_dict
  232. exprs = _init_Q_exprs(
  233. grad,
  234. group["precond_init_scale"],
  235. group["max_size_triangular"],
  236. group["min_ndim_triangular"],
  237. group["memory_save_mode"],
  238. dtype=precond_dtype,
  239. init_q=False,
  240. )
  241. self._param_exprs[p] = exprs
  242. else:
  243. # retrieve cached expressions
  244. exprs = self._param_exprs[p]
  245. # update preconditioners all together deterministically
  246. if update_prob is None:
  247. update_prob = precond_update_prob_schedule
  248. if callable(update_prob):
  249. update_prob = update_prob(state["step"])
  250. state["update_counter"] += 1
  251. do_update = state["update_counter"] >= 1 / update_prob
  252. if do_update:
  253. state["update_counter"] = 0
  254. state["step"] += 1
  255. # Update momentum buffer
  256. beta = group["momentum"]
  257. bias_correction = 1 - beta ** state["step"]
  258. momentum_buffer = state["momentum_buffer"]
  259. momentum_buffer.mul_(group["momentum"]).add_(grad, alpha=1 - group["momentum"])
  260. # Restore momentum dtype
  261. if mu_dtype is not None:
  262. momentum_buffer.copy_(momentum_buffer.to(dtype=mu_dtype))
  263. debiased_momentum = (momentum_buffer / bias_correction).to(dtype=precond_dtype)
  264. # Balance preconditioners roughly every 100 updates
  265. balance = self.rng.random() < 0.01 and do_update
  266. if grad.dim() > 1 and balance:
  267. self._balance_Q(state["Q"])
  268. # Update preconditioner
  269. if do_update:
  270. exprA, exprGs, _ = exprs
  271. Q = state["Q"]
  272. if self.deterministic:
  273. torch_rng = torch.Generator(device=debiased_momentum.device)
  274. torch_rng.manual_seed(self.rng.randint(0, 2 ** 31))
  275. else:
  276. torch_rng = None
  277. V = torch.randn(
  278. debiased_momentum.shape,
  279. generator=torch_rng,
  280. dtype=precond_dtype,
  281. device=debiased_momentum.device,
  282. )
  283. G = debiased_momentum if momentum_into_precond_update else grad
  284. A, conjB = self._calc_A_and_conjB(exprA, G, Q, V)
  285. terms = self._q_terms(exprGs, A, conjB)
  286. for q, (term1, term2) in zip(Q, terms):
  287. tmp = term1 - term2
  288. tmp *= group["precond_lr"]
  289. if q.dim() < 2:
  290. tmp *= q
  291. tmp /= (term1 + term2).norm(float("inf")) + self._tiny
  292. else:
  293. tmp = torch.triu(tmp)
  294. tmp /= _norm_lower_bound(term1 + term2) + self._tiny
  295. tmp @= q
  296. q.sub_(tmp)
  297. # Precondition gradients
  298. pre_grad = self._precond_grad(
  299. state["Q"],
  300. exprs,
  301. debiased_momentum,
  302. ).to(dtype=p.dtype)
  303. # RMS of pre_grad should be 1.0, so let's cap at 1.1
  304. pre_grad.mul_(torch.clamp(1.1 / (pre_grad.square().mean().sqrt_() + 1e-8), max=1.0))
  305. if flattened:
  306. pre_grad = pre_grad.view(p.shape)
  307. # Apply weight decay
  308. weight_decay = group["weight_decay"]
  309. if weight_decay != 0:
  310. if group["stochastic_weight_decay"]:
  311. weight_decay = 2 * self.rng.random() * weight_decay
  312. if group["decoupled_decay"]:
  313. if group['corrected_weight_decay']:
  314. wd_scale = group["lr"] ** 2 / self.defaults['lr']
  315. else:
  316. wd_scale = group["lr"]
  317. p.mul_(1. - wd_scale * weight_decay)
  318. else:
  319. pre_grad.add_(p, alpha=weight_decay)
  320. # Update parameters
  321. p.add_(pre_grad, alpha=-group["lr"])
  322. if total_momentum_size > 0:
  323. _logger.info(f"PSGD Momentum buffer size: {total_momentum_size} elements, {total_momentum_mb:.2f} MB")
  324. _logger.info(f"PSGD Preconditioners size: {total_precond_size} elements, {total_precond_mb:.2f} MB")
  325. return loss
  326. def safe_flatten(tensor, start_dim=0, end_dim=-1):
  327. ndim = tensor.ndim
  328. # Convert negative end_dim to positive and clip to end
  329. end_dim = min(end_dim if end_dim >= 0 else ndim + end_dim, ndim - 1)
  330. # If tensor has fewer dims than start_dim or start > end, return tensor as is
  331. if ndim <= start_dim or start_dim > end_dim:
  332. return tensor
  333. # Now safe to flatten
  334. return tensor.flatten(start_dim, end_dim)
  335. def _init_Q_exprs(
  336. t,
  337. scale,
  338. max_size,
  339. min_ndim_triangular,
  340. memory_save_mode,
  341. dtype=None,
  342. init_q=True,
  343. ):
  344. """For a scalar or tensor t, we initialize its preconditioner Q and
  345. reusable einsum expressions for updating Q and preconditioning gradient.
  346. """
  347. letters = string.ascii_lowercase + string.ascii_uppercase
  348. dtype = dtype if dtype is not None else t.dtype
  349. shape = t.shape
  350. Q = []
  351. if len(shape) == 0: # scalar
  352. if init_q:
  353. Q.append(scale * torch.ones_like(t, dtype=dtype))
  354. exprA = ",->"
  355. exprGs = [",->"]
  356. exprP = ",,->"
  357. else: # tensor
  358. if len(shape) > 13:
  359. raise ValueError(f"Got tensor with dim {len(t.shape)}; Einstein runs out of letters!")
  360. scale = scale ** (1 / len(shape))
  361. if memory_save_mode is None:
  362. dim_diag = [False for _ in shape]
  363. elif memory_save_mode == "one_diag":
  364. rev_sorted_dims = np.argsort(shape)[::-1]
  365. dim_diag = [False for _ in shape]
  366. dim_diag[rev_sorted_dims[0]] = True
  367. elif memory_save_mode == "smart_one_diag":
  368. # addition proposed by Lucas Nestler
  369. rev_sorted_dims = np.argsort(shape)[::-1]
  370. sorted_shape = sorted(shape)
  371. dim_diag = [False for _ in shape]
  372. if len(shape) >= 2 and sorted_shape[-1] > sorted_shape[-2]:
  373. dim_diag[rev_sorted_dims[0]] = True
  374. elif memory_save_mode == "all_diag":
  375. dim_diag = [True for _ in shape]
  376. else:
  377. raise ValueError(
  378. f"Invalid memory_save_mode: {memory_save_mode}, must be one of [None, 'one_diag', 'all_diag']")
  379. piece1A, piece2A, piece3A = ([], "", "")
  380. exprGs = []
  381. piece1P, piece2P, piece3P, piece4P = ([], [], "", "")
  382. for i, (size, dim_d) in enumerate(zip(shape, dim_diag)):
  383. if (
  384. size == 1
  385. or size > max_size
  386. or len(shape) < min_ndim_triangular
  387. or dim_d
  388. ):
  389. # use diagonal matrix as preconditioner for this dim
  390. if init_q:
  391. Q.append(scale * torch.ones(size, dtype=dtype, device=t.device))
  392. piece1A.append(letters[i])
  393. piece2A = piece2A + letters[i]
  394. piece3A = piece3A + letters[i]
  395. piece1 = "".join([letters[i + 13] if j == i else letters[j] for j in range(len(shape))])
  396. subscripts = piece1 + "," + piece1 + "->" + letters[i + 13]
  397. exprGs.append(subscripts)
  398. piece1P.append(letters[i + 13])
  399. piece2P.append(letters[i + 13])
  400. piece3P = piece3P + letters[i + 13]
  401. piece4P = piece4P + letters[i + 13]
  402. else:
  403. # use triangular matrix as preconditioner for this dim
  404. if init_q:
  405. Q.append(scale * torch.eye(size, dtype=dtype, device=t.device))
  406. piece1A.append(letters[i] + letters[i + 13])
  407. piece2A = piece2A + letters[i + 13]
  408. piece3A = piece3A + letters[i]
  409. piece1 = "".join([letters[i + 13] if j == i else letters[j] for j in range(len(shape))])
  410. piece2 = "".join([letters[i + 26] if j == i else letters[j] for j in range(len(shape))])
  411. subscripts = piece1 + "," + piece2 + "->" + letters[i + 13] + letters[i + 26]
  412. exprGs.append(subscripts)
  413. a, b, c = (letters[i], letters[i + 13], letters[i + 26])
  414. piece1P.append(a + b)
  415. piece2P.append(a + c)
  416. piece3P = piece3P + c
  417. piece4P = piece4P + b
  418. exprA = ",".join(piece1A) + "," + piece2A + "->" + piece3A
  419. exprP = ",".join(piece1P) + "," + ",".join(piece2P) + "," + piece3P + "->" + piece4P
  420. exprGs = tuple(exprGs)
  421. if init_q:
  422. return [Q, (exprA, exprGs, exprP)]
  423. else:
  424. return exprA, exprGs, exprP
  425. def _lb(A, max_abs):
  426. A = A / max_abs
  427. aa = torch.real(A * A.conj())
  428. value0, i = torch.max(torch.sum(aa, dim=0), 0)
  429. value1, j = torch.max(torch.sum(aa, dim=1), 0)
  430. if value0 > value1:
  431. x = A[:, i].conj() @ A
  432. return max_abs * torch.linalg.vector_norm((x / torch.linalg.vector_norm(x)) @ A.H)
  433. else:
  434. x = A @ A[j].conj()
  435. return max_abs * torch.linalg.vector_norm(A.H @ (x / torch.linalg.vector_norm(x)))
  436. def _norm_lower_bound(A):
  437. """Cheap lower bound for the spectral norm of A."""
  438. max_abs = A.norm(float("inf"))
  439. return torch.where(max_abs > 0, _lb(A, max_abs), max_abs)
  440. def _solve_triangular_right(X, A):
  441. """X @ inv(A)"""
  442. orig_dtype = X.dtype
  443. X = X.to(dtype=torch.float32)
  444. A = A.to(dtype=torch.float32)
  445. out = torch.linalg.solve_triangular(A, X.reshape(-1, X.size(-1)), upper=True, left=False).reshape_as(X)
  446. return out.to(dtype=orig_dtype)
  447. def _balance_Q(Q_in):
  448. norms = torch.stack([q.norm(float("inf")) for q in Q_in])
  449. geometric_mean = norms.prod() ** (1 / len(Q_in))
  450. norms = geometric_mean / norms
  451. for i, q in enumerate(Q_in):
  452. q.mul_(norms[i])
  453. def _precond_grad(Q, exprs, G):
  454. """Precondition gradient G with preconditioner Q."""
  455. return torch.einsum(exprs[-1], *[q.conj() for q in Q], *Q, G)
  456. def _calc_A_and_conjB(exprA, G, Q, V):
  457. A = torch.einsum(exprA, *Q, G)
  458. order = G.dim()
  459. p = tuple(range(order))
  460. conjB = torch.permute(V.conj(), p[1:] + p[:1])
  461. for i, q in enumerate(Q):
  462. conjB = conjB / q if q.dim() < 2 else _solve_triangular_right(conjB, q)
  463. if i < order - 1:
  464. conjB = torch.transpose(conjB, i, order - 1)
  465. return A, conjB
  466. def _q_terms(exprGs, A, conjB):
  467. terms = []
  468. for exprG in exprGs:
  469. term1 = torch.einsum(exprG, A, A.conj())
  470. term2 = torch.einsum(exprG, conjB.conj(), conjB)
  471. terms.append((term1, term2))
  472. return terms