dependencies.py 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890
  1. import abc
  2. import dataclasses
  3. import itertools
  4. import logging
  5. import re
  6. from collections.abc import Callable, Iterable, Sequence
  7. from typing import Any, Optional, TypeVar, Union
  8. from typing_extensions import Self
  9. from unittest.mock import patch
  10. import sympy
  11. import torch
  12. from torch._inductor.utils import get_free_symbols
  13. from torch.fx.experimental.symbolic_shapes import free_symbols, free_unbacked_symbols
  14. from torch.utils._ordered_set import OrderedSet
  15. from ..utils._sympy.symbol import make_symbol, SymT
  16. from .codegen.common import index_prevent_reordering
  17. from .ops_handler import DefaultHandler
  18. from .utils import (
  19. get_dtype_size,
  20. reduction_num_outputs,
  21. sympy_index_symbol,
  22. sympy_subs,
  23. VarRanges,
  24. )
  25. from .virtualized import ReductionType, V
  26. T = TypeVar("T")
  27. log = logging.getLogger(__name__)
  28. is_indirect = re.compile(r"indirect|tmp").search
  29. class Dep(abc.ABC):
  30. name: str
  31. index: sympy.Expr
  32. @abc.abstractmethod
  33. def get_free_symbol_uses(
  34. self, unbacked_only: bool = False
  35. ) -> OrderedSet[sympy.Symbol]:
  36. pass
  37. @abc.abstractmethod
  38. def rename(self, renames: dict[str, str]) -> Self:
  39. pass
  40. @abc.abstractmethod
  41. def get_numel(self) -> sympy.Expr:
  42. pass
  43. @abc.abstractmethod
  44. def numbytes_hint(self) -> int:
  45. pass
  46. @abc.abstractmethod
  47. def numel_hint(self) -> int:
  48. pass
  49. @abc.abstractmethod
  50. def has_unbacked_symbols(self) -> bool:
  51. pass
  52. @abc.abstractmethod
  53. def is_contiguous(self) -> bool:
  54. pass
  55. def normalize_with_stride_order(self, prefix: str = "t") -> Self:
  56. return self
  57. @dataclasses.dataclass(frozen=True)
  58. class MemoryDep(Dep):
  59. # pyrefly: ignore [bad-override]
  60. name: str
  61. # pyrefly: ignore [bad-override]
  62. index: sympy.Expr
  63. var_names: tuple[sympy.Symbol, ...]
  64. size: tuple[sympy.Expr, ...]
  65. mode: Optional[str] = None
  66. def get_free_symbol_uses(
  67. self, unbacked_only: bool = False
  68. ) -> OrderedSet[sympy.Symbol]:
  69. return (
  70. get_free_symbols(self.index, unbacked_only)
  71. | get_free_symbols(self.size, unbacked_only)
  72. | get_free_symbols(self.var_names, unbacked_only)
  73. )
  74. def __repr__(self) -> str:
  75. maybe_mode = ""
  76. if self.mode is not None:
  77. maybe_mode = f", {self.mode}"
  78. return f"MemoryDep({self.name!r}, {self.index}, {self.ranges}{maybe_mode})"
  79. @property
  80. def num_vars(self) -> int:
  81. return len(self.var_names)
  82. def decide_loop_order_to_match(self, other: "MemoryDep") -> Optional[list[int]]:
  83. """
  84. Can return None if not able to decide loop orders.
  85. """
  86. assert self.num_vars == other.num_vars
  87. # ignore broadcast for now since broadcast causes extra 0 strides
  88. # which makes it hard to decide the correct loop orders.
  89. if self.num_vars != len(self.index.free_symbols):
  90. return None
  91. if other.num_vars != len(other.index.free_symbols):
  92. return None
  93. # bail out if any size is 0 or 1
  94. # For size == 0, it's an empty tensor, any strides for that dimension
  95. # are equivalent. Skip for simplicity and it may not matter that much.
  96. #
  97. # For size == 1, it cause cause tie for strides of different dimensions.
  98. # Also when we first time create LoopBody in ComputedBuffer.simplify_and_reorder
  99. # we can dependencies.index_vars_squeeze which should already sqeeuze
  100. # the size == 1 dimensions.
  101. if any(s == 0 or s == 1 for s in itertools.chain(self.size, other.size)):
  102. return None
  103. # Extract strides for both expression
  104. self_strides = V.graph.sizevars.stride_hints(self.index, self.var_names)
  105. other_strides = V.graph.sizevars.stride_hints(other.index, other.var_names)
  106. # Even if the shape contains no 0/1, some complex index expression may
  107. # still have duplicate stride values. Here is an example:
  108. # https://gist.github.com/shunting314/511a7e1ec88aa2e1a8ec85d8445ab129
  109. # We don't reorder the loop for these cases for now, but in theory
  110. # we could improve the algorithm to detect the correct loop orders.
  111. if len(OrderedSet(self_strides)) != len(self_strides) or len(
  112. OrderedSet(other_strides)
  113. ) != len(other_strides):
  114. log.debug(
  115. "unable to decide loop order. self_dep=%s v.s. other_dep=%s, self_strides=%s v.s. other_strides=%s",
  116. self,
  117. other,
  118. self_strides,
  119. other_strides,
  120. )
  121. return None
  122. # May happen if self and other are as follows
  123. # MemoryDep('addmm_6', 393216*d0 + 768*d1 + d2, {d0: 16, d1: 512, d2: 768}, None)
  124. # MemoryDep('addmm_6', 98304*d0 + d1 + 768*d2, {d0: 64, d1: 768, d2: 128}, None)
  125. if OrderedSet(self_strides) != OrderedSet(other_strides):
  126. return None
  127. stride_to_index = {s: i for i, s in enumerate(self_strides)}
  128. order = [stride_to_index[s] for s in other_strides]
  129. assert OrderedSet(order) == OrderedSet(range(self.num_vars))
  130. return order
  131. def get_offset(self) -> sympy.Expr:
  132. """
  133. Return the offset by setting every variable to be 0.
  134. """
  135. return sympy_subs(self.index, dict.fromkeys(self.var_names, 0))
  136. def normalize(self) -> "MemoryDep":
  137. """
  138. Normalize by merging loops. The different to normalize_with_stride_order is,
  139. this method does not reorder loops while normalize_with_stride_order reorder
  140. loops based on stride order.
  141. """
  142. return MemoryDep(
  143. self.name,
  144. *_RecordLoadStoreInner._normalize(self.index, self.ranges), # type: ignore[arg-type]
  145. self.mode,
  146. )
  147. def normalize_with_stride_order(self, prefix: str = "t") -> "MemoryDep":
  148. r"""
  149. Used to decide if two MemoryDep does not equal due to different loop orders.
  150. More specifically, when dep1 and dep2 are not equal, we can normalize
  151. both and check if they are equal after that. If yes, then the mismatch is
  152. caused by different loop orders.
  153. """
  154. # import here to avoid circular import
  155. from torch._inductor import ir
  156. strides = V.graph.sizevars.stride_hints(self.index, self.var_names)
  157. # pick a loop order with stride ordered decreasingly
  158. order = sorted(range(len(strides)), key=strides.__getitem__, reverse=True)
  159. stride_reorder = ir.same_reorder(order)
  160. sizes = self.size
  161. var_names = self.var_names
  162. new_reordered_sizes = stride_reorder(sizes)
  163. new_reordered_var_names = stride_reorder(var_names)
  164. new_simplified_sizes, reindex, _prune = V.graph.sizevars._simplify_loops(
  165. new_reordered_var_names,
  166. new_reordered_sizes,
  167. index_prevent_reordering(
  168. [self.index], new_reordered_var_names, new_reordered_sizes
  169. ),
  170. )
  171. # now let's create new symbols with the passed in prefix
  172. var_ranges, add_var = var_builder(prefix)
  173. replacement = dict(
  174. zip(
  175. new_reordered_var_names,
  176. reindex([add_var(x) for x in new_simplified_sizes]),
  177. )
  178. )
  179. new_index = sympy_subs(sympy.expand(self.index), replacement) # type: ignore[arg-type] # next PR
  180. out = MemoryDep(
  181. self.name, new_index, tuple(var_ranges.keys()), tuple(var_ranges.values())
  182. ) # type: ignore[arg-type]
  183. return out
  184. @property
  185. def ranges(self) -> dict[sympy.Symbol, sympy.Expr]:
  186. """{c0: 128, c1: 512, ...}"""
  187. return dict(zip(self.var_names, self.size))
  188. def simplify_with_ranges(self) -> "MemoryDep":
  189. return MemoryDep(
  190. name=self.name,
  191. index=V.graph.sizevars.simplify_with_ranges(self.index, self.ranges),
  192. var_names=self.var_names,
  193. size=self.size,
  194. mode=self.mode,
  195. )
  196. def get_numel(self) -> sympy.Expr:
  197. if self.is_indirect():
  198. numel = V.graph.get_numel(self.name)
  199. else:
  200. vars: OrderedSet[sympy.Basic] = OrderedSet(self.index.free_symbols)
  201. numel = sympy.S.One
  202. for var, size in zip(self.var_names, self.size):
  203. if var in vars:
  204. numel = numel * size
  205. return numel # type: ignore[return-value]
  206. def rename(self, renames: dict[str, str]) -> "MemoryDep":
  207. if self.name in renames:
  208. return MemoryDep(
  209. renames[self.name],
  210. self.index,
  211. var_names=self.var_names,
  212. size=self.size,
  213. mode=self.mode,
  214. )
  215. return self
  216. def numbytes_hint(self) -> int:
  217. try:
  218. return V.graph.sizevars.optimization_hint(
  219. self.get_numel(), fallback=0
  220. ) * get_dtype_size(V.graph.get_dtype(self.name))
  221. except NotImplementedError: # NoneLayout
  222. return 0
  223. def numel_hint(self) -> int:
  224. try:
  225. return V.graph.sizevars.optimization_hint(self.get_numel(), fallback=0)
  226. except NotImplementedError: # NoneLayout
  227. return 0
  228. def has_unbacked_symbols(self) -> bool:
  229. return len(free_unbacked_symbols(self.get_numel())) > 0
  230. def is_contiguous(self) -> bool:
  231. if isinstance(self.index, sympy.Integer):
  232. return True
  233. return isinstance(self.index, sympy.Symbol) and self.index in self.var_names
  234. def stride1_for_last_dim(self, result_for_complex_expression: bool = True) -> bool:
  235. """
  236. Whether the stride for the last dimension is 1.
  237. """
  238. # python test/inductor/test_torchinductor_opinfo.py -k test_comprehensive_masked_scatter_cuda_float16
  239. # will exercise thru this corner case.
  240. if len(self.var_names) == 0:
  241. return True
  242. terms = self.index.args if isinstance(self.index, sympy.Add) else [self.index]
  243. last_sym = self.var_names[-1]
  244. for term in terms:
  245. if term == last_sym:
  246. return True
  247. # Having a >1 stride for the last dimension is bad for perf
  248. # return False.
  249. if (
  250. isinstance(term, sympy.Mul)
  251. and len(term.args) == 2
  252. and term.args[1] == last_sym
  253. and isinstance(term.args[0], (int, sympy.Integer))
  254. and term.args[0] > 1
  255. ):
  256. return False
  257. return result_for_complex_expression
  258. def is_scalar(self) -> bool:
  259. if isinstance(self.index, sympy.Symbol):
  260. return self.index not in self.var_names and not self.is_indirect()
  261. return isinstance(self.index, (int, sympy.Integer))
  262. def is_indirect(self) -> bool:
  263. return any(is_indirect(v.name) for v in self.index.free_symbols) # type: ignore[attr-defined]
  264. @dataclasses.dataclass(frozen=True)
  265. class StarDep(Dep):
  266. # pyrefly: ignore [bad-override]
  267. name: str
  268. mode: Optional[str] = None
  269. # depends on the entire buffer
  270. @property
  271. # pyrefly: ignore [bad-override]
  272. def index(self) -> sympy.Expr:
  273. raise NotImplementedError("StarDep does not have an index")
  274. def get_numel(self) -> sympy.Expr:
  275. return V.graph.get_numel(self.name) # type: ignore[return-value]
  276. def rename(self, renames: dict[str, str]) -> "StarDep":
  277. if self.name in renames:
  278. return StarDep(renames[self.name], self.mode)
  279. return self
  280. def get_free_symbol_uses(
  281. self, unbacked_only: bool = False
  282. ) -> OrderedSet[sympy.Symbol]:
  283. return OrderedSet()
  284. def numbytes_hint(self) -> int:
  285. try:
  286. return V.graph.sizevars.optimization_hint(
  287. self.get_numel(), fallback=0
  288. ) * get_dtype_size(V.graph.get_dtype(self.name))
  289. except NotImplementedError:
  290. return 0 # NoneLayout, MultiOutputLayout, etc
  291. def numel_hint(self) -> int:
  292. try:
  293. return V.graph.sizevars.optimization_hint(self.get_numel(), fallback=0)
  294. except NotImplementedError:
  295. return 0 # NoneLayout, MultiOutputLayout, etc
  296. def has_unbacked_symbols(self) -> bool:
  297. return len(free_unbacked_symbols(self.get_numel())) > 0
  298. def is_contiguous(self) -> bool:
  299. return False
  300. def is_scalar(self) -> bool:
  301. return False
  302. def is_indirect(self) -> bool:
  303. return False
  304. # Used for tracking mutation ordering
  305. # if A reads a buffer and B mutates it
  306. # B must be ordered after A
  307. #
  308. # This is useful for a variety of reasons.
  309. # For example, if A's read is never actually used, we can eliminate it.
  310. # Another case is if A's buffer ends up being fused away, we never need to
  311. # materialize that buffer
  312. @dataclasses.dataclass(frozen=True)
  313. class WeakDep(Dep):
  314. # Fake dependency on unused buffer
  315. # pyrefly: ignore [bad-override]
  316. name: str
  317. # Buffer that is doing the mutation
  318. mutating_buf: str
  319. # WeakDep's are also used to add dependencies to prevent some specific reordering,
  320. # E.g. collectives global ordering.
  321. # But if other pass guarantees proper ordering by its logic,
  322. # This additional "fake" deps will be holding optimizations.
  323. # This flag is used to identify those additional deps.
  324. is_fake: bool = False
  325. def get_free_symbol_uses(
  326. self, unbacked_only: bool = False
  327. ) -> OrderedSet[sympy.Symbol]:
  328. return OrderedSet()
  329. @property
  330. # pyrefly: ignore [bad-override]
  331. def index(self) -> sympy.Expr:
  332. raise NotImplementedError("WeakDep does not have an index")
  333. def get_numel(self) -> sympy.Expr:
  334. return sympy.S.One
  335. def rename(self, renames: dict[str, str]) -> "WeakDep":
  336. if self.name in renames:
  337. return WeakDep(renames[self.name], self.mutating_buf, self.is_fake)
  338. return self
  339. def numbytes_hint(self) -> int:
  340. return 1 # Purely inserted for ordering, not an actual dep
  341. def numel_hint(self) -> int:
  342. return 1 # Purely inserted for ordering, not an actual dep
  343. def has_unbacked_symbols(self) -> bool:
  344. return False
  345. def is_contiguous(self) -> bool:
  346. return False
  347. @dataclasses.dataclass(frozen=True)
  348. class IndexExprDep:
  349. index: sympy.Expr # type: ignore[assignment]
  350. var_names: tuple[sympy.Symbol, ...]
  351. size: tuple[sympy.Expr, ...]
  352. @dataclasses.dataclass
  353. class ReadWrites:
  354. reads: OrderedSet[Dep]
  355. writes: OrderedSet[Dep]
  356. index_exprs: OrderedSet[IndexExprDep]
  357. range_vars: Optional[list[sympy.Expr]] = None
  358. var_ranges: Optional[VarRanges] = None
  359. def rename(self, renames: dict[str, str]) -> "ReadWrites":
  360. return ReadWrites(
  361. OrderedSet(dep.rename(renames) for dep in self.reads),
  362. OrderedSet(dep.rename(renames) for dep in self.writes),
  363. self.index_exprs,
  364. self.range_vars,
  365. self.var_ranges,
  366. )
  367. def with_read(self, dep: Union[Dep, OrderedSet[Dep]]) -> "ReadWrites":
  368. assert isinstance(dep, (WeakDep, StarDep, OrderedSet))
  369. if not isinstance(dep, OrderedSet):
  370. dep = OrderedSet([dep])
  371. return ReadWrites(
  372. OrderedSet.union(self.reads, dep),
  373. self.writes,
  374. self.index_exprs,
  375. self.range_vars,
  376. self.var_ranges,
  377. )
  378. def merge(self, other: "ReadWrites") -> "ReadWrites":
  379. reads = OrderedSet.union(self.reads, other.reads)
  380. writes = OrderedSet.union(self.writes, other.writes)
  381. index_exprs = OrderedSet.union(self.index_exprs, other.index_exprs)
  382. return ReadWrites(reads - writes, writes, index_exprs)
  383. @staticmethod
  384. def merge_list(read_writes: list["ReadWrites"]) -> "ReadWrites":
  385. all_writes = OrderedSet.union(*[rw.writes for rw in read_writes])
  386. all_reads = OrderedSet.union(*[rw.reads for rw in read_writes]) - all_writes
  387. all_index_exprs = OrderedSet.union(*[rw.index_exprs for rw in read_writes])
  388. return ReadWrites(all_reads, all_writes, all_index_exprs)
  389. def remove_reads(self, rem_reads: OrderedSet[Dep]) -> "ReadWrites":
  390. return ReadWrites(
  391. self.reads - rem_reads,
  392. self.writes,
  393. self.index_exprs,
  394. self.range_vars,
  395. self.var_ranges,
  396. )
  397. def reads_and_writes(self) -> Iterable[Dep]:
  398. return itertools.chain(self.reads, self.writes)
  399. def buffer_names(self, ignore_integer_index: bool = True) -> OrderedSet[str]:
  400. """
  401. Integer index is used for load_seed.
  402. """
  403. names: OrderedSet[str] = OrderedSet()
  404. for dep in self.reads_and_writes():
  405. if not isinstance(dep, MemoryDep):
  406. continue
  407. if not ignore_integer_index or not isinstance(
  408. dep.index, (int, sympy.Integer)
  409. ):
  410. names.add(dep.name)
  411. return names
  412. def get_free_symbol_uses(
  413. self, unbacked_only: bool = False
  414. ) -> OrderedSet[sympy.Symbol]:
  415. result: OrderedSet[sympy.Symbol] = OrderedSet()
  416. for dep in self.reads_and_writes():
  417. result |= dep.get_free_symbol_uses(unbacked_only)
  418. return result
  419. class _RecordLoadStoreInner(V.MockHandler): # type: ignore[name-defined]
  420. def __init__(self, var_ranges: VarRanges, normalize: bool) -> None:
  421. super().__init__()
  422. self._reads: OrderedSet[Dep] = OrderedSet()
  423. self._writes: OrderedSet[MemoryDep] = OrderedSet()
  424. self._index_exprs: OrderedSet[IndexExprDep] = OrderedSet()
  425. self._var_ranges: VarRanges = var_ranges
  426. self._should_normalize: bool = normalize
  427. @staticmethod
  428. def drop_unused_symbols(
  429. index: Union[int, sympy.Expr],
  430. var_names: list[sympy.Expr],
  431. sizes: list[sympy.Expr],
  432. ) -> None:
  433. """
  434. Reduction has last (reduced) dim in its sizes, but
  435. downstream users won't. Normalize this away.
  436. """
  437. if not isinstance(index, sympy.Expr):
  438. # index can be an int
  439. return
  440. free_symbols = index.free_symbols
  441. while var_names and var_names[-1] not in free_symbols:
  442. var_names.pop()
  443. sizes.pop()
  444. @classmethod
  445. def _normalize(
  446. cls, index: sympy.Expr, var_ranges: VarRanges
  447. ) -> tuple[sympy.Expr, tuple[sympy.Symbol, ...], tuple[sympy.Expr, ...]]:
  448. # Try to further simplify the indexes even if simplify_loops didn't
  449. # convert it to the simplest form because of the interference from
  450. # different indexing formulas.
  451. index_vars = [*var_ranges.keys()]
  452. sizes = tuple(var_ranges.values()) # type: ignore[assignment]
  453. new_sizes, reindex, _prune = V.graph.sizevars._simplify_loops(
  454. index_vars,
  455. sizes,
  456. index_prevent_reordering([index], index_vars, sizes),
  457. )
  458. # assign new variables each dimension to deal with numbering mismatches
  459. # d0, d1, d2 could become d0, d2 -- which won't match d0, d1
  460. new_vars, add_var = var_builder(canonicalization_prefix())
  461. replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes])))
  462. index = sympy_subs(sympy.expand(index), replacement)
  463. new_vars = [*new_vars.keys()]
  464. new_sizes = [*new_sizes]
  465. cls.drop_unused_symbols(index, new_vars, new_sizes)
  466. return index, tuple(new_vars), tuple(new_sizes) # type: ignore[arg-type]
  467. def canonicalize(
  468. self, index: sympy.Expr
  469. ) -> tuple[sympy.Expr, tuple[sympy.Symbol, ...], tuple[sympy.Expr, ...]]:
  470. if not self._should_normalize:
  471. sizes = [V.graph.sizevars.simplify(x) for x in self._var_ranges.values()]
  472. var_names = [k for k, v in zip(self._var_ranges.keys(), sizes) if v != 1]
  473. sizes = [v for v in sizes if v != 1]
  474. self.drop_unused_symbols(index, var_names, sizes)
  475. return index, tuple(var_names), tuple(sizes) # type: ignore[return-value, arg-type]
  476. var_ranges = {
  477. k: V.graph.sizevars.simplify(v)
  478. for k, v in self._var_ranges.items()
  479. # TODO(jansel): explore this further normalization
  480. # if k in free_symbols
  481. }
  482. return self._normalize(index, var_ranges)
  483. def load(self, name: str, index: sympy.Expr) -> None:
  484. self._reads.add(MemoryDep(name, *self.canonicalize(index)))
  485. def load_seed(self, name: str, index: int) -> None:
  486. assert isinstance(index, int)
  487. self.load(name, sympy.Integer(index))
  488. def store(
  489. self, name: str, index: sympy.Expr, value: str, mode: Optional[str] = None
  490. ) -> None:
  491. self._writes.add(MemoryDep(name, *self.canonicalize(index), mode=mode))
  492. def store_reduction(self, name: str, index: sympy.Expr, value: str) -> None:
  493. self.store(name, index, f"store_reduction({value})")
  494. def index_expr(self, index: sympy.Expr, dtype: Optional[torch.dtype]) -> None:
  495. self._index_exprs.add(IndexExprDep(*self.canonicalize(index)))
  496. def bucketize(
  497. self,
  498. values: T,
  499. boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr],
  500. boundary_indices: T,
  501. indexing_dtype: torch.dtype,
  502. right: bool,
  503. sorter: Optional[tuple[str, sympy.Expr]] = None,
  504. sorter_indices: Optional[T] = None,
  505. ) -> None:
  506. """Records the names of the buffers that bucketize will read from."""
  507. self._reads.add(StarDep(boundaries[0]))
  508. if sorter is not None:
  509. self._reads.add(StarDep(sorter[0]))
  510. class RecordLoadStore(V.KernelFormatterHandler): # type: ignore[name-defined]
  511. def __init__(self, var_ranges: VarRanges, normalize: bool) -> None:
  512. parent_handler = _RecordLoadStoreInner(
  513. var_ranges=var_ranges, normalize=normalize
  514. )
  515. super().__init__(parent_handler=parent_handler)
  516. # TODO: check call sites
  517. def var_builder(prefix: str) -> tuple[VarRanges, Callable[[sympy.Expr], sympy.Symbol]]:
  518. cnt = itertools.count()
  519. var_ranges: VarRanges = {}
  520. def add_var(length: sympy.Expr) -> sympy.Symbol:
  521. v = sympy_index_symbol(f"{prefix}{next(cnt)}")
  522. var_ranges[v] = length
  523. return v
  524. return var_ranges, add_var
  525. def index_vars_no_squeeze(
  526. *argsizes: Sequence[sympy.Expr], prefix: str
  527. ) -> tuple[list[list[sympy.Symbol]], VarRanges]:
  528. var_ranges, add_var = var_builder(prefix)
  529. args: list[list[sympy.Symbol]] = [list(map(add_var, size)) for size in argsizes]
  530. return args, var_ranges
  531. def index_vars_squeeze(
  532. *argsizes: Sequence[sympy.Expr], prefix: str = "d"
  533. ) -> tuple[list[Sequence[sympy.Expr]], VarRanges]:
  534. from .ir import SqueezeView
  535. var_ranges, add_var = var_builder(prefix)
  536. args: list[Sequence[sympy.Expr]] = []
  537. new_sizes: list[Sequence[sympy.Expr]] = []
  538. for size in argsizes:
  539. new_size, reindex = SqueezeView.squeezer(size)
  540. new_sizes.append(new_size)
  541. args.append(reindex(list(map(add_var, new_size))))
  542. return args, var_ranges
  543. def extract_read_writes(
  544. fn: Callable[..., Any],
  545. *argsizes: Sequence[sympy.Expr],
  546. normalize: bool = False,
  547. prefix: str = "d",
  548. hidden_args: Sequence[list[sympy.Expr]] = (),
  549. ) -> ReadWrites:
  550. args, var_ranges = index_vars_squeeze(*argsizes, prefix=prefix)
  551. from .loop_body import LoopBody
  552. if isinstance(fn, LoopBody):
  553. inner = extract_loop_body_with_args(
  554. fn,
  555. [*args, *hidden_args], # type: ignore[list-item]
  556. var_ranges,
  557. normalize,
  558. )
  559. else:
  560. # Slow path tracing the function
  561. rw = RecordLoadStore(var_ranges, normalize=normalize)
  562. with V.set_ops_handler(rw):
  563. fn(*args, *hidden_args)
  564. inner = rw.parent_handler
  565. if normalize:
  566. range_vars = [] # Number of vars could differ due to normalization
  567. else:
  568. range_vars = [*itertools.chain.from_iterable(args)]
  569. return ReadWrites(
  570. # pyrefly: ignore [missing-attribute]
  571. OrderedSet(inner._reads),
  572. # pyrefly: ignore [missing-attribute]
  573. OrderedSet(inner._writes),
  574. # pyrefly: ignore [missing-attribute]
  575. inner._index_exprs,
  576. range_vars,
  577. var_ranges,
  578. )
  579. def extract_loop_body_with_args(
  580. fn: Any,
  581. args: list[list[sympy.Expr]],
  582. var_ranges: VarRanges,
  583. normalize: bool = False,
  584. ) -> _RecordLoadStoreInner:
  585. from .loop_body import MemoryUsageType
  586. # Fast path to avoid tracing when we already have a LoopBody
  587. inner = _RecordLoadStoreInner(var_ranges=var_ranges, normalize=normalize)
  588. name_to_index = fn.indexing_from_args(args)
  589. if fn.indirect_vars:
  590. # mimic the `tmpX` naming tracing gives us
  591. repl = {v: make_symbol(SymT.TMP, i) for i, v in enumerate(fn.indirect_vars)}
  592. name_to_index = {k: sympy_subs(v, repl) for k, v in name_to_index.items()} # type: ignore[arg-type]
  593. for entry in fn.memory_usage[MemoryUsageType.LOAD]:
  594. inner.load(entry.buffer_name, name_to_index[entry.index_name]) # type: ignore[arg-type]
  595. for entry in fn.memory_usage[MemoryUsageType.LOAD_SEED]:
  596. inner.load_seed(entry.buffer_name, int(name_to_index[entry.index_name])) # type: ignore[arg-type]
  597. for entry in fn.memory_usage[MemoryUsageType.STORE]:
  598. inner.store(
  599. entry.buffer_name,
  600. name_to_index[entry.index_name],
  601. None, # type: ignore[arg-type]
  602. entry.mode,
  603. )
  604. for entry in fn.memory_usage[MemoryUsageType.STORE_REDUCTION]:
  605. inner.store_reduction(
  606. entry.buffer_name,
  607. name_to_index[entry.index_name],
  608. None, # type: ignore[arg-type]
  609. )
  610. for entry in fn.memory_usage[MemoryUsageType.INDEX_EXPR]:
  611. inner.index_expr(name_to_index[entry.index_name], None)
  612. for entry in fn.memory_usage[MemoryUsageType.BUCKETIZE]:
  613. # All that matters is that we record the buffer name, so place it in the
  614. # "boundaries" name position to ensure that it's recorded.
  615. inner.bucketize(
  616. None,
  617. (entry.buffer_name, None, None, None),
  618. None,
  619. None, # type: ignore[arg-type]
  620. None, # type: ignore[arg-type]
  621. )
  622. # fn.memory_usage[MemoryUsageType.CHECK_BOUNDS] intentionally skipped
  623. return inner
  624. def extract_input_node_reduction_ranges(
  625. input_node: "torch._inductor.ir.IRNode",
  626. ) -> tuple[Optional[list[sympy.Expr]], Optional[list[sympy.Expr]]]:
  627. """
  628. Returns the size and reduction size of all inputs, if the sizes and reduction_sizes (if exist) are all the same.
  629. It's possible that a node has multiple inputs, some are Reduction nodes and others are Pointwise nodes.
  630. In this case, reduction_sizes of the Reduction nodes need to be the same.
  631. Otherwise returns (None, None).
  632. """
  633. from .ir import ComputedBuffer, ExternKernel, Loops
  634. size: Optional[list[sympy.Expr]]
  635. reduction_size: Optional[list[sympy.Expr]]
  636. if isinstance(input_node.get_defining_op(), ComputedBuffer):
  637. # Input node has already been realized. Return its size and reduction_size.
  638. size = [*input_node.get_size()]
  639. reduction_size = [*input_node.get_reduction_size()]
  640. if len(reduction_size) > 0:
  641. return (size, reduction_size)
  642. else:
  643. return (None, None)
  644. if not isinstance(input_node.data.data, Loops): # type: ignore[attr-defined]
  645. # Other IRNodes do not have reduction_ranges.
  646. return (None, None)
  647. # There is one issue: what if there are views / permutations between the input node and its dependent realized nodes?
  648. # The current method still uses reduction ranges from the dependent realized node, which is not ideal.
  649. # Is there a way to check whether there are permutations in between?
  650. reads = input_node.get_reads()
  651. reduction_size: Optional[list[sympy.Expr]] = None
  652. size: Optional[list[sympy.Expr]] = None
  653. while reduction_size is None and len(reads) > 0:
  654. seen: OrderedSet[str] = OrderedSet()
  655. new_reads: list[Dep] = []
  656. for read in reads:
  657. if not isinstance(read, MemoryDep):
  658. continue
  659. if read.name in seen:
  660. continue
  661. seen.add(read.name)
  662. buffer = V.graph.try_get_buffer(read.name)
  663. if buffer is None:
  664. continue
  665. op = buffer.get_defining_op()
  666. if op is None or isinstance(op, ExternKernel):
  667. continue
  668. if isinstance(op, ComputedBuffer) and len(op.get_reduction_size()) > 0:
  669. if reduction_size is None:
  670. reduction_size = [*op.get_reduction_size()]
  671. size = [*op.get_size()]
  672. elif reduction_size != [*op.get_reduction_size()] or size != [
  673. *op.get_size()
  674. ]:
  675. return (None, None)
  676. else:
  677. new_reads.extend(op.get_reads())
  678. if reads == new_reads:
  679. return (size, reduction_size)
  680. else:
  681. reads = OrderedSet(new_reads)
  682. return (size, reduction_size)
  683. def canonicalization_prefix() -> str:
  684. return "c"
  685. # ops handler which computes all the free symbols for an IR
  686. class FreeSymbolsOpsHandler(DefaultHandler):
  687. symbols: OrderedSet[sympy.Symbol]
  688. def __init__(self, unbacked_only: bool = True) -> None:
  689. self.symbols = OrderedSet()
  690. self.get_symbols = free_unbacked_symbols if unbacked_only else free_symbols
  691. def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
  692. for a in itertools.chain(args, kwargs.values()):
  693. if isinstance(a, (sympy.Expr, sympy.logic.boolalg.Boolean)):
  694. self.symbols |= self.get_symbols(a)
  695. def indirect_indexing(
  696. self,
  697. index_var: Any,
  698. size: Union[int, sympy.Expr],
  699. check: bool = True,
  700. wrap_neg: bool = True,
  701. ) -> sympy.Symbol:
  702. assert not isinstance(index_var, (sympy.Expr, sympy.logic.boolalg.Boolean))
  703. self.symbols |= self.get_symbols(size)
  704. return sympy_index_symbol(f"({str(index_var)})")
  705. def frexp(self, x: Any) -> tuple[None, ...]:
  706. return (None,) * 2
  707. def scan(
  708. self, dtypes: Any, combine_fn: Any, values: Sequence[Any]
  709. ) -> tuple[None, ...]:
  710. return (None,) * len(values)
  711. def sort(
  712. self, dtypes: Any, values: Sequence[Any], stable: Any, descending: Any
  713. ) -> tuple[None, ...]:
  714. return (None,) * len(values)
  715. def reduction(
  716. self,
  717. dtype: torch.dtype,
  718. src_dtype: torch.dtype,
  719. reduction_type: ReductionType,
  720. value: Union[None, tuple[None, ...]],
  721. ) -> Union[None, tuple[None, ...]]:
  722. num_values = reduction_num_outputs(reduction_type)
  723. return (None,) * num_values if num_values > 1 else None
  724. def masked(self, mask: Any, body: Callable[..., Any], other: Any) -> None:
  725. assert callable(body), "masked body must always be callable."
  726. # The body can make additional calls, for e.g. ops.indirect_indexing
  727. body()
  728. def extract_free_symbols(
  729. fn: Callable[..., Any],
  730. index: Sequence[sympy.Expr],
  731. rindex: Optional[Sequence[sympy.Expr]] = None,
  732. unbacked_only: bool = True,
  733. ) -> OrderedSet[sympy.Symbol]:
  734. from .ir import FlexibleLayout
  735. args = [index, rindex] if rindex is not None else [index]
  736. handler = FreeSymbolsOpsHandler(unbacked_only)
  737. # NB: I cargo culted the allow_indexing patch here, I don't understand why
  738. # people do this all over
  739. with (
  740. V.set_ops_handler(handler),
  741. patch.object(FlexibleLayout, "allow_indexing", True),
  742. ):
  743. fn(*args)
  744. return handler.symbols