| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603 |
- # mypy: allow-untyped-defs
- import enum
- import inspect
- import numbers
- import types
- import typing
- import warnings
- from collections.abc import Callable
- from typing import Any, cast, NamedTuple, Optional, TYPE_CHECKING
- import torch
- from torch._jit_internal import boolean_dispatched
- from torch._ops import OpOverload, OpOverloadPacket
- from ._compatibility import compatibility
- if TYPE_CHECKING:
- from .node import Argument
# Public names of this module (what ``from ... import *`` exports); the
# underscore-prefixed helpers defined alongside them are internal.
__all__ = [
    "ArgsKwargsPair",
    "check_for_mutable_operation",
    "get_signature_for_torch_op",
    "create_type_hint",
    "type_matches",
    "normalize_function",
    "normalize_module",
]
@compatibility(is_backward_compatible=False)
class ArgsKwargsPair(NamedTuple):
    """
    Immutable pair holding the positional and keyword arguments of a call,
    as produced by the argument-normalization helpers in this module.
    """

    # Positional arguments, in call order.
    args: tuple[Any, ...]
    # Keyword arguments, keyed by parameter name.
    kwargs: dict[str, Any]
# Hand-written Python signatures for ops that use them instead of
# signatures derived from TorchScript schemas;
# ``get_signature_for_torch_op`` consults this table before the JIT lookup.
_manual_overrides: dict[Callable, list[inspect.Signature]] = {}


def _nonzero_schemas():
    """
    Return the two Python-level signatures registered for ``torch.nonzero``:
    one taking only ``self`` and one that also takes the keyword-only
    ``as_tuple: bool`` flag.
    """

    def _without_as_tuple(self):
        pass

    def _with_as_tuple(self, *, as_tuple: bool):
        pass

    # Signatures do not record the function name, so the stub names are
    # irrelevant; only their parameter lists matter.
    return [inspect.signature(stub) for stub in (_without_as_tuple, _with_as_tuple)]


_manual_overrides[torch.nonzero] = _nonzero_schemas()
class _FakeGlobalNamespace:
    """
    Stand-in for the ``__torch__`` name inside evaluated annotation strings.

    Only the ``torch`` attribute resolves (to the real ``torch`` module); any
    other attribute lookup raises ``RuntimeError``. Presumably this serves
    TorchScript-qualified names of the form ``__torch__.torch.<...>`` — confirm
    against the annotation strings actually produced.
    """

    def __getattr__(self, name):
        if name != "torch":
            raise RuntimeError("Expected a torch namespace lookup")
        return torch


# Global namespace used to ``eval`` TorchScript ``annotation_str`` values into
# Python types: TorchScript spellings (``Tensor``, ``Device``, ``number``, ...)
# map to concrete Python/torch types.
_type_eval_globals = {
    "Tensor": torch.Tensor,
    "Device": torch.device,
    "Layout": torch.layout,
    "number": numbers.Number,
    "Future": torch.jit.Future,
    "AnyEnumType": enum.Enum,
    "QScheme": torch.qscheme,
    "__torch__": _FakeGlobalNamespace(),
    "NoneType": type(None),
    "Storage": torch.UntypedStorage,
    "t": typing.TypeVar("t"),
    "PyObject": Any,
}
# Everything in ``typing`` (List, Optional, Union, ...) is layered on top, in
# the same order as before, so any name it shares with the entries above is
# overridden by the ``typing`` object. The comprehension keeps its loop
# variable local instead of leaking a stray module-level ``k``.
_type_eval_globals.update({name: getattr(typing, name) for name in dir(typing)})
def _torchscript_type_to_python_type(ts_type: "torch._C.JitType") -> Any:
    """
    Translate a TorchScript type into the equivalent Python annotation.

    The type's ``annotation_str`` (for example ``"List[int]"`` or
    ``"Optional[Tensor]"``) is evaluated against ``_type_eval_globals``, which
    binds TorchScript spellings such as ``List`` and ``Future`` to real Python
    objects (``typing.List``, ``torch.jit.Future``), so nested subtypes are
    reconstructed as well.
    """
    # NOTE(review): ``eval`` is assumed safe because ``annotation_str`` comes
    # from a JIT type object, not from user-provided text — confirm for any
    # new caller.
    annotation_source = ts_type.annotation_str
    return eval(annotation_source, _type_eval_globals)
def _torchscript_schema_to_signature_impl(
    ts_schema: torch._C.FunctionSchema,
) -> inspect.Signature:
    """
    Build an ``inspect.Signature`` describing a TorchScript ``FunctionSchema``.

    Each schema argument becomes an ``inspect.Parameter`` annotated with the
    Python equivalent of its TorchScript type and carrying the schema default
    (if any). A schema argument named ``self`` is renamed to ``input``.
    Keyword-only schema arguments stay keyword-only and all others are
    positional-or-keyword, except that an argument named ``from`` forces
    itself and every preceding argument to be positional-only.

    The return annotation is ``None`` when the schema has no returns, the
    single return type when it has one, and a plain ``tuple`` of return types
    otherwise.

    Raises:
        AssertionError: if ``from`` is keyword-only, or if an argument that
            precedes ``from`` is not positional-or-keyword.
    """
    from inspect import Parameter

    params: list[Parameter] = []
    for schema_arg in ts_schema.arguments:
        annotation = _torchscript_type_to_python_type(schema_arg.type)
        if schema_arg.has_default_value():
            default = schema_arg.default_value
        else:
            default = Parameter.empty

        # TODO: Figure out if this is safe. It seems like when generating the type signatures for
        # PythonArgParser, we emit signatures with `input` instead of `self` as the first tensor
        # argument name. Downstream, if someone converts that positional argument to a keyword
        # argument, the name mismatch will break things, so here we're going to normalize the
        # name to "input"
        param_name = "input" if schema_arg.name == "self" else schema_arg.name

        if schema_arg.kwarg_only:
            param_kind = Parameter.KEYWORD_ONLY
        else:
            param_kind = Parameter.POSITIONAL_OR_KEYWORD

        if param_name == "from":
            # "from" is a Python keyword, so it can only be a positional-only
            # parameter.
            if param_kind != Parameter.POSITIONAL_OR_KEYWORD:
                raise AssertionError(f"Expected POSITIONAL_OR_KEYWORD, got {param_kind}")
            # ParameterKind type is internal implementation detail to inspec package
            # which makes it hard to do type annotation
            param_kind = Parameter.POSITIONAL_ONLY  # type: ignore[assignment]
            # Positional-only parameters must come first, so every parameter
            # collected so far is demoted as well.
            for earlier in params:
                if earlier.kind != Parameter.POSITIONAL_OR_KEYWORD:
                    raise AssertionError(
                        f"Expected POSITIONAL_OR_KEYWORD for param {earlier.name}, got {earlier.kind}"
                    )
            params = [earlier.replace(kind=Parameter.POSITIONAL_ONLY) for earlier in params]

        params.append(
            Parameter(name=param_name, kind=param_kind, default=default, annotation=annotation)
        )

    return_types = [
        _torchscript_type_to_python_type(ret.type) for ret in ts_schema.returns
    ]
    if not return_types:
        return_annotation = None
    elif len(return_types) == 1:
        (return_annotation,) = return_types
    else:
        return_annotation = tuple(return_types)
    return inspect.Signature(params, return_annotation=return_annotation)
- _SCHEMA_TO_SIGNATURE_CACHE: dict[tuple[str, str], inspect.Signature] = {}
- def _torchscript_schema_to_signature(
- ts_schema: torch._C.FunctionSchema,
- ) -> inspect.Signature:
- # Cached as it's called in the hot path of FakeTensor dispatch
- cache_key = ts_schema.name, ts_schema.overload_name
- cache_val = _SCHEMA_TO_SIGNATURE_CACHE.get(cache_key)
- if cache_val is not None:
- return cache_val
- res = _torchscript_schema_to_signature_impl(ts_schema)
- _SCHEMA_TO_SIGNATURE_CACHE[cache_key] = res
- return res
@compatibility(is_backward_compatible=False)
def check_for_mutable_operation(
    target: Callable, args: tuple["Argument", ...], kwargs: dict[str, "Argument"]
):
    """
    Best-effort check that calling ``target`` with ``args``/``kwargs`` is not a
    mutating operation.

    The call is bound against the Python signature of each TorchScript
    overload of ``target``. If exactly one overload accepts it and that
    overload's schema is mutable, ``RuntimeError`` is raised. When no schemas
    are available, nothing binds, or more than one overload binds, the check
    does nothing.
    """
    signatures, schemas = get_signature_for_torch_op(target, return_schemas=True)
    if not (signatures and schemas):
        # No schema information to check against.
        return

    accepting_schemas = []
    for signature, schema in zip(signatures, schemas):
        try:
            signature.bind(*args, **kwargs)
        except TypeError:
            continue
        accepting_schemas.append(schema)

    # Mutability checking is best effort: an unmatched or ambiguous call is
    # left alone, and only an unambiguous single match is inspected.
    if len(accepting_schemas) != 1:
        return

    (schema_to_check,) = accepting_schemas
    if schema_to_check.is_mutable:
        raise RuntimeError(
            f"Tried to trace mutable operation {schema_to_check}. FX only supports functional "
            f"code, so operations that mutate operands in-place (e.g. via `out` arguments) "
            f"are not supported"
        )
@compatibility(is_backward_compatible=False)
def get_signature_for_torch_op(op: Callable, return_schemas: bool = False):
    """
    Given an operator on the `torch` namespace, return a list of `inspect.Signature`
    objects corresponding to the overloads of that op. May return `None` if a signature
    could not be retrieved.

    ``OpOverload`` and ``OpOverloadPacket`` objects use their own schemas; any
    other callable is first looked up in the manual-override table and then
    resolved through the TorchScript builtin registry.

    Args:
        op (Callable): An operator on the `torch` namespace to look up a signature for
        return_schemas (bool): Whether to also return the TorchScript schemas the
            signatures were derived from.

    Returns:
        Optional[List[inspect.Signature]]: A list of signatures for the overloads of this
        operator, or None if the operator signatures could not be retrieved. If
        return_schemas=True, returns a tuple containing the optional Python signatures
        and the optional TorchScript Function signature
    """
    if isinstance(op, OpOverload):
        schemas = [op._schema]
    elif isinstance(op, OpOverloadPacket):
        schemas = [getattr(op, overload_name)._schema for overload_name in op.overloads()]
    else:
        manual_signatures = _manual_overrides.get(op)
        if manual_signatures:
            # NOTE(review): with return_schemas=False this yields None rather
            # than the manual signatures, unlike the return_schemas=True path.
            # Kept as-is: the manual signatures name the tensor parameter
            # ``self`` (not ``input``), so exposing them to normalization may
            # not be safe. Confirm intent before changing.
            if not return_schemas:
                return None
            return manual_signatures, None

        builtin_op = torch.jit._builtins._find_builtin(op)
        if builtin_op is None:
            # Not a known TorchScript builtin: no signatures available.
            if not return_schemas:
                return None
            return None, None
        schemas = torch._C._jit_get_schemas_for_operator(builtin_op)

    signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]
    if not return_schemas:
        return signatures
    return signatures, schemas
@compatibility(is_backward_compatible=False)
def create_type_hint(x):
    """
    Produce a type hint for the given argument.

    If ``x`` is a ``list`` or ``tuple`` of types, look for the element that is
    a superclass of every other element and return ``list[base]`` (for a list)
    or ``tuple[base, ...]`` (for a tuple). If ``x`` is empty or no element is a
    common superclass, return ``list[Any]`` / ``tuple[Any, ...]``.

    Any other ``x`` is returned unchanged. If inspecting the elements fails
    (for instance because an element is not a class), a warning is emitted
    and ``x`` is returned unchanged.
    """
    try:
        if not isinstance(x, (list, tuple)):
            return x
        is_list = isinstance(x, list)

        def wrap(elem_type):
            # todo(chilli): Figure out the right way for mypy to handle this
            if is_list:
                return list[elem_type]  # type: ignore[valid-type]
            return tuple[elem_type, ...]  # type: ignore[valid-type]

        if len(x) == 0:
            return wrap(Any)

        common_base = x[0]
        for elem in x:
            if issubclass(elem, common_base):
                continue
            if not issubclass(common_base, elem):
                # Neither is a superclass of the other: no common base here.
                return wrap(Any)
            # ``elem`` is wider than everything seen so far.
            common_base = elem
        return wrap(common_base)
    except Exception:
        # We tried to create a type hint for list but failed.
        warnings.warn(
            f"We were not able to successfully create type hint from the type {x}"
        )
    return x
@compatibility(is_backward_compatible=False)
def type_matches(signature_type: Any, argument_type: Any):
    """
    Return whether an argument of type ``argument_type`` is acceptable for a
    parameter annotated with ``signature_type``.

    Rules, checked in order:

    * Identical or equal annotations match.
    * For a union signature type (``typing.Union``/``Optional`` or a PEP 604
      ``X | Y``), the argument must match one of the members.
    * For a ``List[T]`` signature type: ``int`` matches ``List[int]``; a
      ``List[U]`` argument matches when ``U`` is a subclass of ``T``; a
      homogeneous tuple (``Tuple[U, ...]``, ``Tuple[U1, U2]``, ``Tuple[()]``)
      matches when every element type is a subclass of ``T``. A non-class
      ``T`` (nested parametric type) emits a warning and returns ``False``.
    * ``torch.dtype`` matches ``int`` (dtypes are ints in schemas), and
      ``int``/``float`` match ``numbers.Number``.
    * Otherwise two classes match when ``argument_type`` is a subclass of
      ``signature_type``.

    Types that cannot take part in a subclass test (non-class typing
    constructs) are treated as a non-match instead of raising ``TypeError``.
    """

    def _is_subclass(candidate, base):
        # ``issubclass`` raises TypeError for non-class typing constructs such
        # as ``Optional[int]`` appearing as a list element; report a non-match.
        try:
            return issubclass(candidate, base)
        except TypeError:
            return False

    # Equality (not only identity) counts as a match so that equivalent but
    # distinct objects such as ``typing.Optional[int]`` and ``int | None`` are
    # accepted instead of falling through to the final ``False``.
    if signature_type is argument_type or signature_type == argument_type:
        return True

    sig_origin_type = getattr(signature_type, "__origin__", signature_type)

    # Union types in signature. Given type needs to match one of the
    # contained types in the Union. PEP 604 unions (``int | None``) are
    # ``types.UnionType`` instances without ``__origin__``; the empty-tuple
    # fallback keeps the isinstance check False where that type is missing.
    pep604_union_type = getattr(types, "UnionType", ())
    if sig_origin_type is typing.Union or isinstance(signature_type, pep604_union_type):
        return any(type_matches(c, argument_type) for c in signature_type.__args__)

    if getattr(signature_type, "__origin__", None) is list:
        sig_el_type = signature_type.__args__[0]

        # int can be promoted to list[int]
        if argument_type is int and sig_el_type is int:
            return True

        if not inspect.isclass(sig_el_type):
            warnings.warn(
                f"Does not support nested parametric types, got {signature_type}. Please file a bug."
            )
            return False

        if getattr(argument_type, "__origin__", None) is list:
            return _is_subclass(argument_type.__args__[0], sig_el_type)

        def is_homogeneous_tuple(t):
            if getattr(t, "__origin__", None) is not tuple:
                return False
            contained = t.__args__
            if contained == ((),):  # Tuple[()].__args__ == ((),) for some reason
                return True
            return all(
                (c is Ellipsis) or _is_subclass(c, sig_el_type) for c in contained
            )

        # Tuple[T] is accepted for List[T] parameters
        return is_homogeneous_tuple(argument_type)

    # Dtype is an int in schemas
    if signature_type is int and argument_type is torch.dtype:
        return True

    if signature_type is numbers.Number and argument_type in {int, float}:
        return True

    if inspect.isclass(argument_type) and inspect.isclass(signature_type):
        return _is_subclass(argument_type, signature_type)

    return False
@compatibility(is_backward_compatible=False)
def _normalize_function_or_error(
    target: Callable,
    args: tuple[Any, ...],
    kwargs: Optional[dict[str, Any]] = None,
    arg_types: Optional[tuple[Any]] = None,
    kwarg_types: Optional[dict[str, Any]] = None,
    normalize_to_only_use_kwargs: bool = False,
) -> ArgsKwargsPair:
    """
    Strict variant of ``normalize_function``: identical arguments, but a
    failed normalization raises ``RuntimeError`` instead of returning ``None``.
    """
    normalized = normalize_function(
        target, args, kwargs, arg_types, kwarg_types, normalize_to_only_use_kwargs
    )
    if normalized is not None:
        return normalized
    raise RuntimeError(
        f"Failed to normalize function {target} with args {args} and kwargs {kwargs}"
    )
@compatibility(is_backward_compatible=False)
def normalize_function(
    target: Callable,
    args: tuple[Any, ...],
    kwargs: Optional[dict[str, Any]] = None,
    arg_types: Optional[tuple[Any]] = None,
    kwarg_types: Optional[dict[str, Any]] = None,
    normalize_to_only_use_kwargs: bool = False,
) -> Optional[ArgsKwargsPair]:
    """
    Returns normalized arguments to PyTorch functions. This means that
    `args/kwargs` will be matched up to the functional's
    signature and return exclusively kwargs in positional order if
    `normalize_to_only_use_kwargs` is True.
    Also populates default values. Does not support positional-only
    parameters or varargs parameters (*args, **kwargs) -- apart from the
    ``(input, from, to, generator)`` signature special-cased for
    random/uniform ops. Does not support modules.

    May require `arg_types` and `kwarg_types` in order to disambiguate overloads.

    Builtins and ``torch.ops`` overloads are matched against their TorchScript
    schemas; any other callable is analyzed through ``inspect.signature``.

    Args:
        target (Callable): Function that we are normalizing
        args (Tuple[Any]): Tuple of args to the function
        kwargs (Optional[Dict[str, Any]]): Dict of kwargs to the function
        arg_types (Optional[Tuple[Any]]): Tuple of arg types for the args
        kwarg_types (Optional[Dict[str, Any]]): Dict of arg types for the kwargs
        normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.

    Returns:
        Returns normalized_args_and_kwargs, or `None` if not successful.

    Raises:
        RuntimeError: if several schemas accept the call and neither
            ``arg_types`` nor ``kwarg_types`` was given to disambiguate.
        TypeError: for plain Python callables, a call that does not bind to
            the target's signature is not caught here and the binding error
            propagates.
    """
    if kwargs is None:
        kwargs = {}

    # ExecuTorch's EdgeOpOverload are a wrapper around PyTorch's OpOverload,
    # so we can unwrap it here to get its schema
    # Can't import EdgeOpOverload directly because of a circular dependency,
    # so checking for "_op" existing is the next best thing.
    if not _is_builtin_or_op_overload(target) and hasattr(target, "_op"):
        target = target._op

    # Repeat the condition after checking for the inner _op field.
    if _is_builtin_or_op_overload(target):
        return _normalize_torch_op(
            target, args, kwargs, arg_types, kwarg_types, normalize_to_only_use_kwargs
        )
    return _normalize_python_callable(
        target, args, kwargs, normalize_to_only_use_kwargs
    )


def _is_builtin_or_op_overload(target) -> bool:
    """Return True for targets normalized through TorchScript schemas: C builtins and ``torch.ops`` overloads/packets."""
    return isinstance(target, types.BuiltinFunctionType) or isinstance(
        target, (OpOverloadPacket, OpOverload)
    )


def _normalize_python_callable(target, args, kwargs, normalize_to_only_use_kwargs):
    """Normalize a call to a plain Python callable using its ``inspect.signature``."""
    analyzed = target
    if target in boolean_dispatched:
        # HACK: `boolean_dispatch` as used in `torch.nn.functional` makes it so that we have
        # a 2-way dispatch based on a boolean value. Here we check that the `true` and `false`
        # branches of the dispatch have exactly the same signature. If they do, use the `true`
        # branch signature for analysis. Otherwise, leave this un-normalized
        if isinstance(target, str):
            raise AssertionError("target should not be a string here")
        branches = boolean_dispatched[target]
        true_branch = branches["if_true"]
        false_branch = branches["if_false"]
        if (
            inspect.signature(true_branch).parameters
            != inspect.signature(false_branch).parameters
        ):
            return None
        analyzed = true_branch

    if not callable(analyzed):
        raise AssertionError(
            f"target_for_analysis must be callable, got {type(analyzed)}"
        )
    sig = inspect.signature(inspect.unwrap(analyzed))
    return _args_kwargs_to_normalized_args_kwargs(
        sig, args, kwargs, normalize_to_only_use_kwargs
    )


def _normalize_torch_op(
    target, args, kwargs, arg_types, kwarg_types, normalize_to_only_use_kwargs
):
    """Normalize a call to a builtin/``torch.ops`` target by matching it against its schema signatures."""
    if not callable(target):
        raise AssertionError(f"target must be callable, got {type(target)}")
    candidates = get_signature_for_torch_op(target)
    if not candidates:
        return None

    # Collect every overload whose signature accepts the concrete call.
    accepting = []
    for signature in candidates:
        try:
            signature.bind(*args, **kwargs)
        except TypeError:
            continue
        accepting.append(signature)

    if not accepting:
        # Did not match any schema. Cannot normalize
        return None
    if len(accepting) == 1:
        # Matched exactly one schema, unambiguous
        return _args_kwargs_to_normalized_args_kwargs(
            accepting[0], args, kwargs, normalize_to_only_use_kwargs
        )

    if arg_types is None and kwarg_types is None:
        # Matched more than one schema. In this situation, the caller must provide the types of
        # the arguments of the overload they expect.
        schema_printouts = "\n".join(str(schema) for schema in accepting)
        raise RuntimeError(
            f"Tried to normalize arguments to {torch.typename(target)} but "
            f"the schema match was ambiguous! Please provide argument types to "
            f"the normalize_arguments() call. Available schemas:\n{schema_printouts}"
        )

    # Disambiguate by the caller-provided types: the first overload (in
    # schema order) whose annotations accept every bound type wins.
    arg_types = arg_types if arg_types else cast(tuple[Any], ())
    kwarg_types = kwarg_types if kwarg_types else {}
    for signature in candidates:
        try:
            bound_types = signature.bind(*arg_types, **kwarg_types)
            matches = True
            for arg_name, arg_type in bound_types.arguments.items():
                # ``and`` short-circuits so type_matches stops once one fails.
                matches = matches and type_matches(
                    signature.parameters[arg_name].annotation, arg_type
                )
        except TypeError:
            matches = False
        if matches:
            return _args_kwargs_to_normalized_args_kwargs(
                signature, args, kwargs, normalize_to_only_use_kwargs
            )
    return None
@compatibility(is_backward_compatible=False)
def normalize_module(
    root: torch.nn.Module,
    target: str,
    args: tuple[Any],
    kwargs: Optional[dict[str, Any]] = None,
    normalize_to_only_use_kwargs: bool = False,
) -> Optional[ArgsKwargsPair]:
    """
    Returns normalized arguments to PyTorch modules. This means that
    `args/kwargs` will be matched up to the module's ``forward``
    signature and return exclusively kwargs in positional order if
    `normalize_to_only_use_kwargs` is True.
    Also populates default values. Does not support positional-only
    parameters or varargs parameters (*args, **kwargs).

    Only submodules whose class is exactly the class of the same name exported
    from ``torch.nn`` are normalized; for any other module ``None`` is returned.

    Args:
        root (nn.Module): root module upon which we query modules
        target (str): Qualified name of the submodule of ``root`` being called
        args (Tuple[Any]): Tuple of args to the function
        kwargs (Optional[Dict[str, Any]]): Dict of kwargs to the function
        normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.

    Returns:
        Returns normalized_args_and_kwargs, or `None` if not successful.

    Raises:
        RuntimeError: if ``root`` has no submodule named ``target``.
    """
    try:
        submod = root.get_submodule(target)
    except AttributeError as e:
        raise RuntimeError(
            f"Tried to normalize node with target {target} but root did not "
            f"have that target!"
        ) from e

    module_cls = submod.__class__
    if not hasattr(module_cls, "__name__"):
        return None
    # User-defined modules (including subclasses of torch.nn classes) are
    # not normalized.
    if not getattr(torch.nn, module_cls.__name__, None) == module_cls:
        return None

    forward_sig = inspect.signature(inspect.unwrap(submod.forward))
    return _args_kwargs_to_normalized_args_kwargs(
        forward_sig,
        args,
        {} if kwargs is None else kwargs,
        normalize_to_only_use_kwargs,
    )
def _args_kwargs_to_normalized_args_kwargs(
    sig: inspect.Signature,
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
    normalize_to_only_use_kwargs: bool,
) -> Optional[ArgsKwargsPair]:
    """
    Given a call target, args, and kwargs, return the arguments normalized into
    an ArgsKwargsPair, or None if the type signature is not supported by
    this normalization.

    The call is bound to ``sig`` with defaults applied. The first
    ``len(args)`` parameters stay positional (unless
    ``normalize_to_only_use_kwargs`` is set) and every remaining parameter is
    emitted as a keyword argument, in signature order.

    Args:
        sig (inspect.Signature): Signature object for the target
        args (Tuple): Arguments that appear at the callsite for `target`
        kwargs (Dict): Keyword arguments that appear at the callsite for `target`
        normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.

    Returns:
        Optional[ArgsKwargsPair]: Normalized args and kwargs for `target`, or `None` if
            this target is not supported.

    Raises:
        TypeError: propagated from ``inspect.Signature.bind`` when the call does
            not fit ``sig``.
    """
    # Don't currently support positional-only
    # or varargs (*args, **kwargs) signatures
    allowed_kinds = {
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    }
    param_names = list(sig.parameters)
    has_unsupported_kind = any(
        p.kind not in allowed_kinds for p in sig.parameters.values()
    )
    # Add an exception for one signature, which is common for random/uniform, i.e.:
    # Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None
    # `from` is Python keyword and as such functions with that signature should have
    # positional-only args, but at the same time they could be dispatched as kwargs
    if has_unsupported_kind and param_names != ["input", "from", "to", "generator"]:
        return None

    bound = sig.bind(*args, **kwargs)
    bound.apply_defaults()
    values = bound.arguments

    num_positional = 0 if normalize_to_only_use_kwargs else len(args)
    new_args = tuple(values[name] for name in param_names[:num_positional])
    new_kwargs: dict[str, Any] = {
        name: values[name] for name in param_names[num_positional:]
    }
    return ArgsKwargsPair(new_args, new_kwargs)
|