# lbfgs.py (20 KB)
  1. # mypy: allow-untyped-defs
  2. import torch
  3. from torch import Tensor
  4. from .optimizer import _to_scalar, Optimizer, ParamsT
  5. __all__ = ["LBFGS"]
  6. def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None):
  7. # ported from https://github.com/torch/optim/blob/master/polyinterp.lua
  8. # Compute bounds of interpolation area
  9. if bounds is not None:
  10. xmin_bound, xmax_bound = bounds
  11. else:
  12. xmin_bound, xmax_bound = (x1, x2) if x1 <= x2 else (x2, x1)
  13. # Code for most common case: cubic interpolation of 2 points
  14. # w/ function and derivative values for both
  15. # Solution in this case (where x2 is the farthest point):
  16. # d1 = g1 + g2 - 3*(f1-f2)/(x1-x2);
  17. # d2 = sqrt(d1^2 - g1*g2);
  18. # min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2));
  19. # t_new = min(max(min_pos,xmin_bound),xmax_bound);
  20. d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2)
  21. d2_square = d1**2 - g1 * g2
  22. if d2_square >= 0:
  23. d2 = d2_square.sqrt()
  24. if x1 <= x2:
  25. min_pos = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2))
  26. else:
  27. min_pos = x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2))
  28. return min(max(min_pos, xmin_bound), xmax_bound)
  29. else:
  30. return (xmin_bound + xmax_bound) / 2.0
def _strong_wolfe(
    obj_func, x, t, d, f, g, gtd, c1=1e-4, c2=0.9, tolerance_change=1e-9, max_ls=25
):
    """Line search along ``d`` for a step satisfying the strong Wolfe conditions.

    Runs in two phases: a bracketing phase that grows the trial step until an
    interval containing an acceptable point is found, then a zoom phase that
    shrinks that interval using cubic interpolation.

    Args:
        obj_func: callable invoked as ``obj_func(x, t, d)``; must return a
            ``(value, gradient)`` pair where the gradient supports
            ``.dot(d)`` and ``.clone(...)``.
        x: opaque starting point, only forwarded to ``obj_func``.
        t: initial trial step size.
        d: search direction; ``d.abs().max()`` scales the bracket-width test.
        f: objective value at step 0.
        g: gradient at step 0 (cloned on entry, never modified).
        gtd: directional derivative at step 0.
        c1: constant of the sufficient-decrease test
            ``f_new > f + c1 * t * gtd``.
        c2: constant of the curvature test ``abs(gtd_new) <= -c2 * gtd``.
        tolerance_change: the zoom phase stops once
            ``bracket width * d.abs().max()`` drops below this value.
        max_ls: cap on line-search iterations, shared by both phases.

    Returns:
        ``(f_new, g_new, t, ls_func_evals)``: objective value, gradient and
        step size of the selected point, plus the number of ``obj_func``
        calls made.
    """
    # ported from https://github.com/torch/optim/blob/master/lswolfe.lua
    d_norm = d.abs().max()
    g = g.clone(memory_format=torch.contiguous_format)
    # evaluate objective and gradient using initial step
    f_new, g_new = obj_func(x, t, d)
    ls_func_evals = 1
    gtd_new = g_new.dot(d)

    # bracket an interval containing a point satisfying the Wolfe criteria
    t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd
    done = False
    ls_iter = 0
    while ls_iter < max_ls:
        # check conditions
        # sufficient decrease violated, or the value stopped decreasing:
        # an acceptable point lies between the previous and current step
        if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new >= f_prev):
            bracket = [t_prev, t]
            bracket_f = [f_prev, f_new]
            bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
            bracket_gtd = [gtd_prev, gtd_new]
            break

        # curvature condition holds as well: the current step is accepted
        # (single-element bracket, zoom phase is skipped via `done`)
        if abs(gtd_new) <= -c2 * gtd:
            bracket = [t]
            bracket_f = [f_new]
            bracket_g = [g_new]
            done = True
            break

        # slope turned non-negative: a minimizer was stepped over
        if gtd_new >= 0:
            bracket = [t_prev, t]
            bracket_f = [f_prev, f_new]
            bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
            bracket_gtd = [gtd_prev, gtd_new]
            break

        # interpolate
        # the next trial step is forced to lie beyond the current one
        min_step = t + 0.01 * (t - t_prev)
        max_step = t * 10
        tmp = t
        t = _cubic_interpolate(
            t_prev, f_prev, gtd_prev, t, f_new, gtd_new, bounds=(min_step, max_step)
        )

        # next step
        t_prev = tmp
        f_prev = f_new
        g_prev = g_new.clone(memory_format=torch.contiguous_format)
        gtd_prev = gtd_new
        f_new, g_new = obj_func(x, t, d)
        ls_func_evals += 1
        gtd_new = g_new.dot(d)
        ls_iter += 1

    # reached max number of iterations?
    # (`bracket_gtd` stays unbound here; it is never read because the zoom
    # loop below cannot run once ls_iter == max_ls)
    if ls_iter == max_ls:
        bracket = [0, t]
        bracket_f = [f, f_new]
        bracket_g = [g, g_new]

    # zoom phase: we now have a point satisfying the criteria, or
    # a bracket around it. We refine the bracket until we find the
    # exact point satisfying the criteria
    insuf_progress = False
    # find high and low points in bracket
    low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0)  # type: ignore[possibly-undefined]
    while not done and ls_iter < max_ls:
        # line-search bracket is so small
        if abs(bracket[1] - bracket[0]) * d_norm < tolerance_change:  # type: ignore[possibly-undefined]
            break

        # compute new trial value
        t = _cubic_interpolate(
            # pyrefly: ignore [index-error]
            # pyrefly: ignore [unbound-name]
            bracket[0],
            # pyrefly: ignore [unbound-name]
            bracket_f[0],
            bracket_gtd[0],  # type: ignore[possibly-undefined]
            # pyrefly: ignore [index-error]
            # pyrefly: ignore [unbound-name]
            bracket[1],
            # pyrefly: ignore [unbound-name]
            bracket_f[1],
            # pyrefly: ignore [unbound-name]
            bracket_gtd[1],
        )

        # test that we are making sufficient progress:
        # in case `t` is so close to boundary, we mark that we are making
        # insufficient progress, and if
        #   + we have made insufficient progress in the last step, or
        #   + `t` is at one of the boundary,
        # we will move `t` to a position which is `0.1 * len(bracket)`
        # away from the nearest boundary point.
        # pyrefly: ignore [unbound-name]
        eps = 0.1 * (max(bracket) - min(bracket))
        # pyrefly: ignore [unbound-name]
        if min(max(bracket) - t, t - min(bracket)) < eps:
            # interpolation close to boundary
            # pyrefly: ignore [unbound-name]
            if insuf_progress or t >= max(bracket) or t <= min(bracket):
                # evaluate at 0.1 away from boundary
                # pyrefly: ignore [unbound-name]
                if abs(t - max(bracket)) < abs(t - min(bracket)):
                    # pyrefly: ignore [unbound-name]
                    t = max(bracket) - eps
                else:
                    # pyrefly: ignore [unbound-name]
                    t = min(bracket) + eps
                insuf_progress = False
            else:
                insuf_progress = True
        else:
            insuf_progress = False

        # Evaluate new point
        f_new, g_new = obj_func(x, t, d)
        ls_func_evals += 1
        gtd_new = g_new.dot(d)
        ls_iter += 1

        # pyrefly: ignore [unbound-name]
        if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]:
            # Armijo condition not satisfied or not lower than lowest point
            # pyrefly: ignore [unsupported-operation]
            # pyrefly: ignore [unbound-name]
            bracket[high_pos] = t
            # pyrefly: ignore [unbound-name]
            bracket_f[high_pos] = f_new
            bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format)  # type: ignore[possibly-undefined]
            # pyrefly: ignore [unbound-name]
            bracket_gtd[high_pos] = gtd_new
            # the replaced end may now hold the lower value: re-rank the ends
            # pyrefly: ignore [unbound-name]
            low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0)
        else:
            if abs(gtd_new) <= -c2 * gtd:
                # Wolfe conditions satisfied
                done = True
            # pyrefly: ignore [index-error]
            # pyrefly: ignore [unbound-name]
            elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0:
                # old high becomes new low
                # pyrefly: ignore [unsupported-operation]
                # pyrefly: ignore [unbound-name]
                bracket[high_pos] = bracket[low_pos]
                # pyrefly: ignore [unbound-name]
                bracket_f[high_pos] = bracket_f[low_pos]
                bracket_g[high_pos] = bracket_g[low_pos]  # type: ignore[possibly-undefined]
                # pyrefly: ignore [unbound-name]
                bracket_gtd[high_pos] = bracket_gtd[low_pos]

            # new point becomes new low
            # (runs for every outcome of the if/elif above, including the
            # accepted case, so the return below picks up the new point)
            # pyrefly: ignore [unsupported-operation]
            # pyrefly: ignore [unbound-name]
            bracket[low_pos] = t
            # pyrefly: ignore [unbound-name]
            bracket_f[low_pos] = f_new
            bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format)  # type: ignore[possibly-undefined]
            # pyrefly: ignore [unbound-name]
            bracket_gtd[low_pos] = gtd_new

    # return stuff
    t = bracket[low_pos]  # type: ignore[possibly-undefined]
    # pyrefly: ignore [unbound-name]
    f_new = bracket_f[low_pos]
    g_new = bracket_g[low_pos]  # type: ignore[possibly-undefined]
    return f_new, g_new, t, ls_func_evals
  188. class LBFGS(Optimizer):
  189. """Implements L-BFGS algorithm.
  190. Heavily inspired by `minFunc
  191. <https://www.cs.ubc.ca/~schmidtm/Software/minFunc.html>`_.
  192. .. warning::
  193. This optimizer doesn't support per-parameter options and parameter
  194. groups (there can be only one).
  195. .. warning::
  196. Right now all parameters have to be on a single device. This will be
  197. improved in the future.
  198. .. note::
  199. This is a very memory intensive optimizer (it requires additional
  200. ``param_bytes * (history_size + 1)`` bytes). If it doesn't fit in memory
  201. try reducing the history size, or use a different algorithm.
  202. Args:
  203. params (iterable): iterable of parameters to optimize. Parameters must be real.
  204. lr (float, optional): learning rate (default: 1)
  205. max_iter (int, optional): maximal number of iterations per optimization step
  206. (default: 20)
  207. max_eval (int, optional): maximal number of function evaluations per optimization
  208. step (default: max_iter * 1.25).
  209. tolerance_grad (float, optional): termination tolerance on first order optimality
  210. (default: 1e-7).
  211. tolerance_change (float, optional): termination tolerance on function
  212. value/parameter changes (default: 1e-9).
  213. history_size (int, optional): update history size (default: 100).
  214. line_search_fn (str, optional): either 'strong_wolfe' or None (default: None).
  215. """
  216. def __init__(
  217. self,
  218. params: ParamsT,
  219. lr: float | Tensor = 1,
  220. max_iter: int = 20,
  221. max_eval: int | None = None,
  222. tolerance_grad: float = 1e-7,
  223. tolerance_change: float = 1e-9,
  224. history_size: int = 100,
  225. line_search_fn: str | None = None,
  226. ) -> None:
  227. if isinstance(lr, Tensor) and lr.numel() != 1:
  228. raise ValueError("Tensor lr must be 1-element")
  229. if not 0.0 <= lr:
  230. raise ValueError(f"Invalid learning rate: {lr}")
  231. if max_eval is None:
  232. max_eval = max_iter * 5 // 4
  233. defaults = {
  234. "lr": lr,
  235. "max_iter": max_iter,
  236. "max_eval": max_eval,
  237. "tolerance_grad": tolerance_grad,
  238. "tolerance_change": tolerance_change,
  239. "history_size": history_size,
  240. "line_search_fn": line_search_fn,
  241. }
  242. super().__init__(params, defaults)
  243. if len(self.param_groups) != 1:
  244. raise ValueError(
  245. "LBFGS doesn't support per-parameter options (parameter groups)"
  246. )
  247. self._params = self.param_groups[0]["params"]
  248. self._numel_cache = None
  249. def _numel(self):
  250. if self._numel_cache is None:
  251. # pyrefly: ignore [bad-assignment]
  252. self._numel_cache = sum(
  253. 2 * p.numel() if torch.is_complex(p) else p.numel()
  254. for p in self._params
  255. )
  256. return self._numel_cache
  257. def _gather_flat_grad(self):
  258. views = []
  259. for p in self._params:
  260. if p.grad is None:
  261. view = p.new(p.numel()).zero_()
  262. elif p.grad.is_sparse:
  263. view = p.grad.to_dense().view(-1)
  264. else:
  265. view = p.grad.view(-1)
  266. if torch.is_complex(view):
  267. view = torch.view_as_real(view).view(-1)
  268. views.append(view)
  269. return torch.cat(views, 0)
  270. def _add_grad(self, step_size, update) -> None:
  271. offset = 0
  272. for p in self._params:
  273. if torch.is_complex(p):
  274. p = torch.view_as_real(p)
  275. numel = p.numel()
  276. # view as to avoid deprecated pointwise semantics
  277. p.add_(update[offset : offset + numel].view_as(p), alpha=step_size)
  278. offset += numel
  279. if offset != self._numel():
  280. raise AssertionError(f"Expected offset {offset} to equal {self._numel()}")
  281. def _clone_param(self):
  282. return [p.clone(memory_format=torch.contiguous_format) for p in self._params]
  283. def _set_param(self, params_data) -> None:
  284. for p, pdata in zip(self._params, params_data, strict=True):
  285. p.copy_(pdata)
  286. def _directional_evaluate(self, closure, x, t, d):
  287. self._add_grad(t, d)
  288. loss = float(closure())
  289. flat_grad = self._gather_flat_grad()
  290. self._set_param(x)
  291. return loss, flat_grad
  292. @torch.no_grad()
  293. def step(self, closure): # type: ignore[override]
  294. """Perform a single optimization step.
  295. Args:
  296. closure (Callable): A closure that reevaluates the model
  297. and returns the loss.
  298. """
  299. if len(self.param_groups) != 1:
  300. raise AssertionError(
  301. f"Expected exactly one param_group, but got {len(self.param_groups)}"
  302. )
  303. # Make sure the closure is always called with grad enabled
  304. closure = torch.enable_grad()(closure)
  305. group = self.param_groups[0]
  306. lr = _to_scalar(group["lr"])
  307. max_iter = group["max_iter"]
  308. max_eval = group["max_eval"]
  309. tolerance_grad = group["tolerance_grad"]
  310. tolerance_change = group["tolerance_change"]
  311. line_search_fn = group["line_search_fn"]
  312. history_size = group["history_size"]
  313. # NOTE: LBFGS has only global state, but we register it as state for
  314. # the first param, because this helps with casting in load_state_dict
  315. state = self.state[self._params[0]]
  316. state.setdefault("func_evals", 0)
  317. state.setdefault("n_iter", 0)
  318. # evaluate initial f(x) and df/dx
  319. orig_loss = closure()
  320. loss = float(orig_loss)
  321. current_evals = 1
  322. state["func_evals"] += 1
  323. flat_grad = self._gather_flat_grad()
  324. opt_cond = flat_grad.abs().max() <= tolerance_grad
  325. # optimal condition
  326. if opt_cond:
  327. return orig_loss
  328. # tensors cached in state (for tracing)
  329. d = state.get("d")
  330. t = state.get("t")
  331. old_dirs = state.get("old_dirs")
  332. old_stps = state.get("old_stps")
  333. ro = state.get("ro")
  334. H_diag = state.get("H_diag")
  335. prev_flat_grad = state.get("prev_flat_grad")
  336. prev_loss = state.get("prev_loss")
  337. n_iter = 0
  338. # optimize for a max of max_iter iterations
  339. while n_iter < max_iter:
  340. # keep track of nb of iterations
  341. n_iter += 1
  342. state["n_iter"] += 1
  343. ############################################################
  344. # compute gradient descent direction
  345. ############################################################
  346. if state["n_iter"] == 1:
  347. d = flat_grad.neg()
  348. old_dirs = []
  349. old_stps = []
  350. ro = []
  351. H_diag = 1
  352. else:
  353. # do lbfgs update (update memory)
  354. y = flat_grad.sub(prev_flat_grad)
  355. s = d.mul(t)
  356. ys = y.dot(s) # y*s
  357. if ys > 1e-10:
  358. # updating memory
  359. if len(old_dirs) == history_size:
  360. # shift history by one (limited-memory)
  361. old_dirs.pop(0)
  362. old_stps.pop(0)
  363. ro.pop(0)
  364. # store new direction/step
  365. old_dirs.append(y)
  366. old_stps.append(s)
  367. ro.append(1.0 / ys)
  368. # update scale of initial Hessian approximation
  369. H_diag = ys / y.dot(y) # (y*y)
  370. # compute the approximate (L-BFGS) inverse Hessian
  371. # multiplied by the gradient
  372. num_old = len(old_dirs)
  373. if "al" not in state:
  374. state["al"] = [None] * history_size
  375. al = state["al"]
  376. # iteration in L-BFGS loop collapsed to use just one buffer
  377. q = flat_grad.neg()
  378. for i in range(num_old - 1, -1, -1):
  379. al[i] = old_stps[i].dot(q) * ro[i]
  380. q.add_(old_dirs[i], alpha=-al[i])
  381. # multiply by initial Hessian
  382. # r/d is the final direction
  383. d = r = torch.mul(q, H_diag)
  384. for i in range(num_old):
  385. be_i = old_dirs[i].dot(r) * ro[i]
  386. r.add_(old_stps[i], alpha=al[i] - be_i)
  387. if prev_flat_grad is None:
  388. prev_flat_grad = flat_grad.clone(memory_format=torch.contiguous_format)
  389. else:
  390. prev_flat_grad.copy_(flat_grad)
  391. prev_loss = loss
  392. ############################################################
  393. # compute step length
  394. ############################################################
  395. # reset initial guess for step size
  396. if state["n_iter"] == 1:
  397. t = min(1.0, 1.0 / flat_grad.abs().sum()) * lr
  398. else:
  399. t = lr
  400. # directional derivative
  401. gtd = flat_grad.dot(d) # g * d
  402. # directional derivative is below tolerance
  403. if gtd > -tolerance_change:
  404. break
  405. # optional line search: user function
  406. ls_func_evals = 0
  407. if line_search_fn is not None:
  408. # perform line search, using user function
  409. if line_search_fn != "strong_wolfe":
  410. raise RuntimeError("only 'strong_wolfe' is supported")
  411. else:
  412. x_init = self._clone_param()
  413. def obj_func(x, t, d):
  414. return self._directional_evaluate(closure, x, t, d)
  415. loss, flat_grad, t, ls_func_evals = _strong_wolfe(
  416. obj_func,
  417. x_init,
  418. t,
  419. d,
  420. loss,
  421. flat_grad,
  422. gtd,
  423. max_ls=max_eval - current_evals,
  424. )
  425. self._add_grad(t, d)
  426. opt_cond = flat_grad.abs().max() <= tolerance_grad
  427. else:
  428. # no line search, simply move with fixed-step
  429. self._add_grad(t, d)
  430. if n_iter != max_iter:
  431. # re-evaluate function only if not in last iteration
  432. # the reason we do this: in a stochastic setting,
  433. # no use to re-evaluate that function here
  434. with torch.enable_grad():
  435. loss = closure()
  436. loss = float(loss)
  437. flat_grad = self._gather_flat_grad()
  438. opt_cond = flat_grad.abs().max() <= tolerance_grad
  439. ls_func_evals = 1
  440. # update func eval
  441. current_evals += ls_func_evals
  442. state["func_evals"] += ls_func_evals
  443. ############################################################
  444. # check conditions
  445. ############################################################
  446. if n_iter == max_iter:
  447. break
  448. if current_evals >= max_eval:
  449. break
  450. # optimal condition
  451. if opt_cond:
  452. break
  453. # lack of progress
  454. if d.mul(t).abs().max() <= tolerance_change:
  455. break
  456. if abs(loss - prev_loss) < tolerance_change:
  457. break
  458. state["d"] = d
  459. state["t"] = t
  460. state["old_dirs"] = old_dirs
  461. state["old_stps"] = old_stps
  462. state["ro"] = ro
  463. state["H_diag"] = H_diag
  464. state["prev_flat_grad"] = prev_flat_grad
  465. state["prev_loss"] = prev_loss
  466. return orig_loss