_decompositions.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. # mypy: allow-untyped-defs
  2. import torch
  3. from torch import Tensor
  4. aten = torch.ops.aten
  5. import inspect
  6. import warnings
  7. from collections.abc import Callable
  8. from typing import Optional, TypeVar
  9. from typing_extensions import ParamSpec
  10. from torch.types import Number
  11. decomposition_table: dict[str, torch.jit.ScriptFunction] = {}
  12. function_name_set: set[str] = set()
  13. _T = TypeVar("_T")
  14. _P = ParamSpec("_P")
  15. def check_decomposition_has_type_annotations(f) -> None:
  16. inspect_empty = inspect._empty # type: ignore[attr-defined]
  17. sig = inspect.signature(f)
  18. for param in sig.parameters.values():
  19. if param.annotation == inspect_empty:
  20. raise AssertionError(
  21. f"No signature on param {param.name} for function {f.name}"
  22. )
  23. if sig.return_annotation == inspect_empty:
  24. raise AssertionError(f"No return annotation for function {f.name}")
  25. def signatures_match(decomposition_sig, torch_op_sig):
  26. decomp_params = decomposition_sig.parameters
  27. op_params = torch_op_sig.parameters
  28. if len(decomp_params) != len(op_params):
  29. return False
  30. for decomp_param, op_param in zip(decomp_params.values(), op_params.values()):
  31. # can't check full equality yet because not all fields are correctly deduced
  32. # in the torch_op_sig - like default value
  33. # can't check 'kind' bc
  34. # kwarg-only values with defaults not yet supported in TS
  35. inspect_empty = inspect._empty # type: ignore[attr-defined]
  36. for field in ["name", "annotation"]:
  37. if field == "name" and decomp_param.name == "self":
  38. warnings.warn(
  39. "PyTorch uses 'input' instead of 'self' on public api", stacklevel=2
  40. )
  41. if getattr(decomp_param, field) != getattr(op_param, field):
  42. return False
  43. decomp_default = decomp_param.default
  44. op_default = op_param.default
  45. # default value not always correctly inferred as being present on torch schema,
  46. # but if specified on both they should be equal
  47. if decomp_default != inspect_empty and op_default != inspect_empty:
  48. if decomp_default != op_default:
  49. return False
  50. return decomposition_sig.return_annotation == torch_op_sig.return_annotation
  51. def register_decomposition(
  52. aten_op: torch._ops.OpOverload,
  53. registry: Optional[dict[str, torch.jit.ScriptFunction]] = None,
  54. ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
  55. def decomposition_decorator(f: Callable[_P, _T]) -> Callable[_P, _T]:
  56. nonlocal registry
  57. if registry is None:
  58. registry = decomposition_table
  59. if not isinstance(aten_op, torch._ops.OpOverload):
  60. raise AssertionError(
  61. f"Expected aten_op to be OpOverload, got {type(aten_op)}"
  62. )
  63. # Need unique name for jit function serialization
  64. if f.__name__ in function_name_set:
  65. raise AssertionError(f"Duplicated function name {f.__name__}")
  66. function_name_set.add(f.__name__)
  67. scripted_func = torch.jit.script(f)
  68. torch._C._jit_pass_inline(scripted_func.graph)
  69. for _ in range(2):
  70. torch._C._jit_pass_peephole(scripted_func.graph)
  71. torch._C._jit_pass_constant_propagation(scripted_func.graph)
  72. registry[str(aten_op._schema)] = scripted_func
  73. return f
  74. return decomposition_decorator
  75. # TODO: replace torch.sigmoid -> aten.sigmoid
  76. @register_decomposition(aten.var.correction)
  77. def var_decomposition(
  78. input: Tensor,
  79. dim: Optional[list[int]] = None,
  80. correction: Optional[Number] = None,
  81. keepdim: bool = False,
  82. ) -> Tensor:
  83. if dim is None:
  84. dim_i: list[int] = []
  85. dim = dim_i
  86. if isinstance(dim, (tuple, list)) and len(dim) == 0:
  87. n = input.numel()
  88. else:
  89. n = 1
  90. for dim_i in dim: # type: ignore[assignment]
  91. n *= input.shape[dim_i] # type: ignore[call-overload]
  92. mean = aten.mean(input, dim, True)
  93. sub = input - mean
  94. sq = sub * sub
  95. sum = aten.sum(sq, dim, keepdim)
  96. if correction is None:
  97. denom = float(n - 1)
  98. else:
  99. if isinstance(correction, int):
  100. denom = float(n - correction)
  101. elif isinstance(correction, float):
  102. denom = float(n) - correction
  103. else:
  104. raise RuntimeError("correction must be int or float")
  105. # pyrefly: ignore [no-matching-overload]
  106. return sum / max(0, denom)
  107. @register_decomposition(aten.var.default)
  108. def var(input: Tensor, unbiased: bool = True) -> Tensor:
  109. return var_decomposition(input, correction=(1 if unbiased else 0))