| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356 |
- # mypy: allow-untyped-defs
- import collections
- import inspect
- import typing
- from types import GenericAlias
- from typing import Optional, Union
- import torch
- from torch import device, dtype, Tensor, types
- from torch.utils._exposed_in import exposed_in
- from .opaque_object import _OPAQUE_TYPES, is_opaque_type
# This is used as a negative test for
# test_custom_ops.py::TestTypeConversion::test_type_eval.
# NOTE(review): infer_schema evaluates string annotations in the prototype
# function's own module globals, not in this module's. Presumably the test
# checks that this module-private alias does NOT resolve from user code.
# Confirm against the test before renaming or removing it.
_TestTensor = torch.Tensor
@exposed_in("torch.library")
def infer_schema(
    prototype_function: typing.Callable,
    /,
    *,
    mutates_args,
    op_name: Optional[str] = None,
) -> str:
    r"""Parses the schema of a given function with type hints. The schema is inferred from the
    function's type hints, and can be used to define a new operator.

    We make the following assumptions:

    * None of the outputs alias any of the inputs or each other.
    * | String type annotations (e.g. under ``from __future__ import annotations``) are
      | evaluated in the globals of the module that defines ``prototype_function``, so names
      | such as ``Tensor``, ``device``, ``dtype``, ``Optional`` or ``List`` must be in scope there.
    * | Only the args listed in ``mutates_args`` are being mutated. If ``mutates_args`` is "unknown",
      | it assumes that all Tensor inputs to the operator are being mutated.

    Callers (e.g. the custom ops API) are responsible for checking these assumptions.

    Args:
        prototype_function: The function whose type annotations the schema is inferred from.
        op_name (Optional[str]): The name of the operator in the schema. If ``op_name`` is None, then the
            name is not included in the inferred schema. Note that the input schema to
            ``torch.library.Library.define`` requires an operator name.
        mutates_args ("unknown" | Iterable[str]): The arguments that are mutated in the function.
            Any iterable of names is accepted (it is consumed exactly once).

    Returns:
        The inferred schema.

    Raises:
        ValueError: If ``mutates_args`` is a string other than "unknown"; if the signature
            uses positional-only args, varargs or varkwargs; if a parameter or the return is
            unannotated or has an unsupported type or default value; or if ``mutates_args``
            names a non-Tensor parameter or a parameter that is not in the signature.

    Example:
        >>> def foo_impl(x: torch.Tensor) -> torch.Tensor:
        >>>     return x.sin()
        >>>
        >>> infer_schema(foo_impl, op_name="foo", mutates_args={})
        foo(Tensor x) -> Tensor
        >>>
        >>> infer_schema(foo_impl, mutates_args={})
        (Tensor x) -> Tensor
    """
    UNKNOWN_MUTATES = "unknown"
    pf_globals = prototype_function.__globals__
    pf_locals = None
    # TODO: Once our minimum version is py3.10+ pass `eval_str=True` to
    # inspect.signature() and we no longer need to deal with stringified
    # annotations below.
    sig = inspect.signature(prototype_function)

    def error_fn(what):
        raise ValueError(f"infer_schema(func): {what} Got func with signature {sig}")

    # Validate mutates_args once, before walking the signature. A string is
    # only accepted as the "unknown" sentinel. Any other iterable is
    # materialized into a set, so that a one-shot iterable (e.g. a generator)
    # is not exhausted by the per-parameter membership tests below and then
    # seen as empty by the final "not found" check.
    mutates_unknown = isinstance(mutates_args, str)
    if mutates_unknown:
        if mutates_args != UNKNOWN_MUTATES:
            raise ValueError(
                "mutates_args must either be a sequence of the names of "
                "the arguments that are mutated or the string 'unknown'. "
            )
        mutated_names: set[str] = set()
    else:
        mutated_names = set(mutates_args)

    def convert_type_string(annotation_type: str):
        # Resolve a stringified annotation in the prototype function's module.
        try:
            return eval(annotation_type, pf_globals, pf_locals)
        except Exception:
            error_fn(
                f"Unsupported type annotation {annotation_type}. It is not a type."
            )

    def unstringify_types(
        tys: tuple[Union[type[object], str], ...],
    ) -> tuple[tuple[typing.Any, ...], bool]:
        # Resolve each element; return the original tuple when nothing changed.
        res = []
        changed = False
        for ty in tys:
            ty, ty_changed = unstringify_type(ty)
            res.append(ty)
            changed |= ty_changed
        if changed:
            return tuple(res), True
        else:
            return tys, False  # type: ignore[return-value]

    def unstringify_type(ty: Union[type[object], str]) -> tuple[typing.Any, bool]:
        # Dig through a generic type and if it contains a stringified type
        # convert that to a real type. The second return value indicates if the
        # type contained a string or not.
        if isinstance(ty, str):
            return convert_type_string(ty), True
        elif origin := typing.get_origin(ty):
            args, args_changed = unstringify_types(typing.get_args(ty))
            if args_changed:
                return GenericAlias(origin, args), True

        return ty, False

    params = []
    seen_args = set()
    saw_kwarg_only_arg = False
    for idx, (name, param) in enumerate(sig.parameters.items()):
        if not supported_param(param):
            error_fn("We do not support positional-only args, varargs, or varkwargs.")

        if param.kind == inspect.Parameter.KEYWORD_ONLY:
            # The first time we see a kwarg-only arg, add "*" to the schema.
            if not saw_kwarg_only_arg:
                params.append("*")
                saw_kwarg_only_arg = True

        if param.annotation is inspect.Parameter.empty:
            error_fn(f"Parameter {name} must have a type annotation.")

        # Annotations may be strings (e.g. under `from __future__ import
        # annotations`); resolve them to the actual types.
        annotation_type, _ = unstringify_type(param.annotation)

        schema_type = None
        if annotation_type not in SUPPORTED_PARAM_TYPES:
            if is_opaque_type(annotation_type):
                schema_type = _OPAQUE_TYPES[annotation_type].class_name
            elif annotation_type == torch._C.ScriptObject:
                error_fn(
                    f"Parameter {name}'s type cannot be inferred from the schema "
                    "as it is a ScriptObject. Please manually specify the schema "
                    "using the `schema=` kwarg with the actual type of the ScriptObject."
                )
            elif (
                hasattr(annotation_type, "__origin__")
                and annotation_type.__origin__ is tuple
            ):
                list_type = tuple_to_list(annotation_type)
                example_type_str = "\n\n"
                # Only suggest the list type if this type is supported.
                if list_type in SUPPORTED_PARAM_TYPES:
                    example_type_str = f"For example, {list_type}.\n\n"
                error_fn(
                    f"Parameter {name} has unsupported type {param.annotation}. "
                    f"We do not support Tuple inputs in schema. As a workaround, please try to use List instead. "
                    f"{example_type_str}"
                    f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}."
                )
            else:
                error_fn(
                    f"Parameter {name} has unsupported type {param.annotation}. "
                    f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}."
                )
        else:
            schema_type = SUPPORTED_PARAM_TYPES[annotation_type]
        if schema_type is None:
            raise AssertionError(f"schema_type is None for param {name}")

        # Mutated Tensor-like params are rewritten as "Tensor(a<idx>!)...":
        # a per-parameter alias set marked with "!".
        if mutates_unknown:
            if schema_type.startswith("Tensor"):
                schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor') :]}"
        elif name in mutated_names:
            if not schema_type.startswith("Tensor"):
                error_fn(
                    f"Parameter {name} is in mutates_args but only Tensors or collections of Tensors can be mutated"
                )
            schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor') :]}"
        seen_args.add(name)

        if param.default is inspect.Parameter.empty:
            # pyrefly: ignore [bad-argument-type]
            params.append(f"{schema_type} {name}")
        else:
            default_repr = None
            if param.default is None or isinstance(param.default, (int, float, bool)):
                default_repr = str(param.default)
            elif isinstance(param.default, (str, torch.device)):
                default_repr = f'"{param.default}"'
            elif isinstance(param.default, torch.dtype):
                # Schemas spell dtypes without the "torch." prefix.
                dtype_repr = str(param.default)
                torch_dot = "torch."
                if not dtype_repr.startswith(torch_dot):
                    raise AssertionError(
                        f"dtype repr {dtype_repr!r} must start with 'torch.'"
                    )
                default_repr = dtype_repr[len(torch_dot) :]
            else:
                error_fn(
                    f"Parameter {name} has an unsupported default value type {type(param.default)}. "
                    f"Please file an issue on GitHub so we can prioritize this."
                )
            # pyrefly: ignore [bad-argument-type]
            params.append(f"{schema_type} {name}={default_repr}")

    if not mutates_unknown:
        mutates_args_not_seen = mutated_names - seen_args
        if len(mutates_args_not_seen) > 0:
            error_fn(
                f"{mutates_args_not_seen} in mutates_args were not found in "
                f"the custom op's signature. "
                f"mutates_args should contain the names of all args that the "
                f"custom op mutates, or just the string 'unknown' if you don't know."
            )

    return_annotation, _ = unstringify_type(sig.return_annotation)
    ret = parse_return(return_annotation, error_fn)
    if op_name is not None:
        return f"{op_name}({', '.join(params)}) -> {ret}"
    return f"({', '.join(params)}) -> {ret}"
def derived_types(
    base_type: Union[type, typing._SpecialForm],
    cpp_type: str,
    list_base: bool,
    optional_base_list: bool,
    optional_list_base: bool,
):
    """Expand one base type into ``(annotation, schema type)`` pairs for its variants.

    ``base_type`` and ``Optional[base_type]`` are always included. Each enabled
    flag then adds the four accepted sequence spellings (``typing.Sequence``,
    ``typing.List``, ``collections.abc.Sequence[...]`` and ``list[...]``) of:

    * ``list_base``: ``base_type`` -> ``"<cpp_type>[]"``
    * ``optional_base_list``: ``Optional[base_type]`` -> ``"<cpp_type>?[]"``
    * ``optional_list_base``: the sequence wrapped in ``Optional`` -> ``"<cpp_type>[]?"``

    The pairs are returned as a list in exactly that order.
    """

    def sequence_spellings(elem):
        # Every Python spelling of "sequence of elem" that maps to one schema type.
        return [
            typing.Sequence[elem],  # type: ignore[valid-type] # noqa: UP006
            typing.List[elem],  # type: ignore[valid-type] # noqa: UP006
            GenericAlias(collections.abc.Sequence, (elem,)),
            GenericAlias(list, (elem,)),
        ]

    # pyrefly: ignore [not-a-type]
    optional_base = typing.Optional[base_type]
    pairs: list[tuple[Union[type, typing._SpecialForm, GenericAlias], str]] = [
        (base_type, cpp_type),
        (optional_base, f"{cpp_type}?"),
    ]
    if list_base:
        pairs += [(seq, f"{cpp_type}[]") for seq in sequence_spellings(base_type)]
    if optional_base_list:
        pairs += [(seq, f"{cpp_type}?[]") for seq in sequence_spellings(optional_base)]
    if optional_list_base:
        pairs += [
            (typing.Optional[seq], f"{cpp_type}[]?")
            for seq in sequence_spellings(base_type)
        ]
    return pairs
def get_supported_param_types():
    """Build the mapping from accepted parameter annotations to schema type strings.

    Each table row is expanded by ``derived_types`` into the base type, its
    ``Optional`` form, and whichever list / optional-list variants the row's
    flags enable. If two rows yield the same annotation, the later row wins.
    """
    # Columns: python type, schema type, emit "T[]", emit "T?[]", emit "T[]?"
    table: list[tuple[Union[type, typing._SpecialForm], str, bool, bool, bool]] = [
        (Tensor, "Tensor", True, True, False),
        (int, "SymInt", True, False, True),
        (float, "float", True, False, True),
        (bool, "bool", True, False, True),
        (str, "str", False, False, False),
        (types.Number, "Scalar", True, False, False),
        (dtype, "ScalarType", False, False, False),
        (device, "Device", False, False, False),
    ]
    if torch.distributed.is_available():
        # GroupName is only looked up when torch.distributed is available.
        from torch.distributed.distributed_c10d import GroupName

        table.append((typing.cast(type, GroupName), "str", False, False, False))

    return {
        annotation: schema
        for row in table
        for annotation, schema in derived_types(*row)
    }
# Return annotations accepted by parse_return, mapped to the schema return type
# each one produces. parse_return handles tuple[...] returns separately, by
# looking up each element here. The dict is also interpolated verbatim into
# parse_return's error messages, so its order is user-visible.
SUPPORTED_RETURN_TYPES = {
    Tensor: "Tensor",
    # Both the typing.List and builtin list spellings are accepted.
    typing.List[Tensor]: "Tensor[]",  # noqa: UP006
    list[Tensor]: "Tensor[]",
    int: "SymInt",
    float: "float",
    bool: "bool",
    types.Number: "Scalar",
}
def parse_return(annotation, error_fn):
    """Translate a return annotation into a schema return-type string.

    ``None`` means the op returns nothing and yields ``"()"``. A ``tuple[...]``
    annotation yields a parenthesized, comma-separated list of its element
    types, and a one-element tuple is rendered with doubled parentheses,
    ``"((T))"``. Any other annotation must be a key of
    ``SUPPORTED_RETURN_TYPES``.

    ``error_fn`` is called with a message when the annotation is missing or
    unsupported. infer_schema passes one that raises ValueError.
    """
    if annotation is None:
        return "()"
    if annotation is inspect.Parameter.empty:
        error_fn("No return type annotation was provided. Please add one.")

    def report_unsupported():
        error_fn(
            f"Return has unsupported type {annotation}. "
            f"The valid types are: {SUPPORTED_RETURN_TYPES}."
        )

    if typing.get_origin(annotation) is tuple:
        elements = typing.get_args(annotation)
        for element in elements:
            if element not in SUPPORTED_RETURN_TYPES:
                report_unsupported()
        inner = ", ".join([SUPPORTED_RETURN_TYPES[element] for element in elements])
        # A single-element tuple gets an extra pair of parentheses: "((T))".
        if len(elements) == 1:
            inner = f"({inner})"
        return f"({inner})"

    if annotation not in SUPPORTED_RETURN_TYPES:
        report_unsupported()
    return SUPPORTED_RETURN_TYPES[annotation]
# Accepted parameter annotation -> schema type string, built once at import
# time. infer_schema looks every parameter's (resolved) annotation up here.
SUPPORTED_PARAM_TYPES = get_supported_param_types()
def supported_param(param: inspect.Parameter) -> bool:
    """Return True if ``param``'s kind can be expressed in an operator schema.

    Only ordinary (positional-or-keyword) and keyword-only parameters qualify.
    Positional-only parameters, ``*args`` and ``**kwargs`` do not.
    """
    kind = param.kind
    return (
        kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
        or kind == inspect.Parameter.KEYWORD_ONLY
    )
def tuple_to_list(tuple_type: type[tuple]) -> type[list]:
    """Return the ``list`` type corresponding to the tuple type ``tuple_type``.

    * bare ``tuple`` / ``typing.Tuple``, or a tuple type with no element types -> ``list``
    * ``tuple[T]`` and ``tuple[T, ...]`` -> ``list[T]``
    * heterogeneous ``tuple[A, B, C]`` -> ``list[Union[A, B, C]]``

    Accepts both ``typing.Tuple[...]`` and builtin ``tuple[...]`` forms.
    """
    element_types = getattr(tuple_type, "__args__", None)
    # Account for different python versions: an empty/bare tuple type may
    # report ``()`` for ``__args__`` or not have it at all (None here).
    is_bare = tuple_type is typing.Tuple or tuple_type is tuple  # noqa: UP006
    if is_bare or element_types is None or element_types == ():
        return list

    homogeneous = len(element_types) == 1 or (
        len(element_types) == 2 and element_types[1] is Ellipsis
    )
    if homogeneous:
        return list[element_types[0]]  # type: ignore[valid-type]
    return list[typing.Union[tuple(element_types)]]  # type: ignore[misc, return-value]
|