optimizer.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421
  1. """
  2. This module implements variable tracking for PyTorch optimizers during Dynamo tracing.
  3. The OptimizerVariable class provides specialized handling for optimizer instances by:
  4. - Optimizing the tracing of expensive optimizer initialization
  5. - Managing optimizer state and parameter group tracking
  6. - Handling tensor sources and guards for optimizer state tensors
  7. - Supporting CUDA graph execution through static tensor address management
  8. - Providing special handling for parameter gradients and optimizer state tensors
  9. Key features include:
  10. - Efficient initialization tracing via _init_group optimization
  11. - Automatic marking of optimizer state tensors as static for CUDA graphs
  12. - Proper source tracking for parameter groups, gradients, and state tensors
  13. - Guard installation for optimizer state structure
  14. - Support for both CPU and GPU tensor handling
  15. - Cleanup of static tensor references via finalizers
  16. The module integrates with Dynamo's broader tracing system while providing
  17. optimizer-specific optimizations and safety guarantees.
  18. """
  19. import logging
  20. import weakref
  21. from collections.abc import Iterable
  22. from typing import Any, Optional, TYPE_CHECKING
  23. import torch
  24. from torch._dynamo.variables.tensor import TensorVariable
  25. from torch._guards import Source
  26. from torch._logging import getArtifactLogger
  27. from torch.utils._pytree import tree_map_only
  28. from ..guards import GuardBuilder, install_guard
  29. from ..source import (
  30. AttrSource,
  31. ConstDictKeySource,
  32. DictGetItemSource,
  33. GetItemSource,
  34. GlobalWeakRefSource,
  35. GradSource,
  36. )
  37. from ..utils import GLOBAL_KEY_PREFIX
  38. from .base import VariableTracker
  39. from .constant import ConstantVariable
  40. from .dicts import ConstDictVariable
  41. from .lists import ListVariable
  42. from .misc import GetAttrVariable
  43. from .user_defined import UserDefinedObjectVariable
  44. if TYPE_CHECKING:
  45. from torch._dynamo.symbolic_convert import InstructionTranslator
  46. class ArgMappingException(Exception):
  47. pass
  48. class GuardInstallException(Exception):
  49. pass
# Logger for the "perf_hints" artifact. OptimizerVariable uses it to warn when
# gradient tensors are not static for cudagraphs and will therefore be copied.
perf_hint_log = getArtifactLogger(__name__, "perf_hints")
  51. def _is_static_for_cudagraphs(x: torch.Tensor) -> bool:
  52. from torch._inductor.cudagraph_trees import get_manager
  53. if x.is_cuda:
  54. manager = get_manager(x.device.index, False)
  55. is_static_address = torch._dynamo.utils.get_static_address_type(x) is not None
  56. if manager:
  57. assert manager.current_node is not None
  58. return (
  59. is_static_address
  60. or manager.current_node._is_cuda_graph_recorded_tensor(x)
  61. )
  62. else:
  63. return is_static_address
  64. else:
  65. # Don't print a warning for non-cuda tensors
  66. return True
  67. class OptimizerVariable(UserDefinedObjectVariable):
  68. _nonvar_fields = {
  69. "grad_to_source",
  70. "tensor_to_source",
  71. "static_tensor_names",
  72. *UserDefinedObjectVariable._nonvar_fields,
  73. }
  74. def __init__(
  75. self,
  76. value: torch.optim.Optimizer,
  77. grad_to_source: Optional[dict[Any, GradSource]] = None,
  78. static_tensor_names: Optional[set[str]] = None,
  79. tensor_to_source: Optional[dict[torch.Tensor, Source]] = None,
  80. **kwargs: Any,
  81. ) -> None:
  82. super().__init__(value, **kwargs)
  83. # pyrefly: ignore [bad-override]
  84. self.value: torch.optim.Optimizer = value
  85. self.grad_to_source = grad_to_source or {}
  86. self.tensor_to_source = tensor_to_source or {}
  87. self.static_tensor_names = static_tensor_names or set()
  88. def call_method(
  89. self,
  90. tx: "InstructionTranslator",
  91. name: str,
  92. args: list[VariableTracker],
  93. kwargs: dict[str, VariableTracker],
  94. ) -> "VariableTracker":
  95. """This is an optimization to avoid tracing the very slow initialization of the optimizer"""
  96. if name == "_init_group":
  97. if not hasattr(self.value, "_init_group"):
  98. # Fallback: if the optimizer does not have _init_group, trace normally
  99. return super().call_method(tx, name, args, kwargs)
  100. try:
  101. self.graph_break_if_pending_mutation(tx)
  102. self.move_step_if_cpu()
  103. py_args, py_kwargs = self.get_python_args(*args, **kwargs)
  104. ret_val = self.value._init_group(*py_args, **py_kwargs)
  105. self.map_sources_and_install_guards(tx)
  106. self.update_list_args(tx, args, kwargs, py_args, py_kwargs)
  107. # stash a weak_ptr to optimizer to invalidate code
  108. # if the optimizer object dies
  109. mangled_name = f"__optimizer_{id(self.value)}"
  110. tx.store_global_weakref_by_id(mangled_name, self.value)
  111. self.create_finalizer(tx)
  112. # This is currently safe only because the only actual `ret_val`s returned
  113. # by the `_init_group` of existing optimizers are properties that are invariant
  114. # to the input tensors (e.g. dtype, layout). Changing these would trigger a
  115. # recompilation and hence never result in the wrong specialization of `ret_val`.
  116. return ConstantVariable.create(ret_val)
  117. except (ArgMappingException, GuardInstallException) as _:
  118. # trace normally if we can't map args or install guards correctly
  119. pass
  120. return super().call_method(tx, name, args, kwargs)
  121. def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
  122. # Note: this allows us to intercept the call in call_method
  123. # in the typical case, we return a UserMethodVariable
  124. # which will directly inline
  125. if name in ("_init_group"):
  126. assert self.source
  127. return GetAttrVariable(self, name, source=AttrSource(self.source, name))
  128. if name == "param_groups":
  129. from ..decorators import mark_static_address
  130. for group in self.value.param_groups:
  131. for p in group["params"]:
  132. mark_static_address(p, guard=True)
  133. self._set_capturable(tx)
  134. return super().var_getattr(tx, name)
  135. def graph_break_if_pending_mutation(self, tx: "InstructionTranslator") -> None:
  136. # If there are pending mutations on a parameter (due to using closure)
  137. # then we need to graph break to allow the python version of the parameter
  138. # to update, so that running _init_group will initialize the states with
  139. # the correct values
  140. for g in self.value.param_groups:
  141. for p in g["params"]:
  142. side_effects = tx.output.side_effects
  143. variable = side_effects.id_to_variable.get(id(p), None)
  144. if variable and side_effects.has_pending_mutation(variable):
  145. from ..exc import unimplemented
  146. unimplemented(
  147. gb_type="optimizer: pending mutation on parameter",
  148. context=f"variable: {variable}, parameter: {p}",
  149. explanation="Pending mutations on a parameter (e.g. due to using closure) require a graph break.",
  150. hints=[],
  151. )
  152. def _set_capturable(self, tx: "InstructionTranslator") -> None:
  153. from . import LazyVariableTracker
  154. # We only set capturable if params are on cuda
  155. # and the state is not initialized
  156. def safe_to_set_capturable(group: dict[str, Any]) -> bool:
  157. all_uninitialized = True
  158. all_gpu = True
  159. for p in group.get("params", []):
  160. all_gpu &= p.is_cuda or p.is_xpu
  161. all_uninitialized &= p not in self.value.state
  162. return "capturable" in group and all_uninitialized and all_gpu
  163. # track indices to not set so we don't need to
  164. # in the variable tracker realize the whole state
  165. # we handle guarding the state specially
  166. for group in self.value.param_groups:
  167. if safe_to_set_capturable(group):
  168. group["capturable"] = True
  169. source = self.source and AttrSource(self.source, "param_groups")
  170. param_groups_vt = LazyVariableTracker.realize_all(
  171. VariableTracker.build(tx, self.value.param_groups, source)
  172. )
  173. for param_group_vt in param_groups_vt.items:
  174. key = ConstDictVariable._HashableTracker(
  175. ConstantVariable.create("capturable")
  176. )
  177. param_group_vt.items[key] = ConstantVariable.create(True)
  178. def get_python_args(
  179. self, *args: Any, **kwargs: Any
  180. ) -> tuple[list[Any], dict[str, Any]]:
  181. """Get python values equivalent to the variable tracker args"""
  182. def map_arg(arg: Any) -> Any:
  183. if isinstance(arg, VariableTracker) and arg.is_python_constant():
  184. return arg.as_python_constant()
  185. elif isinstance(arg, ListVariable) and not arg.items:
  186. # pyrefly: ignore [implicit-any]
  187. return []
  188. elif (
  189. isinstance(arg, ConstDictVariable)
  190. and isinstance(arg.source, GetItemSource)
  191. and isinstance(arg.source.base, AttrSource)
  192. and arg.source.base.member == "param_groups"
  193. ):
  194. return self.value.param_groups[arg.source.index]
  195. raise ArgMappingException
  196. new_args = [map_arg(arg) for arg in args]
  197. new_kwargs = {k: map_arg(v) for k, v in kwargs.items()}
  198. return new_args, new_kwargs
  199. # If users load an old state dictionary,
  200. # it's possible that step could be on the cpu
  201. # if this is the case, move it to the GPU
  202. # corresponding to the parameter
  203. # in most cases this is a no-op because the state is empty
  204. def move_step_if_cpu(self) -> None:
  205. for p, state in self.value.state.items():
  206. if "step" in state and state["step"].is_cpu:
  207. state["step"] = state["step"].to(p.device)
  208. def map_sources_and_install_guards(self, tx: "InstructionTranslator") -> None:
  209. from ..decorators import mark_static_address
  210. from .lazy import LazyVariableTracker
  211. self.grad_to_source = {}
  212. self.tensor_to_source = {}
  213. def mark_static(x: Any) -> None:
  214. mark_static_address(x, guard=True)
  215. tree_map_only(torch.Tensor, mark_static, self.value.state)
  216. # Recursively realize the variable trackers for optim.state and
  217. # optim.param_groups, which recursively install the necessary guards.
  218. params_groups_source = self.source and AttrSource(self.source, "param_groups")
  219. param_groups_vt = LazyVariableTracker.realize_all(
  220. VariableTracker.build(tx, self.value.param_groups, params_groups_source)
  221. )
  222. state_source = self.source and AttrSource(self.source, "state")
  223. state_vt = VariableTracker.build(tx, self.value.state, state_source)
  224. # We need to realize the top level state dict to populate
  225. # the guard locals
  226. state_vt.realize()
  227. assert state_source is not None
  228. tx.output.guard_on_key_order.add(state_source)
  229. # Populate self.grad_to_source and self.tensor_to_source so that we can
  230. # manually update_list_args
  231. for group, group_vt in zip(self.value.param_groups, param_groups_vt.items):
  232. # we assume here that all params within a param group
  233. # are initialized similarly
  234. if len(group["params"]) > 0:
  235. for param in group["params"]:
  236. if param.grad is not None:
  237. key_index = None
  238. for i, k in enumerate(self.value.state.keys()):
  239. if k is param:
  240. key_index = i
  241. break
  242. if key_index:
  243. LazyVariableTracker.realize_all(
  244. VariableTracker.build(
  245. tx,
  246. self.value.state[param],
  247. DictGetItemSource(
  248. state_source,
  249. ConstDictKeySource(state_source, key_index),
  250. ),
  251. )
  252. )
  253. break
  254. params_vt = group_vt.getitem_const(tx, ConstantVariable.create("params"))
  255. all_static = True
  256. non_static_grads = []
  257. for p, p_vt in zip(group["params"], params_vt.unpack_var_sequence(tx)):
  258. param_source = p_vt.source
  259. self.tensor_to_source[p] = param_source
  260. grad_source = GradSource(
  261. param_source,
  262. "grad",
  263. )
  264. if p.grad is not None:
  265. self.grad_to_source[p.grad] = grad_source
  266. if not _is_static_for_cudagraphs(p.grad):
  267. all_static = False
  268. non_static_grads.append(grad_source)
  269. else:
  270. install_guard(grad_source.make_guard(GuardBuilder.CONSTANT_MATCH))
  271. # Note: to avoid spam logs only warn if perf hint artifact is enabled
  272. # (NB: artifacts are only enabled at the debug or warning level)
  273. if not all_static and perf_hint_log.isEnabledFor(logging.DEBUG):
  274. non_static_grad_names = [src.name for src in non_static_grads]
  275. perf_hint_log.warning(
  276. (
  277. "Grad tensors %s will be copied during cudagraphs execution."
  278. "If using cudagraphs and the grad tensor addresses will be the same across runs,"
  279. " use torch._dynamo.decorators.mark_static_address to elide this copy.",
  280. ),
  281. non_static_grad_names,
  282. )
  283. # We have to again iterate over the state dict to collect the
  284. # tensor_to_source dict. This is used for the finalizer.
  285. for idx, value in enumerate(self.value.state.values()):
  286. p_state_source = DictGetItemSource(
  287. state_source, ConstDictKeySource(state_source, idx)
  288. )
  289. tx.output.guard_on_key_order.add(p_state_source)
  290. for inner_idx, v in enumerate(value.values()):
  291. if (
  292. isinstance(v, torch.Tensor)
  293. and v not in self.grad_to_source
  294. and v not in self.tensor_to_source
  295. ):
  296. self.tensor_to_source[v] = DictGetItemSource(
  297. p_state_source, ConstDictKeySource(p_state_source, inner_idx)
  298. )
  299. def wrap_tensor(
  300. self, tx: "InstructionTranslator", tensor_value: torch.Tensor
  301. ) -> TensorVariable:
  302. """Wrap state tensor in a TensorVariable"""
  303. from ..decorators import mark_static_address
  304. # If we have a source for a tensor already use it,
  305. # if we have not seen a tensor before, stash and use a
  306. # global weak ref source, since it must be an optimizer tensor
  307. # that we have missed
  308. if tensor_value in self.tensor_to_source:
  309. # mark these tensors as static for cudagraphs
  310. mark_static_address(tensor_value, guard=True)
  311. source = self.tensor_to_source[tensor_value]
  312. self.static_tensor_names.add(tx.output.module_key_name(source.name))
  313. elif tensor_value in self.grad_to_source:
  314. source = self.grad_to_source[tensor_value]
  315. else:
  316. # mark these tensors as static for cudagraphs
  317. mark_static_address(tensor_value, guard=True)
  318. global_name = tx.store_global_weakref_by_id(GLOBAL_KEY_PREFIX, tensor_value)
  319. source = GlobalWeakRefSource(global_name)
  320. self.static_tensor_names.add(tx.output.module_key_name(source.name))
  321. return VariableTracker.build(tx, tensor_value, source)
  322. def update_list_args(
  323. self,
  324. tx: "InstructionTranslator",
  325. args: Iterable[VariableTracker],
  326. kwargs: Any,
  327. py_args: Iterable[Any],
  328. py_kwargs: Any,
  329. ) -> None:
  330. """Update the args and kwargs to the traced optimizer call"""
  331. for arg, py_arg in zip(args, py_args):
  332. if isinstance(arg, ListVariable):
  333. assert isinstance(py_arg, list), (
  334. "py_arg should be a list in optimizer variable"
  335. )
  336. for i, val in enumerate(py_arg):
  337. tx.output.side_effects.mutation(arg)
  338. if isinstance(val, torch.Tensor):
  339. arg.items.append(self.wrap_tensor(tx, val))
  340. else:
  341. source = arg.source and GetItemSource(arg.source, i)
  342. arg.items.append(VariableTracker.build(tx, val, source))
  343. def create_finalizer(self, tx: "InstructionTranslator") -> None:
  344. names_to_delete = self.static_tensor_names
  345. value = self.value
  346. tc = tx.output.tracing_context
  347. def init_finalizer(gm: torch.fx.GraphModule) -> None:
  348. def clear_static_tensor_refs() -> None:
  349. for name in names_to_delete:
  350. gm._buffers.pop(name, None)
  351. gm._parameters.pop(name, None)
  352. if tc.params_flat:
  353. tc.params_flat.clear()
  354. if tc.params_flat_unwrap_subclasses:
  355. tc.params_flat_unwrap_subclasses.clear()
  356. weakref.finalize(value, clear_static_tensor_refs)
  357. tx.output.add_graph_finalizer(init_finalizer)