freezing.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. # mypy: allow-untyped-defs
  2. from __future__ import annotations
  3. import itertools
  4. import logging
  5. import weakref
  6. from typing import Any, Optional
  7. import torch
  8. import torch.utils._pytree as pytree
  9. from torch._dynamo.utils import dynamo_timed, lazy_format_graph_code
  10. from torch._functorch.aot_autograd import MutationType
  11. from torch._functorch.compile_utils import fx_graph_cse
  12. from torch._inductor.constant_folding import constant_fold, replace_node_with_constant
  13. from torch._inductor.freezing_utils import enter_freezing, record_has_frozen_params
  14. from torch._inductor.fx_passes.freezing_patterns import freezing_passes
  15. from torch._inductor.fx_passes.post_grad import view_to_reshape
  16. from . import config
  17. aten = torch.ops.aten
  18. prims = torch.ops.prims
  19. log = logging.getLogger(__name__)
  20. def replace_params_with_constants(
  21. gm: torch.fx.GraphModule,
  22. flat_params: list[Any],
  23. fw_metadata: torch._functorch.aot_autograd.ViewAndMutationMeta,
  24. ) -> list[int]:
  25. """
  26. Replaces the parameters of a PyTorch GraphModule with constants wherever possible.
  27. Returns a list of indices representing the input parameters that were not converted to constants.
  28. """
  29. params = gm.graph.find_nodes(op="placeholder")
  30. fake_inp_nodes = params[: len(params)]
  31. preserved_arg_indices = []
  32. aliased_input_args = [
  33. out_info.base_idx
  34. for out_info in fw_metadata.output_info
  35. if out_info.base_idx is not None
  36. ]
  37. # TODO (tmanlaibaatar) figure out why this is different
  38. # from mutated_inp_runtime_indices
  39. mutated_inps = [
  40. i
  41. for i, m in enumerate(fw_metadata.input_info)
  42. if m.mutation_type
  43. in (MutationType.MUTATED_IN_GRAPH, MutationType.MUTATED_OUT_GRAPH)
  44. ]
  45. static_indices_new = []
  46. static_indices_offset = 0
  47. for i, (real_input, node) in enumerate(zip(flat_params, fake_inp_nodes)):
  48. if i in mutated_inps or i in aliased_input_args:
  49. preserved_arg_indices.append(i)
  50. if i in fw_metadata.static_input_indices:
  51. new_static_index = i - static_indices_offset
  52. static_indices_new.append(new_static_index)
  53. else:
  54. replace_node_with_constant(gm, node, real_input)
  55. static_indices_offset += 1
  56. # add on non param inputs
  57. preserved_arg_indices.extend(range(len(flat_params), len(params)))
  58. # is this necessary ?
  59. fw_metadata.static_input_indices = static_indices_new
  60. gm.recompile()
  61. return preserved_arg_indices
  62. def freeze(
  63. dynamo_gm: torch.fx.GraphModule,
  64. aot_autograd_gm: torch.fx.GraphModule,
  65. example_inputs: list[torch._subclasses.FakeTensor],
  66. ) -> tuple[torch.fx.GraphModule, list[int]]:
  67. """
  68. Inlines parameters that are not mutated into constants and optimizes the graph through constant propagation
  69. and other techniques. If enabled, the function also discards the original parameters of the module for memory efficiency.
  70. Assumes that this function is run in dynamo tracing post aot_autograd.
  71. Args:
  72. dynamo_gm (torch.fx.GraphModule): The Dynamo constructed GraphModule.
  73. aot_autograd_gm (torch.fx.GraphModule): The aot_autograd constructed GraphModule to be frozen.
  74. example_inputs (List[torch.Tensor]): A list of example input tensors to be used in the freezing process.
  75. Returns:
  76. Tuple[torch.fx.GraphModule, List[int]]: A tuple containing the frozen GraphModule and a list of indices
  77. of the inputs that were preserved (not turned into constants).
  78. """
  79. with enter_freezing():
  80. return _freeze(dynamo_gm, aot_autograd_gm, example_inputs)
  81. def _freeze(
  82. dynamo_gm: torch.fx.GraphModule,
  83. aot_autograd_gm: torch.fx.GraphModule,
  84. example_inputs: list[torch._subclasses.FakeTensor],
  85. ) -> tuple[torch.fx.GraphModule, list[int]]:
  86. # We have convert conv's weight to channels last which may meet error for .view
  87. # when doing fake_tensor_prop. So we need to convert view to reshape first.
  88. # See the details in fx_codegen_and_compile of compile_fx.py.
  89. view_to_reshape(aot_autograd_gm)
  90. if tracing_context := torch._guards.TracingContext.try_get():
  91. fw_metadata = tracing_context.fw_metadata
  92. assert tracing_context.params_flat_unwrap_subclasses is not None
  93. params_flat = tracing_context.params_flat_unwrap_subclasses
  94. assert fw_metadata is not None and params_flat is not None
  95. preserved_arg_indices = replace_params_with_constants(
  96. aot_autograd_gm, params_flat, fw_metadata
  97. )
  98. else:
  99. inputs = aot_autograd_gm.graph.find_nodes(op="placeholder")
  100. preserved_arg_indices = list(range(len(inputs)))
  101. # TODO - further restrict cse ? right now needed to dedup aliasing ops
  102. cse_graph = fx_graph_cse(aot_autograd_gm.graph)
  103. aot_autograd_gm.graph = cse_graph
  104. aot_autograd_gm.recompile()
  105. aot_example_inputs = [example_inputs[ind] for ind in preserved_arg_indices]
  106. freezing_passes(aot_autograd_gm, aot_example_inputs)
  107. constant_fold(aot_autograd_gm)
  108. # invalidate nn Modules
  109. if config.freezing_discard_parameters:
  110. invalidate_eager_modules()
  111. discard_traced_gm_params(dynamo_gm)
  112. log.debug(
  113. "%s", lazy_format_graph_code("FROZEN GRAPH", aot_autograd_gm, colored=True)
  114. )
  115. record_has_frozen_params(aot_autograd_gm)
  116. return aot_autograd_gm, preserved_arg_indices
  117. class ErasedTensor(torch.Tensor):
  118. @staticmethod
  119. def __new__(cls, elem, name, owning_mod):
  120. return super().__new__(cls, elem.to(device="meta"))
  121. def __init__(self, elem, name: Optional[str], mod) -> None:
  122. self.erased_name = name
  123. self.owning_mod_ref = weakref.ref(mod)
  124. @classmethod
  125. def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore[override]
  126. erased_tensors = [
  127. e
  128. # pyrefly: ignore [bad-unpacking]
  129. for e in pytree.arg_tree_leaves(*args, **kwargs)
  130. if isinstance(e, ErasedTensor)
  131. ]
  132. assert len(erased_tensors) > 0
  133. e = erased_tensors[0]
  134. raise RuntimeError(
  135. f"Trying to run Pytorch Eager Module after Dynamo Freezing. "
  136. "The original parameters have been discarded for memory efficiency. "
  137. f"Found in op {func} for erased parameter {e.erased_name} of {e.owning_mod_ref()}"
  138. )
  139. def invalidate_eager_modules():
  140. with torch.utils._python_dispatch._disable_current_modes():
  141. for (
  142. mod
  143. ) in torch._guards.TracingContext.get().module_context.nn_modules.values():
  144. if not isinstance(mod, torch.nn.Module):
  145. continue
  146. for attr_name, tensor in list(
  147. itertools.chain(
  148. mod.named_parameters(recurse=False),
  149. # pyrefly: ignore [bad-argument-type]
  150. mod.named_buffers(recurse=False),
  151. )
  152. ):
  153. with torch._dispatch.python.no_python_dispatcher():
  154. e_t = ErasedTensor(tensor, attr_name, mod)
  155. if isinstance(tensor, torch.nn.Parameter):
  156. e_t.requires_grad_(True)
  157. e_t._is_param = True
  158. setattr(mod, attr_name, e_t)
  159. def discard_traced_gm_params(mod: torch.fx.GraphModule):
  160. with torch.utils._python_dispatch._disable_current_modes():
  161. for attr_name, tensor in list(
  162. itertools.chain(
  163. mod.named_parameters(recurse=False),
  164. # pyrefly: ignore [bad-argument-type]
  165. mod.named_buffers(recurse=False),
  166. )
  167. ):
  168. with torch._dispatch.python.no_python_dispatcher():
  169. e_t = ErasedTensor(tensor, attr_name, mod)
  170. if isinstance(tensor, torch.nn.Parameter):
  171. e_t.requires_grad_(True)
  172. e_t._is_param = True
  173. setattr(mod, attr_name, e_t)
  174. def enforce_output_layout(gm: torch.fx.GraphModule):
  175. """
  176. Make sure the output node's layout does not change due to compiler optimizations
  177. by adding aten.as_strided nodes with the expected strides.
  178. Only used for inference so we can assume all graph outputs are model outputs.
  179. """
  180. *_, output_node = gm.graph.nodes
  181. out_list = output_node.args[0]
  182. with gm.graph.inserting_before(output_node):
  183. for n in out_list:
  184. if not isinstance(
  185. n.meta["val"], torch.Tensor
  186. ) or not torch._prims_common.is_non_overlapping_and_dense_or_false(
  187. n.meta["val"]
  188. ):
  189. continue
  190. # add a node to enforce eager layout
  191. ft = n.meta["val"]
  192. new_node = gm.graph.call_function(
  193. prims.inductor_force_stride_order.default, (n, ft.stride())
  194. )
  195. # can not call
  196. # n.replace_all_uses_with(new_node)
  197. # since it will replace the usage of n in new_node itself.
  198. output_node.replace_input_with(n, new_node)
  199. gm.graph.lint()
  200. gm.recompile()
  201. def enforce_as_strided_input_layout(gm: torch.fx.GraphModule):
  202. """
  203. Make sure the as_strided node's input's layout does not change due to compiler
  204. optimizations, because the as_strided strides info depends on input tensor stride info.
  205. """
  206. as_strided_ops = [
  207. torch.ops.aten.as_strided.default,
  208. torch.ops.aten.as_strided_.default,
  209. torch.ops.aten.as_strided_scatter.default,
  210. ]
  211. strided_nodes = [n for n in gm.graph.nodes if n.target in as_strided_ops]
  212. for n in strided_nodes:
  213. with gm.graph.inserting_before(n):
  214. # add a node to enforce eager layout
  215. ft = n.args[0].meta["val"]
  216. new_node = gm.graph.call_function(
  217. prims.inductor_force_stride_order.default, (n.args[0], ft.stride())
  218. )
  219. n.replace_input_with(n.args[0], new_node)
  220. gm.graph.lint()
  221. gm.recompile()
  222. def convert_conv_weights_to_channels_last(gm: torch.fx.GraphModule):
  223. """
  224. Convert 4d convolution weight tensor to channels last format.
  225. This pass is performed before freezing so the added nodes can be constant
  226. folded by freezing.
  227. """
  228. with dynamo_timed("convert_conv_weights_to_channels_last"):
  229. convs = [n for n in gm.graph.nodes if n.target is aten.convolution.default]
  230. for conv in convs:
  231. weight_node = conv.args[1]
  232. if len(weight_node.meta["val"].size()) != 4 or weight_node.meta[
  233. "val"
  234. ].is_contiguous(memory_format=torch.channels_last):
  235. # not a 4d tensor or already channels last, skip
  236. continue
  237. with gm.graph.inserting_before(conv):
  238. new_node = gm.graph.call_function(
  239. aten.clone.default,
  240. (weight_node,),
  241. {"memory_format": torch.channels_last},
  242. )
  243. conv.replace_input_with(weight_node, new_node)
  244. enforce_as_strided_input_layout(gm)
  245. enforce_output_layout(gm)