interp.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. # mypy: allow-untyped-defs
  2. """
  3. This is a simple interpreter for Sympy expressions that dispatches to
  4. classes following the torch._inductor.virtualized calling convention.
  5. For directness, the interpreter takes the handler directly rather than
  6. consulting the TLS. It does not use most of the methods on the full
  7. handler; only those with corresponding Sympy expressions. To see an example
  8. of a full handler, see torch.utils._sympy.value_ranges.ValueRangeAnalysis.
  9. """
  10. import functools
  11. import logging
  12. from typing import Any
  13. import sympy
  14. from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom
  15. import torch
  16. from .functions import (
  17. BitwiseFn_bitwise_and,
  18. BitwiseFn_bitwise_or,
  19. BitwiseFn_bitwise_xor,
  20. CeilToInt,
  21. CleanDiv,
  22. FloatPow,
  23. FloatTrueDiv,
  24. FloorDiv,
  25. FloorToInt,
  26. Identity,
  27. IntTrueDiv,
  28. IsNonOverlappingAndDenseIndicator,
  29. Max,
  30. Min,
  31. Mod,
  32. ModularIndexing,
  33. OpaqueUnaryFn_log2,
  34. PowByNatural,
  35. PythonMod,
  36. RoundDecimal,
  37. RoundToInt,
  38. ToFloat,
  39. TruncToFloat,
  40. TruncToInt,
  41. Where,
  42. )
  43. log = logging.getLogger(__name__)
  44. # TODO: Dedupe this with SYMPY_INTERP
  45. @functools.cache
  46. def handlers():
  47. # TODO add CeilDiv (it doesn't appear in the index_expr)
  48. # TODO default to some decompositions if the interpreter doesn't have them
  49. # like decomposing ModularIndexing or implementing Le(a,b) as Ge(b, a)
  50. HANDLERS = {
  51. sympy.Or: "or_",
  52. sympy.And: "and_",
  53. sympy.Eq: "eq",
  54. sympy.Ne: "ne",
  55. sympy.Lt: "lt",
  56. sympy.Gt: "gt",
  57. sympy.Le: "le",
  58. sympy.Ge: "ge",
  59. sympy.Not: "not_",
  60. IntTrueDiv: "int_truediv",
  61. FloatTrueDiv: "truediv",
  62. FloorDiv: "floordiv",
  63. CleanDiv: "floordiv", # TODO: hmm?
  64. TruncToFloat: "trunc",
  65. Where: "where",
  66. sympy.Add: "add",
  67. sympy.Mul: "mul",
  68. FloatPow: "pow",
  69. PowByNatural: "pow_by_natural",
  70. # sympy simplifies x * x into Pow(x, 2), so we need to handle this.
  71. # Do NOT use builtin Pow for floats
  72. # TODO: There is a hazard here, if we have float * float it will
  73. # also get turned into Pow(float, 2) but we don't want this because
  74. # pow_by_natural is assumed to only be integers. Probably the fix is
  75. # to add a FloatMul to impede this optimization
  76. sympy.Pow: "pow_by_natural",
  77. Mod: "mod",
  78. PythonMod: "python_mod",
  79. # TODO: Inductor can generate these, but it's ill-specified which
  80. # semantics were intended here. Needs to be cleaned up along with
  81. # FloorDiv in a bigger cleanup
  82. sympy.Mod: "mod",
  83. sympy.Abs: "abs",
  84. sympy.log: "log",
  85. sympy.exp: "exp",
  86. sympy.Min: "minimum",
  87. sympy.Max: "maximum",
  88. Min: "minimum",
  89. Max: "maximum",
  90. ModularIndexing: "modular_indexing",
  91. sympy.functions.elementary.piecewise.ExprCondPair: "expr_cond_pair",
  92. sympy.Piecewise: "piecewise",
  93. Identity: "identity",
  94. IsNonOverlappingAndDenseIndicator: "is_non_overlapping_and_dense_indicator",
  95. RoundDecimal: "round_decimal",
  96. # TODO: do the rest of the opaque unary functions...
  97. OpaqueUnaryFn_log2: "log2",
  98. BitwiseFn_bitwise_and: "bitwise_and",
  99. BitwiseFn_bitwise_or: "bitwise_or",
  100. BitwiseFn_bitwise_xor: "bitwise_xor",
  101. }
  102. # TODO: This is kind of pointless, we shouldn't be generating sympy.sin
  103. # for these functions, they should be Opaque instead
  104. for name in ["cos", "sin", "tan", "sinh", "cosh", "tanh", "asin", "acos", "atan"]:
  105. HANDLERS[getattr(sympy, name)] = name
  106. return HANDLERS
  107. ASSOCIATIVE_OPS = {"minimum", "maximum", "mul", "add", "and_", "or_"}
  108. def _run_sympy_handler(analysis, args, expr, index_dtype=torch.int64):
  109. # Special cases
  110. if isinstance(expr, sympy.Pow) and isinstance(
  111. expr.args[1], sympy.core.numbers.Half
  112. ):
  113. return analysis.sqrt(args[0])
  114. if isinstance(expr, ToFloat):
  115. return analysis.to_dtype(args[0], torch.float64)
  116. # These handlers are special because they take an extra dtype argument
  117. # specifying what they should convert to, and we need to appropriately set
  118. # this up when we convert from Sympy. A reasonable default when you
  119. # are translating is to conservatively do int64, and then narrow these
  120. # arguments later when you discover you can narrow the index range. But
  121. # if you already know that 32-bit indexing is OK, you can directly do the
  122. # sympy translation with index_dtype=torch.int32
  123. INDEX_DTYPE_HANDLERS = {
  124. TruncToInt: "trunc_to_int",
  125. sympy.floor: "floor_to_int",
  126. sympy.ceiling: "ceil_to_int",
  127. FloorToInt: "floor_to_int",
  128. CeilToInt: "ceil_to_int",
  129. RoundToInt: "round_to_int",
  130. }
  131. if (handler_name := INDEX_DTYPE_HANDLERS.get(expr.func)) is not None:
  132. return getattr(analysis, handler_name)(*args, index_dtype)
  133. # Fastpath for n-ary integral addition
  134. if expr.func is sympy.Add and expr.is_integer and hasattr(analysis, "sym_sum"):
  135. r = analysis.sym_sum(args)
  136. log.debug("sym_sum(%s) -> %s", args, r)
  137. return r
  138. if hasattr(expr.func, "_torch_handler_name"):
  139. handler_name = expr.func._torch_handler_name
  140. else:
  141. handler_name = handlers()[expr.func]
  142. handler = getattr(analysis, handler_name)
  143. try:
  144. if handler_name in ASSOCIATIVE_OPS:
  145. if len(args) <= 1:
  146. raise AssertionError("associative op needs >1 args")
  147. acc = handler(args[0], args[1])
  148. for i in range(2, len(args)):
  149. acc = handler(acc, args[i])
  150. log.debug("%s(%s) -> %s", handler_name, args, acc)
  151. return acc
  152. else:
  153. r = handler(*args)
  154. log.debug("%s(%s) -> %s", handler_name, args, r)
  155. return r
  156. except NotImplementedError:
  157. raise
  158. except Exception:
  159. log.warning("failed while executing %s(%s)", handler_name, args)
  160. raise
  161. _nil = object()
  162. def sympy_interp(
  163. analysis,
  164. env: dict[sympy.Symbol, Any],
  165. expr: sympy.Expr | SympyBoolean,
  166. *,
  167. index_dtype=torch.int64,
  168. missing_handler=None,
  169. ):
  170. # Handle base cases
  171. dtype = None
  172. if isinstance(expr, BooleanAtom):
  173. dtype = torch.bool
  174. elif isinstance(expr, sympy.Integer):
  175. dtype = torch.int64
  176. elif isinstance(expr, sympy.Number):
  177. dtype = torch.double
  178. if dtype is not None:
  179. return analysis.constant(expr, dtype)
  180. elif isinstance(expr, sympy.Symbol):
  181. if (r := env.get(expr, _nil)) is not _nil:
  182. return r
  183. elif missing_handler:
  184. return missing_handler(expr)
  185. else:
  186. raise KeyError(expr)
  187. # Recursive case
  188. return _run_sympy_handler(
  189. analysis,
  190. [
  191. sympy_interp(
  192. analysis,
  193. env,
  194. arg,
  195. index_dtype=index_dtype,
  196. missing_handler=missing_handler,
  197. )
  198. for arg in expr.args
  199. ],
  200. expr,
  201. index_dtype=index_dtype,
  202. )