# fake_profile.py (11 KB)
  1. import contextlib
  2. import io
  3. import logging
  4. import os
  5. from collections.abc import Callable, Generator
  6. from dataclasses import dataclass
  7. from typing import Any, Optional, Union
  8. import torch
  9. from torch._library.custom_ops import _maybe_get_opdef
  10. from torch.types import FileLike
  11. log = logging.getLogger(__name__)
  12. class MissingOpProfile(RuntimeError):
  13. """
  14. This is raised when we don't have an operator profile available for the
  15. given inputs.
  16. """
  17. @dataclass(frozen=True)
  18. class TensorMetadata:
  19. rank: int
  20. dtype: torch.dtype
  21. device: torch.device
  22. layout: torch.layout
  23. @staticmethod
  24. def maybe_from_tensor(t: Any) -> Optional["TensorMetadata"]:
  25. if not isinstance(t, torch.Tensor):
  26. return None
  27. return TensorMetadata(t.dim(), t.dtype, t.device, t.layout)
  28. @dataclass(frozen=True)
  29. class OpProfile:
  30. args_profile: tuple[Optional[TensorMetadata]]
  31. out_profile: Union[TensorMetadata, tuple[TensorMetadata]]
  32. def _generate_fake_kernel(op_name: str, op_profile: set[OpProfile]) -> Callable:
  33. def _match_args(args_profile: tuple[Optional[TensorMetadata]], args: Any) -> bool:
  34. return all(
  35. TensorMetadata.maybe_from_tensor(arg) == args_profile[i]
  36. for i, arg in enumerate(args)
  37. )
  38. def _generate_res(
  39. out_profile: Union[TensorMetadata, tuple[TensorMetadata]],
  40. ) -> Union[torch.Tensor, list[torch.Tensor]]:
  41. ctx = torch.library.get_ctx()
  42. def _generate_tensor_out(t: TensorMetadata) -> torch.Tensor:
  43. fake_shape = [ctx.new_dynamic_size() for _ in range(t.rank)]
  44. fake_strides = [-1] * t.rank
  45. expected = 1
  46. fake_stride = expected
  47. # pyrefly: ignore [bad-assignment]
  48. for i in range(t.rank):
  49. fake_strides[i] = fake_stride # type: ignore[assignment]
  50. fake_stride = fake_stride * fake_shape[i] # type: ignore[assignment]
  51. return torch.empty_strided(
  52. fake_shape,
  53. fake_strides,
  54. device=t.device,
  55. dtype=t.dtype,
  56. layout=t.layout,
  57. )
  58. if isinstance(out_profile, TensorMetadata):
  59. return _generate_tensor_out(out_profile)
  60. else:
  61. return [_generate_tensor_out(t) for t in out_profile]
  62. def _fake_kernel(*args, **kwargs): # type: ignore[no-untyped-def]
  63. for profile in op_profile:
  64. if _match_args(profile.args_profile, (*args, *kwargs.values())):
  65. return _generate_res(profile.out_profile)
  66. raise MissingOpProfile(
  67. f"No fake kernel was found for {op_name}, and although we have "
  68. "previously registered some profiles to generate a fake kernel, "
  69. f"no profiles match the given inputs: {args, kwargs}."
  70. )
  71. return _fake_kernel
  72. @contextlib.contextmanager
  73. def unsafe_generate_fake_kernels(op_profiles: dict[str, set[OpProfile]]) -> Generator:
  74. """
  75. Registers a fake kernel based on the given operator profiles. This fake
  76. kernel registration will override any existing fake kernel registrations.
  77. The input is a dictionary mapping operator names to a set of operator
  78. profiles, which we will use to generate fake kernels. The operator profiles
  79. are a record of the input and output tensor metadata. Based on this
  80. information we will match a given input to the recorded profile, and return
  81. an output with the same metadata as in the recorded profile. If a profile
  82. doesn't exist then an exception will be thrown.
  83. The fake kernel generation is considered unsafe because it relies on the
  84. rigid, pre-defined operator profiles that do not account for potential
  85. variations in output behavior. Specifically, the generated kernels assume a
  86. fixed relationship between input and output ranks. However, in reality, it's
  87. possible that data-dependent operations may produce outputs of different
  88. ranks even when given inputs of the same rank. The generated fake kernels
  89. are inflexible and unable to accommodate these nuances, making them
  90. potentially unsafe.
  91. Args:
  92. op_profiles (dict[str, set[OpProfile]]): A dictionary mapping operator
  93. name to a set of operator profiles from which we will generate fake
  94. kernels.
  95. Examples:
  96. >>> # Example: Registering an op-profile from draft-export
  97. >>> import torch
  98. >>> from torch.export._draft_export import draft_export
  99. >>>
  100. >>> @torch.library.custom_op("mylib::foo", mutates_args=())
  101. >>> def foo(x: Tensor, y: Tensor) -> Tensor:
  102. >>> return x + y
  103. >>>
  104. >>> class M(torch.nn.Module):
  105. >>> def forward(self, a, b):
  106. >>> res = torch.ops.mylib.foo(a, b) # no fake impl
  107. >>> return res
  108. >>>
  109. >>> ep = draft_export(M(), (torch.ones(3, 4), torch.ones(3, 4))
  110. >>>
  111. >>> with torch._library.fake_profile.unsafe_generate_fake_kernels(ep._report.op_profiles):
  112. >>> decomp = ep.run_decompositions()
  113. """
  114. libs: list[torch.library.Library] = []
  115. # Stores old fake impls from custom ops declared through @custom_op
  116. old_fake_impls: dict[str, Callable] = {}
  117. for op_name, profiles in op_profiles.items():
  118. log.warning(
  119. "Registering fake profile for %s. This will override any existing "
  120. "fake kernel registration.",
  121. op_name,
  122. )
  123. op_name_split = op_name.split(".")
  124. namespace, op_name_str = op_name_split[0], op_name_split[1]
  125. op_str = f"{namespace}::{op_name_str}"
  126. fake_kernel = _generate_fake_kernel(op_str, profiles)
  127. if opdef := _maybe_get_opdef(op_str):
  128. # If the op is a CustomOpDef, save the existing abstract_fn so that
  129. # we can restore it after this contextmanager
  130. if opdef._abstract_fn is not None:
  131. old_fake_impls[op_str] = opdef._abstract_fn
  132. opdef.register_fake(fake_kernel)
  133. else:
  134. # Create a new library so that we can register a new fake impl.
  135. # These libraries will then be destroyed after the contextmanager,
  136. # which will automatically restore the previously registered fake
  137. # impls.
  138. newlib = torch.library.Library(namespace, "FRAGMENT") # noqa: TOR901
  139. torch.library.register_fake(
  140. op_str, fake_kernel, lib=newlib, allow_override=True
  141. )
  142. libs.append(newlib)
  143. try:
  144. yield libs
  145. finally:
  146. # Destroying the libraries will automatically restore the previously
  147. # registered fake impls
  148. for lib in libs:
  149. lib._destroy()
  150. # Restore abstract_fns for CustomOpDefs
  151. for op_str, old_fake in old_fake_impls.items():
  152. opdef = _maybe_get_opdef(op_str)
  153. if opdef is None:
  154. raise AssertionError(f"opdef for {op_str} must not be None")
  155. opdef.register_fake(old_fake)
  156. def get_torch_version() -> str:
  157. version = torch.__version__.split(".")
  158. return f"{int(version[0])}.{int(version[1])}"
  159. def generate_yaml_from_profiles(op_profiles: dict[str, set[OpProfile]]) -> str:
  160. """
  161. Generates a yaml string from the given operator profiles which can be saved
  162. to a file. The yaml string can be loaded back into an operator profile
  163. structure using `read_profiles_from_yaml`.
  164. """
  165. import yaml
  166. from torch._export.serde.serialize import (
  167. _TORCH_TO_SERIALIZE_DTYPE,
  168. _TORCH_TO_SERIALIZE_LAYOUT,
  169. )
  170. def serialize_tensor_metadata(t: TensorMetadata) -> dict:
  171. return {
  172. "rank": t.rank,
  173. "dtype": _TORCH_TO_SERIALIZE_DTYPE[t.dtype].value,
  174. "device": str(t.device),
  175. "layout": _TORCH_TO_SERIALIZE_LAYOUT[t.layout].value,
  176. }
  177. def serialize_op_profile(op: OpProfile) -> dict:
  178. return {
  179. "args_profile": [
  180. serialize_tensor_metadata(arg)
  181. for arg in op.args_profile
  182. if arg is not None
  183. ],
  184. "out_profile": (
  185. serialize_tensor_metadata(op.out_profile)
  186. if isinstance(op.out_profile, TensorMetadata)
  187. else [serialize_tensor_metadata(out) for out in op.out_profile]
  188. ),
  189. }
  190. serialized_data = {
  191. operator: [serialize_op_profile(profile) for profile in profiles]
  192. for operator, profiles in op_profiles.items()
  193. }
  194. return yaml.dump(
  195. {"torch_version": get_torch_version(), "operators": serialized_data},
  196. sort_keys=False,
  197. )
  198. def save_op_profiles(op_profiles: dict[str, set[OpProfile]], f: FileLike) -> None:
  199. """
  200. Serializes the given operator profiles into a yaml format and saves it to
  201. the given file. The operator profile can be loaded back using `load_op_profiles`.
  202. """
  203. yaml_str = generate_yaml_from_profiles(op_profiles)
  204. if isinstance(f, (str, os.PathLike)):
  205. f = os.fspath(f)
  206. with open(f, "w") as file:
  207. file.write(yaml_str)
  208. elif isinstance(f, io.BytesIO):
  209. f.write(yaml_str.encode("utf-8"))
  210. else:
  211. raise ValueError(f"Invalid type of file {f}")
  212. def read_profiles_from_yaml(yaml_str: str) -> dict[str, set[OpProfile]]:
  213. """
  214. Reads the yaml saved by `save_op_profiles` and returns the operator profiles.
  215. """
  216. import yaml
  217. from torch._export.serde.serialize import (
  218. _SERIALIZE_TO_TORCH_DTYPE,
  219. _SERIALIZE_TO_TORCH_LAYOUT,
  220. )
  221. def deserialize_tensor_metadata(data: dict) -> TensorMetadata:
  222. return TensorMetadata(
  223. rank=data["rank"],
  224. dtype=_SERIALIZE_TO_TORCH_DTYPE[data["dtype"]],
  225. device=torch.device(data["device"]),
  226. layout=_SERIALIZE_TO_TORCH_LAYOUT[data["layout"]],
  227. )
  228. def deserialize_op_profile(data: dict) -> OpProfile:
  229. args_profile = tuple(
  230. deserialize_tensor_metadata(arg) for arg in data["args_profile"]
  231. )
  232. out_profile_data = data["out_profile"]
  233. out_profile: Union[tuple[TensorMetadata], TensorMetadata] = (
  234. tuple(deserialize_tensor_metadata(out) for out in out_profile_data) # type: ignore[assignment]
  235. if isinstance(out_profile_data, list)
  236. else deserialize_tensor_metadata(out_profile_data)
  237. )
  238. return OpProfile(args_profile=args_profile, out_profile=out_profile) # type: ignore[arg-type]
  239. loaded_data = yaml.safe_load(yaml_str)
  240. loaded_torch_version = loaded_data["torch_version"]
  241. if loaded_torch_version != get_torch_version():
  242. raise RuntimeError(
  243. "Unable to load outdated profile. It was saved with torch version: "
  244. f"{loaded_torch_version} but the current torch version is: {get_torch_version()}"
  245. )
  246. operators_data = loaded_data["operators"]
  247. return {
  248. operator: {deserialize_op_profile(profile) for profile in profiles}
  249. for operator, profiles in operators_data.items()
  250. }
  251. def load_op_profiles(f: FileLike) -> dict[str, set[OpProfile]]:
  252. """
  253. Loads the saved operator profiles from `save_op_profiles`.
  254. """
  255. if isinstance(f, (str, os.PathLike)):
  256. f = os.fspath(f)
  257. with open(f) as file:
  258. yaml_str = file.read()
  259. elif isinstance(f, io.BytesIO):
  260. yaml_str = f.read().decode("utf-8")
  261. else:
  262. raise ValueError(f"Invalid type of file {f}")
  263. return read_profiles_from_yaml(yaml_str)