symbol.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. # mypy: allow-untyped-defs
  2. """
  3. This file contains canonical definitions for our symbol naming conventions,
  4. across torch.fx.experimental.symbolic_shapes and torch._inductor. The
  5. intention is:
1. Make it easy to grep for every site that uses a given prefix.
2. Make it easy to tell whether a new prefix can be introduced without
   introducing a conflict.
You can check whether prefixes have been hardcoded elsewhere by renaming
them in this file and seeing what breaks.
  11. """
  12. from collections.abc import Iterable
  13. from enum import auto, Enum
  14. import sympy
class SymT(Enum):
    """Kinds of sympy symbols, distinguished by the name prefix in ``prefix_str``.

    Values are assigned by ``auto()`` in declaration order, so reordering
    members changes their values.
    """

    # Not marked "Inductor:" like the members below; presumably these are the
    # symbolic_shapes symbols mentioned in the module docstring -- confirm.
    SIZE = auto()
    FLOAT = auto()
    UNBACKED_INT = auto()
    UNBACKED_FLOAT = auto()
    # Inductor: The intermediates in inner_fn tmp0, one generated per ops call.
    # If one of these shows up in an indexing expression, that means an
    # indirect load is happening.
    TMP = auto()
    # Inductor: Placeholder variable that is later replaced with TMP
    INDIRECT = auto()
    # Inductor: Some size expressions are replaced with a precomputed size ps0
    # which is computed host side, and then directly reused in the kernel, so
    # we don't repeatedly recompute it on device.
    PRECOMPUTED_SIZE = auto()
    # Inductor: An indexing variable i0 in loops IR which ranges over a
    # non-reduced dim in the loop
    INDEX = auto()
    # Inductor: Reduction indexing variables (r0_*, r1_*) in loops IR which
    # range over the reduced dim(s) in the loop
    R0_INDEX = auto()
    R1_INDEX = auto()
    # Inductor: In templated kernels (torch._inductor.kernel), we have a hook
    # to store the final output and append epilogue fusions. To do this, we
    # must know which indexes the outputs range over. NB: the "idx" prefix
    # starts with INDEX's "i" prefix, so these symbols also advertise as INDEX
    # in symbol_is_type; this is... probably OK?
    TEMPLATE_INDEX = auto()
    # Inductor: iteration domain for blockIdx.x/blockIdx.y (ZBLOCK presumably
    # covers blockIdx.z -- confirm)
    XBLOCK = auto()
    YBLOCK = auto()
    ZBLOCK = auto()
    # Inductor: this is used solely for dynamic_reshape_indexer
    VIEW = auto()
    # Alternate (non-modular) indexing used in halide kernels
    HALIDE = auto()
# Maps each SymT to the name prefix of its symbols (used by make_symbol to
# build names and by symbol_is_type to classify them via str.startswith).
#
# Intended invariant: no prefix should be a prefix of another string, as this
# introduces ambiguity (a symbol would match more than one SymT).
# NOTE(review): the table below does NOT fully satisfy this today:
#   - "i" (INDEX) is a prefix of "idx" (TEMPLATE_INDEX) and "indirect"
#     (INDIRECT), so those symbols also test positive for SymT.INDEX;
#   - "z" (ZBLOCK) is a prefix of "zf" (FLOAT) and "zuf" (UNBACKED_FLOAT), so
#     float symbols also test positive for SymT.ZBLOCK.
# Check symbol_is_type call sites before relying on these checks being exclusive.
prefix_str = {
    SymT.SIZE: "s",  # integer
    SymT.UNBACKED_INT: "u",  # integer
    # Prefix z here is chosen to avoid false aliasing in the symbol_is_type
    # test (e.g. a "uf" prefix would start with UNBACKED_INT's "u").
    # The original guidance was "DO NOT add a z type", but ZBLOCK below now
    # uses "z" (see the NOTE above). You also need to avoid conflicts on these
    # prefixes, but this is somewhat easier to manage.
    SymT.FLOAT: "zf",
    SymT.UNBACKED_FLOAT: "zuf",
    SymT.TMP: "tmp",
    SymT.PRECOMPUTED_SIZE: "ps",
    SymT.INDEX: "i",
    SymT.R0_INDEX: "r0_",
    SymT.R1_INDEX: "r1_",
    SymT.TEMPLATE_INDEX: "idx",
    SymT.XBLOCK: "x",
    SymT.YBLOCK: "y",
    SymT.ZBLOCK: "z",
    SymT.INDIRECT: "indirect",  # false aliasing: starts with INDEX's "i"
    SymT.VIEW: "view",
    SymT.HALIDE: "h",
}
  73. def make_symbol(prefix: SymT, idx: int, **kwargs) -> sympy.Symbol:
  74. # TODO: maybe put the assumptions here directly
  75. return sympy.Symbol(f"{prefix_str[prefix]}{idx}", **kwargs)
  76. # This type is a little wider than it should be, because free_symbols says
  77. # that it contains Basic, rather than Symbol
  78. def symbol_is_type(sym: sympy.Basic, prefix: SymT | Iterable[SymT]) -> bool:
  79. if not isinstance(sym, sympy.Symbol):
  80. raise AssertionError("expected sympy.Symbol")
  81. name_str = sym.name.lower() # Match capitalized names like XBLOCK, RBLOCK
  82. if isinstance(prefix, SymT):
  83. return name_str.startswith(prefix_str[prefix])
  84. else:
  85. return name_str.startswith(tuple(prefix_str[p] for p in prefix))
  86. def free_symbol_is_type(e: sympy.Expr, prefix: SymT | Iterable[SymT]) -> bool:
  87. return any(symbol_is_type(v, prefix) for v in e.free_symbols)