operator_schemas.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603
  1. # mypy: allow-untyped-defs
  2. import enum
  3. import inspect
  4. import numbers
  5. import types
  6. import typing
  7. import warnings
  8. from collections.abc import Callable
  9. from typing import Any, cast, NamedTuple, Optional, TYPE_CHECKING
  10. import torch
  11. from torch._jit_internal import boolean_dispatched
  12. from torch._ops import OpOverload, OpOverloadPacket
  13. from ._compatibility import compatibility
  14. if TYPE_CHECKING:
  15. from .node import Argument
  16. __all__ = [
  17. "ArgsKwargsPair",
  18. "check_for_mutable_operation",
  19. "get_signature_for_torch_op",
  20. "create_type_hint",
  21. "type_matches",
  22. "normalize_function",
  23. "normalize_module",
  24. ]
  25. @compatibility(is_backward_compatible=False)
  26. class ArgsKwargsPair(NamedTuple):
  27. """
  28. Simple named tuple for wrapping args/kwargs pairs.
  29. """
  30. args: tuple[Any, ...]
  31. kwargs: dict[str, Any]
  32. _manual_overrides: dict[Callable, list[inspect.Signature]] = {}
  33. def _nonzero_schemas():
  34. signatures = []
  35. def nonzero(self):
  36. pass
  37. signatures.append(inspect.signature(nonzero))
  38. def nonzero(self, *, as_tuple: bool): # type: ignore[no-redef]
  39. pass
  40. signatures.append(inspect.signature(nonzero))
  41. return signatures
  42. _manual_overrides[torch.nonzero] = _nonzero_schemas()
  43. class _FakeGlobalNamespace:
  44. def __getattr__(self, name):
  45. if name == "torch":
  46. return torch
  47. raise RuntimeError("Expected a torch namespace lookup")
  48. _type_eval_globals = {
  49. "Tensor": torch.Tensor,
  50. "Device": torch.device,
  51. "Layout": torch.layout,
  52. "number": numbers.Number,
  53. "Future": torch.jit.Future,
  54. "AnyEnumType": enum.Enum,
  55. "QScheme": torch.qscheme,
  56. "__torch__": _FakeGlobalNamespace(),
  57. "NoneType": type(None),
  58. "Storage": torch.UntypedStorage,
  59. "t": typing.TypeVar("t"),
  60. "PyObject": Any,
  61. }
  62. for k in dir(typing):
  63. _type_eval_globals[k] = getattr(typing, k)
  64. def _torchscript_type_to_python_type(ts_type: "torch._C.JitType") -> Any:
  65. """
  66. Convert a TorchScript type to a Python type (including subtypes) via
  67. eval'ing the annotation_str. _type_eval_globals sets up expressions
  68. like "List" and "Future" to map to actual types (typing.List and jit.Future)
  69. """
  70. return eval(ts_type.annotation_str, _type_eval_globals)
  71. def _torchscript_schema_to_signature_impl(
  72. ts_schema: torch._C.FunctionSchema,
  73. ) -> inspect.Signature:
  74. from inspect import Parameter
  75. parameters: list[Parameter] = []
  76. for arg in ts_schema.arguments:
  77. arg_type = _torchscript_type_to_python_type(arg.type)
  78. default = arg.default_value if arg.has_default_value() else Parameter.empty
  79. # TODO: Figure out if this is safe. It seems like when generating the type signatures for
  80. # PythonArgParser, we emit signatures with `input` instead of `self` as the first tensor
  81. # argument name. Downstream, if someone converts that positional argument to a keyword
  82. # argument, the name mismatch will break things, so here we're going to normalize the
  83. # name to "input"
  84. name = arg.name if arg.name != "self" else "input"
  85. kind = (
  86. Parameter.KEYWORD_ONLY
  87. if arg.kwarg_only
  88. else Parameter.POSITIONAL_OR_KEYWORD
  89. )
  90. # "from" is a keyword therefore it must be a POSITIONAL_ONLY argument
  91. if name == "from":
  92. if kind != Parameter.POSITIONAL_OR_KEYWORD:
  93. raise AssertionError(f"Expected POSITIONAL_OR_KEYWORD, got {kind}")
  94. # ParameterKind type is internal implementation detail to inspec package
  95. # which makes it hard to do type annotation
  96. kind = Parameter.POSITIONAL_ONLY # type: ignore[assignment]
  97. # This renders all previous arguments to positional only
  98. for idx, p in enumerate(parameters):
  99. if p.kind != Parameter.POSITIONAL_OR_KEYWORD:
  100. raise AssertionError(
  101. f"Expected POSITIONAL_OR_KEYWORD for param {p.name}, got {p.kind}"
  102. )
  103. parameters[idx] = Parameter(
  104. name=p.name,
  105. kind=Parameter.POSITIONAL_ONLY,
  106. default=p.default,
  107. annotation=p.annotation,
  108. )
  109. parameters.append(
  110. Parameter(name=name, kind=kind, default=default, annotation=arg_type)
  111. )
  112. return_types = [
  113. _torchscript_type_to_python_type(ret.type) for ret in ts_schema.returns
  114. ]
  115. if len(return_types) == 0:
  116. return_type = None
  117. elif len(return_types) == 1:
  118. return_type = return_types[0]
  119. else:
  120. return_type = tuple(return_types)
  121. return inspect.Signature(parameters, return_annotation=return_type)
  122. _SCHEMA_TO_SIGNATURE_CACHE: dict[tuple[str, str], inspect.Signature] = {}
  123. def _torchscript_schema_to_signature(
  124. ts_schema: torch._C.FunctionSchema,
  125. ) -> inspect.Signature:
  126. # Cached as it's called in the hot path of FakeTensor dispatch
  127. cache_key = ts_schema.name, ts_schema.overload_name
  128. cache_val = _SCHEMA_TO_SIGNATURE_CACHE.get(cache_key)
  129. if cache_val is not None:
  130. return cache_val
  131. res = _torchscript_schema_to_signature_impl(ts_schema)
  132. _SCHEMA_TO_SIGNATURE_CACHE[cache_key] = res
  133. return res
  134. @compatibility(is_backward_compatible=False)
  135. def check_for_mutable_operation(
  136. target: Callable, args: tuple["Argument", ...], kwargs: dict[str, "Argument"]
  137. ):
  138. signatures, schemas = get_signature_for_torch_op(target, return_schemas=True)
  139. if signatures and schemas:
  140. matched_schemas = []
  141. # Iterate through all of the schema until we find one that matches
  142. # If one matches, populate `new_args_and_kwargs` with the new args/kwargs
  143. # values. If none matches, `new_args_and_kwargs` will be None
  144. for candidate_signature, schema in zip(signatures, schemas):
  145. try:
  146. candidate_signature.bind(*args, **kwargs)
  147. matched_schemas.append((candidate_signature, schema))
  148. except TypeError:
  149. continue
  150. def throw_if_mutable(schema):
  151. if schema.is_mutable:
  152. raise RuntimeError(
  153. f"Tried to trace mutable operation {schema}. FX only supports functional "
  154. f"code, so operations that mutate operands in-place (e.g. via `out` arguments) "
  155. f"are not supported"
  156. )
  157. if len(matched_schemas) == 0:
  158. # Did not match any schema. Cannot check for mutation
  159. pass
  160. elif len(matched_schemas) == 1:
  161. # Matched exactly one schema, unambiguous
  162. _, schema_to_check = matched_schemas[0]
  163. throw_if_mutable(schema_to_check)
  164. else:
  165. # Ambiguous schema match. Since mutability checking is best effort,
  166. # do nothing.
  167. pass
  168. @compatibility(is_backward_compatible=False)
  169. def get_signature_for_torch_op(op: Callable, return_schemas: bool = False):
  170. """
  171. Given an operator on the `torch` namespace, return a list of `inspect.Signature`
  172. objects corresponding to the overloads of that op.. May return `None` if a signature
  173. could not be retrieved.
  174. Args:
  175. op (Callable): An operator on the `torch` namespace to look up a signature for
  176. Returns:
  177. Optional[List[inspect.Signature]]: A list of signatures for the overloads of this
  178. operator, or None if the operator signatures could not be retrieved. If
  179. return_schemas=True, returns a tuple containing the optional Python signatures
  180. and the optional TorchScript Function signature
  181. """
  182. if isinstance(op, OpOverload):
  183. schemas = [op._schema]
  184. elif isinstance(op, OpOverloadPacket):
  185. schemas = [getattr(op, overload)._schema for overload in op.overloads()]
  186. else:
  187. override = _manual_overrides.get(op)
  188. if override:
  189. return (override, None) if return_schemas else None
  190. aten_fn = torch.jit._builtins._find_builtin(op)
  191. if aten_fn is None:
  192. return (None, None) if return_schemas else None
  193. schemas = torch._C._jit_get_schemas_for_operator(aten_fn)
  194. signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]
  195. return (signatures, schemas) if return_schemas else signatures
  196. @compatibility(is_backward_compatible=False)
  197. def create_type_hint(x):
  198. """
  199. Produces a type hint for the given argument.
  200. The :func:`create_type_hint` looks for a type hint compatible with the input argument `x`.
  201. If `x` is a `list` or `tuple`, it looks for an object in the list whose type is a superclass
  202. of the rest, and uses that as `base_type` for the `List` or `Tuple` to be returned.
  203. If no such object is found, it defaults to `List[Any]`.
  204. If `x` is neither a `list` nor a `tuple`, it returns `x`.
  205. """
  206. try:
  207. if isinstance(x, (list, tuple)):
  208. # todo(chilli): Figure out the right way for mypy to handle this
  209. if isinstance(x, list):
  210. def ret_type(x):
  211. return list[x] # type: ignore[valid-type]
  212. else:
  213. def ret_type(x):
  214. return tuple[x, ...] # type: ignore[valid-type]
  215. if len(x) == 0:
  216. return ret_type(Any)
  217. base_type = x[0]
  218. for t in x:
  219. if issubclass(t, base_type):
  220. continue
  221. elif issubclass(base_type, t):
  222. base_type = t
  223. else:
  224. return ret_type(Any)
  225. return ret_type(base_type)
  226. except Exception:
  227. # We tried to create a type hint for list but failed.
  228. warnings.warn(
  229. f"We were not able to successfully create type hint from the type {x}"
  230. )
  231. return x
  232. @compatibility(is_backward_compatible=False)
  233. def type_matches(signature_type: Any, argument_type: Any):
  234. sig_origin_type = getattr(signature_type, "__origin__", signature_type)
  235. if signature_type is argument_type:
  236. return True
  237. # Union types in signature. Given type needs to match one of the
  238. # contained types in the Union
  239. if sig_origin_type is typing.Union and signature_type != argument_type:
  240. sig_contained = signature_type.__args__
  241. return any(type_matches(c, argument_type) for c in sig_contained)
  242. if getattr(signature_type, "__origin__", None) is list:
  243. sig_el_type = signature_type.__args__[0]
  244. # int can be promoted to list[int]
  245. if argument_type is int and sig_el_type is int:
  246. return True
  247. if not inspect.isclass(sig_el_type):
  248. warnings.warn(
  249. f"Does not support nested parametric types, got {signature_type}. Please file a bug."
  250. )
  251. return False
  252. if getattr(argument_type, "__origin__", None) is list:
  253. return issubclass(argument_type.__args__[0], sig_el_type)
  254. def is_homogeneous_tuple(t):
  255. if getattr(t, "__origin__", None) is not tuple:
  256. return False
  257. contained = t.__args__
  258. if t.__args__ == ((),): # Tuple[()].__args__ == ((),) for some reason
  259. return True
  260. return all((c is Ellipsis) or issubclass(c, sig_el_type) for c in contained)
  261. # Tuple[T] is accepted for List[T] parameters
  262. return is_homogeneous_tuple(argument_type)
  263. # Dtype is an int in schemas
  264. if signature_type is int and argument_type is torch.dtype:
  265. return True
  266. if signature_type is numbers.Number and argument_type in {int, float}:
  267. return True
  268. if inspect.isclass(argument_type) and inspect.isclass(signature_type):
  269. return issubclass(argument_type, signature_type)
  270. return False
  271. @compatibility(is_backward_compatible=False)
  272. def _normalize_function_or_error(
  273. target: Callable,
  274. args: tuple[Any, ...],
  275. kwargs: Optional[dict[str, Any]] = None,
  276. arg_types: Optional[tuple[Any]] = None,
  277. kwarg_types: Optional[dict[str, Any]] = None,
  278. normalize_to_only_use_kwargs: bool = False,
  279. ) -> ArgsKwargsPair:
  280. """
  281. Wrapper around normalize_function that never returns None, but
  282. loudly errors instead
  283. """
  284. res = normalize_function(
  285. target, args, kwargs, arg_types, kwarg_types, normalize_to_only_use_kwargs
  286. )
  287. if res is None:
  288. raise RuntimeError(
  289. f"Failed to normalize function {target} with args {args} and kwargs {kwargs}"
  290. )
  291. else:
  292. return res
  293. @compatibility(is_backward_compatible=False)
  294. def normalize_function(
  295. target: Callable,
  296. args: tuple[Any, ...],
  297. kwargs: Optional[dict[str, Any]] = None,
  298. arg_types: Optional[tuple[Any]] = None,
  299. kwarg_types: Optional[dict[str, Any]] = None,
  300. normalize_to_only_use_kwargs: bool = False,
  301. ) -> Optional[ArgsKwargsPair]:
  302. """
  303. Returns normalized arguments to PyTorch functions. This means that
  304. `args/kwargs` will be matched up to the functional's
  305. signature and return exclusively kwargs in positional order if
  306. `normalize_to_only_use_kwargs` is True.
  307. Also populates default values. Does not support positional-only
  308. parameters or varargs parameters (*args, **kwargs). Does not support modules.
  309. May require `arg_types` and `kwarg_types` in order to disambiguate overloads.
  310. Args:
  311. target (Callable): Function that we are normalizing
  312. args (Tuple[Any]): Tuple of args to the function
  313. kwargs (Optional[Dict[str, Any]]): Dict of kwargs to the function
  314. arg_types (Optional[Tuple[Any]]): Tuple of arg types for the args
  315. kwarg_types (Optional[Dict[str, Any]]): Dict of arg types for the kwargs
  316. normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.
  317. Returns:
  318. Returns normalized_args_and_kwargs, or `None` if not successful.
  319. """
  320. if kwargs is None:
  321. kwargs = {}
  322. new_args_and_kwargs = None
  323. if (
  324. not isinstance(target, types.BuiltinFunctionType)
  325. and not (isinstance(target, (OpOverloadPacket, OpOverload)))
  326. and hasattr(target, "_op")
  327. ):
  328. # ExecuTorch's EdgeOpOverload are a wrapper around PyTorch's OpOverload,
  329. # so we can unwrap it here to get its schema
  330. # Can't import EdgeOpOverload directly because of a circular dependency,
  331. # so checking for "_op" existing is the next best thing.
  332. target = target._op
  333. # Repeat the condition after checking for the inner _op field.
  334. if not isinstance(target, types.BuiltinFunctionType) and not (
  335. isinstance(target, (OpOverloadPacket, OpOverload))
  336. ):
  337. target_for_analysis = target
  338. if target in boolean_dispatched:
  339. # HACK: `boolean_dispatch` as used in `torch.nn.functional` makes it so that we have
  340. # a 2-way dispatch based on a boolean value. Here we check that the `true` and `false`
  341. # branches of the dispatch have exactly the same signature. If they do, use the `true`
  342. # branch signature for analysis. Otherwise, leave this un-normalized
  343. if isinstance(target, str):
  344. raise AssertionError("target should not be a string here")
  345. dispatched = boolean_dispatched[target]
  346. if_true, if_false = dispatched["if_true"], dispatched["if_false"]
  347. if (
  348. inspect.signature(if_true).parameters
  349. != inspect.signature(if_false).parameters
  350. ):
  351. return None
  352. target_for_analysis = if_true
  353. if not callable(target_for_analysis):
  354. raise AssertionError(
  355. f"target_for_analysis must be callable, got {type(target_for_analysis)}"
  356. )
  357. sig = inspect.signature(inspect.unwrap(target_for_analysis))
  358. new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(
  359. sig, args, kwargs, normalize_to_only_use_kwargs
  360. )
  361. else:
  362. if not callable(target):
  363. raise AssertionError(f"target must be callable, got {type(target)}")
  364. torch_op_schemas = get_signature_for_torch_op(target)
  365. matched_schemas = []
  366. if torch_op_schemas:
  367. # Iterate through all of the schema until we find one that matches
  368. # If one matches, populate `new_args_and_kwargs` with the new args/kwargs
  369. # values. If none matches, `new_args_and_kwargs` will be None
  370. for candidate_signature in torch_op_schemas:
  371. try:
  372. candidate_signature.bind(*args, **kwargs)
  373. matched_schemas.append(candidate_signature)
  374. except TypeError:
  375. continue
  376. if len(matched_schemas) == 0:
  377. # Did not match any schema. Cannot normalize
  378. pass
  379. elif len(matched_schemas) == 1:
  380. # Matched exactly one schema, unambiguous
  381. new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(
  382. matched_schemas[0], args, kwargs, normalize_to_only_use_kwargs
  383. )
  384. else:
  385. if arg_types is not None or kwarg_types is not None:
  386. arg_types = arg_types if arg_types else cast(tuple[Any], ())
  387. kwarg_types = kwarg_types if kwarg_types else {}
  388. for candidate_signature in torch_op_schemas:
  389. sig_matches = True
  390. try:
  391. bound_types = candidate_signature.bind(
  392. *arg_types, **kwarg_types
  393. )
  394. for arg_name, arg_type in bound_types.arguments.items():
  395. param = candidate_signature.parameters[arg_name]
  396. sig_matches = sig_matches and type_matches(
  397. param.annotation, arg_type
  398. )
  399. except TypeError:
  400. sig_matches = False
  401. if sig_matches:
  402. new_args_and_kwargs = (
  403. _args_kwargs_to_normalized_args_kwargs(
  404. candidate_signature,
  405. args,
  406. kwargs,
  407. normalize_to_only_use_kwargs,
  408. )
  409. )
  410. break
  411. else:
  412. # Matched more than one schema. In this situation, the caller must provide the types of
  413. # the arguments of the overload they expect.
  414. schema_printouts = "\n".join(
  415. str(schema) for schema in matched_schemas
  416. )
  417. raise RuntimeError(
  418. f"Tried to normalize arguments to {torch.typename(target)} but "
  419. f"the schema match was ambiguous! Please provide argument types to "
  420. f"the normalize_arguments() call. Available schemas:\n{schema_printouts}"
  421. )
  422. return new_args_and_kwargs
  423. @compatibility(is_backward_compatible=False)
  424. def normalize_module(
  425. root: torch.nn.Module,
  426. target: str,
  427. args: tuple[Any],
  428. kwargs: Optional[dict[str, Any]] = None,
  429. normalize_to_only_use_kwargs: bool = False,
  430. ) -> Optional[ArgsKwargsPair]:
  431. """
  432. Returns normalized arguments to PyTorch modules. This means that
  433. `args/kwargs` will be matched up to the functional's
  434. signature and return exclusively kwargs in positional order if
  435. `normalize_to_only_use_kwargs` is True.
  436. Also populates default values. Does not support positional-only
  437. parameters or varargs parameters (*args, **kwargs).
  438. Args:
  439. root (nn.Module): root module upon which we query modules
  440. target (Callable): Function that we are normalizing
  441. args (Tuple[Any]): Tuple of args to the function
  442. kwargs (Optional[Dict[str, Any]]): Dict of kwargs to the function
  443. normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.
  444. Returns:
  445. Returns normalized_args_and_kwargs, or `None` if not successful.
  446. """
  447. try:
  448. submod = root.get_submodule(target)
  449. except AttributeError as e:
  450. raise RuntimeError(
  451. f"Tried to normalize node with target {target} but root did not "
  452. f"have that target!"
  453. ) from e
  454. if hasattr(submod.__class__, "__name__"):
  455. classname = submod.__class__.__name__
  456. if getattr(torch.nn, classname, None) == submod.__class__:
  457. sig = inspect.signature(inspect.unwrap(submod.forward))
  458. if kwargs is None:
  459. kwargs = {}
  460. new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(
  461. sig, args, kwargs, normalize_to_only_use_kwargs
  462. )
  463. return new_args_and_kwargs
  464. return None
  465. def _args_kwargs_to_normalized_args_kwargs(
  466. sig: inspect.Signature,
  467. args: tuple[Any, ...],
  468. kwargs: dict[str, Any],
  469. normalize_to_only_use_kwargs: bool,
  470. ) -> Optional[ArgsKwargsPair]:
  471. """
  472. Given a call target, args, and kwargs, return the arguments normalized into
  473. an ArgsKwargsPair, or None if the type signature is not supported by
  474. this normalization.
  475. Args:
  476. sig (inspect.Signature): Signature object for the target
  477. args (Tuple): Arguments that appear at the callsite for `target`
  478. kwargs (Dict): Keyword arguments that appear at the callsite for `target`
  479. normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.
  480. Returns:
  481. Optional[ArgsKwargsPair]: Normalized args and kwargs for `target`, or `None` if
  482. this target is not supported.
  483. """
  484. # Don't currently support positional-only
  485. # or varargs (*args, **kwargs) signatures
  486. supported_parameter_types = {
  487. inspect.Parameter.POSITIONAL_OR_KEYWORD,
  488. inspect.Parameter.KEYWORD_ONLY,
  489. }
  490. if any(p.kind not in supported_parameter_types for p in sig.parameters.values()):
  491. # Add an exception for one signature, which is common for random/uniform, i.e.:
  492. # Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None
  493. # `from` is Python keyword and as such functions with that signature should have
  494. # positional-only args, but at the same time they could be dispatched as kwargs
  495. if list(sig.parameters.keys()) != ["input", "from", "to", "generator"]:
  496. return None
  497. bound_args = sig.bind(*args, **kwargs)
  498. bound_args.apply_defaults()
  499. new_kwargs: dict[str, Any] = {}
  500. new_args: list[Any] = []
  501. for i, param in enumerate(sig.parameters):
  502. if not normalize_to_only_use_kwargs and i < len(args):
  503. new_args.append(bound_args.arguments[param])
  504. else:
  505. new_kwargs[param] = bound_args.arguments[param]
  506. return ArgsKwargsPair(tuple(new_args), new_kwargs)