meta_tracer.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330
  1. # mypy: allow-untyped-defs
  2. import builtins
  3. import functools
  4. import warnings
  5. from collections.abc import Callable
  6. from typing import Any, Optional, Union
  7. import torch
  8. import torch.fx
  9. def embedding_override(self, input):
  10. return torch.empty(*input.shape, self.weight.shape[-1], device="meta")
  11. def nn_layernorm_override(self, input):
  12. return input
  13. def torch_relu_override(x):
  14. return x
  15. def torch_nn_relu_override(self, x):
  16. return x
  17. def functional_relu_override(x, inplace=False):
  18. if inplace:
  19. raise AssertionError(
  20. "dont support inplace functional.relu for metatensor analysis"
  21. )
  22. return x
  23. def torch_where_override(condition, x, y):
  24. # torch.where returns the broadcasted tensor of condition, x, and y,
  25. # so hack it by using addition
  26. return condition.to(device="meta") + x.to(device="meta") + y.to(device="meta")
  27. def torch_abs_override(input, *, out=None):
  28. if out is not None:
  29. raise AssertionError("Dont support in-place abs for MetaTensor analysis")
  30. return input
  31. manual_meta_overrides: dict[Callable, Callable] = {
  32. torch.nn.Embedding: embedding_override,
  33. torch.nn.LayerNorm: nn_layernorm_override,
  34. torch.relu: torch_relu_override,
  35. torch.nn.functional.relu: functional_relu_override,
  36. torch.nn.ReLU: torch_nn_relu_override,
  37. torch.where: torch_where_override,
  38. torch.abs: torch_abs_override,
  39. }
  40. def gen_constructor_wrapper(target):
  41. @functools.wraps(target)
  42. def wrapper(*args, **kwargs):
  43. proxy = None
  44. def check_has_proxy(v):
  45. if isinstance(v, torch.fx.Proxy):
  46. nonlocal proxy
  47. proxy = v
  48. torch.fx.node.map_aggregate(args, check_has_proxy)
  49. torch.fx.node.map_aggregate(kwargs, check_has_proxy)
  50. if proxy is not None:
  51. return proxy.tracer.create_proxy("call_function", target, args, kwargs)
  52. else:
  53. return target(*args, **kwargs)
  54. return wrapper, target
  55. class MetaProxy(torch.fx.Proxy):
  56. def install_tensor_meta(self, tensor_meta):
  57. self._tensor_meta = tensor_meta
  58. def size(self, dim=None):
  59. if hasattr(self, "_tensor_meta") and self._tensor_meta is not None:
  60. return self._tensor_meta.size(*[dim] if dim else [])
  61. return self.tracer.create_proxy(
  62. "call_method", "size", (self, dim) if dim else (self,), {}
  63. )
  64. def dim(self):
  65. if hasattr(self, "_tensor_meta") and self._tensor_meta is not None:
  66. return self._tensor_meta.dim()
  67. return self.tracer.create_proxy("call_method", "dim", (self,), {})
  68. @property
  69. def shape(self):
  70. if hasattr(self, "_tensor_meta") and self._tensor_meta is not None:
  71. return self._tensor_meta.shape
  72. return self.tracer.create_proxy(
  73. "call_function", builtins.getattr, (self, "shape"), {}
  74. )
  75. @property
  76. def dtype(self):
  77. if hasattr(self, "_tensor_meta") and self._tensor_meta is not None:
  78. return self._tensor_meta.dtype
  79. return self.tracer.create_proxy(
  80. "call_function", builtins.getattr, (self, "dtype"), {}
  81. )
  82. @property
  83. def device(self):
  84. # Hack so we can track when devices are used. During meta-tensor propagation,
  85. # replace these values with a constant 'meta'
  86. return MetaDeviceAttribute(self, "device")
  87. def __getattr__(self, k):
  88. if k == "_tensor_meta":
  89. return self.__getattribute__(k)
  90. # note: not added to the graph yet, if this is a method call
  91. # we peephole optimize to the method invocation
  92. return MetaAttribute(self, k)
  93. class MetaAttribute(MetaProxy):
  94. def __init__(self, root, attr: str):
  95. self.root = root
  96. self.attr = attr
  97. self.tracer = root.tracer
  98. self._node = None
  99. @property
  100. def node(self): # type: ignore[override]
  101. # the node for attributes is added lazily, since most will just be method calls
  102. # which do not rely on the getitem call
  103. if self._node is None:
  104. self._node = self.tracer.create_proxy(
  105. "call_function", getattr, (self.root, self.attr), {}
  106. ).node
  107. return self._node
  108. def __call__(self, *args, **kwargs):
  109. return self.tracer.create_proxy(
  110. "call_method", self.attr, (self.root,) + args, kwargs
  111. )
  112. class MetaDeviceAttribute(MetaAttribute):
  113. pass
  114. def proxys_to_metas(v):
  115. if isinstance(v, MetaDeviceAttribute):
  116. return "meta"
  117. if isinstance(v, torch.fx.Proxy):
  118. if not isinstance(v, MetaProxy):
  119. raise AssertionError(f"Expected MetaProxy but got {type(v)}")
  120. if not hasattr(v, "_tensor_meta"):
  121. raise AssertionError("MetaProxy does not have an associated meta")
  122. return v._tensor_meta
  123. return v
  124. class MetaTracer(torch.fx.Tracer):
  125. allow_insert_stateless_mods: bool = True
  126. _TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full_like", "eye"]
  127. def create_proxy(
  128. self,
  129. kind,
  130. target,
  131. args,
  132. kwargs,
  133. name=None,
  134. type_expr=None,
  135. proxy_factory_fn=None,
  136. ):
  137. rv = super().create_proxy(
  138. kind,
  139. target,
  140. args,
  141. kwargs,
  142. name,
  143. type_expr,
  144. # pyrefly: ignore [bad-argument-type]
  145. proxy_factory_fn,
  146. )
  147. if kind == "placeholder" and target in self.meta_args:
  148. rv.install_tensor_meta(self.meta_args[target])
  149. return rv
  150. if target in self.orig_fns:
  151. # NOTE: tensor constructors in PyTorch define the `device` argument as
  152. # *kwargs-only*. That is why this works. If you add methods to
  153. # _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only,
  154. # this will break and you will likely see issues where we cannot infer
  155. # the size of the output.
  156. if "device" in kwargs:
  157. kwargs["device"] = "meta"
  158. try:
  159. args_metas = torch.fx.node.map_aggregate(args, proxys_to_metas)
  160. kwargs_metas = torch.fx.node.map_aggregate(kwargs, proxys_to_metas)
  161. if kind == "call_function":
  162. meta_target = manual_meta_overrides.get(target, target)
  163. meta_out = meta_target(*args_metas, **kwargs_metas)
  164. elif kind == "call_method":
  165. meta_target = getattr(args_metas[0], target) # type: ignore[index]
  166. meta_out = meta_target(*args_metas[1:], **kwargs_metas) # type: ignore[index]
  167. elif kind == "call_module":
  168. if not hasattr(self, "orig_forward"):
  169. raise AssertionError("orig_forward not set for call_module")
  170. self._disable_module_getattr = True
  171. try:
  172. mod = self.root.get_submodule(target)
  173. mod_type = type(mod)
  174. if mod_type in manual_meta_overrides:
  175. meta_out = manual_meta_overrides[mod_type](
  176. mod, *args_metas, **kwargs_metas
  177. ) # type: ignore[misc, arg-type]
  178. else:
  179. meta_out = self.orig_forward(*args_metas, **kwargs_metas)
  180. finally:
  181. self._disable_module_getattr = False
  182. elif kind == "get_attr":
  183. self._disable_module_getattr = True
  184. try:
  185. attr_itr = self.root
  186. atoms = target.split(".")
  187. for atom in atoms:
  188. attr_itr = getattr(attr_itr, atom)
  189. if not isinstance(attr_itr, torch.Tensor):
  190. raise AssertionError(f"Expected Tensor, got {type(attr_itr)}")
  191. meta_out = attr_itr.to(device="meta")
  192. finally:
  193. self._disable_module_getattr = False
  194. else:
  195. return rv
  196. # TODO
  197. if not isinstance(rv, torch.fx.Proxy):
  198. raise AssertionError("Dont support composite output yet")
  199. rv.install_tensor_meta(meta_out)
  200. except Exception as e:
  201. warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}")
  202. return rv
  203. def getattr(self, attr, attr_val, parameter_proxy_cache):
  204. if getattr(self, "_disable_module_getattr", False):
  205. return attr_val
  206. else:
  207. return super().getattr(attr, attr_val, parameter_proxy_cache)
  208. def call_module(self, m, forward, args, kwargs):
  209. self.orig_forward = forward
  210. return super().call_module(m, forward, args, kwargs)
  211. def _insert_module_as_submodule(self, mod: torch.nn.Module) -> str:
  212. """
  213. Helper method which tries to insert a module that was not declared as submodule.
  214. """
  215. idx = 0
  216. mod_name = mod.__class__.__name__.lower()
  217. path = f"{mod_name}_{idx}"
  218. while hasattr(self.root, path):
  219. path = f"{mod_name}_{idx}"
  220. idx += 1
  221. self.root.add_module(path, mod)
  222. return path
  223. def path_of_module(self, mod: torch.nn.Module) -> str:
  224. try:
  225. return super().path_of_module(mod)
  226. except NameError:
  227. if (
  228. self.allow_insert_stateless_mods
  229. and len(list(mod.parameters())) == 0
  230. and len(list(mod.buffers())) == 0
  231. ):
  232. path = self._insert_module_as_submodule(mod)
  233. self.prev_module = path
  234. return path
  235. raise
  236. def proxy(self, node):
  237. return MetaProxy(node, self)
  238. def trace(self, root, meta_args: dict[str, torch.Tensor], concrete_args=None): # type: ignore[override]
  239. if not isinstance(meta_args, dict):
  240. raise AssertionError(f"Expected dict for meta_args, got {type(meta_args)}")
  241. self.meta_args = meta_args
  242. self.patched_torch_methods = {
  243. target: gen_constructor_wrapper(getattr(torch, target))
  244. for target in self._TORCH_METHODS_TO_PATCH
  245. }
  246. self.orig_fns = set()
  247. for name, (wrapper, orig) in self.patched_torch_methods.items():
  248. setattr(torch, name, wrapper)
  249. self.orig_fns.add(orig)
  250. try:
  251. graph = super().trace(root, concrete_args)
  252. graph._tracer_extras = {"meta_args": meta_args}
  253. return graph
  254. finally:
  255. for name, (_, orig) in self.patched_torch_methods.items():
  256. setattr(torch, name, orig)
  257. def symbolic_trace(
  258. root: Union[torch.nn.Module, Callable[..., Any]],
  259. meta_args: Optional[dict[str, torch.Tensor]] = None,
  260. concrete_args: Optional[dict[str, Any]] = None,
  261. ) -> torch.fx.GraphModule:
  262. tracer = MetaTracer()
  263. graph = tracer.trace(root, meta_args, concrete_args) # type: ignore[arg-type]
  264. name = (
  265. root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
  266. )
  267. gm = torch.fx.GraphModule(tracer.root, graph, name)
  268. return gm