bounds.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260
  1. import logging
  2. import operator
  3. from collections.abc import Callable
  4. from functools import partial
  5. from typing import Any, Optional, Union
  6. import sympy
  7. from sympy import Expr
  8. import torch
  9. from torch.utils._sympy.value_ranges import (
  10. bound_sympy,
  11. SymPyValueRangeAnalysis,
  12. ValueRanges,
  13. )
  14. from ..utils._sympy.functions import PowByNatural
  15. from ..utils._sympy.numbers import int_oo
  16. from .loop_body import InterpreterShim, LoopBody, LoopBodyBlock
  17. from .ops_handler import DefaultHandler, ReductionType, StoreMode
  18. from .utils import cache_on_self, dominated_nodes
  19. from .virtualized import V
  20. log = logging.getLogger(__name__)
  21. class BoundVars:
  22. """
  23. Performs Value Range Analysis on LoopBody's fx graph by calling BoundVars.run()
  24. It exposes the ranges of the nodes in the `bounds` variable
  25. Note. A current limitation of this analysis is that it just works on a per-loop basis.
  26. We should be able to propagate the bounds between across the whole graph. This may benefit
  27. the case a bounded variable is returned by a kernel and fed into another.
  28. """
  29. def __init__(self, loop_body: LoopBody) -> None:
  30. def upper_bound(v: Union[Expr, int]) -> int:
  31. return bound_sympy(v).upper if isinstance(v, Expr) else v
  32. self.loop_body = loop_body
  33. self.replacement_vals = {
  34. k: ValueRanges[Expr](0, upper_bound(v) - 1)
  35. for k, v in loop_body.var_ranges.items()
  36. }
  37. # avoid computing these values, pessimistically assume that they are unbounded
  38. self.unbounded_vars = dominated_nodes(
  39. node
  40. for node in self.loop_body.get_nodes()
  41. if node.target in ["load", "reduction", operator.getitem]
  42. or "masked_subblock" in node.target
  43. )
  44. # To access this variable call `get_bounds()`
  45. self._bounds: dict[torch.fx.Node, ValueRanges[Expr]] = {}
  46. def __repr__(self) -> str:
  47. return (
  48. f"{self.__class__.__name__}("
  49. f"loop_body={self.loop_body},\n "
  50. f"replacement_vals={self.replacement_vals}, \n"
  51. f"unbounded_vars={self.unbounded_vars}, \n"
  52. f"_bounds={self._bounds})"
  53. )
  54. @cache_on_self
  55. def get_bounds(self) -> dict[torch.fx.Node, ValueRanges[Expr]]:
  56. submodules = self.swap_submodules(self.loop_body.submodules)
  57. # Initialize the environment with the unbounded variables
  58. for node in self.unbounded_vars:
  59. # we need to evaluate masked_subblock to recurse, and we need to set indirect values
  60. if not isinstance(node.target, str) or (
  61. "masked_subblock" not in node.target
  62. and "set_indirect" not in node.target
  63. ):
  64. self._bounds[node] = ValueRanges[Expr].unknown()
  65. with V.set_ops_handler(ValueRangeAnalysis()):
  66. interpreter = InterpreterShim(self.loop_body.root_block.graph, submodules)
  67. log.debug("get_bounds:\n%s", self.loop_body.root_block.graph)
  68. interpreter.run(V.get_ops_handler(), initial_env=self._bounds)
  69. return self._bounds
  70. def swap_submodules(
  71. self, submodules: dict[str, Callable[..., Any]]
  72. ) -> dict[str, Callable[..., ValueRanges[Expr]]]:
  73. result: dict[str, Callable[..., ValueRanges[Expr]]] = {}
  74. for key in submodules:
  75. if key == "get_index":
  76. result[key] = self.get_index
  77. elif "masked_subblock" in key:
  78. subblock = self.loop_body.subblocks[key]
  79. # The result within the lambda will reference to the final
  80. # set of modules at the end of the for-loop as it stores a reference to it
  81. # bind subblock in a function because python lambdas close over by reference
  82. # moving the lambda out of make_fn would close over the reference to subblock,
  83. # so all lambdas would have the same subblock reference that is the final
  84. # subblock in the loop
  85. def make_fn(
  86. subblock: LoopBodyBlock,
  87. ) -> Callable[[Any, Any], ValueRanges[Expr]]:
  88. return lambda mask, value: self.masked_subblock(
  89. subblock, self._bounds, mask, value, result
  90. )
  91. result[key] = make_fn(subblock)
  92. elif "set_indirect" in key:
  93. idx = int(key[len("set_indirect") :])
  94. var = self.loop_body.indirect_vars[idx]
  95. indirect = partial(self.set_indirect, var)
  96. result[key] = indirect
  97. else:
  98. assert "scan" in key
  99. result[key] = submodules[key]
  100. return result
  101. def masked_subblock(
  102. self,
  103. subblock: LoopBodyBlock,
  104. env: dict[torch.fx.Node, ValueRanges[Expr]],
  105. mask: Any,
  106. value: Any,
  107. submodules: dict[str, Callable[..., Any]],
  108. ) -> ValueRanges[Expr]:
  109. interp = InterpreterShim(subblock.graph, submodules)
  110. interp.run(V.get_ops_handler(), initial_env=env)
  111. output = [node for node in subblock.graph.nodes if node.target == "output"]
  112. assert len(output) == 1
  113. # dont bother unioning with value since the load from buffer will be
  114. # pessimistically assumed to be inf anyway
  115. return interp.env[output[0]]
  116. def set_indirect(self, old: Expr, new: ValueRanges[Expr]) -> ValueRanges[Expr]:
  117. assert isinstance(new, ValueRanges)
  118. self.replacement_vals[old] = new
  119. return new
  120. def get_index(self, name: str) -> ValueRanges[Expr]:
  121. expr = self.loop_body.indexing_exprs[name]
  122. bound = self.replacement_vals.get(expr)
  123. if bound is None:
  124. bound = bound_sympy(expr, self.replacement_vals)
  125. # The following assertion is true at the time of this writing
  126. # We don't assert is as to not execute bound_sympy when bound is not None
  127. # assert bound is None or bound == bound_sympy(expr, self.replacement_vals)
  128. self.replacement_vals[name] = bound
  129. return bound
  130. class ValueRangeAnalysis(SymPyValueRangeAnalysis, DefaultHandler):
  131. def __init__(self) -> None:
  132. self.name = "ValueRangeAnalysis"
  133. boolean_operators = (
  134. "xor",
  135. "logical_and",
  136. "logical_or",
  137. "logical_not",
  138. )
  139. for op in boolean_operators:
  140. setattr(self, op, self.bool_handler)
  141. @staticmethod
  142. def bool_handler(*args: Any, **kwargs: Any) -> ValueRanges[Any]:
  143. # just assuming bools can have both values
  144. return ValueRanges(sympy.false, sympy.true) # type: ignore[arg-type]
  145. def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
  146. # many ops are unlikely to show up in optimizable indexing compute,
  147. # so we dont have full coverage
  148. return ValueRanges.unknown()
  149. def load(self, name: str, index: sympy.Expr) -> ValueRanges[Any]:
  150. return ValueRanges.unknown()
  151. def store(
  152. self, name: str, index: sympy.Expr, value: Any, mode: StoreMode = None
  153. ) -> None:
  154. return
  155. def reduction(
  156. self,
  157. dtype: torch.dtype,
  158. src_dtype: torch.dtype,
  159. reduction_type: ReductionType,
  160. value: Any,
  161. ) -> ValueRanges[Any]:
  162. return ValueRanges.unknown()
  163. @classmethod
  164. def index_expr(cls, index: Any, dtype: torch.dtype) -> ValueRanges[Any]:
  165. assert isinstance(index, ValueRanges)
  166. return cls.to_dtype(index, dtype)
  167. @staticmethod
  168. def to_dtype(
  169. x: Any,
  170. dtype: torch.dtype,
  171. src_dtype: Optional[torch.dtype] = None,
  172. use_compute_types: bool = True,
  173. ) -> ValueRanges[Any]:
  174. x = ValueRanges.wrap(x)
  175. if dtype == torch.bool:
  176. if x.is_singleton():
  177. return ValueRanges.wrap(x.lower != 0)
  178. elif x.is_bool:
  179. return x
  180. elif 0 not in x:
  181. return ValueRanges.wrap(sympy.true)
  182. else:
  183. return ValueRanges(sympy.false, sympy.true)
  184. def cast(x: Any, dtype: torch.dtype) -> sympy.Expr:
  185. # dtype is int or float
  186. if dtype.is_floating_point:
  187. return sympy.Float(x)
  188. else:
  189. if x in (int_oo, -int_oo):
  190. return x
  191. try:
  192. return sympy.Integer(x)
  193. except TypeError:
  194. # inf cannot be cast to Integer
  195. return x
  196. if x.is_bool:
  197. if x.is_singleton():
  198. val = 1 if x.lower else 0
  199. return ValueRanges.wrap(cast(val, dtype))
  200. else:
  201. return ValueRanges(cast(0, dtype), cast(1, dtype))
  202. else:
  203. # int to float or float to int
  204. return ValueRanges(cast(x.lower, dtype), cast(x.upper, dtype))
  205. @staticmethod
  206. def square(x: Any) -> ValueRanges[Any]:
  207. return ValueRanges.convex_min_zero_map(x, lambda y: PowByNatural(y, 2))
  208. @staticmethod
  209. def neg(x: Any) -> ValueRanges[Any]:
  210. return ValueRanges.decreasing_map(x, operator.neg)
  211. # TODO: this is slightly inaccurate because truncdiv operates at integer
  212. # precision, but we're going through float truediv which means we can
  213. # potentially lose precision on the bounds
  214. @classmethod
  215. def truncdiv(cls, a: Any, b: Any) -> ValueRanges[Any]:
  216. x = cls.truediv(a, b)
  217. if x == ValueRanges.unknown():
  218. return x
  219. return cls.trunc(x)
  220. @classmethod
  221. def sub(cls, a: Any, b: Any) -> ValueRanges[Any]:
  222. return cls.add(a, cls.neg(b))