_backward.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428
  1. # mypy: allow-untyped-defs
  2. # Copyright (c) Meta Platforms, Inc. and affiliates
  3. import collections
  4. import logging
  5. from collections.abc import Iterator
  6. from typing import Any
  7. import torch
  8. from torch.autograd.graph import GradientEdge, Node
  9. from torch.nn import Parameter
  10. from ._debug import map_debug_info
  11. logger = logging.getLogger(__name__)
  12. def _get_grad_fn_or_grad_acc(t: torch.Tensor) -> Node | None:
  13. """
  14. Get the grad function or grad accumulator for a tensor.
  15. Accumulate grad nodes are lazily created, so we need to a
  16. dummy view in order to trigger its creation.
  17. """
  18. if t.requires_grad and t.grad_fn is None:
  19. # if no grad function (leaf tensors) we use view
  20. viewed_t = t.view_as(t)
  21. grad_fn = viewed_t.grad_fn
  22. if grad_fn is not None:
  23. return grad_fn.next_functions[0][0]
  24. else:
  25. raise RuntimeError(
  26. "Attempted to get grad_fn, but got None."
  27. "Is this being created in a no-grad context?"
  28. )
  29. else:
  30. return t.grad_fn
  31. def reverse_closure(
  32. roots: list[Node], target_nodes: set[Node], reverse_edges_dict
  33. ) -> tuple[set[Node], set[Node]]:
  34. """
  35. This function returns the reverse closure of the given roots,
  36. i.e. the set of nodes that can be reached from the roots by following the
  37. reverse edges of the graph. The target_nodes are the nodes that we want to
  38. include in the closure.
  39. """
  40. # Recurse until we reach a target node
  41. closure: set[Node] = set()
  42. visited_target_nodes = set()
  43. q: collections.deque[Node] = collections.deque()
  44. for node in roots:
  45. if node is not None and node not in closure:
  46. closure.add(node)
  47. q.append(node)
  48. while q:
  49. node = q.popleft()
  50. reverse_edges = reverse_edges_dict[node]
  51. for fn in reverse_edges:
  52. if fn in closure or fn is None:
  53. continue
  54. if fn in target_nodes:
  55. visited_target_nodes.add(fn)
  56. continue
  57. closure.add(fn)
  58. q.append(fn)
  59. return closure, visited_target_nodes
  60. def construct_reverse_graph(roots: list[Node]) -> dict[Node, list[Node]]:
  61. q: collections.deque[Node] = collections.deque()
  62. root_seen: set[Node] = set()
  63. reverse_edges_dict: dict[Node, list[Node]] = collections.defaultdict(list)
  64. for node in roots:
  65. if node is not None and node not in root_seen:
  66. q.append(node)
  67. root_seen.add(node)
  68. while q:
  69. node = q.popleft()
  70. for fn, _ in node.next_functions:
  71. if fn is not None:
  72. if len(reverse_edges_dict[fn]) == 0:
  73. q.append(fn)
  74. reverse_edges_dict[fn].append(node)
  75. return reverse_edges_dict
def get_param_groups(
    inputs: list[Node], params: list[Node], reverse_edges_dict
) -> list[dict[str, Any]]:
    """
    Group parameter nodes by the intermediate nodes they share.

    For every node in ``params`` the reverse graph is walked (see
    ``reverse_closure``) until it reaches nodes that belong to the reverse
    closure of ``inputs``. Those boundary nodes are the group's intermediates.
    A parameter whose walk reaches an intermediate that already has a group is
    merged into that group.

    Args:
        inputs: nodes from which the inputs' reverse closure is computed.
        params: nodes standing for the parameters.
        reverse_edges_dict: reverse adjacency passed through to
            ``reverse_closure``.

    Returns:
        A list of distinct group dictionaries, each containing the keys:
          - "params": a set of parameter nodes
          - "intermediates": a set of intermediate nodes
        A parameter whose walk reaches no intermediate does not appear in any
        returned group.
    """
    # reverse graph that starts with inputs, and goes up to the dOutput or the loss,
    # but omits weights and any subgraphs connecting weights to this closure
    inputs_closure, _ = reverse_closure(inputs, set(), reverse_edges_dict)
    param_groups: dict[Node, dict[str, set]] = dict()  # keyed on intermediates
    for param in params:
        # Only `intersected` (the nodes of inputs_closure that the walk ran
        # into) is used below; `closure` itself is not needed.
        closure, intersected = reverse_closure(
            [param], inputs_closure, reverse_edges_dict
        )
        param_group: dict[str, set] = {
            "params": {param},
            "intermediates": intersected,
        }
        for input_node in intersected:
            existing = param_groups.get(input_node)
            if existing is not None:
                # This intermediate already has a group: fold the current
                # group into it and keep working with the merged one, so any
                # further intermediates of this param alias the same dict.
                # NOTE(review): if `intersected` touches two distinct existing
                # groups, the later one absorbs the earlier one's contents but
                # the earlier group stays registered under its own keys --
                # confirm this partial merge is intended.
                existing["params"] = existing["params"].union(param_group["params"])
                existing["intermediates"] = existing["intermediates"].union(
                    param_group["intermediates"]
                )
                param_group = existing
            else:
                param_groups[input_node] = param_group

    # Several intermediates can map to the very same group dict, so
    # de-duplicate by object identity before returning.
    # Sanity check: union of all param_groups params should be equal to all params
    # (`union_params` only feeds the disabled assert below.)
    union_params: set[Node] = set()
    seen_ids: set[int] = set()
    unique_param_groups = []
    for param_group in param_groups.values():
        if id(param_group) not in seen_ids:
            seen_ids.add(id(param_group))
            unique_param_groups.append(param_group)
            union_params = union_params.union(param_group["params"])

    # The assert will only be true if the input tensor requires gradients,
    # otherwise the autograd graph will miss the first layer of inputs
    # assert union_params == set(params)
    return unique_param_groups
  124. def stage_backward_input(
  125. stage_outputs_or_loss: list[torch.Tensor],
  126. output_grads: list[torch.Tensor] | None,
  127. input_values: list[torch.Tensor],
  128. weights: Iterator[Parameter],
  129. ) -> tuple[tuple[torch.Tensor | None, ...], list[dict[str, Any]]]:
  130. """
  131. Compute the gradients for only the stage inputs with
  132. respect to the stage outputs (if non-last stage) or loss (if last stage)
  133. After computing input gradients, we save the intermediate nodes in `param_groups`
  134. for later use in stage_backward_weight. We don't need to save any other intermediate nodes
  135. that aren't needed for dW because when we do dW calculation, we start from saved intermediates.
  136. Detaching the stage_outputs_or_loss at the end of this function is important as
  137. it frees up the memory that the autograd graph is anticipating to be used later (but doesn't actually need).
  138. """
  139. stage_output_grad_fns: list[Node] = list(
  140. filter(None, map(_get_grad_fn_or_grad_acc, stage_outputs_or_loss))
  141. )
  142. stage_input_grad_fns: list[Node] = list(
  143. filter(None, map(_get_grad_fn_or_grad_acc, input_values))
  144. )
  145. weight_grad_fns: list[Node] = list(
  146. filter(None, map(_get_grad_fn_or_grad_acc, weights))
  147. )
  148. reverse_edges_dict = construct_reverse_graph(stage_output_grad_fns)
  149. param_groups = get_param_groups(
  150. stage_input_grad_fns, weight_grad_fns, reverse_edges_dict
  151. )
  152. handles = []
  153. for param_group in param_groups:
  154. for i, intermediate in enumerate(param_group["intermediates"]):
  155. def get_hook(param_group, i):
  156. def hook(grad_inputs):
  157. if param_group.get("grads", None) is None:
  158. param_group["grads"] = [None] * len(
  159. param_group["intermediates"]
  160. )
  161. param_group["grads"][i] = grad_inputs
  162. return hook
  163. # These are always "split" nodes that we need to recompute, so
  164. # save their inputs.
  165. handle = intermediate.register_prehook(get_hook(param_group, i))
  166. handles.append(handle)
  167. if output_grads is None:
  168. # In case this is the loss and there are no output_grads, then we just use 1s
  169. output_grads = [
  170. torch.ones_like(stage_output) for stage_output in stage_outputs_or_loss
  171. ]
  172. # Some inputs may not be used or may not require gradients, so we filter them out
  173. input_values = [inp for inp in input_values if inp.requires_grad]
  174. dinputs = torch.autograd.grad(
  175. stage_outputs_or_loss,
  176. inputs=input_values,
  177. grad_outputs=output_grads,
  178. retain_graph=True,
  179. )
  180. # Update the gradients for inputs
  181. for inp, dinput in zip(input_values, dinputs):
  182. if inp.grad is None:
  183. inp.grad = dinput
  184. else:
  185. inp.grad += dinput
  186. # stage_outputs_or_loss are not used in backwards after this point, so we can safely remove it from the autograd graph
  187. # this allows autograd to clear up the graph dedicated for this tensor and free up significant memory
  188. for t in stage_outputs_or_loss:
  189. t.detach_()
  190. # hooks are no longer necessary, clean up for consistency
  191. for handle in handles:
  192. handle.remove()
  193. return dinputs, param_groups
  194. def stage_backward_weight(
  195. weights: Iterator[Parameter], param_groups: list[dict[str, Any]], retain_graph=False
  196. ) -> tuple[torch.Tensor | None, ...]:
  197. # map weights to param_group_weights
  198. grad_acc_to_weight = {}
  199. weight_grads: list[torch.Tensor | None] = []
  200. for index, weight in enumerate(weights):
  201. grad_acc = _get_grad_fn_or_grad_acc(weight)
  202. # pyrefly: ignore [unsupported-operation]
  203. grad_acc_to_weight[grad_acc] = weight, index
  204. weight_grads.append(weight.grad)
  205. for param_group in param_groups:
  206. valid_edges = []
  207. valid_grad_outputs: list[torch.Tensor] = []
  208. for grads_tuple, intermediate in zip(
  209. param_group["grads"], param_group["intermediates"]
  210. ):
  211. non_none_grads = [g for g in grads_tuple if g is not None]
  212. if non_none_grads:
  213. summed_grad = sum(non_none_grads)
  214. valid_edges.append(GradientEdge(intermediate, 0))
  215. # pyrefly: ignore [bad-argument-type]
  216. valid_grad_outputs.append(summed_grad)
  217. # Break a reference cycle caused inside stage_backward_input->get_hook->hook
  218. # The summarized cycle is:
  219. # `hook` -> cell -> param_group -> intermediates -> `hook`
  220. # because we install the hook function onto each of the intermediate autograd nodes.
  221. # We need to keep intermediates alive up until backward_weight, but we can free it now.
  222. del param_group["intermediates"]
  223. if valid_edges: # Only call autograd.grad if we have valid gradients
  224. # [NEW!] Able to pass a GradientEdge to autograd.grad as output
  225. weights_edges = tuple(GradientEdge(w, 0) for w in param_group["params"])
  226. dweights = torch.autograd.grad(
  227. valid_edges,
  228. weights_edges,
  229. grad_outputs=valid_grad_outputs,
  230. retain_graph=retain_graph,
  231. )
  232. # release grad memory early after use
  233. del param_group["grads"]
  234. for grad_acc, dw in zip(param_group["params"], dweights):
  235. weight, index = grad_acc_to_weight[grad_acc]
  236. if weight.grad is None:
  237. weight.grad = dw
  238. else:
  239. weight.grad += dw
  240. # return grads in the original order weights were provided in
  241. return tuple(weight_grads)
  242. def stage_backward(
  243. stage_output,
  244. output_grads,
  245. input_values,
  246. outputs_with_grads_idxs: list[int] | None = None, # deprecated, not used
  247. ) -> tuple[torch.Tensor | None, ...]:
  248. """
  249. This is a helper function to:
  250. 1. compute the gradients for the stage inputs, and
  251. 2. accumulate gradients for the stage module's parameters.
  252. Given the input value(s) and the corresponding gradient for the output
  253. value(s), compute and accumulate gradients for all parameter values (leaves
  254. in the autograd trace) as well as return a list of the gradients for the
  255. input values
  256. """
  257. if outputs_with_grads_idxs is not None:
  258. # Deprecated, not used in runtime calls, only exists in compiler
  259. stage_output = [stage_output[i] for i in outputs_with_grads_idxs]
  260. output_grads = [output_grads[i] for i in outputs_with_grads_idxs]
  261. try:
  262. # stage_output may be a composite datatype like dict. Extract all individual
  263. # tensor values here
  264. stage_output_tensors: list[torch.Tensor] = []
  265. output_grad_tensors: list[torch.Tensor | None] = []
  266. def extract_tensors_with_grads(
  267. output_val,
  268. grad_val,
  269. # Don't delete me- see [Note: ref cycle]
  270. extract_tensors_with_grads,
  271. ):
  272. if isinstance(output_val, torch.Tensor):
  273. if not output_val.requires_grad and output_val.grad_fn is None:
  274. return
  275. if not isinstance(grad_val, (torch.Tensor, type(None))):
  276. raise AssertionError(
  277. f"Expected Tensor or None gradient but got {type(grad_val)}"
  278. )
  279. stage_output_tensors.append(output_val)
  280. output_grad_tensors.append(grad_val)
  281. elif isinstance(output_val, (tuple, list)):
  282. if grad_val is None:
  283. return
  284. if not isinstance(grad_val, (tuple, list)):
  285. raise AssertionError(
  286. f"grad_value expected to have type {type(output_val)} but got {type(grad_val)}"
  287. )
  288. if not len(output_val) == len(grad_val):
  289. raise AssertionError(
  290. f"Expected len(output_val) == len(grad_val), got {len(output_val)} != {len(grad_val)}"
  291. )
  292. for ov, gv in zip(output_val, grad_val):
  293. extract_tensors_with_grads(
  294. ov,
  295. gv,
  296. extract_tensors_with_grads,
  297. )
  298. elif isinstance(output_val, dict):
  299. if grad_val is None:
  300. return
  301. if not isinstance(grad_val, dict):
  302. raise AssertionError(f"Expected dict, got {type(grad_val)}")
  303. if not set(output_val.keys()) == set(grad_val.keys()):
  304. raise AssertionError(
  305. f"Expected keys {set(output_val.keys())}, got {set(grad_val.keys())}"
  306. )
  307. for k in output_val:
  308. extract_tensors_with_grads(
  309. output_val[k], grad_val[k], extract_tensors_with_grads
  310. )
  311. else:
  312. # Output is a non-tensor type; just ignore it
  313. pass
  314. # Note: ref cycle
  315. # break a ref cycle that would keep tensors alive until GC runs
  316. # 1. extract_tensors_with_grads refers to a cell that holds refs to any vars defined in stage_backward
  317. # and used in extract_tensors_with_grads
  318. # 2. extract_tensors_with_grads referred to both stage_output_tensors, output_grad_tensors,
  319. # and to itself (extract_tensors_with_grads) since it makes a recursive call
  320. # 3. stage_output_tensors was kept alive by the above refcycle, and it holds activation tensors, which is bad
  321. # fix -> explicitly pass in the ref to the fn, so there is no gc cycle anymore
  322. extract_tensors_with_grads(
  323. stage_output, output_grads, extract_tensors_with_grads
  324. )
  325. torch.autograd.backward(
  326. stage_output_tensors,
  327. grad_tensors=output_grad_tensors, # type: ignore[arg-type]
  328. )
  329. # Extract gradients wrt the input values
  330. grad_inputs: list[torch.Tensor | None] = []
  331. for val in input_values:
  332. if isinstance(val, torch.Tensor):
  333. grad_inputs.append(val.grad)
  334. # Since gradients that will pass back to previous stages do not require gradient accumulation,
  335. # by decrementing the gradients' reference count at this point, the memory of gradients will be
  336. # returned to the allocator as soon as the next micro batch's get_bwd_send_ops comes and current
  337. # asynchronous send completes.
  338. # This prevents the gradients from persisting in GPU memory for the entire duration of step_microbatches
  339. # until clear_runtime_states() is called.
  340. val.grad = None
  341. else:
  342. grad_inputs.append(None)
  343. # Alternative impl: `torch.autograd.grad`.
  344. # Note that `torch.autograd.grad` will not accumulate gradients into the
  345. # model's parameters.
  346. """
  347. inputs_with_grad = []
  348. for val in input_values:
  349. if isinstance(val, torch.Tensor) and val.requires_grad:
  350. inputs_with_grad.append(val)
  351. grad_inputs = torch.autograd.grad(
  352. stage_output_tensors, inputs_with_grad, output_grad_tensors, # type: ignore[arg-type]
  353. )
  354. """
  355. except Exception as e:
  356. exc_msg = f"""
  357. Failed to run stage backward:
  358. Stage output: {map_debug_info(stage_output)}
  359. Output gradient: {map_debug_info(output_grads)}
  360. Input: {map_debug_info(input_values)}
  361. """
  362. raise RuntimeError(exc_msg) from e
  363. return tuple(grad_inputs)
  364. # TODO: handling requires_grad=False dynamically. Can we analyze this during initial
  365. # IR emission?
  366. def _null_coalesce_accumulate(lhs, rhs):
  367. """
  368. Coalesce two values, even if one of them is null, returning the non-null
  369. value.
  370. """
  371. if lhs is None:
  372. return rhs
  373. elif rhs is None:
  374. return lhs
  375. else:
  376. return torch.add(lhs, rhs)