infer_schema.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356
  1. # mypy: allow-untyped-defs
  2. import collections
  3. import inspect
  4. import typing
  5. from types import GenericAlias
  6. from typing import Optional, Union
  7. import torch
  8. from torch import device, dtype, Tensor, types
  9. from torch.utils._exposed_in import exposed_in
  10. from .opaque_object import _OPAQUE_TYPES, is_opaque_type
# Module-level alias of torch.Tensor. This is used as a negative test for
# test_custom_ops.py::TestTypeConversion::test_type_eval — presumably to check
# that string annotations are evaluated in the prototype function's own module
# globals (see convert_type_string inside infer_schema), not in this module's.
_TestTensor = torch.Tensor
  14. @exposed_in("torch.library")
  15. def infer_schema(
  16. prototype_function: typing.Callable,
  17. /,
  18. *,
  19. mutates_args,
  20. op_name: Optional[str] = None,
  21. ) -> str:
  22. r"""Parses the schema of a given function with type hints. The schema is inferred from the
  23. function's type hints, and can be used to define a new operator.
  24. We make the following assumptions:
  25. * None of the outputs alias any of the inputs or each other.
  26. * | String type annotations "device, dtype, Tensor, types" without library specification are
  27. | assumed to be torch.*. Similarly, string type annotations "Optional, List, Sequence, Union"
  28. | without library specification are assumed to be typing.*.
  29. * | Only the args listed in ``mutates_args`` are being mutated. If ``mutates_args`` is "unknown",
  30. | it assumes that all inputs to the operator are being mutates.
  31. Callers (e.g. the custom ops API) are responsible for checking these assumptions.
  32. Args:
  33. prototype_function: The function from which to infer a schema for from its type annotations.
  34. op_name (Optional[str]): The name of the operator in the schema. If ``name`` is None, then the
  35. name is not included in the inferred schema. Note that the input schema to
  36. ``torch.library.Library.define`` requires a operator name.
  37. mutates_args ("unknown" | Iterable[str]): The arguments that are mutated in the function.
  38. Returns:
  39. The inferred schema.
  40. Example:
  41. >>> def foo_impl(x: torch.Tensor) -> torch.Tensor:
  42. >>> return x.sin()
  43. >>>
  44. >>> infer_schema(foo_impl, op_name="foo", mutates_args={})
  45. foo(Tensor x) -> Tensor
  46. >>>
  47. >>> infer_schema(foo_impl, mutates_args={})
  48. (Tensor x) -> Tensor
  49. """
  50. UNKNOWN_MUTATES = "unknown"
  51. pf_globals = prototype_function.__globals__
  52. pf_locals = None
  53. # TODO: Once our minimum version is py3.10+ pass `eval_str=True` to
  54. # inspect.signature() and we no longer need to deal with stringified
  55. # annotations below.
  56. sig = inspect.signature(prototype_function)
  57. def error_fn(what):
  58. raise ValueError(f"infer_schema(func): {what} Got func with signature {sig})")
  59. def convert_type_string(annotation_type: str):
  60. try:
  61. return eval(annotation_type, pf_globals, pf_locals)
  62. except Exception:
  63. error_fn(
  64. f"Unsupported type annotation {annotation_type}. It is not a type."
  65. )
  66. def unstringify_types(
  67. tys: tuple[Union[type[object], str], ...],
  68. ) -> tuple[tuple[typing.Any, ...], bool]:
  69. res = []
  70. changed = False
  71. for ty in tys:
  72. ty, ty_changed = unstringify_type(ty)
  73. res.append(ty)
  74. changed |= ty_changed
  75. if changed:
  76. return tuple(res), True
  77. else:
  78. return tys, False # type: ignore[return-value]
  79. def unstringify_type(ty: Union[type[object], str]) -> tuple[typing.Any, bool]:
  80. # Dig through a generic type and if it contains a stringified type
  81. # convert that to a real type. The second return value indicates if the
  82. # type contained a string or not.
  83. if isinstance(ty, str):
  84. return convert_type_string(ty), True
  85. elif origin := typing.get_origin(ty):
  86. args, args_changed = unstringify_types(typing.get_args(ty))
  87. if args_changed:
  88. return GenericAlias(origin, args), True
  89. return ty, False
  90. params = []
  91. seen_args = set()
  92. saw_kwarg_only_arg = False
  93. for idx, (name, param) in enumerate(sig.parameters.items()):
  94. if not supported_param(param):
  95. error_fn("We do not support positional-only args, varargs, or varkwargs.")
  96. if param.kind == inspect.Parameter.KEYWORD_ONLY:
  97. # The first time we see a kwarg-only arg, add "*" to the schema.
  98. if not saw_kwarg_only_arg:
  99. params.append("*")
  100. saw_kwarg_only_arg = True
  101. if param.annotation is inspect.Parameter.empty:
  102. error_fn(f"Parameter {name} must have a type annotation.")
  103. # The annotation might be converted to a string by annotation,
  104. # we convert it to the actual type.
  105. annotation_type, _ = unstringify_type(param.annotation)
  106. schema_type = None
  107. if annotation_type not in SUPPORTED_PARAM_TYPES:
  108. if is_opaque_type(annotation_type):
  109. schema_type = _OPAQUE_TYPES[annotation_type].class_name
  110. elif annotation_type == torch._C.ScriptObject:
  111. error_fn(
  112. f"Parameter {name}'s type cannot be inferred from the schema "
  113. "as it is a ScriptObject. Please manually specify the schema "
  114. "using the `schema=` kwarg with the actual type of the ScriptObject."
  115. )
  116. elif (
  117. hasattr(annotation_type, "__origin__")
  118. and annotation_type.__origin__ is tuple
  119. ):
  120. list_type = tuple_to_list(annotation_type)
  121. example_type_str = "\n\n"
  122. # Only suggest the list type if this type is supported.
  123. if list_type in SUPPORTED_PARAM_TYPES:
  124. example_type_str = f"For example, {list_type}.\n\n"
  125. error_fn(
  126. f"Parameter {name} has unsupported type {param.annotation}. "
  127. f"We do not support Tuple inputs in schema. As a workaround, please try to use List instead. "
  128. f"{example_type_str}"
  129. f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}."
  130. )
  131. else:
  132. error_fn(
  133. f"Parameter {name} has unsupported type {param.annotation}. "
  134. f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}."
  135. )
  136. else:
  137. schema_type = SUPPORTED_PARAM_TYPES[annotation_type]
  138. if schema_type is None:
  139. raise AssertionError(f"schema_type is None for param {name}")
  140. if type(mutates_args) is str:
  141. if mutates_args != UNKNOWN_MUTATES:
  142. raise ValueError(
  143. "mutates_args must either be a sequence of the names of "
  144. "the arguments that are mutated or the string 'unknown'. "
  145. )
  146. if schema_type.startswith("Tensor"):
  147. schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor') :]}"
  148. elif name in mutates_args:
  149. if not schema_type.startswith("Tensor"):
  150. error_fn(
  151. f"Parameter {name} is in mutable_args but only Tensors or collections of Tensors can be mutated"
  152. )
  153. schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor') :]}"
  154. seen_args.add(name)
  155. if param.default is inspect.Parameter.empty:
  156. # pyrefly: ignore [bad-argument-type]
  157. params.append(f"{schema_type} {name}")
  158. else:
  159. default_repr = None
  160. if param.default is None or isinstance(param.default, (int, float, bool)):
  161. default_repr = str(param.default)
  162. elif isinstance(param.default, (str, torch.device)):
  163. default_repr = f'"{param.default}"'
  164. elif isinstance(param.default, torch.dtype):
  165. dtype_repr = str(param.default)
  166. torch_dot = "torch."
  167. if not dtype_repr.startswith(torch_dot):
  168. raise AssertionError(
  169. f"dtype repr {dtype_repr!r} must start with 'torch.'"
  170. )
  171. default_repr = dtype_repr[len(torch_dot) :]
  172. else:
  173. error_fn(
  174. f"Parameter {name} has an unsupported default value type {type(param.default)}. "
  175. f"Please file an issue on GitHub so we can prioritize this."
  176. )
  177. # pyrefly: ignore [bad-argument-type]
  178. params.append(f"{schema_type} {name}={default_repr}")
  179. if mutates_args != UNKNOWN_MUTATES:
  180. mutates_args_not_seen = set(mutates_args) - seen_args
  181. if len(mutates_args_not_seen) > 0:
  182. error_fn(
  183. f"{mutates_args_not_seen} in mutates_args were not found in "
  184. f"the custom op's signature. "
  185. f"mutates_args should contain the names of all args that the "
  186. f"custom op mutates, or just the string 'unknown' if you don't know."
  187. )
  188. return_annotation, _ = unstringify_type(sig.return_annotation)
  189. ret = parse_return(return_annotation, error_fn)
  190. if op_name is not None:
  191. return f"{op_name}({', '.join(params)}) -> {ret}"
  192. return f"({', '.join(params)}) -> {ret}"
  193. def derived_types(
  194. base_type: Union[type, typing._SpecialForm],
  195. cpp_type: str,
  196. list_base: bool,
  197. optional_base_list: bool,
  198. optional_list_base: bool,
  199. ):
  200. result: list[tuple[Union[type, typing._SpecialForm, GenericAlias], str]] = [
  201. (base_type, cpp_type),
  202. # pyrefly: ignore [not-a-type]
  203. (typing.Optional[base_type], f"{cpp_type}?"),
  204. ]
  205. def derived_seq_types(typ: Union[type, typing._SpecialForm]):
  206. return (
  207. typing.Sequence[typ], # type: ignore[valid-type] # noqa: UP006
  208. typing.List[typ], # type: ignore[valid-type] # noqa: UP006
  209. GenericAlias(collections.abc.Sequence, (typ,)),
  210. GenericAlias(list, (typ,)),
  211. )
  212. if list_base:
  213. result.extend(
  214. (seq_typ, f"{cpp_type}[]") for seq_typ in derived_seq_types(base_type)
  215. )
  216. if optional_base_list:
  217. result.extend(
  218. (seq_typ, f"{cpp_type}?[]")
  219. # pyrefly: ignore [not-a-type]
  220. for seq_typ in derived_seq_types(typing.Optional[base_type])
  221. )
  222. if optional_list_base:
  223. result.extend(
  224. (typing.Optional[seq_typ], f"{cpp_type}[]?")
  225. for seq_typ in derived_seq_types(base_type)
  226. )
  227. return result
  228. def get_supported_param_types():
  229. data: list[tuple[Union[type, typing._SpecialForm], str, bool, bool, bool]] = [
  230. # (python type, schema type, type[] variant, type?[] variant, type[]? variant
  231. (Tensor, "Tensor", True, True, False),
  232. (int, "SymInt", True, False, True),
  233. (float, "float", True, False, True),
  234. (bool, "bool", True, False, True),
  235. (str, "str", False, False, False),
  236. (types.Number, "Scalar", True, False, False),
  237. (dtype, "ScalarType", False, False, False),
  238. (device, "Device", False, False, False),
  239. ]
  240. if torch.distributed.is_available():
  241. from torch.distributed.distributed_c10d import GroupName
  242. data.append((typing.cast(type, GroupName), "str", False, False, False))
  243. result = []
  244. for line in data:
  245. result.extend(derived_types(*line))
  246. return dict(result)
  247. SUPPORTED_RETURN_TYPES = {
  248. Tensor: "Tensor",
  249. typing.List[Tensor]: "Tensor[]", # noqa: UP006
  250. list[Tensor]: "Tensor[]",
  251. int: "SymInt",
  252. float: "float",
  253. bool: "bool",
  254. types.Number: "Scalar",
  255. }
  256. def parse_return(annotation, error_fn):
  257. if annotation is None:
  258. return "()"
  259. if annotation is inspect.Parameter.empty:
  260. error_fn("No return type annotation was provided. Please add one.")
  261. origin = typing.get_origin(annotation)
  262. if origin is not tuple:
  263. if annotation not in SUPPORTED_RETURN_TYPES:
  264. error_fn(
  265. f"Return has unsupported type {annotation}. "
  266. f"The valid types are: {SUPPORTED_RETURN_TYPES}."
  267. )
  268. return SUPPORTED_RETURN_TYPES[annotation]
  269. args = typing.get_args(annotation)
  270. for arg in args:
  271. if arg not in SUPPORTED_RETURN_TYPES:
  272. error_fn(
  273. f"Return has unsupported type {annotation}. "
  274. f"The valid types are: {SUPPORTED_RETURN_TYPES}."
  275. )
  276. output_ty = ", ".join([SUPPORTED_RETURN_TYPES[arg] for arg in args])
  277. # use (()) to represent tuple with single element
  278. if len(args) == 1:
  279. output_ty = "(" + output_ty + ")"
  280. return "(" + output_ty + ")"
# Parameter-annotation -> schema-type table, built once at import time. Its
# contents depend on torch.distributed.is_available() at that moment (see
# get_supported_param_types). infer_schema looks this name up at call time, so
# defining it after infer_schema in the module is fine.
SUPPORTED_PARAM_TYPES = get_supported_param_types()
  282. def supported_param(param: inspect.Parameter) -> bool:
  283. return param.kind in (
  284. inspect.Parameter.POSITIONAL_OR_KEYWORD,
  285. inspect.Parameter.KEYWORD_ONLY,
  286. )
  287. def tuple_to_list(tuple_type: type[tuple]) -> type[list]:
  288. """
  289. Convert `tuple_type` into a list type with the same type arguments. Assumes that `tuple_type` is typing.Tuple type.
  290. """
  291. type_args = getattr(tuple_type, "__args__", None)
  292. # Account for different python versions, e.g. python 3.8 would give ()
  293. # but python 3.12 would give None.
  294. if (
  295. tuple_type is typing.Tuple # noqa: UP006
  296. or tuple_type is tuple
  297. or type_args == ()
  298. or type_args is None
  299. ):
  300. # Handle the case of an empty tuple type
  301. return list
  302. elif len(type_args) == 1:
  303. # General case: create a List with the same type arguments
  304. return list[type_args[0]] # type: ignore[valid-type]
  305. elif len(type_args) == 2 and type_args[1] is Ellipsis:
  306. return list[type_args[0]] # type: ignore[valid-type]
  307. else:
  308. return list[typing.Union[tuple(type_args)]] # type: ignore[misc, return-value]