trainer_optimizer.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627
  1. # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. Optimizer utilities for the Trainer class.
  16. """
  17. from __future__ import annotations
  18. import importlib.metadata
  19. import logging
  20. from collections.abc import Callable
  21. from dataclasses import dataclass
  22. from typing import TYPE_CHECKING, Any
  23. import torch
  24. from packaging import version
  25. from torch import nn
  26. from .optimization import Adafactor
  27. from .trainer_pt_utils import LayerWiseDummyOptimizer
  28. from .trainer_utils import check_target_module_exists
  29. from .training_args import OptimizerNames, ParallelMode
  30. from .utils import (
  31. is_apollo_torch_available,
  32. is_bitsandbytes_available,
  33. is_galore_torch_available,
  34. is_grokadamw_available,
  35. is_lomo_available,
  36. is_schedulefree_available,
  37. is_torch_optimi_available,
  38. is_torchao_available,
  39. strtobool,
  40. )
  41. if TYPE_CHECKING:
  42. from .modeling_utils import PreTrainedModel
  43. from .training_args import TrainingArguments
  44. logger = logging.getLogger(__name__)
@dataclass
class OptimizerContext:
    """
    Context object passed to all optimizer handlers.

    Bundles the inputs a handler reads when it selects an optimizer class and assembles its keyword arguments.
    """

    # Training arguments; handlers read fields such as `optim`, `weight_decay` and `adam_beta1` from it.
    args: TrainingArguments
    # Model being trained. May be None, but the low-rank (GaLore/Apollo) and LOMO handlers raise without it.
    model: PreTrainedModel | None
    # Base optimizer keyword arguments; most handlers update this dict in place and return it.
    optimizer_kwargs: dict[str, Any]
    # Adam-style keyword arguments that the Adam-family handlers merge into `optimizer_kwargs`.
    adam_kwargs: dict[str, Any]
    # Extra optimizer options as string key/value pairs (the shape produced by `_parse_optim_args`).
    optim_args: dict[str, str]
  53. def _parse_optim_args(optim_args_str: str | None) -> dict[str, str]:
  54. """Parse optimizer arguments from a comma-separated string."""
  55. if not optim_args_str:
  56. return {}
  57. optim_args = {}
  58. for mapping in optim_args_str.replace(" ", "").split(","):
  59. key, value = mapping.split("=")
  60. optim_args[key] = value
  61. return optim_args
# Type alias for optimizer handler functions: a handler takes an `OptimizerContext` and returns a
# `(optimizer class or factory, optimizer keyword arguments)` tuple.
OptimizerHandler = Callable[[OptimizerContext], tuple[Any, dict[str, Any]]]
  64. def is_optimizer_factory(optimizer_cls_or_factory: Any) -> bool:
  65. """
  66. Check if the returned value from a handler is a factory rather than an Optimizer class.
  67. Factory callables are used for complex optimizers like Muon or Dion that need to:
  68. - Split parameters between multiple internal optimizers
  69. - Handle complex sharding logic
  70. - Access the full model structure for parameter grouping
  71. Args:
  72. optimizer_cls_or_factory: The first element returned by an optimizer handler.
  73. Returns:
  74. `bool`: True if it's not an Optimizer class (i.e., likely a factory), False if it's an Optimizer class.
  75. """
  76. # If it's a class that's a subclass of torch.optim.Optimizer, it's not a factory
  77. if isinstance(optimizer_cls_or_factory, type) and issubclass(optimizer_cls_or_factory, torch.optim.Optimizer):
  78. return False
  79. return True
  80. def _setup_low_rank_optimizer(
  81. args: TrainingArguments,
  82. model: PreTrainedModel,
  83. optimizer_name: str,
  84. optimizer_mapping: dict[str, Any],
  85. optim_kwargs: dict[str, Any],
  86. optimizer_kwargs: dict[str, Any],
  87. is_layerwise_supported: bool = True,
  88. ) -> tuple[Any, dict[str, Any]]:
  89. """
  90. Helper function to set up low-rank optimizers like GaLore and Apollo.
  91. These optimizers apply low-rank projections to specific target modules (typically linear layers).
  92. """
  93. is_layerwise = optimizer_name.lower().endswith("layerwise")
  94. if is_layerwise and args.parallel_mode == ParallelMode.DISTRIBUTED and is_layerwise_supported:
  95. raise NotImplementedError(f"Layer-wise {optimizer_name} does not support DDP at this time")
  96. optimizer_cls = optimizer_mapping[optimizer_name]
  97. if args.optim_target_modules is None:
  98. raise ValueError(f"You need to define `optim_target_modules` to use {optimizer_name} optimizers")
  99. if not isinstance(args.optim_target_modules, (list, str)):
  100. raise TypeError(
  101. f"`optim_target_modules` must be a list of strings, a regex string, or 'all-linear'. "
  102. f"Got: {args.optim_target_modules}"
  103. )
  104. if model is None:
  105. raise ValueError(f"You need to pass a model to initialize {optimizer_name} optimizer.")
  106. all_linear = (
  107. isinstance(args.optim_target_modules, str) and args.optim_target_modules.replace("_", "-") == "all-linear"
  108. )
  109. target_params_names = []
  110. for module_name, module in model.named_modules():
  111. target_module_exists, is_regex = check_target_module_exists(
  112. args.optim_target_modules, module_name, return_is_regex=True
  113. )
  114. if not isinstance(module, nn.Linear):
  115. if target_module_exists and not is_regex:
  116. logger.warning(f"{module_name} matched but ignored. {optimizer_name} only supports linear layers.")
  117. continue
  118. if not target_module_exists and not all_linear:
  119. continue
  120. target_params_names.append(module_name + ".weight")
  121. if len(target_params_names) == 0:
  122. raise ValueError(f"No target modules found for {optimizer_name} ({args.optim_target_modules}).")
  123. target_params = [p for n, p in model.named_parameters() if n in target_params_names]
  124. non_target_params = [p for n, p in model.named_parameters() if n not in target_params_names]
  125. param_groups = [
  126. {"params": non_target_params},
  127. {"params": target_params, **optim_kwargs},
  128. ]
  129. if is_layerwise:
  130. if args.gradient_accumulation_steps != 1:
  131. raise ValueError(f"Layerwise {optimizer_name} does not support gradient accumulation!")
  132. optimizer_dict = {}
  133. for param in non_target_params:
  134. optimizer_dict[param] = optimizer_cls([{"params": [param]}], **optimizer_kwargs)
  135. for param in target_params:
  136. optimizer_dict[param] = optimizer_cls([{"params": [param], **optim_kwargs}], **optimizer_kwargs)
  137. def optimizer_hook(param):
  138. if param.grad is not None:
  139. optimizer_dict[param].step()
  140. optimizer_dict[param].zero_grad()
  141. for param in model.parameters():
  142. if param.requires_grad:
  143. param.register_post_accumulate_grad_hook(optimizer_hook)
  144. optimizer_cls = LayerWiseDummyOptimizer
  145. optimizer_kwargs.update({"optimizer_dict": optimizer_dict})
  146. optimizer_kwargs.update({"params": param_groups})
  147. return optimizer_cls, optimizer_kwargs
  148. # =============================================================================
  149. # Individual optimizer handlers
  150. # =============================================================================
  151. def _get_adafactor(ctx: OptimizerContext) -> tuple[Any, dict[str, Any]]:
  152. """Get Adafactor optimizer."""
  153. ctx.optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
  154. return Adafactor, ctx.optimizer_kwargs
  155. def _get_adamw_torch(ctx: OptimizerContext) -> tuple[Any, dict[str, Any]]:
  156. """Get PyTorch AdamW optimizer (regular or fused)."""
  157. from torch.optim import AdamW
  158. ctx.optimizer_kwargs.update(ctx.adam_kwargs)
  159. if ctx.args.optim == OptimizerNames.ADAMW_TORCH_FUSED:
  160. ctx.optimizer_kwargs.update({"fused": True})
  161. return AdamW, ctx.optimizer_kwargs
  162. def _get_adamw_torch_xla(ctx: OptimizerContext) -> tuple[Any, dict[str, Any]]:
  163. """Get Torch XLA syncfree AdamW optimizer."""
  164. try:
  165. from torch_xla.amp.syncfree import AdamW
  166. ctx.optimizer_kwargs.update(ctx.adam_kwargs)
  167. return AdamW, ctx.optimizer_kwargs
  168. except ImportError:
  169. raise ValueError("Trainer failed to import syncfree AdamW from torch_xla.")
  170. def _get_adamw_torch_npu_fused(ctx: OptimizerContext) -> tuple[Any, dict[str, Any]]:
  171. """Get NPU Fused AdamW optimizer."""
  172. try:
  173. from torch_npu.optim import NpuFusedAdamW
  174. ctx.optimizer_kwargs.update(ctx.adam_kwargs)
  175. return NpuFusedAdamW, ctx.optimizer_kwargs
  176. except ImportError:
  177. raise ValueError("Trainer failed to import FusedAdamW from torch_npu.")
  178. def _get_adamw_apex_fused(ctx: OptimizerContext) -> tuple[Any, dict[str, Any]]:
  179. """Get Apex Fused Adam optimizer."""
  180. try:
  181. from apex.optimizers import FusedAdam
  182. ctx.optimizer_kwargs.update(ctx.adam_kwargs)
  183. return FusedAdam, ctx.optimizer_kwargs
  184. except ImportError:
  185. raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!")
  186. def _get_bitsandbytes_optimizer(ctx: OptimizerContext) -> tuple[Any, dict[str, Any]]:
  187. """Get bitsandbytes optimizer (AdamW, Lion, RMSprop variants)."""
  188. if not is_bitsandbytes_available():
  189. raise ImportError(
  190. "You need to install `bitsandbytes` in order to use bitsandbytes optimizers: `pip install -U bitsandbytes`"
  191. )
  192. from bitsandbytes.optim import AdamW, Lion, RMSprop
  193. optim_name = ctx.args.optim
  194. is_paged = "paged" in optim_name
  195. optim_bits = 8 if "8bit" in optim_name else 32
  196. optimizer_cls = None
  197. additional_optim_kwargs = ctx.adam_kwargs
  198. if "adam" in optim_name:
  199. optimizer_cls = AdamW
  200. elif "lion" in optim_name:
  201. optimizer_cls = Lion
  202. additional_optim_kwargs = {"betas": (ctx.args.adam_beta1, ctx.args.adam_beta2)}
  203. elif "rmsprop" in optim_name:
  204. optimizer_cls = RMSprop
  205. additional_optim_kwargs = ctx.optim_args
  206. elif "ademamix" in optim_name:
  207. from bitsandbytes.optim import AdEMAMix
  208. optimizer_cls = AdEMAMix
  209. additional_optim_kwargs = {
  210. "betas": (
  211. float(ctx.optim_args.get("beta1", ctx.args.adam_beta1)),
  212. float(ctx.optim_args.get("beta2", ctx.args.adam_beta2)),
  213. float(ctx.optim_args.get("beta3", 0.9999)),
  214. ),
  215. "alpha": float(ctx.optim_args.get("alpha", 5.0)),
  216. "eps": float(ctx.optim_args.get("eps", ctx.args.adam_epsilon)),
  217. }
  218. if "t_alpha" in ctx.optim_args:
  219. additional_optim_kwargs["t_alpha"] = int(ctx.optim_args["t_alpha"])
  220. if "t_beta3" in ctx.optim_args:
  221. additional_optim_kwargs["t_beta3"] = int(ctx.optim_args["t_beta3"])
  222. bnb_kwargs = {"optim_bits": optim_bits}
  223. if "rmsprop" not in optim_name:
  224. bnb_kwargs["is_paged"] = is_paged
  225. ctx.optimizer_kwargs.update(additional_optim_kwargs)
  226. ctx.optimizer_kwargs.update(bnb_kwargs)
  227. return optimizer_cls, ctx.optimizer_kwargs
  228. def _get_adamw_anyprecision(ctx: OptimizerContext) -> tuple[Any, dict[str, Any]]:
  229. """Get AnyPrecision AdamW optimizer."""
  230. try:
  231. from torchdistx.optimizers import AnyPrecisionAdamW
  232. ctx.optimizer_kwargs.update(ctx.adam_kwargs)
  233. ctx.optimizer_kwargs.update(
  234. {
  235. "use_kahan_summation": strtobool(ctx.optim_args.get("use_kahan_summation", "False")),
  236. "momentum_dtype": getattr(torch, ctx.optim_args.get("momentum_dtype", "float32")),
  237. "variance_dtype": getattr(torch, ctx.optim_args.get("variance_dtype", "float32")),
  238. "compensation_buffer_dtype": getattr(
  239. torch, ctx.optim_args.get("compensation_buffer_dtype", "bfloat16")
  240. ),
  241. }
  242. )
  243. return AnyPrecisionAdamW, ctx.optimizer_kwargs
  244. except ImportError:
  245. raise ValueError("Please install https://github.com/pytorch/torchdistx")
  246. def _get_sgd(ctx: OptimizerContext) -> tuple[Any, dict[str, Any]]:
  247. """Get SGD optimizer."""
  248. kwargs = ctx.optimizer_kwargs.copy()
  249. if ctx.optim_args:
  250. for key in ("momentum", "dampening", "weight_decay"):
  251. if key in ctx.optim_args:
  252. kwargs[key] = float(ctx.optim_args[key])
  253. if "nesterov" in ctx.optim_args:
  254. kwargs["nesterov"] = ctx.optim_args["nesterov"].lower() in ("true", "1", "yes")
  255. return torch.optim.SGD, kwargs
  256. def _get_adagrad(ctx: OptimizerContext) -> tuple[Any, dict[str, Any]]:
  257. """Get Adagrad optimizer."""
  258. kwargs = ctx.optimizer_kwargs.copy()
  259. if ctx.optim_args:
  260. for key in ("lr_decay", "weight_decay", "eps"):
  261. if key in ctx.optim_args:
  262. kwargs[key] = float(ctx.optim_args[key])
  263. return torch.optim.Adagrad, kwargs
  264. def _get_rmsprop(ctx: OptimizerContext) -> tuple[Any, dict[str, Any]]:
  265. """Get RMSprop optimizer."""
  266. kwargs = ctx.optimizer_kwargs.copy()
  267. if ctx.optim_args:
  268. for key in ("momentum", "alpha", "eps", "weight_decay"):
  269. if key in ctx.optim_args:
  270. kwargs[key] = float(ctx.optim_args[key])
  271. if "centered" in ctx.optim_args:
  272. kwargs["centered"] = ctx.optim_args["centered"].lower() in ("true", "1", "yes")
  273. return torch.optim.RMSprop, kwargs
  274. def _get_galore_optimizer(ctx: OptimizerContext) -> tuple[Any, dict[str, Any]]:
  275. """Get GaLore optimizer."""
  276. if not is_galore_torch_available():
  277. raise ImportError(
  278. "You need to install `galore_torch` in order to use GaLore optimizers. "
  279. "Install it with `pip install git+https://github.com/jiaweizzhao/GaLore`"
  280. )
  281. from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit
  282. optimizer_mapping = {
  283. OptimizerNames.GALORE_ADAMW: GaLoreAdamW,
  284. OptimizerNames.GALORE_ADAMW_8BIT: GaLoreAdamW8bit,
  285. OptimizerNames.GALORE_ADAFACTOR: GaLoreAdafactor,
  286. OptimizerNames.GALORE_ADAMW_LAYERWISE: GaLoreAdamW,
  287. OptimizerNames.GALORE_ADAMW_8BIT_LAYERWISE: GaLoreAdamW8bit,
  288. OptimizerNames.GALORE_ADAFACTOR_LAYERWISE: GaLoreAdafactor,
  289. }
  290. galore_optim_kwargs = {
  291. "rank": int(ctx.optim_args.pop("rank", 128)),
  292. "update_proj_gap": int(ctx.optim_args.pop("update_proj_gap", 200)),
  293. "scale": float(ctx.optim_args.pop("scale", 0.25)),
  294. "proj_type": ctx.optim_args.pop("proj_type", "std"),
  295. }
  296. optimizer_cls, optimizer_kwargs = _setup_low_rank_optimizer(
  297. ctx.args, ctx.model, ctx.args.optim, optimizer_mapping, galore_optim_kwargs, ctx.optimizer_kwargs
  298. )
  299. if ctx.args.optim == OptimizerNames.GALORE_ADAFACTOR:
  300. optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
  301. return optimizer_cls, optimizer_kwargs
  302. def _get_apollo_optimizer(ctx: OptimizerContext) -> tuple[Any, dict[str, Any]]:
  303. """Get Apollo optimizer."""
  304. if not is_apollo_torch_available():
  305. raise ImportError(
  306. "You need to install `apollo_torch` in order to use APOLLO optimizers. "
  307. "Install it with `pip install git+https://github.com/zhuhanqing/APOLLO`"
  308. )
  309. from apollo_torch import APOLLOAdamW
  310. optimizer_mapping = {
  311. OptimizerNames.APOLLO_ADAMW: APOLLOAdamW,
  312. OptimizerNames.APOLLO_ADAMW_LAYERWISE: APOLLOAdamW,
  313. }
  314. apollo_optim_kwargs = {
  315. "rank": int(ctx.optim_args.pop("rank", 128)),
  316. "proj": ctx.optim_args.pop("proj", "random"),
  317. "scale_type": ctx.optim_args.pop("scale_type", "channel"),
  318. "update_proj_gap": int(ctx.optim_args.pop("update_proj_gap", 200)),
  319. "scale": float(ctx.optim_args.pop("scale", 1.0)),
  320. "proj_type": ctx.optim_args.pop("proj_type", "std"),
  321. }
  322. apollo_optim_kwargs.update(ctx.adam_kwargs)
  323. return _setup_low_rank_optimizer(
  324. ctx.args, ctx.model, ctx.args.optim, optimizer_mapping, apollo_optim_kwargs, ctx.optimizer_kwargs
  325. )
  326. def _get_lomo_optimizer(ctx: OptimizerContext) -> tuple[Any, dict[str, Any]]:
  327. """Get LOMO optimizer."""
  328. if not is_lomo_available():
  329. raise ImportError(
  330. "You need to install `lomo_optim` in order to use LOMO optimizers. "
  331. "Install it with `pip install lomo-optim`"
  332. )
  333. if ctx.model is None:
  334. raise ValueError("You need to pass a `model` in order to correctly initialize a LOMO optimizer.")
  335. from lomo_optim import AdaLomo, Lomo
  336. optimizer_cls = AdaLomo if "ada" in ctx.args.optim else Lomo
  337. ctx.optimizer_kwargs.update({"model": ctx.model})
  338. return optimizer_cls, ctx.optimizer_kwargs
  339. def _get_grokadamw(ctx: OptimizerContext) -> tuple[Any, dict[str, Any]]:
  340. """Get GrokAdamW optimizer."""
  341. if not is_grokadamw_available():
  342. raise ValueError("Please install grokadamw with `pip install grokadamw`")
  343. from grokadamw import GrokAdamW
  344. ctx.optimizer_kwargs.update(
  345. {
  346. "alpha_init": float(ctx.optim_args.get("alpha_init", 0.98)),
  347. "lamb": float(ctx.optim_args.get("lamb", 2.0)),
  348. "gamma": float(ctx.optim_args.get("gamma", 0.1)),
  349. "grokking_signal_decay_rate": float(ctx.optim_args.get("grokking_signal_decay_rate", 0.1)),
  350. "gradient_clipping": float(ctx.optim_args.get("gradient_clipping", 1.0)),
  351. }
  352. )
  353. return GrokAdamW, ctx.optimizer_kwargs
  354. def _get_torchao_optimizer(ctx: OptimizerContext) -> tuple[Any, dict[str, Any]]:
  355. """Get TorchAO 4-bit or 8-bit optimizer."""
  356. if not is_torchao_available() or version.parse(importlib.metadata.version("torchao")) < version.parse("0.4.0"):
  357. raise ImportError(
  358. "You need to have `torchao>=0.4.0` in order to use torch 4-bit optimizers. "
  359. "Install it with `pip install torchao` or follow the instructions here: "
  360. "https://github.com/pytorch/ao"
  361. )
  362. if version.parse(importlib.metadata.version("torch")) <= version.parse("2.4"):
  363. raise ImportError(
  364. "You need to have `torch>2.4` in order to use torch 4-bit optimizers. "
  365. "Install it with `pip install --upgrade torch` it is available on pipy. "
  366. "Otherwise, you need to install torch nightly."
  367. )
  368. if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.11.0"):
  369. from torchao.optim import AdamW4bit, AdamW8bit
  370. else:
  371. from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit
  372. if ctx.args.optim == OptimizerNames.ADAMW_TORCH_4BIT:
  373. optimizer_cls = AdamW4bit
  374. else:
  375. optimizer_cls = AdamW8bit
  376. ctx.optimizer_kwargs.update(
  377. {
  378. "block_size": ctx.optim_args.get("block_size", 256),
  379. "bf16_stochastic_round": strtobool(ctx.optim_args.get("bf16_stochastic_round", "False")),
  380. }
  381. )
  382. ctx.optimizer_kwargs.update(ctx.adam_kwargs)
  383. return optimizer_cls, ctx.optimizer_kwargs
  384. def _get_schedule_free_optimizer(ctx: OptimizerContext) -> tuple[Any, dict[str, Any]]:
  385. """Get ScheduleFree optimizer."""
  386. if not is_schedulefree_available():
  387. raise ImportError(
  388. "You need to install `schedulefree` in order to use schedulefree optimizers. "
  389. "Install it with `pip install schedulefree.`"
  390. )
  391. from schedulefree import AdamWScheduleFree, SGDScheduleFree
  392. additional_optim_kwargs = {}
  393. require_warmup = True
  394. if ctx.args.optim == OptimizerNames.SCHEDULE_FREE_RADAM:
  395. if not is_schedulefree_available("1.4.0"):
  396. raise ImportError(
  397. "You need to install `schedulefree>=1.4.0` in order to use RAdamScheduleFree optimizer. "
  398. "Install it with `pip install schedulefree.`"
  399. )
  400. from schedulefree import RAdamScheduleFree
  401. optimizer_cls = RAdamScheduleFree
  402. additional_optim_kwargs = ctx.adam_kwargs
  403. require_warmup = False
  404. elif ctx.args.optim == OptimizerNames.SCHEDULE_FREE_ADAMW:
  405. optimizer_cls = AdamWScheduleFree
  406. additional_optim_kwargs = ctx.adam_kwargs
  407. elif ctx.args.optim == OptimizerNames.SCHEDULE_FREE_SGD:
  408. optimizer_cls = SGDScheduleFree
  409. else:
  410. raise ValueError("Invalid schedulefree optimizer")
  411. additional_optim_kwargs["weight_decay"] = ctx.args.weight_decay
  412. if require_warmup:
  413. additional_optim_kwargs["warmup_steps"] = ctx.args.warmup_steps
  414. additional_optim_kwargs.update(
  415. {
  416. "weight_lr_power": float(ctx.optim_args.get("weight_lr_power", 2.0)),
  417. "r": float(ctx.optim_args.get("r", 0.0)),
  418. }
  419. )
  420. ctx.optimizer_kwargs.update(additional_optim_kwargs)
  421. return optimizer_cls, ctx.optimizer_kwargs
  422. def _get_stable_adamw(ctx: OptimizerContext) -> tuple[Any, dict[str, Any]]:
  423. """Get StableAdamW optimizer from torch-optimi."""
  424. if not is_torch_optimi_available():
  425. raise ImportError(
  426. "You need to install `torch-optimi` in order to use stable_adamw optimizers. "
  427. "Install it with `pip install torch-optimi`."
  428. )
  429. from optimi import StableAdamW
  430. max_lr = ctx.optim_args.pop("max_lr", None)
  431. if max_lr is not None:
  432. max_lr = float(max_lr)
  433. kahan_sum = ctx.optim_args.pop("kahan_sum", None)
  434. if kahan_sum is not None:
  435. kahan_sum = bool(kahan_sum)
  436. ctx.adam_kwargs["weight_decay"] = ctx.args.weight_decay
  437. stable_adamw_kwargs = {
  438. "decouple_lr": bool(ctx.optim_args.pop("decouple_lr", False)),
  439. "max_lr": max_lr,
  440. "kahan_sum": kahan_sum,
  441. }
  442. ctx.optimizer_kwargs.update(ctx.adam_kwargs)
  443. ctx.optimizer_kwargs.update(stable_adamw_kwargs)
  444. return StableAdamW, ctx.optimizer_kwargs
# =============================================================================
# Dispatch table
# =============================================================================
# Each list groups the optimizer names that are served by one shared handler function.

# Names handled by `_get_bitsandbytes_optimizer`.
_BITSANDBYTES_OPTIMIZERS = [
    OptimizerNames.ADAMW_BNB,
    OptimizerNames.ADAMW_8BIT,
    OptimizerNames.PAGED_ADAMW,
    OptimizerNames.PAGED_ADAMW_8BIT,
    OptimizerNames.ADEMAMIX,
    OptimizerNames.ADEMAMIX_8BIT,
    OptimizerNames.PAGED_ADEMAMIX,
    OptimizerNames.PAGED_ADEMAMIX_8BIT,
    OptimizerNames.LION,
    OptimizerNames.LION_8BIT,
    OptimizerNames.PAGED_LION,
    OptimizerNames.PAGED_LION_8BIT,
    OptimizerNames.RMSPROP_BNB,
    OptimizerNames.RMSPROP_8BIT,
    OptimizerNames.RMSPROP_32BIT,
]
# Names handled by `_get_galore_optimizer`.
_GALORE_OPTIMIZERS = [
    OptimizerNames.GALORE_ADAMW,
    OptimizerNames.GALORE_ADAMW_8BIT,
    OptimizerNames.GALORE_ADAFACTOR,
    OptimizerNames.GALORE_ADAMW_LAYERWISE,
    OptimizerNames.GALORE_ADAMW_8BIT_LAYERWISE,
    OptimizerNames.GALORE_ADAFACTOR_LAYERWISE,
]
# Names handled by `_get_apollo_optimizer`.
_APOLLO_OPTIMIZERS = [
    OptimizerNames.APOLLO_ADAMW,
    OptimizerNames.APOLLO_ADAMW_LAYERWISE,
]
# Names handled by `_get_torchao_optimizer`.
_TORCHAO_OPTIMIZERS = [
    OptimizerNames.ADAMW_TORCH_4BIT,
    OptimizerNames.ADAMW_TORCH_8BIT,
]
# Names handled by `_get_schedule_free_optimizer`.
_SCHEDULE_FREE_OPTIMIZERS = [
    OptimizerNames.SCHEDULE_FREE_RADAM,
    OptimizerNames.SCHEDULE_FREE_ADAMW,
    OptimizerNames.SCHEDULE_FREE_SGD,
]
# =============================================================================
# Built-in optimizer handlers registry
# =============================================================================
# Maps each optimizer name to the handler that resolves its class and keyword arguments.
# `dict.fromkeys` points every name of a group at that group's single shared handler.
_OPTIMIZER_HANDLERS: dict[str, OptimizerHandler] = {
    OptimizerNames.ADAFACTOR: _get_adafactor,
    OptimizerNames.ADAMW_TORCH: _get_adamw_torch,
    OptimizerNames.ADAMW_TORCH_FUSED: _get_adamw_torch,
    OptimizerNames.ADAMW_TORCH_XLA: _get_adamw_torch_xla,
    OptimizerNames.ADAMW_TORCH_NPU_FUSED: _get_adamw_torch_npu_fused,
    OptimizerNames.ADAMW_APEX_FUSED: _get_adamw_apex_fused,
    OptimizerNames.ADAMW_ANYPRECISION: _get_adamw_anyprecision,
    OptimizerNames.SGD: _get_sgd,
    OptimizerNames.ADAGRAD: _get_adagrad,
    OptimizerNames.RMSPROP: _get_rmsprop,
    OptimizerNames.GROKADAMW: _get_grokadamw,
    OptimizerNames.STABLE_ADAMW: _get_stable_adamw,
    OptimizerNames.LOMO: _get_lomo_optimizer,
    OptimizerNames.ADALOMO: _get_lomo_optimizer,
    **dict.fromkeys(_BITSANDBYTES_OPTIMIZERS, _get_bitsandbytes_optimizer),
    **dict.fromkeys(_GALORE_OPTIMIZERS, _get_galore_optimizer),
    **dict.fromkeys(_APOLLO_OPTIMIZERS, _get_apollo_optimizer),
    **dict.fromkeys(_TORCHAO_OPTIMIZERS, _get_torchao_optimizer),
    **dict.fromkeys(_SCHEDULE_FREE_OPTIMIZERS, _get_schedule_free_optimizer),
}