optimizer.py 50 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193
  1. # mypy: allow-untyped-defs
  2. """Base optimizer."""
  3. import functools
  4. import warnings
  5. from collections import defaultdict, OrderedDict
  6. from collections.abc import Callable, Hashable, Iterable, Sequence
  7. from copy import deepcopy
  8. from itertools import chain
  9. from typing import Any, cast, overload, TypeAlias, TypeVar
  10. from typing_extensions import ParamSpec, Self
  11. import torch
  12. import torch.utils.hooks as hooks
  13. from torch.utils._foreach_utils import (
  14. _get_foreach_kernels_supported_devices,
  15. _get_fused_kernels_supported_devices,
  16. _group_tensors_by_device_and_dtype,
  17. Indices,
  18. TensorListList,
  19. )
  20. from torch.utils.hooks import RemovableHandle
  # Generic return type / parameter specification used by the decorators in this
  # module so that wrapped callables keep their signatures.
  _T = TypeVar("_T")
  _P = ParamSpec("_P")

  # Shorthand aliases for positional arguments, keyword arguments and serialized state.
  Args: TypeAlias = tuple[Any, ...]
  Kwargs: TypeAlias = dict[str, Any]
  StateDict: TypeAlias = dict[str, Any]
  DeviceDict: TypeAlias = dict[torch.device | None, torch.Tensor]
  DeviceDtypeDict: TypeAlias = dict[tuple[torch.device, torch.dtype] | None, torch.Tensor]

  # Signatures of hooks registered for every optimizer: a pre hook may return
  # replacement (args, kwargs) or None, a post hook returns nothing.
  GlobalOptimizerPreHook: TypeAlias = Callable[
      ["Optimizer", Args, Kwargs], tuple[Args, Kwargs] | None
  ]
  GlobalOptimizerPostHook: TypeAlias = Callable[["Optimizer", Args, Kwargs], None]

  __all__ = [
      "Optimizer",
      "register_optimizer_step_pre_hook",
      "register_optimizer_step_post_hook",
  ]

  # Module-wide hook registries, keyed by the id of the RemovableHandle created
  # at registration time.
  _global_optimizer_pre_hooks: dict[int, GlobalOptimizerPreHook] = OrderedDict()
  _global_optimizer_post_hooks: dict[int, GlobalOptimizerPostHook] = OrderedDict()

  # Exact types (compared with ``type(p) in ...``, so subclasses do not match)
  # accepted when deciding whether the fused/foreach implementations may be used.
  _foreach_supported_types = [torch.Tensor, torch.nn.parameter.Parameter]
  class _RequiredParameter:
      """Sentinel type; its single instance flags an Optimizer argument as mandatory."""

      def __repr__(self) -> str:
          return "<required parameter>"


  # The one shared sentinel instance.
  required = _RequiredParameter()
  def _use_grad_for_differentiable(func: Callable[_P, _T]) -> Callable[_P, _T]:
      """Wrap an optimizer method so it runs under the grad mode selected by ``differentiable``.

      The wrapper reads ``defaults["differentiable"]`` from its first positional
      argument, enables or disables autograd accordingly for the duration of the
      call, and restores the previous grad mode afterwards even if ``func`` raises.

      Args:
          func: the method to wrap; its first positional argument must be the optimizer.

      Returns:
          The wrapped callable, carrying ``func``'s metadata.
      """

      @functools.wraps(func)
      def _use_grad(*args: _P.args, **kwargs: _P.kwargs) -> _T:
          import torch._dynamo

          # pyrefly: ignore [unsupported-operation]
          optimizer = cast(Optimizer, args[0])  # assume first positional arg is `self`
          grad_was_enabled = torch.is_grad_enabled()
          try:
              # Note on the graph breaks:
              # a graph break is required so that aot respects the no_grad annotation.
              # Without it, functionalization generates an epilogue updating the mutated
              # optimizer parameters that is *not* visible to inductor, so inductor
              # allocates for every parameter in the model, which is terrible for perf.
              # With it, aot sees an inference graph and the epilogue is appended to the
              # graph where inductor *can* see it, recognise that step is in place and
              # skip the extra allocation.
              # In the future we will either 1) keep graph breaking on backward, making
              # this break irrelevant, or 2) have a fully fused forward/backward graph
              # that is no_grad by default, at which point this break can be removed so
              # the fused fwd-bwd-optimizer graph can be compiled.
              # see https://github.com/pytorch/pytorch/issues/104053
              torch.set_grad_enabled(optimizer.defaults["differentiable"])
              torch._dynamo.graph_break()
              result = func(*args, **kwargs)
          finally:
              torch._dynamo.graph_break()
              torch.set_grad_enabled(grad_was_enabled)
          return result

      return _use_grad
  73. def _get_value(x):
  74. # item is significantly faster than a cpu tensor in eager mode
  75. if not torch.jit.is_scripting() and torch.compiler.is_compiling():
  76. return x
  77. else:
  78. return x.item() if isinstance(x, torch.Tensor) else x
  79. def _stack_if_compiling(x):
  80. if not torch.jit.is_scripting() and torch.compiler.is_compiling():
  81. return torch.stack(x)
  82. else:
  83. return x
  def _disable_dynamo_if_unsupported(
      single_tensor_fn: Callable[..., object] | None = None,
  ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
      """Build a decorator that reroutes certain compiled calls to a dynamo-disabled copy of the function.

      Args:
          single_tensor_fn: optional callable that is published into this module's
              globals under its own ``__name__`` before the decorator is created.

      Returns:
          A decorator. The decorated function calls a ``torch._disable_dynamo``
          version of itself when compiling and ``state_steps`` holds CUDA tensors
          in the situations spelled out in ``maybe_fallback``; otherwise it calls
          the original function.
      """
      # workaround for torchscript BC
      # it requires all called functions to be in the
      # global environment at the site at which the
      # maybe_fallback closure is created
      if single_tensor_fn:
          globals()[single_tensor_fn.__name__] = single_tensor_fn

      def wrapper(func: Callable[_P, _T]) -> Callable[_P, _T]:
          import inspect

          disabled_func = torch._disable_dynamo(func)
          ps = inspect.signature(func).parameters
          # Locate the positional slot of ``state_steps`` (if the function has one)
          # so that it can be inspected when passed positionally.
          has_state_steps = True
          try:
              state_steps_ind = list(ps.keys()).index("state_steps")
          except ValueError:
              has_state_steps = False

          # Today, there are cases where we stack state steps
          # and pass them as the value arg of foreach ops.
          # Having state steps on cuda as the value arg is not supported in eager,
          # but this only occurs in the rare case that the user explicitly deletes
          # the capturable flag. If capturable=True, this is not a problem.
          @functools.wraps(func)
          def maybe_fallback(*args: _P.args, **kwargs: _P.kwargs):
              # `and` binds tighter than `or`: the capturable / has_state_steps checks
              # guard only the positional branch; the keyword branch after `or` is
              # evaluated on its own.
              # NOTE(review): when `state_steps` is supplied by keyword and fewer than
              # `state_steps_ind + 1` positional arguments are given, `args[state_steps_ind]`
              # looks like it could raise IndexError before the keyword branch is
              # reached - confirm against callers.
              if torch.compiler.is_compiling() and (
                  not kwargs.get("capturable", False)
                  and has_state_steps
                  # pyrefly: ignore [unsupported-operation]
                  and (arg := args[state_steps_ind])
                  and isinstance(arg, Sequence)
                  and arg[0].is_cuda
                  or (
                      "state_steps" in kwargs
                      # pyrefly: ignore [unsupported-operation]
                      and (kwarg := kwargs["state_steps"])
                      and isinstance(kwarg, Sequence)
                      and kwarg[0].is_cuda
                  )
              ):
                  return disabled_func(*args, **kwargs)
              else:
                  return func(*args, **kwargs)

          return maybe_fallback

      return wrapper
  # For any optimizer with a faster implementation, we attempt to default to the
  # fastest + stablest whenever possible. For foreach, the requirements are to have
  # native params all on CUDA. For fused, there's currently the additional requirement
  # that the tensors' dtypes must be floating point. Neither alternative supports
  # torch.jit.script nor differentiable, so we fall back to the single tensor
  # implementation in those cases.
  def _default_to_fused_or_foreach(
      params: list[torch.Tensor], differentiable: bool, use_fused: bool = False
  ) -> tuple[bool, bool]:
      """Decide whether the fused or the foreach implementation can be used by default.

      Args:
          params: tensors to be optimized; ``None`` entries are ignored.
          differentiable: whether the optimizer step must be differentiable.
          use_fused: whether the fused implementation should be considered at all.

      Returns:
          ``(fused, foreach)``; at most one is True, and both are False when
          scripting or when ``differentiable`` is set.
      """
      if torch.jit.is_scripting() or differentiable:
          return False, False

      fused_devices = _get_fused_kernels_supported_devices()
      foreach_devices = _get_foreach_kernels_supported_devices()

      def is_native_on(p, device_types) -> bool:
          # Exact type match on purpose: tensor subclasses are not considered native.
          return type(p) in _foreach_supported_types and p.device.type in device_types

      fused = use_fused and all(
          p is None or (is_native_on(p, fused_devices) and torch.is_floating_point(p))
          for p in params
      )
      foreach = not fused and all(
          p is None or is_native_on(p, foreach_devices) for p in params
      )
      return fused, foreach
  160. def _device_dtype_check_for_fused(
  161. p: torch.Tensor, cuda_unsupported: bool = False
  162. ) -> None:
  163. fused_supported_devices = _get_fused_kernels_supported_devices()
  164. if cuda_unsupported:
  165. fused_supported_devices.remove("cuda")
  166. if not (p.device.type in fused_supported_devices and torch.is_floating_point(p)):
  167. raise RuntimeError(
  168. "`fused=True` requires all the params to be floating point Tensors of "
  169. f"supported devices: {fused_supported_devices} but {p.dtype} and {p.device.type}"
  170. )
  171. def _view_as_real(params, *state_and_grads) -> None:
  172. for i, p in enumerate(params):
  173. if torch.is_complex(p):
  174. params[i] = torch.view_as_real(params[i])
  175. for s in state_and_grads:
  176. s[i] = torch.view_as_real(s[i])
  177. def _get_scalar_dtype(is_fused=None):
  178. if is_fused:
  179. return torch.float32
  180. return (
  181. torch.float64 if torch.get_default_dtype() == torch.float64 else torch.float32
  182. )
  183. def _get_capturable_supported_devices(supports_xla: bool = True) -> list[str]:
  184. r"""Return the device type list that supports capturable optimizer."""
  185. capturable_supported_devices = ["cuda", "xpu", "hpu"]
  186. if not torch.jit.is_scripting():
  187. capturable_supported_devices.append(torch._C._get_privateuse1_backend_name())
  188. if supports_xla:
  189. capturable_supported_devices.append("xla")
  190. return capturable_supported_devices
  191. def _to_scalar(x: float | torch.Tensor):
  192. r"""This function converts a hyperparameter to a 0-dimension (scalar) tensor
  193. if it is a nonzero-dimensions 1-element tensor. If it is not a tensor, it is
  194. kept as is.
  195. Args:
  196. x (float or Tensor): A hyperparameter of the optimizer.
  197. If it is Tensor, it is needed to be 1-element.
  198. Returns:
  199. float or Tensor:
  200. a scalar tensor if x is Tensor otherwise Python scalar (float) value.
  201. """
  202. if isinstance(x, torch.Tensor) and x.dim() != 0:
  203. return x.squeeze()
  204. else:
  205. return x
  # Common doc strings among optimizers
  # Each constant below documents one constructor argument shared by several
  # optimizers. NOTE(review): presumably these are interpolated into the concrete
  # optimizers' docstrings - the use site is not visible here.

  # Documents the ``params`` argument.
  _params_doc = r"""params (iterable): iterable of parameters or named_parameters to optimize
  or iterable of dicts defining parameter groups. When using named_parameters,
  all parameters in all groups should be named"""

  # Documents the ``foreach`` argument.
  _foreach_doc = r"""foreach (bool, optional): whether foreach implementation of optimizer
  is used. If unspecified by the user (so foreach is None), we will try to use
  foreach over the for-loop implementation on CUDA, since it is usually
  significantly more performant. Note that the foreach implementation uses
  ~ sizeof(params) more peak memory than the for-loop version due to the intermediates
  being a tensorlist vs just one tensor. If memory is prohibitive, batch fewer
  parameters through the optimizer at a time or switch this flag to False (default: None)"""

  # Documents the ``fused`` argument, plus a note comparing foreach and fused.
  _fused_doc = r"""fused (bool, optional): whether the fused implementation is used.
  Currently, `torch.float64`, `torch.float32`, `torch.float16`, and `torch.bfloat16`
  are supported. (default: None)
  .. note:: The foreach and fused implementations are typically faster than the for-loop,
  single-tensor implementation, with fused being theoretically fastest with both
  vertical and horizontal fusion. As such, if the user has not specified either
  flag (i.e., when foreach = fused = None), we will attempt defaulting to the foreach
  implementation when the tensors are all on CUDA. Why not fused? Since the fused
  implementation is relatively new, we want to give it sufficient bake-in time.
  To specify fused, pass True for fused. To force running the for-loop
  implementation, pass False for either foreach or fused. """

  # Documents the ``capturable`` argument.
  _capturable_doc = r"""capturable (bool, optional): whether this instance is safe to
  capture in a graph, whether for CUDA graphs or for torch.compile support.
  Tensors are only capturable when on supported :ref:`accelerators<accelerators>`.
  Passing True can impair ungraphed performance, so if you don't intend to graph
  capture this instance, leave it False (default: False)"""

  # Documents the ``differentiable`` argument.
  _differentiable_doc = r"""differentiable (bool, optional): whether autograd should
  occur through the optimizer step in training. Otherwise, the step()
  function runs in a torch.no_grad() context. Setting to True can impair
  performance, so leave it False if you don't intend to run autograd
  through this instance (default: False)"""

  # Documents the ``maximize`` argument.
  _maximize_doc = r"""maximize (bool, optional): maximize the objective with respect to the
  params, instead of minimizing (default: False)"""
  def register_optimizer_step_pre_hook(hook: GlobalOptimizerPreHook) -> RemovableHandle:
      r"""Install a step pre hook shared by all optimizers.

      The hook is called with the following signature::

          hook(optimizer, args, kwargs) -> None or modified args and kwargs

      Args:
          hook (Callable): user-defined callable stored in the module-wide
              pre-hook registry.

      Returns:
          :class:`torch.utils.hooks.RemovableHandle`:
              handle whose ``handle.remove()`` unregisters the hook again.
      """
      removable = RemovableHandle(_global_optimizer_pre_hooks)
      _global_optimizer_pre_hooks[removable.id] = hook
      return removable
  def register_optimizer_step_post_hook(hook: GlobalOptimizerPostHook) -> RemovableHandle:
      r"""Install a step post hook shared by all optimizers.

      The hook is called with the following signature::

          hook(optimizer, args, kwargs) -> None

      Args:
          hook (Callable): user-defined callable stored in the module-wide
              post-hook registry.

      Returns:
          :class:`torch.utils.hooks.RemovableHandle`:
              handle whose ``handle.remove()`` unregisters the hook again.
      """
      removable = RemovableHandle(_global_optimizer_post_hooks)
      _global_optimizer_post_hooks[removable.id] = hook
      return removable
  # Accepted shapes of the ``params`` constructor argument: plain tensors,
  # parameter-group dicts, or (name, tensor) pairs.
  ParamsT: TypeAlias = (
      Iterable[torch.Tensor]
      | Iterable[dict[str, Any]]
      | Iterable[tuple[str, torch.Tensor]]
  )

  # Generic type variables; ``R`` is the return type preserved by
  # ``Optimizer.profile_hook_step``.
  R = TypeVar("R")
  T = TypeVar("T")
  275. class Optimizer:
  276. r"""Base class for all optimizers.
  277. .. warning::
  278. Parameters need to be specified as collections that have a deterministic
  279. ordering that is consistent between runs. Examples of objects that don't
  280. satisfy those properties are sets and iterators over values of dictionaries.
  281. Args:
  282. params (iterable): an iterable of :class:`torch.Tensor` s or
  283. :class:`dict` s. Specifies what Tensors should be optimized.
  284. defaults: (dict): a dict containing default values of optimization
  285. options (used when a parameter group doesn't specify them).
  286. """
  # Signatures of per-instance step hooks: a pre hook may return replacement
  # (args, kwargs) or None, a post hook returns nothing.
  OptimizerPreHook: TypeAlias = Callable[
      [Self, Args, Kwargs],  # type: ignore[misc]
      tuple[Args, Kwargs] | None,
  ]
  OptimizerPostHook: TypeAlias = Callable[[Self, Args, Kwargs], None]  # type: ignore[misc]

  # Hook tables keyed by RemovableHandle id; each is created empty in __init__
  # and re-created in __setstate__ when missing.
  _optimizer_step_pre_hooks: dict[int, OptimizerPreHook]
  _optimizer_step_post_hooks: dict[int, OptimizerPostHook]
  # pyrefly: ignore [not-a-type]
  _optimizer_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]'
  _optimizer_state_dict_post_hooks: (
      # pyrefly: ignore [not-a-type]
      'OrderedDict[int, Callable[["Optimizer", StateDict], StateDict | None]]'
  )
  _optimizer_load_state_dict_pre_hooks: (
      # pyrefly: ignore [not-a-type]
      'OrderedDict[int, Callable[["Optimizer", StateDict], StateDict | None]]'
  )
  _optimizer_load_state_dict_post_hooks: (
      # pyrefly: ignore [not-a-type]
      'OrderedDict[int, Callable[["Optimizer"], None]]'
  )
  def __init__(self, params: ParamsT, defaults: dict[str, Any]) -> None:  # noqa: D107
      # Set up defaults and empty hook tables, patch ``step`` for profiling/hooks,
      # then register every parameter group found in ``params``.
      torch._C._log_api_usage_once("python.optimizer")
      self.defaults = defaults

      self._optimizer_step_pre_hooks = OrderedDict()
      self._optimizer_step_post_hooks = OrderedDict()
      self._optimizer_state_dict_pre_hooks = OrderedDict()
      self._optimizer_state_dict_post_hooks = OrderedDict()
      self._optimizer_load_state_dict_pre_hooks = OrderedDict()
      self._optimizer_load_state_dict_post_hooks = OrderedDict()

      self._patch_step_function()

      # A bare Tensor is iterable, so it has to be rejected explicitly.
      if isinstance(params, torch.Tensor):
          raise TypeError(
              "params argument given to the optimizer should be "
              "an iterable of Tensors or dicts, but got " + torch.typename(params)
          )

      self.state: defaultdict[torch.Tensor, Any] = defaultdict(dict)
      self.param_groups: list[dict[str, Any]] = []

      groups = list(params)
      if not groups:
          raise ValueError("optimizer got an empty parameter list")
      # Anything that is not already a list of group dicts becomes a single group.
      if not isinstance(groups[0], dict):
          groups = [{"params": groups}]
      for group in groups:
          self.add_param_group(cast(dict, group))

      # Allows _accelerator_graph_capture_health_check to rig a poor man's TORCH_WARN_ONCE in python,
      # which I don't think exists
      # https://github.com/pytorch/pytorch/issues/72948
      # NOTE(review): the flag starts out True, so the "running uncaptured" warning
      # in the health check stays silent unless something resets it - confirm intended.
      self._warned_capturable_if_run_uncaptured = True
  336. def __getstate__(self) -> dict[str, Any]: # noqa: D105
  337. return {
  338. "defaults": self.defaults,
  339. "state": self.state,
  340. "param_groups": self.param_groups,
  341. }
  def __setstate__(self, state: dict[str, Any]) -> None:  # noqa: D105
      self.__dict__.update(state)

      # Hook tables are not part of the pickled state; create any that are missing.
      hook_table_names = (
          "_optimizer_step_pre_hooks",
          "_optimizer_step_post_hooks",
          "_optimizer_state_dict_pre_hooks",
          "_optimizer_state_dict_post_hooks",
          "_optimizer_load_state_dict_pre_hooks",
          "_optimizer_load_state_dict_post_hooks",
      )
      for table_name in hook_table_names:
          if table_name not in self.__dict__:
              setattr(self, table_name, OrderedDict())

      self._patch_step_function()  # To support multiprocessing pickle/unpickle
      self.defaults.setdefault("differentiable", False)
  358. def __repr__(self) -> str: # noqa: D105
  359. format_string = self.__class__.__name__ + " ("
  360. for i, group in enumerate(self.param_groups):
  361. format_string += "\n"
  362. format_string += f"Parameter Group {i}\n"
  363. for key in sorted(group.keys()):
  364. if key != "params":
  365. format_string += f" {key}: {group[key]}\n"
  366. format_string += ")"
  367. return format_string
  368. # Currently needed by Adam and AdamW
  369. def _accelerator_graph_capture_health_check(self) -> None:
  370. # Note [torch.compile x capturable]
  371. # If we are compiling, we try to take the capturable path automatically by
  372. # setting the flag to True during tracing. Due to this, we skip all the checks
  373. # normally required for determining whether we can use CUDA/XPU graphs and
  374. # shunt the responsibility to torch.inductor. This saves time during tracing
  375. # since the checks are slow without sacrificing UX since inductor will warn
  376. # later if CUDA/XPU graphs cannot be enabled, e.g.,
  377. # https://github.com/pytorch/pytorch/blob/d3ba8901d8640eb16f88b2bfef9df7fa383d4b47/torch/_inductor/compile_fx.py#L390.
  378. # Thus, when compiling, inductor will determine if cudagraphs
  379. # can be enabled based on whether there is input mutation or CPU tensors.
  380. if torch.compiler.is_compiling():
  381. return
  382. # Determine available accelerator device
  383. accelerator = None
  384. if torch.cuda.is_available():
  385. accelerator = (torch.cuda, "CUDA")
  386. elif torch.xpu.is_available():
  387. accelerator = (torch.xpu, "XPU")
  388. if accelerator:
  389. device_module, device_name = accelerator
  390. capturing = device_module.is_current_stream_capturing()
  391. if capturing and not all(
  392. group["capturable"] for group in self.param_groups
  393. ):
  394. raise RuntimeError(
  395. f"Attempting {device_name} graph capture of step() for an instance of "
  396. + self.__class__.__name__
  397. + " but param_groups' capturable is False."
  398. )
  399. if (
  400. (not getattr(self, "_warned_capturable_if_run_uncaptured", False))
  401. and all(group["capturable"] for group in self.param_groups)
  402. and (not capturing)
  403. ):
  404. warnings.warn(
  405. "This instance was constructed with capturable=True or some of all the param_groups came with capturable=True, "
  406. f"but step() is running without {device_name} graph capture. If you never intend to graph-capture this "
  407. "instance, capturable=True can impair performance, and you should set capturable=False.",
  408. stacklevel=2,
  409. )
  410. self._warned_capturable_if_run_uncaptured = True
  411. def _optimizer_step_code(self) -> None:
  412. """Entry point for `torch.profile.profiler`.
  413. When python tracing is enabled the profiler will hook into this
  414. function at the CPython level to inspect the optimizer's parameters and
  415. param groups. It is called it after `step()` since many optimizers
  416. lazily initialize state.
  417. This is a workaround due to lack of a proper step hook on the optimizer,
  418. and will be removed if it exists.
  419. """
  @staticmethod
  def profile_hook_step(func: Callable[_P, R]) -> Callable[_P, R]:  # noqa: D102
      """Wrap a ``step`` implementation with profiling and pre/post hook dispatch.

      The returned wrapper runs everything inside a
      ``torch.autograd.profiler.record_function`` scope named after the
      optimizer class, and performs, in order:

      1. global pre hooks, then this instance's pre hooks; a hook may return
         ``(new_args, new_kwargs)`` to replace the arguments seen by later
         hooks and by ``func``;
      2. ``func`` itself, followed by ``self._optimizer_step_code()``;
      3. this instance's post hooks, then the global post hooks.

      Args:
          func: the ``step`` function to wrap; its first positional argument
              must be the optimizer instance.

      Returns:
          The wrapped function, which returns whatever ``func`` returns.

      Raises:
          RuntimeError: (from the wrapper) if a pre hook returns something
              other than ``None`` or a 2-tuple.
      """

      @functools.wraps(func)
      def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> R:
          self, *_ = args
          self = cast(Optimizer, self)
          profile_name = f"Optimizer.step#{self.__class__.__name__}.step"
          with torch.autograd.profiler.record_function(profile_name):
              # call optimizer step pre hooks
              for pre_hook in chain(
                  _global_optimizer_pre_hooks.values(),
                  self._optimizer_step_pre_hooks.values(),
              ):
                  result = pre_hook(self, args, kwargs)
                  if result is not None:
                      if isinstance(result, tuple) and len(result) == 2:
                          args, kwargs = result  # type: ignore[assignment]
                      else:
                          # Name the offending hook, not the wrapped step function:
                          # it is the hook's return value that is malformed.
                          raise RuntimeError(
                              f"{pre_hook} must return None or a tuple of (new_args, new_kwargs), but got {result}."
                          )

              # pyrefly: ignore [invalid-param-spec]
              out = func(*args, **kwargs)
              self._optimizer_step_code()

              # call optimizer step post hooks
              for post_hook in chain(
                  self._optimizer_step_post_hooks.values(),
                  _global_optimizer_post_hooks.values(),
              ):
                  post_hook(self, args, kwargs)

              return out

      return wrapper
  452. @staticmethod
  453. def _group_tensors_by_device_and_dtype(
  454. tensorlistlist: TensorListList,
  455. with_indices: bool = False,
  456. ) -> (
  457. dict[tuple[None, None], tuple[TensorListList, Indices]]
  458. | dict[tuple[torch.device, torch.dtype], tuple[TensorListList, Indices]]
  459. ):
  460. """Group a list of lists of tensors by device and dtype.
  461. Skips this step if we are compiling since this will occur during inductor lowering.
  462. """
  463. if torch.compiler.is_compiling():
  464. return {(None, None): (tensorlistlist, list(range(len(tensorlistlist[0]))))}
  465. else:
  466. return _group_tensors_by_device_and_dtype(tensorlistlist, with_indices) # type: ignore[return-value, arg-type]
  467. def _patch_step_function(self) -> None:
  468. self._zero_grad_profile_name = (
  469. f"Optimizer.zero_grad#{self.__class__.__name__}.zero_grad"
  470. )
  471. hooked = getattr(self.__class__.step, "hooked", None)
  472. if not hooked:
  473. self.__class__.step = self.profile_hook_step(self.__class__.step) # type: ignore[assignment]
  474. self.__class__.step.hooked = True # type: ignore[attr-defined]
  475. def register_step_pre_hook(self, hook: OptimizerPreHook) -> RemovableHandle:
  476. r"""Register an optimizer step pre hook which will be called before optimizer step.
  477. It should have the following signature::
  478. hook(optimizer, args, kwargs) -> None or modified args and kwargs
  479. The ``optimizer`` argument is the optimizer instance being used. If
  480. args and kwargs are modified by the pre-hook, then the transformed
  481. values are returned as a tuple containing the new_args and new_kwargs.
  482. Args:
  483. hook (Callable): The user defined hook to be registered.
  484. Returns:
  485. :class:`torch.utils.hooks.RemovableHandle`:
  486. a handle that can be used to remove the added hook by calling
  487. ``handle.remove()``
  488. """
  489. handle = hooks.RemovableHandle(self._optimizer_step_pre_hooks)
  490. self._optimizer_step_pre_hooks[handle.id] = hook
  491. return handle
  492. def register_step_post_hook(self, hook: OptimizerPostHook) -> RemovableHandle:
  493. r"""Register an optimizer step post hook which will be called after optimizer step.
  494. It should have the following signature::
  495. hook(optimizer, args, kwargs) -> None
  496. The ``optimizer`` argument is the optimizer instance being used.
  497. Args:
  498. hook (Callable): The user defined hook to be registered.
  499. Returns:
  500. :class:`torch.utils.hooks.RemovableHandle`:
  501. a handle that can be used to remove the added hook by calling
  502. ``handle.remove()``
  503. """
  504. handle = hooks.RemovableHandle(self._optimizer_step_post_hooks)
  505. self._optimizer_step_post_hooks[handle.id] = hook
  506. return handle
  507. def register_state_dict_pre_hook(
  508. self, hook: Callable[["Optimizer"], None], prepend: bool = False
  509. ) -> RemovableHandle: # noqa: D101
  510. r"""Register a state dict pre-hook which will be called before :meth:`~torch.optim.Optimizer.state_dict` is called.
  511. It should have the following signature::
  512. hook(optimizer) -> None
  513. The ``optimizer`` argument is the optimizer instance being used.
  514. The hook will be called with argument ``self`` before calling ``state_dict`` on ``self``.
  515. The registered hook can be used to perform pre-processing before the ``state_dict``
  516. call is made.
  517. Args:
  518. hook (Callable): The user defined hook to be registered.
  519. prepend (bool): If True, the provided pre ``hook`` will be fired before
  520. all the already registered pre-hooks on ``state_dict``. Otherwise,
  521. the provided ``hook`` will be fired after all the already registered
  522. pre-hooks. (default: False)
  523. Returns:
  524. :class:`torch.utils.hooks.RemovableHandle`:
  525. a handle that can be used to remove the added hook by calling
  526. ``handle.remove()``
  527. """
  528. handle = hooks.RemovableHandle(self._optimizer_state_dict_pre_hooks)
  529. self._optimizer_state_dict_pre_hooks[handle.id] = hook
  530. if prepend:
  531. self._optimizer_state_dict_pre_hooks.move_to_end(handle.id, last=False)
  532. return handle
  533. def register_state_dict_post_hook(
  534. self,
  535. hook: Callable[["Optimizer", StateDict], StateDict | None],
  536. prepend: bool = False,
  537. ) -> RemovableHandle:
  538. r"""Register a state dict post-hook which will be called after :meth:`~torch.optim.Optimizer.state_dict` is called.
  539. It should have the following signature::
  540. hook(optimizer, state_dict) -> state_dict or None
  541. The hook will be called with arguments ``self`` and ``state_dict`` after generating
  542. a ``state_dict`` on ``self``. The hook may modify the state_dict inplace or optionally
  543. return a new one. The registered hook can be used to perform post-processing
  544. on the ``state_dict`` before it is returned.
  545. Args:
  546. hook (Callable): The user defined hook to be registered.
  547. prepend (bool): If True, the provided post ``hook`` will be fired before
  548. all the already registered post-hooks on ``state_dict``. Otherwise,
  549. the provided ``hook`` will be fired after all the already registered
  550. post-hooks. (default: False)
  551. Returns:
  552. :class:`torch.utils.hooks.RemovableHandle`:
  553. a handle that can be used to remove the added hook by calling
  554. ``handle.remove()``
  555. """
  556. handle = hooks.RemovableHandle(self._optimizer_state_dict_post_hooks)
  557. self._optimizer_state_dict_post_hooks[handle.id] = hook
  558. if prepend:
  559. self._optimizer_state_dict_post_hooks.move_to_end(handle.id, last=False)
  560. return handle
  561. @torch._disable_dynamo
  562. def state_dict(self) -> StateDict:
  563. r"""Return the state of the optimizer as a :class:`dict`.
  564. It contains two entries:
  565. * ``state``: a Dict holding current optimization state. Its content
  566. differs between optimizer classes, but some common characteristics
  567. hold. For example, state is saved per parameter, and the parameter
  568. itself is NOT saved. ``state`` is a Dictionary mapping parameter ids
  569. to a Dict with state corresponding to each parameter.
  570. * ``param_groups``: a List containing all parameter groups where each
  571. parameter group is a Dict. Each parameter group contains metadata
  572. specific to the optimizer, such as learning rate and weight decay,
  573. as well as a List of parameter IDs of the parameters in the group.
  574. If a param group was initialized with ``named_parameters()`` the names
  575. content will also be saved in the state dict.
  576. NOTE: The parameter IDs may look like indices but they are just IDs
  577. associating state with param_group. When loading from a state_dict,
  578. the optimizer will zip the param_group ``params`` (int IDs) and the
  579. optimizer ``param_groups`` (actual ``nn.Parameter`` s) in order to
  580. match state WITHOUT additional verification.
  581. A returned state dict might look something like:
  582. .. code-block:: text
  583. {
  584. 'state': {
  585. 0: {'momentum_buffer': tensor(...), ...},
  586. 1: {'momentum_buffer': tensor(...), ...},
  587. 2: {'momentum_buffer': tensor(...), ...},
  588. 3: {'momentum_buffer': tensor(...), ...}
  589. },
  590. 'param_groups': [
  591. {
  592. 'lr': 0.01,
  593. 'weight_decay': 0,
  594. ...
  595. 'params': [0]
  596. 'param_names' ['param0'] (optional)
  597. },
  598. {
  599. 'lr': 0.001,
  600. 'weight_decay': 0.5,
  601. ...
  602. 'params': [1, 2, 3]
  603. 'param_names': ['param1', 'layer.weight', 'layer.bias'] (optional)
  604. }
  605. ]
  606. }
  607. """
  608. for pre_hook in self._optimizer_state_dict_pre_hooks.values():
  609. pre_hook(self)
  610. # Save order indices instead of Tensors
  611. param_mappings: dict[int, int] = {}
  612. start_index = 0
  613. def pack_group(group: dict[str, Any]) -> dict[str, Any]:
  614. nonlocal start_index
  615. packed = {k: v for k, v in group.items() if k != "params"}
  616. param_mappings.update(
  617. {
  618. id(p): i
  619. for i, p in enumerate(group["params"], start_index)
  620. if id(p) not in param_mappings
  621. }
  622. )
  623. packed["params"] = [param_mappings[id(p)] for p in group["params"]]
  624. start_index += len(packed["params"])
  625. return packed
  626. param_groups = [pack_group(g) for g in self.param_groups]
  627. # Remap state to use order indices as keys
  628. packed_state = {
  629. (param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v
  630. for k, v in self.state.items()
  631. }
  632. state_dict = {
  633. "state": packed_state,
  634. "param_groups": param_groups,
  635. }
  636. for post_hook in self._optimizer_state_dict_post_hooks.values():
  637. hook_result = post_hook(self, state_dict)
  638. if hook_result is not None:
  639. state_dict = hook_result
  640. return state_dict
  641. @staticmethod
  642. def _process_value_according_to_param_policy(
  643. param: torch.Tensor,
  644. value: torch.Tensor,
  645. param_id: int,
  646. param_groups: list[dict[Any, Any]],
  647. key: Hashable = None,
  648. ) -> torch.Tensor:
  649. # Floating-point types are a bit special here. They are the only ones
  650. # that are assumed to always match the type of params.
  651. # Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424
  652. # UNLESS fused or capturable, see note [special device hosting for step]
  653. fused = False
  654. capturable = False
  655. if param_groups is None:
  656. raise AssertionError("Expected param_groups to be set")
  657. for pg in param_groups:
  658. if param_id in pg["params"]:
  659. fused = pg.get("fused", False)
  660. capturable = pg.get("capturable", False)
  661. break
  662. if key == "step":
  663. if capturable or fused:
  664. return value.to(dtype=torch.float32, device=param.device)
  665. else:
  666. return value
  667. else:
  668. if param.is_floating_point():
  669. return value.to(dtype=param.dtype, device=param.device)
  670. else:
  671. return value.to(device=param.device)
  672. def register_load_state_dict_pre_hook(
  673. self,
  674. hook: Callable[["Optimizer", StateDict], StateDict | None],
  675. prepend: bool = False,
  676. ) -> RemovableHandle: # noqa: D205 D400
  677. r"""Register a load_state_dict pre-hook which will be called before
  678. :meth:`~torch.optim.Optimizer.load_state_dict` is called. It should have the
  679. following signature::
  680. hook(optimizer, state_dict) -> state_dict or None
  681. The ``optimizer`` argument is the optimizer instance being used and the
  682. ``state_dict`` argument is a shallow copy of the ``state_dict`` the user
  683. passed in to ``load_state_dict``. The hook may modify the state_dict inplace
  684. or optionally return a new one. If a state_dict is returned, it will be used
  685. to be loaded into the optimizer.
  686. The hook will be called with argument ``self`` and ``state_dict`` before
  687. calling ``load_state_dict`` on ``self``. The registered hook can be used to
  688. perform pre-processing before the ``load_state_dict`` call is made.
  689. Args:
  690. hook (Callable): The user defined hook to be registered.
  691. prepend (bool): If True, the provided pre ``hook`` will be fired before
  692. all the already registered pre-hooks on ``load_state_dict``. Otherwise,
  693. the provided ``hook`` will be fired after all the already registered
  694. pre-hooks. (default: False)
  695. Returns:
  696. :class:`torch.utils.hooks.RemovableHandle`:
  697. a handle that can be used to remove the added hook by calling
  698. ``handle.remove()``
  699. """
  700. handle = hooks.RemovableHandle(self._optimizer_load_state_dict_pre_hooks)
  701. self._optimizer_load_state_dict_pre_hooks[handle.id] = hook
  702. if prepend:
  703. self._optimizer_load_state_dict_pre_hooks.move_to_end(handle.id, last=False)
  704. return handle
  705. def register_load_state_dict_post_hook(
  706. self, hook: Callable[["Optimizer"], None], prepend: bool = False
  707. ) -> RemovableHandle: # noqa: D205 D400
  708. r"""Register a load_state_dict post-hook which will be called after
  709. :meth:`~torch.optim.Optimizer.load_state_dict` is called. It should have the
  710. following signature::
  711. hook(optimizer) -> None
  712. The ``optimizer`` argument is the optimizer instance being used.
  713. The hook will be called with argument ``self`` after calling
  714. ``load_state_dict`` on ``self``. The registered hook can be used to
  715. perform post-processing after ``load_state_dict`` has loaded the
  716. ``state_dict``.
  717. Args:
  718. hook (Callable): The user defined hook to be registered.
  719. prepend (bool): If True, the provided post ``hook`` will be fired before
  720. all the already registered post-hooks on ``load_state_dict``. Otherwise,
  721. the provided ``hook`` will be fired after all the already registered
  722. post-hooks. (default: False)
  723. Returns:
  724. :class:`torch.utils.hooks.RemovableHandle`:
  725. a handle that can be used to remove the added hook by calling
  726. ``handle.remove()``
  727. """
  728. handle = hooks.RemovableHandle(self._optimizer_load_state_dict_post_hooks)
  729. self._optimizer_load_state_dict_post_hooks[handle.id] = hook
  730. if prepend:
  731. self._optimizer_load_state_dict_post_hooks.move_to_end(
  732. handle.id, last=False
  733. ) # type: ignore[attr-defined]
  734. return handle
  735. @torch._disable_dynamo
  736. def load_state_dict(self, state_dict: StateDict) -> None:
  737. r"""Load the optimizer state.
  738. Args:
  739. state_dict (dict): optimizer state. Should be an object returned
  740. from a call to :meth:`state_dict`.
  741. .. warning::
  742. Make sure this method is called after initializing :class:`torch.optim.lr_scheduler.LRScheduler`,
  743. as calling it beforehand will overwrite the loaded learning rates.
  744. .. note::
  745. The names of the parameters (if they exist under the "param_names" key of each param group
  746. in :meth:`state_dict`) will not affect the loading process.
  747. To use the parameters' names for custom cases (such as when the parameters in the loaded state dict
  748. differ from those initialized in the optimizer),
  749. a custom ``register_load_state_dict_pre_hook`` should be implemented to adapt the loaded dict
  750. accordingly.
  751. If ``param_names`` exist in loaded state dict ``param_groups`` they will be saved and override
  752. the current names, if present, in the optimizer state. If they do not exist in loaded state dict,
  753. the optimizer ``param_names`` will remain unchanged.
  754. Example:
  755. >>> # xdoctest: +SKIP
  756. >>> model = torch.nn.Linear(10, 10)
  757. >>> optim = torch.optim.SGD(model.parameters(), lr=3e-4)
  758. >>> scheduler1 = torch.optim.lr_scheduler.LinearLR(
  759. ... optim,
  760. ... start_factor=0.1,
  761. ... end_factor=1,
  762. ... total_iters=20,
  763. ... )
  764. >>> scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(
  765. ... optim,
  766. ... T_max=80,
  767. ... eta_min=3e-5,
  768. ... )
  769. >>> lr = torch.optim.lr_scheduler.SequentialLR(
  770. ... optim,
  771. ... schedulers=[scheduler1, scheduler2],
  772. ... milestones=[20],
  773. ... )
  774. >>> lr.load_state_dict(torch.load("./save_seq.pt"))
  775. >>> # now load the optimizer checkpoint after loading the LRScheduler
  776. >>> optim.load_state_dict(torch.load("./save_optim.pt"))
  777. """
  778. # shallow copy, to be consistent with module API
  779. state_dict = state_dict.copy()
  780. for pre_hook in self._optimizer_load_state_dict_pre_hooks.values():
  781. hook_result = pre_hook(self, state_dict)
  782. if hook_result is not None:
  783. state_dict = hook_result
  784. # Validate the state_dict
  785. groups = self.param_groups
  786. # Deepcopy as we write into saved_groups later to update state
  787. saved_groups = deepcopy(state_dict["param_groups"])
  788. if len(groups) != len(saved_groups):
  789. raise ValueError(
  790. "loaded state dict has a different number of parameter groups"
  791. )
  792. param_lens = (len(g["params"]) for g in groups)
  793. saved_lens = (len(g["params"]) for g in saved_groups)
  794. if any(
  795. p_len != s_len for p_len, s_len in zip(param_lens, saved_lens, strict=True)
  796. ):
  797. raise ValueError(
  798. "loaded state dict contains a parameter group "
  799. "that doesn't match the size of optimizer's group"
  800. )
  801. # Update the state
  802. id_map = dict(
  803. zip(
  804. chain.from_iterable(g["params"] for g in saved_groups),
  805. chain.from_iterable(g["params"] for g in groups),
  806. strict=True,
  807. )
  808. )
  809. def _cast(param, value, param_id=None, param_groups=None, key=None):
  810. r"""Make a deep copy of value, casting all tensors to device of param."""
  811. if isinstance(value, torch.Tensor):
  812. return Optimizer._process_value_according_to_param_policy(
  813. param,
  814. value,
  815. # pyrefly: ignore [bad-argument-type]
  816. param_id,
  817. # pyrefly: ignore [bad-argument-type]
  818. param_groups,
  819. key,
  820. )
  821. elif isinstance(value, dict):
  822. return {
  823. k: _cast(
  824. param, v, param_id=param_id, param_groups=param_groups, key=k
  825. )
  826. for k, v in value.items()
  827. }
  828. elif isinstance(value, Iterable):
  829. # pyrefly: ignore [bad-instantiation]
  830. return type(value)(
  831. # pyrefly: ignore [bad-argument-count]
  832. _cast(param, v, param_id=param_id, param_groups=param_groups)
  833. for v in value
  834. ) # type: ignore[call-arg]
  835. else:
  836. return value
  837. # Copy state assigned to params (and cast tensors to appropriate types).
  838. # State that is not assigned to params is copied as is (needed for
  839. # backward compatibility).
  840. state: defaultdict[torch.Tensor, dict[Any, Any]] = defaultdict(dict)
  841. for k, v in state_dict["state"].items():
  842. if k in id_map:
  843. param = id_map[k]
  844. state[param] = _cast(
  845. param, v, param_id=k, param_groups=state_dict["param_groups"]
  846. )
  847. else:
  848. state[k] = v
  849. # Update parameter groups, setting their 'params' value
  850. def update_group(
  851. group: dict[str, Any], new_group: dict[str, Any]
  852. ) -> dict[str, Any]:
  853. new_group["params"] = group["params"]
  854. if "param_names" in group and "param_names" not in new_group:
  855. new_group["param_names"] = group["param_names"]
  856. return new_group
  857. param_groups = [
  858. update_group(g, ng) for g, ng in zip(groups, saved_groups, strict=True)
  859. ]
  860. self.__setstate__({"state": state, "param_groups": param_groups})
  861. for post_hook in self._optimizer_load_state_dict_post_hooks.values():
  862. post_hook(self)
  863. @torch._disable_dynamo
  864. def zero_grad(self, set_to_none: bool = True) -> None:
  865. r"""Reset the gradients of all optimized :class:`torch.Tensor` s.
  866. Args:
  867. set_to_none (bool, optional): Instead of setting to zero, set the grads to None. Default: ``True``
  868. This will in general have lower memory footprint, and can modestly improve performance.
  869. However, it changes certain behaviors. For example:
  870. 1. When the user tries to access a gradient and perform manual ops on it,
  871. a None attribute or a Tensor full of 0s will behave differently.
  872. 2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ s
  873. are guaranteed to be None for params that did not receive a gradient.
  874. 3. ``torch.optim`` optimizers have a different behavior if the gradient is 0 or None
  875. (in one case it does the step with a gradient of 0 and in the other it skips
  876. the step altogether).
  877. """
  878. foreach = self.defaults.get("foreach", False) or self.defaults.get(
  879. "fused", False
  880. )
  881. if not hasattr(self, "_zero_grad_profile_name"):
  882. self._patch_step_function()
  883. per_device_and_dtype_grads: (
  884. defaultdict[torch.device, defaultdict[torch.dtype, list[torch.Tensor]]]
  885. | None
  886. )
  887. if foreach:
  888. per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))
  889. else:
  890. per_device_and_dtype_grads = None
  891. with torch.autograd.profiler.record_function(self._zero_grad_profile_name):
  892. for group in self.param_groups:
  893. for p in group["params"]:
  894. if p.grad is not None:
  895. if set_to_none:
  896. p.grad = None
  897. else:
  898. if p.grad.grad_fn is not None:
  899. p.grad.detach_()
  900. else:
  901. p.grad.requires_grad_(False)
  902. if not foreach or p.grad.is_sparse:
  903. p.grad.zero_()
  904. else:
  905. if per_device_and_dtype_grads is None:
  906. raise AssertionError(
  907. "Expected per_device_and_dtype_grads to be set"
  908. )
  909. per_device_and_dtype_grads[p.grad.device][
  910. p.grad.dtype
  911. ].append(p.grad)
  912. if foreach:
  913. if per_device_and_dtype_grads is None:
  914. raise AssertionError(
  915. "Expected per_device_and_dtype_grads to be set"
  916. )
  917. for per_dtype_grads in per_device_and_dtype_grads.values():
  918. for grads in per_dtype_grads.values():
  919. torch._foreach_zero_(grads)
  920. @overload
  921. def step(self, closure: None = None) -> None: ...
  922. @overload
  923. def step(self, closure: Callable[[], float]) -> float: ...
  924. def step(self, closure: Callable[[], float] | None = None) -> float | None:
  925. r"""Perform a single optimization step to update parameter.
  926. Args:
  927. closure (Callable): A closure that reevaluates the model and
  928. returns the loss. Optional for most optimizers.
  929. """
  930. raise NotImplementedError
  931. @torch._disable_dynamo
  932. def add_param_group(self, param_group: dict[str, Any]) -> None:
  933. r"""Add a param group to the :class:`Optimizer` s `param_groups`.
  934. This can be useful when fine tuning a pre-trained network as frozen layers can be made
  935. trainable and added to the :class:`Optimizer` as training progresses.
  936. Args:
  937. param_group (dict): Specifies what Tensors should be optimized along with group
  938. specific optimization options.
  939. """
  940. if not isinstance(param_group, dict):
  941. raise TypeError(f"param_group must be a dict, but got {type(param_group)}")
  942. params = param_group["params"]
  943. if isinstance(params, torch.Tensor):
  944. param_group["params"] = [params]
  945. elif isinstance(params, set):
  946. raise TypeError(
  947. "optimizer parameters need to be organized in ordered collections, but "
  948. "the ordering of tensors in sets will change between runs. Please use a list instead."
  949. )
  950. else:
  951. param_group["params"] = list(params)
  952. extracted_param_tensors = []
  953. extracted_param_names = []
  954. for param in param_group["params"]:
  955. if isinstance(param, tuple):
  956. param_name = param[0]
  957. extracted_param_names.append(param_name)
  958. extracted_param_tensors.append(param[1])
  959. else:
  960. extracted_param_tensors.append(param)
  961. param_group["params"] = extracted_param_tensors
  962. if len(extracted_param_names) != 0:
  963. if len(extracted_param_names) == len(extracted_param_tensors):
  964. param_group["param_names"] = extracted_param_names
  965. else:
  966. raise ValueError(
  967. "all optimizer params should be with/without names. Some param names are missing"
  968. )
  969. for param in param_group["params"]:
  970. if not isinstance(param, torch.Tensor):
  971. raise TypeError(
  972. "optimizer can only optimize Tensors, "
  973. "but one of the params is " + torch.typename(param)
  974. )
  975. if not self.defaults.get("differentiable", None) and not (
  976. param.is_leaf or param.retains_grad
  977. ):
  978. raise ValueError("can't optimize a non-leaf Tensor")
  979. for name, default in self.defaults.items():
  980. if default is required and name not in param_group:
  981. raise ValueError(
  982. f"parameter group didn't specify a value of required optimization parameter {name}"
  983. )
  984. else:
  985. param_group.setdefault(name, default)
  986. params = param_group["params"]
  987. if len(params) != len(set(params)):
  988. warnings.warn(
  989. "optimizer contains a parameter group with duplicate parameters; "
  990. "in future, this will cause an error; "
  991. "see github.com/pytorch/pytorch/issues/40967 for more information",
  992. stacklevel=3,
  993. )
  994. param_set: set[torch.Tensor] = set()
  995. for group in self.param_groups:
  996. param_set.update(set(group["params"]))
  997. if ("param_names" in param_group) != ("param_names" in group):
  998. current_group_txt = (
  999. "with names" if "param_names" in param_group else "without names"
  1000. )
  1001. raise ValueError(
  1002. "all optimizer param groups should be with/without names. "
  1003. f"cannot add param group {current_group_txt} to the optimizer"
  1004. )
  1005. if not param_set.isdisjoint(set(param_group["params"])):
  1006. raise ValueError("some parameters appear in more than one parameter group")
  1007. self.param_groups.append(param_group)