utils.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659
  1. # mypy: allow-untyped-defs
  2. import dataclasses
  3. import inspect
  4. import sys
  5. from collections.abc import Callable, Iterable, Iterator
  6. from typing import Any, Literal, Optional, overload, Union
  7. import torch
  8. import torch.utils._pytree as pytree
  9. import torchgen
  10. from torch import _C, _utils_internal
  11. from torch._ops import OpOverload
  12. @dataclasses.dataclass
  13. class Kernel:
  14. """Models a (function, source location)"""
  15. func: Callable
  16. source: str
  17. def __call__(self, *args, **kwargs):
  18. return self.func(*args, **kwargs)
  19. class RegistrationHandle:
  20. """Does something when someone calls .destroy() on it"""
  21. def __init__(self, on_destroy: Callable):
  22. self._on_destroy = on_destroy
  23. def destroy(self) -> None:
  24. self._on_destroy()
  25. def get_source(stacklevel: int) -> str:
  26. """Get a string that represents the caller.
  27. Example: "/path/to/foo.py:42"
  28. Use stacklevel=1 to get the caller's source
  29. Use stacklevel=2 to get the caller's caller's source
  30. etc.
  31. """
  32. frame = inspect.getframeinfo(sys._getframe(stacklevel))
  33. source = f"{frame.filename}:{frame.lineno}"
  34. return source
  35. def parse_namespace(qualname: str) -> tuple[str, str]:
  36. splits = qualname.split("::")
  37. if len(splits) != 2:
  38. raise ValueError(
  39. f"Expected `qualname` to be of the form "
  40. f'"namespace::name", but got {qualname}. '
  41. f"The qualname passed to the torch.library APIs must consist "
  42. f"of a namespace and a name, e.g. aten::sin"
  43. )
  44. return splits[0], splits[1]
  45. def lookup_op(qualname: str) -> OpOverload:
  46. namespace, name = parse_namespace(qualname)
  47. if "." in name:
  48. name, overload = name.split(".")
  49. else:
  50. overload = "default"
  51. ns = getattr(torch.ops, namespace)
  52. packet = getattr(ns, name)
  53. return getattr(packet, overload)
  54. def is_builtin(op: OpOverload) -> bool:
  55. if not isinstance(op, OpOverload):
  56. raise AssertionError(f"op must be OpOverload, got {type(op)}")
  57. return op.namespace in {"aten", "prim", "prims"}
def is_functional_schema(schema: Any, *, allow_valid_view: bool = False) -> bool:
    """Check if the schema is functional.

    An operator is functional if:
    - it does not mutate any of its inputs
    - If no views are allowed
        - it does not return a view on any of its inputs
    - If valid views are allowed
        - it is not a view, or is a view with a single input Tensor and a
          single output Tensor
    - it has at least one return

    Args:
        schema: A ``torch._C.FunctionSchema``, a ``torchgen.model.FunctionSchema``,
            or a schema string (parsed with torchgen's ``FunctionSchema.parse``).
        allow_valid_view: If True, a non-mutating view whose schema counts
            exactly one tensor input and one tensor output is still
            considered functional.

    Returns:
        True if the schema is functional under the rules above.

    Raises:
        AssertionError: if ``schema`` is not one of the supported types.
    """

    # Shared by both schema flavors; only relies on attributes that both
    # torch._C.FunctionSchema and torchgen.model.FunctionSchema are used
    # with here (is_mutable, returns, alias_info, arguments).
    def is_functional(schema):
        if schema.is_mutable:
            return False
        rets = schema.returns
        # A return with read-only alias info is a view on some input.
        is_non_mutating_view = len(rets) > 0 and any(
            r.alias_info is not None and not r.alias_info.is_write for r in rets
        )
        num_tensor_inputs = 0
        num_tensor_outputs = 0

        # The two schema flavors expose tensor-ness differently.
        # NOTE(review): the C++ path counts only exact TensorType while the
        # torchgen path uses is_tensor_like(); these may disagree for
        # Optional[Tensor] / Tensor[] arguments — confirm intent.
        if isinstance(schema, torch.FunctionSchema):
            for arg in schema.arguments:
                if isinstance(arg.type, torch.TensorType):
                    num_tensor_inputs += 1

            for ret in schema.returns:
                if isinstance(ret.type, torch.TensorType):
                    num_tensor_outputs += 1

        # torchgen.model is only guaranteed to be loaded once the lazy import
        # below has run, which is the only way a torchgen schema reaches here.
        elif isinstance(schema, torchgen.model.FunctionSchema):
            for argument in schema.arguments.flat_non_out:
                if argument.type.is_tensor_like():
                    num_tensor_inputs += 1

            for ret_arg in schema.returns:
                if ret_arg.type.is_tensor_like():
                    num_tensor_outputs += 1

        if is_non_mutating_view:
            # A view is only acceptable when explicitly allowed and it maps
            # exactly one tensor input to exactly one tensor output.
            return allow_valid_view and (
                num_tensor_inputs == 1 and num_tensor_outputs == 1
            )
        # An op with no returns cannot be functional.
        if not schema.returns:
            return False
        return True

    if isinstance(schema, torch._C.FunctionSchema):
        return is_functional(schema)

    # Lazy import because not all PyTorch builds have torchgen
    from torchgen.model import FunctionSchema

    if isinstance(schema, str):
        schema = FunctionSchema.parse(schema)
    if not isinstance(schema, FunctionSchema):
        raise AssertionError(f"schema must be FunctionSchema, got {type(schema)}")
    return is_functional(schema)
  107. # should be torch._C.JitType but that annotation is busted
  108. def is_tensorlist_like_type(typ: Any) -> bool:
  109. return (
  110. typ == _C.ListType(_C.TensorType.get())
  111. or typ == _C.ListType(_C.OptionalType(_C.TensorType.get()))
  112. or typ == _C.OptionalType(_C.ListType(_C.TensorType.get()))
  113. or typ == _C.OptionalType(_C.ListType(_C.OptionalType(_C.TensorType.get())))
  114. )
  115. # should be torch._C.JitType but that annotation is busted
  116. def is_tensor_like_type(typ: Any) -> bool:
  117. return typ == _C.TensorType.get() or typ == _C.OptionalType(_C.TensorType.get())
  118. def mutates_and_returns_first_arg(op: OpOverload):
  119. """Check if an op is an inplace aten op, i.e. it mutates and returns the first arg.
  120. TODO: torchgen/model.py's FunctionSchema.parse is the source of truth for this,
  121. but not all PyTorch builds have torchgen (due to the yaml dependency being weird).
  122. Figure this out.
  123. Example: add_(Tensor(a!) x, Tensor y) -> Tensor(a)
  124. """
  125. if op.namespace != "aten":
  126. return False
  127. schema = op._schema
  128. if len(schema.returns) != 1:
  129. return False
  130. if schema.returns[0].alias_info is None:
  131. return False
  132. alias_set = schema.returns[0].alias_info.after_set
  133. if len(alias_set) != 1:
  134. return False
  135. loc = next(iter(alias_set))
  136. if len(schema.arguments) < 1:
  137. return False
  138. first_arg = schema.arguments[0]
  139. if first_arg.alias_info is None:
  140. return False
  141. if not first_arg.alias_info.is_write:
  142. return False
  143. alias_set = first_arg.alias_info.after_set
  144. if len(alias_set) != 1:
  145. return False
  146. if loc != next(iter(alias_set)):
  147. return False
  148. for arg in schema.arguments[1:]:
  149. if arg.alias_info is not None:
  150. return False
  151. return True
  152. def fill_defaults(schema, args, kwargs):
  153. new_args = []
  154. new_kwargs = {}
  155. for i in range(len(schema.arguments)):
  156. info = schema.arguments[i]
  157. if info.kwarg_only:
  158. if info.name in kwargs:
  159. new_kwargs[info.name] = kwargs[info.name]
  160. else:
  161. new_kwargs[info.name] = info.default_value
  162. else:
  163. if i < len(args):
  164. new_args.append(args[i])
  165. else:
  166. new_args.append(info.default_value)
  167. return tuple(new_args), new_kwargs
  168. def zip_schema(
  169. schema: _C.FunctionSchema, args: tuple[Any, ...], kwargs: dict[str, Any]
  170. ) -> Iterable[tuple[_C.Argument, Any]]:
  171. """zips schema.arguments and (args, kwargs) together.
  172. Assumes that (args, kwargs) were the inputs to some torch._ops.OpOverload:
  173. that is, (args, kwargs) must be bindable to the schema (args, kwargs).
  174. """
  175. if len(schema.arguments) < len(args) + len(kwargs):
  176. raise AssertionError(
  177. f"schema has {len(schema.arguments)} arguments but got {len(args)} args and {len(kwargs)} kwargs"
  178. )
  179. for i in range(len(schema.arguments)):
  180. info = schema.arguments[i]
  181. if info.kwarg_only:
  182. if info.name in kwargs:
  183. yield info, kwargs[info.name]
  184. continue
  185. if i >= len(args):
  186. if not info.kwarg_only and info.name in kwargs:
  187. yield info, kwargs[info.name]
  188. # args that are equal to their default values are not populated
  189. # if they are followed by args that are equal to their defaults.
  190. # Skip these.
  191. continue
  192. yield info, args[i]
  193. return
def hop_schema_from_fx_node(node):
    """Build a function schema for the higher-order operator called by ``node``.

    Example values for the inputs and the output are read from the fx nodes'
    ``meta["val"]``. ``get_attr`` nodes without a ``val`` fall back to the
    attribute on the owning module. The values are then handed to torchgen's
    ``FunctionSchemaGen.from_example``.

    Raises:
        RuntimeError: if ``node.target`` is not a HigherOrderOperator, or an
            argument is neither an fx Node nor a list/tuple of them.
        AssertionError: if a node has no ``val`` meta and is not a get_attr.
        TypeError: presumably from ``inspect.Signature.bind`` when the number
            of collected inputs does not fit ``hop.__call__``.
    """
    from torchgen.gen_schema_utils import FunctionSchemaGen

    hop = node.target
    if not isinstance(hop, torch._ops.HigherOrderOperator):
        raise RuntimeError("fx_node's target must be a hop.")

    # Example value for one fx node: its meta "val", or for get_attr nodes
    # without one, the attribute fetched from the owning GraphModule.
    def _collect_example_val(node):
        meta_val = node.meta.get("val", None)
        if meta_val is None:
            if node.op != "get_attr":
                raise AssertionError(
                    f"node.op must be 'get_attr' when val is None, got {node.op!r}"
                )
            meta_val = getattr(node.graph.owning_module, node.target)
        return meta_val

    example_inputs = []
    for arg in node.args:
        # NOTE(review): both names presumably refer to the same Node class.
        if isinstance(arg, (torch.fx.Node, torch.fx.node.Node)):
            example_inputs.append(_collect_example_val(arg))
        elif isinstance(
            arg, (torch.fx.immutable_collections.immutable_list, list, tuple)
        ):
            # Assumes every element is an fx Node (non-Nodes have no .meta).
            example_inputs.append([_collect_example_val(x) for x in arg])
        else:
            raise RuntimeError(f"Unsupported arg type {type(arg)}")

    # Bound the arguments to make sure number of inputs are correct
    bound_args: inspect.BoundArguments = inspect.signature(hop.__call__).bind(
        *example_inputs
    )

    # We treat example_output as a single value in return. This is to differentiate 1. return a single val
    # vs 2. return a tuple with one element.
    # NOTE(review): list(example_output) assumes the HOP's output val is an
    # iterable (tuple/list) of values; a bare Tensor would be split along dim 0.
    example_output = _collect_example_val(node)
    return FunctionSchemaGen.from_example(
        hop._name, tuple(bound_args.arguments.items()), (list(example_output),)
    )
  228. def can_generate_trivial_fake_impl(op: OpOverload) -> bool:
  229. if not isinstance(op, OpOverload):
  230. raise AssertionError(f"op must be OpOverload, got {type(op)}")
  231. if is_builtin(op):
  232. # We control the built-ins. These may (in rare cases)
  233. # do input metadata mutation (which we have banned on custom ops)
  234. return False
  235. schema = op._schema
  236. # It's suspicious if the op is not mutable but returns nothing, so we return False out of an abundance of caution
  237. if not schema.is_mutable:
  238. return False
  239. if len(schema.returns) > 0:
  240. return False
  241. # If the op returns nothing, then it has a trivial fake impl.
  242. return True
  243. def requires_set_python_module() -> bool:
  244. """If an op was defined in C++ and extended from Python using the
  245. torch.library APIs, returns if we require that there have been a
  246. m.set_python_module("mylib.ops") call from C++ that associates
  247. the C++ op with a python module.
  248. """
  249. return getattr(_utils_internal, "REQUIRES_SET_PYTHON_MODULE", True)
  250. def handle_dispatch_mode(curr_mode, op_overload, *args, **kwargs):
  251. if not isinstance(curr_mode, torch.utils._python_dispatch.TorchDispatchMode):
  252. raise AssertionError(
  253. f"curr_mode must be TorchDispatchMode, got {type(curr_mode)}"
  254. )
  255. args_flattened, _ = torch.utils._pytree.tree_flatten((args, kwargs.values()))
  256. # TODO: need to double check the semantics of the "types" argument to torch_dispatch.
  257. # It's generated in PyInterpreter.cpp, but seems to be generated in two places,
  258. # where in one case we only include tensors with the python key, and in another
  259. # we include **all** tensors.
  260. overload_types = [
  261. type(a)
  262. for a in args_flattened
  263. if isinstance(a, torch.Tensor)
  264. and torch._C._dispatch_keys(a).has(torch._C.DispatchKey.Python)
  265. ]
  266. # TODO: check that I got these args correct (in C++, we pass in "0000"??)
  267. return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)
  268. def has_kwarg_only_args(schema: _C.FunctionSchema):
  269. return any(a.kwarg_only for a in schema.arguments)
  270. def has_kwarg_only_tensors(schema: _C.FunctionSchema):
  271. for a in schema.arguments:
  272. if not (is_tensor_like_type(a.type) or is_tensorlist_like_type(a.type)):
  273. continue
  274. if not a.kwarg_only:
  275. continue
  276. return True
  277. return False
  278. def has_tensor_arg(schema: _C.FunctionSchema) -> bool:
  279. """
  280. Given a schema, returns True if the schema has a Tensor arg.
  281. A Tensor arg is any arg with a type annotation that might involve Tensor.
  282. """
  283. return any(
  284. (is_tensor_like_type(a.type) or is_tensorlist_like_type(a.type))
  285. for a in schema.arguments
  286. )
  287. def get_device_arg_index(schema: _C.FunctionSchema) -> Union[int, None]:
  288. """
  289. Given a schema, returns the id of the `device: torch.device` argument.
  290. If it does not exist, returns None.
  291. """
  292. for index, arg in enumerate(schema.arguments):
  293. if arg.type is _C.DeviceObjType.get() and arg.name == "device":
  294. return index
  295. return None
  296. def iter_tensors(
  297. args: tuple[Any], kwargs: dict[str, Any], allowed_nesting: int = 1
  298. ) -> Iterator[torch.Tensor]:
  299. def check(arg):
  300. if isinstance(arg, torch.Tensor):
  301. yield arg
  302. elif allowed_nesting > 0 and isinstance(arg, (tuple, list)):
  303. yield from iter_tensors(tuple(arg), {}, allowed_nesting - 1)
  304. for arg in args:
  305. yield from check(arg)
  306. for kwarg in kwargs.values():
  307. yield from check(kwarg)
  308. def check_aliasing_constraint(name, prev, result, get_module=lambda: "???"):
  309. """
  310. custom operators' outputs must not alias any inputs or other outputs.
  311. """
  312. storages = {t.untyped_storage()._cdata for t in prev if isinstance(t, torch.Tensor)}
  313. tuple_result = result
  314. if not isinstance(result, tuple):
  315. tuple_result = (result,)
  316. for tensor in iter_tensors(tuple_result, {}):
  317. key = tensor.untyped_storage()._cdata
  318. if tensor.untyped_storage()._cdata in storages:
  319. raise RuntimeError(
  320. f"{name} (with implementation in {get_module()}): "
  321. f"The output of this custom operator (1) must not "
  322. f"also be an input to this custom operator and "
  323. f"(2) may not alias any inputs to this custom operator "
  324. f"or other returns. "
  325. f"The most common way to trigger this error is if "
  326. f"we have y = custom_op(x) and y and x are the same Tensor. "
  327. f"Please instead return a clone of the offending output "
  328. f"tensor(s) (e.g. return x.clone()) or refactor the custom "
  329. f"operator to not return y."
  330. )
  331. storages.add(key)
  332. def _c_check_aliasing_constraint(name, args, kwargs, result, get_module=lambda: "???"):
  333. """
  334. custom operators' outputs must not have any aliases
  335. This version uses C++ implementation for perf.
  336. Only List container is supported.
  337. Tensors in Lists with not only Tensors are checked.
  338. """
  339. tuple_result = result
  340. if not isinstance(result, tuple):
  341. tuple_result = (result,)
  342. if _C._any_output_is_alias_to_input_or_output(args, kwargs, tuple_result):
  343. raise RuntimeError(
  344. f"{name} (with implementation in {get_module()}): "
  345. f"The output of this custom operator (1) must not "
  346. f"also be an input to this custom operator and "
  347. f"(2) may not alias any inputs to this custom operator "
  348. f"or other returns. "
  349. f"The most common way to trigger this error is if "
  350. f"we have y = custom_op(x) and y and x are the same Tensor. "
  351. f"Please instead return a clone of the offending output "
  352. f"tensor(s) (e.g. return x.clone()) or refactor the custom "
  353. f"operator to not return y."
  354. )
  355. class MutationChecker:
  356. """
  357. Check if an operator mutated its arguments.
  358. Usage:
  359. checker = MutationChecker(op, flat_args, args_spec)
  360. op(*args, **kwargs)
  361. checker.check()
  362. """
  363. def __init__(self, op, flat_args, args_spec):
  364. self.op = op
  365. self.args_spec = args_spec
  366. self.flat_args = flat_args
  367. self.real_pre_hashes = [
  368. hash_tensor(a) if isinstance(a, torch.Tensor) else None for a in flat_args
  369. ]
  370. def check(self):
  371. real_post_hashes = [
  372. hash_tensor(a) if isinstance(a, torch.Tensor) else None
  373. for a in self.flat_args
  374. ]
  375. was_mutated = [
  376. not torch.equal(pre, post)
  377. and not (pre.isnan().all() and post.isnan().all())
  378. if isinstance(pre, torch.Tensor) and isinstance(post, torch.Tensor)
  379. else None
  380. for pre, post in zip(self.real_pre_hashes, real_post_hashes)
  381. ]
  382. was_mutated_args, was_mutated_kwargs = pytree.tree_unflatten(
  383. was_mutated, self.args_spec
  384. )
  385. for info, was_mutated in zip_schema(
  386. self.op._schema, was_mutated_args, was_mutated_kwargs
  387. ):
  388. def check_one(info, was_mutated):
  389. if info.is_write == was_mutated:
  390. return
  391. raise RuntimeError(
  392. f"{self.op._name}: for argument '{info.name}': the operator's schema "
  393. f"{self.op._schema} specified that "
  394. f"the operator {'mutates' if info.is_write else 'does not mutate'} "
  395. f"the argument, but this seems to be empirically wrong. "
  396. f"Please make the schema and operator behavior consistent. "
  397. f"You can specify that an operator mutates a Tensor by "
  398. f"e.g. changing its schema type from 'Tensor name' to 'Tensor(a!) name'"
  399. f"(use different identifiers (a, b, c, ...) for different Tensors)"
  400. )
  401. if is_tensor_like_type(info.type):
  402. check_one(info, was_mutated)
  403. elif is_tensorlist_like_type(info.type):
  404. was_any_mutated = False if was_mutated is None else any(was_mutated)
  405. check_one(info, was_any_mutated)
  406. def hash_tensor(t: torch.Tensor) -> torch.Tensor:
  407. """Some inexpensive hash. Used as a quick and dirty indicator for tensor mutation"""
  408. return t.detach().float().mean()
  409. def has_fake_kernel(op: torch._ops.OpOverload) -> bool:
  410. """If an operator (that stays alive until FakeTensorMode) has a Fake kernel.
  411. Don't use this if the operator decomposes before FakeTensorMode.
  412. """
  413. if can_generate_trivial_fake_impl(op):
  414. return True
  415. name = op._name
  416. if torch._C._dispatch_has_kernel_for_dispatch_key(
  417. name, "CompositeImplicitAutograd"
  418. ):
  419. return True
  420. opdef = torch._library.custom_ops._maybe_get_opdef(name)
  421. if opdef is None:
  422. # the non-torch.library.custom_op path
  423. if torch._C._dispatch_has_kernel_for_dispatch_key(
  424. name, "CompositeExplicitAutograd"
  425. ):
  426. return True
  427. entry = torch._library.simple_registry.singleton.find(name)
  428. if entry.fake_impl.kernel is not None:
  429. return True
  430. if torch._C._dispatch_has_kernel_for_dispatch_key(name, "Meta"):
  431. return True
  432. else:
  433. # the torch.library.custom_op path
  434. if opdef._abstract_fn is not None:
  435. return True
  436. return False
  437. def mutated_args_kwargs(schema: _C.FunctionSchema) -> tuple[list[int], list[str]]:
  438. idxs = []
  439. keys = []
  440. for i, info in enumerate(schema.arguments):
  441. if info.alias_info is not None and info.alias_info.is_write:
  442. if info.kwarg_only:
  443. keys.append(info.name)
  444. else:
  445. idxs.append(i)
  446. return idxs, keys
# Layout-constraint tags, highest priority first. get_layout_constraint_tag
# scans this list in order and returns the first tag present on the op.
tags_by_priority = [
    _C.Tag.needs_exact_strides,
    _C.Tag.needs_contiguous_strides,
    _C.Tag.needs_fixed_stride_order,
    _C.Tag.flexible_layout,
]
  453. # Case 1: with_default=True (or omitted). Return type is guaranteed to be a Tag.
  454. @overload
  455. def get_layout_constraint_tag(
  456. fn: Any, *, with_default: Literal[True] = True
  457. ) -> _C.Tag: ...
  458. # Case 2: with_default=False. Return type can be a Tag or None.
  459. @overload
  460. def get_layout_constraint_tag(
  461. fn: Any, *, with_default: Literal[False]
  462. ) -> Optional[_C.Tag]: ...
  463. def get_layout_constraint_tag(fn, *, with_default=True):
  464. for tag in tags_by_priority:
  465. if tag in fn.tags:
  466. return tag
  467. if with_default:
  468. if is_builtin(fn):
  469. return _C.Tag.flexible_layout
  470. import torch._functorch
  471. from torch._functorch import config
  472. return getattr(torch._C.Tag, config.custom_op_default_layout_constraint)
  473. return None
# Set of random-number functions that is_impure always reports as impure,
# regardless of its impure_random flag, because they affect the global RNG
# state even though they lack the _nondeterministic_seeded attribute.
_RANDOM_FUNCTIONS = {
    torch.rand,
    torch.randn,
    torch.randint,
    torch.randperm,
    torch.rand_like,
    torch.randn_like,
    torch.randint_like,
    torch.normal,
    torch.poisson,
    torch.bernoulli,
    torch.multinomial,
}
def is_impure(
    op: Callable,
    *,
    args: Optional[tuple[Any, ...]] = None,
    kwargs: Optional[dict[str, Any]] = None,
    impure_random: bool = True,
) -> bool:
    """
    Return whether calling ``op`` may have side effects.

    An operator is impure if it:
    - Mutates its inputs (has a mutable schema)
    - Has nondeterministic/random behavior that mutates RNG state
    - Is explicitly marked as effectful via torch.library._register_effectful_op

    Args:
        op: The operator to check (function, OpOverload, HigherOrderOperator, etc.)
        args: Optional arguments that would be passed to the callable. Only
            consulted for auto_functionalized HOPs, whose first argument is
            the wrapped operator.
        kwargs: Optional keyword arguments that would be passed to the
            callable (not currently read by this function).
        impure_random: Whether ops flagged ``_nondeterministic_seeded`` are
            treated as impure (default: True). Functions in
            ``_RANDOM_FUNCTIONS`` are impure regardless of this flag.

    Returns:
        bool: True if the callable has side effects, False otherwise
    """
    # Import here to avoid circular dependencies
    from torch._higher_order_ops.effects import _get_effect
    from torch.fx.node import _side_effectful_functions

    if isinstance(op, torch._ops.OpOverload):
        schema = getattr(op, "_schema", None)
        if schema is not None and schema.is_mutable:
            return True

    if op in _side_effectful_functions:
        return True

    if _get_effect(op) is not None:
        return True

    if isinstance(op, torch._ops.HigherOrderOperator):
        if op in (
            torch.ops.higher_order.auto_functionalized,
            torch.ops.higher_order.auto_functionalized_v2,
        ):
            # Check if the auto-functionalized operator (the first argument) is
            # side-effectful
            if args and len(args) > 0:
                return args[0] in _side_effectful_functions

        # NOTE(review): with this nesting, an effectful op already returned
        # True above, so this check appears unreachable; kept as-is.
        if _get_effect(op) is not None:
            return True

        # Any other HOP is considered pure; the checks below are not applied.
        return False

    # Impure since it mutates RNG state
    if impure_random and getattr(op, "_nondeterministic_seeded", False):
        return True

    # Handle Python random functions that don't have _nondeterministic_seeded
    # but still affect global RNG state (issue #151524)
    # These should be impure regardless of impure_random setting to maintain
    # consistency between eager and compiled execution
    # All random operations are impure to ensure consistent behavior
    # between eager and compiled execution, regardless of generator usage
    if op in _RANDOM_FUNCTIONS:
        return True

    # Covers non-OpOverload callables that still expose a _schema.
    schema = getattr(op, "_schema", None)
    if schema is not None and schema.is_mutable:
        return True

    # NOTE(review): appears redundant with the earlier identical check.
    if op in _side_effectful_functions:
        return True

    return False