_decomposition_utils.py 510 B

1234567891011121314
  1. # mypy: allow-untyped-defs
  2. import torch
  3. from torch._ops import OpOverload, OpOverloadPacket
  4. def _register_decomposition(op: OpOverload, graph: torch._C.Graph) -> None:
  5. if isinstance(op, OpOverloadPacket):
  6. raise AssertionError(
  7. f"Must pass specific op overload, not overload packet, found {op}"
  8. )
  9. if not isinstance(op, OpOverload):
  10. raise AssertionError(f"Expected OpOverload, got {type(op)}")
  11. torch._C._jit_register_decomposition_for_schema(op._schema, graph)