shape_prop.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. # mypy: ignore-errors
  2. import traceback
  3. from typing import Any, NamedTuple, Optional
  4. import torch
  5. import torch.fx
  6. from torch._dispatch.python import enable_python_dispatcher
  7. from torch._guards import detect_fake_mode
  8. from torch._prims_common import is_contiguous_for_memory_format_or_false
  9. from torch._subclasses.meta_utils import is_sparse_any
  10. from torch.fx._compatibility import compatibility
  11. from torch.fx.node import map_aggregate, Node
  12. __all__ = ["TensorMetadata", "ShapeProp"]
  13. @compatibility(is_backward_compatible=True)
  14. class TensorMetadata(NamedTuple):
  15. # TensorMetadata is a structure containing pertinent information
  16. # about a tensor within a PyTorch program.
  17. # General Tensor metadata
  18. shape: torch.Size
  19. dtype: torch.dtype
  20. requires_grad: bool
  21. stride: tuple[int, ...]
  22. memory_format: Optional[torch.memory_format]
  23. # Quantization metadata
  24. is_quantized: bool
  25. qparams: dict[str, Any]
  26. # When include_contiguity is True, we set the memory format only when contiguity always holds for the tensor.
  27. # Some tensors can represent both contiguous and non-contiguous tensors, e.g. (u0, u1) with (u2, u3).
  28. # In such situations contiguity is not set. We could also make it a tri-state, i.e. (def_contiguous,
  29. # def_not_contiguous and unknown).
  30. def _extract_tensor_metadata(
  31. result: torch.Tensor, include_contiguity=True
  32. ) -> TensorMetadata:
  33. """
  34. Extract a TensorMetadata NamedTuple describing `result`.
  35. """
  36. shape = result.shape
  37. dtype = result.dtype
  38. requires_grad = result.requires_grad
  39. stride = result.stride() if not is_sparse_any(result) else ()
  40. memory_format = None
  41. if include_contiguity and not is_sparse_any(result):
  42. memory_formats = (
  43. torch.contiguous_format,
  44. torch.channels_last,
  45. torch.channels_last_3d,
  46. )
  47. for query_format in memory_formats:
  48. if is_contiguous_for_memory_format_or_false(
  49. result, memory_format=query_format
  50. ):
  51. memory_format = query_format
  52. break
  53. is_quantized = result.is_quantized
  54. qparams: dict[str, Any] = {}
  55. if is_quantized:
  56. qscheme = result.qscheme()
  57. qparams["qscheme"] = qscheme
  58. if qscheme in (torch.per_tensor_affine, torch.per_tensor_symmetric):
  59. qparams["scale"] = result.q_scale() # type: ignore[assignment]
  60. qparams["zero_point"] = result.q_zero_point() # type: ignore[assignment]
  61. elif qscheme in (
  62. torch.per_channel_affine,
  63. torch.per_channel_affine_float_qparams,
  64. torch.per_channel_symmetric,
  65. ):
  66. # In this branch, scale and zero_point are expected to be tensors,
  67. # we store the values as immutable_list in TensorMetadata for
  68. # easier serialization downstream
  69. qparams["scale"] = result.q_per_channel_scales().tolist() # type: ignore[assignment]
  70. qparams["zero_point"] = result.q_per_channel_zero_points().tolist() # type: ignore[assignment]
  71. qparams["axis"] = result.q_per_channel_axis() # type: ignore[assignment]
  72. return TensorMetadata(
  73. shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams
  74. )
  75. @compatibility(is_backward_compatible=True)
  76. class ShapeProp(torch.fx.Interpreter):
  77. """
  78. Execute an FX graph Node-by-Node and
  79. record the shape and type of the result
  80. into the corresponding node.
  81. Example:
  82. In this example, we record the shape
  83. and data type of a module given
  84. an example input ``torch.randn(50, D_in)``.
  85. We print the name, shape and dtype of each node.
  86. class TwoLayerNet(torch.nn.Module):
  87. def __init__(self, D_in, H, D_out):
  88. super().__init__()
  89. self.linear1 = torch.nn.Linear(D_in, H)
  90. self.linear2 = torch.nn.Linear(H, D_out)
  91. def forward(self, x):
  92. h_relu = self.linear1(x).clamp(min=0)
  93. y_pred = self.linear2(h_relu)
  94. return y_pred
  95. N, D_in, H, D_out = 64, 1000, 100, 10
  96. x = torch.randn(N, D_in)
  97. y = torch.randn(N, D_out)
  98. model = TwoLayerNet(D_in, H, D_out)
  99. gm = torch.fx.symbolic_trace(model)
  100. sample_input = torch.randn(50, D_in)
  101. ShapeProp(gm).propagate(sample_input)
  102. for node in gm.graph.nodes:
  103. print(node.name, node.meta['tensor_meta'].dtype,
  104. node.meta['tensor_meta'].shape)
  105. The output of this code is:
  106. x torch.float32 torch.Size([50, 1000])
  107. linear1 torch.float32 torch.Size([50, 100])
  108. clamp_1 torch.float32 torch.Size([50, 100])
  109. linear2 torch.float32 torch.Size([50, 10])
  110. output torch.float32 torch.Size([50, 10])
  111. Args:
  112. module (GraphModule): The module to be executed
  113. fake_mode (FakeTensorMode): A fake mode for copying the gm
  114. """
  115. def __init__(self, gm, fake_mode=None):
  116. super().__init__(gm)
  117. if fake_mode is None:
  118. fake_mode = detect_fake_mode()
  119. if fake_mode is not None:
  120. from torch._dynamo.utils import deepcopy_to_fake_tensor
  121. # Note:
  122. # We need fake execution cause the inputs are fake, however, we cannot fakify the module
  123. # - because we need to write to the tensor_meta of the real module. So we fakify to
  124. # produce a result (L131 below), to extract tensor meta, and then keep going.
  125. #
  126. # If we were to fakify, we would write to the wrong node, and then downstream fusion
  127. # would be missing the tensor_meta.
  128. #
  129. # See torch/_inductor/overrides.py for where this is called upstream of fusion.
  130. self.fake_module = deepcopy_to_fake_tensor(self.module, fake_mode)
  131. self.fake_mode = fake_mode
  132. else:
  133. self.fake_module = None
  134. self.fake_mode = None
  135. self.real_module = self.module
  136. def run_node(self, n: Node) -> Any:
  137. from torch.fx.experimental.symbolic_shapes import (
  138. compute_unbacked_bindings,
  139. rebind_unbacked,
  140. )
  141. try:
  142. if self.fake_module is not None:
  143. # Hacky swap. Alternatively, we could do this with overriding
  144. # call_module and get_attr.
  145. self.module = self.fake_module
  146. try:
  147. if self.fake_mode is not None:
  148. with self.fake_mode, enable_python_dispatcher():
  149. result = super().run_node(n)
  150. rebind_unbacked(self.fake_mode.shape_env, n, result)
  151. else:
  152. result = super().run_node(n)
  153. finally:
  154. self.module = self.real_module
  155. except Exception as e:
  156. traceback.print_exc()
  157. raise RuntimeError(
  158. f"ShapeProp error for: node={n.format_node()} with meta={n.meta}"
  159. ) from e
  160. found_tensor = False
  161. def extract_tensor_meta(obj):
  162. if isinstance(obj, torch.Tensor):
  163. nonlocal found_tensor
  164. found_tensor = True
  165. return _extract_tensor_metadata(obj)
  166. else:
  167. return obj
  168. meta = map_aggregate(result, extract_tensor_meta)
  169. if found_tensor:
  170. n.meta["tensor_meta"] = meta
  171. if self.fake_mode:
  172. if (shape_env := self.fake_mode.shape_env) and (
  173. symbol_to_path := compute_unbacked_bindings(shape_env, result)
  174. ):
  175. n.meta["unbacked_bindings"] = symbol_to_path
  176. n.meta["type"] = type(result)
  177. return result
  178. def propagate(self, *args):
  179. """
  180. Run `module` via interpretation and return the result and
  181. record the shape and type of each node.
  182. Args:
  183. *args (Tensor): the sample input.
  184. Returns:
  185. Any: The value returned from executing the Module
  186. """
  187. if self.fake_mode is not None:
  188. fake_args = [
  189. self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
  190. for t in args
  191. ]
  192. else:
  193. fake_args = args
  194. return super().run(*fake_args)