choices.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658
  1. from __future__ import annotations
  2. import dataclasses
  3. import typing
  4. from typing import Any, Optional, TYPE_CHECKING, Union
  5. import sympy
  6. import torch
  7. from torch._inductor.runtime.runtime_utils import next_power_of_2
  8. from torch._inductor.scheduler import MixOrderReduction
  9. from torch.utils._sympy.value_ranges import bound_sympy
  10. from . import config
  11. from .codecache import write_text
  12. from .kernel_inputs import KernelInputs # noqa: TC001
  13. from .kernel_template_choice import make_ktc_generator
  14. from .metrics import get_metric_table, is_metric_table_enabled
  15. from .runtime.hints import DeviceProperties, ReductionHint
  16. from .scheduler import BaseSchedulerNode, Scheduler, WhyNoFuse
  17. from .select_algorithm import ExternKernelChoice
  18. from .template_heuristics import get_template_heuristic
  19. from .template_heuristics.triton import (
  20. BaseConfigHeuristic,
  21. CPUConfigHeuristic,
  22. CUDAConfigHeuristic,
  23. MTIAConfigHeuristic,
  24. ROCmConfigHeuristic,
  25. XPUConfigHeuristic,
  26. )
  27. from .utils import _use_autotune_backend
  28. from .virtualized import V
  29. if TYPE_CHECKING:
  30. from collections.abc import Generator
  31. from functools import partial
  32. from triton import Config as TritonConfig
  33. from .codegen.common import KernelTemplate
  34. from .codegen.simd_kernel_features import SIMDKernelFeatures
  35. from .codegen.triton import TritonKernel
  36. from .ir import ChoiceCaller
  37. from .kernel_template_choice import KernelTemplateChoice
  38. from torch.utils._ordered_set import OrderedSet # isort: skip
  39. class Sortable(typing.Protocol):
  40. """Anything that can be used as a list.sort() key (int/tuple/etc)"""
  41. def __lt__(self, other: typing.Self) -> bool: ...
  42. @dataclasses.dataclass
  43. class FusionScore:
  44. template_score: int
  45. node_type_score: bool
  46. memory_score: int
  47. proximity_score: int
  48. def __lt__(self, other):
  49. """
  50. node_type_score has higher priority than memory_score unless
  51. the memory_score differs too much
  52. """
  53. threshold = 16
  54. if self.template_score != other.template_score:
  55. return self.template_score < other.template_score
  56. if (
  57. max(self.memory_score, other.memory_score)
  58. > min(self.memory_score, other.memory_score) * threshold
  59. ):
  60. return self.memory_score < other.memory_score
  61. return (self.node_type_score, self.memory_score, self.proximity_score) < (
  62. other.node_type_score,
  63. other.memory_score,
  64. other.proximity_score,
  65. )
  66. class InductorChoices:
  67. """
  68. This class contains a collection of default heuristics that affect performance of our generated
  69. code. We try to not put correctness requirements in this file.
  70. You can override the choices made here by doing:
  71. class MyHeuristics(InductorChoices):
  72. ...
  73. torch._inductor.virtualized.V.set_choices_handler(MyHeuristics())
  74. """
  75. def get_config_heuristics(
  76. self, device_type: Optional[str] = "cuda"
  77. ) -> BaseConfigHeuristic:
  78. if device_type == "cuda":
  79. if torch.version.hip is None:
  80. return CUDAConfigHeuristic()
  81. else:
  82. return ROCmConfigHeuristic()
  83. elif device_type == "xpu":
  84. return XPUConfigHeuristic()
  85. elif device_type == "cpu":
  86. return CPUConfigHeuristic()
  87. elif device_type == "mtia":
  88. return MTIAConfigHeuristic()
  89. else:
  90. return BaseConfigHeuristic()
  91. # Conv configs
  92. def get_conv_configs(
  93. self, device_type: Optional[str] = "cuda"
  94. ) -> partial[Generator[TritonConfig, None, None]]:
  95. conv_heuristics = self.get_config_heuristics(device_type)
  96. return conv_heuristics.get_conv_configs()
  97. # Flex attention configs
  98. # TODO(coconutruben): break out flexattention/decode configs into the new retrieval mechanism
  99. def get_flex_attention_fwd_configs(
  100. self, head_dim: int, dtype: torch.dtype, device_type: Optional[str] = "cuda"
  101. ) -> list[Any]:
  102. flex_heuristics = self.get_config_heuristics(device_type)
  103. return flex_heuristics.get_flex_attn_fwd_configs(head_dim, dtype)
  104. def get_flex_attention_bwd_configs(
  105. self, head_dim: int, dtype: torch.dtype, device_type: Optional[str] = "cuda"
  106. ) -> list[Any]:
  107. flex_heuristics = self.get_config_heuristics(device_type)
  108. return flex_heuristics.get_flex_attn_bwd_configs(head_dim, dtype)
  109. def get_flex_decode_configs(
  110. self, head_dim: int, dtype: torch.dtype, device_type: Optional[str] = "cuda"
  111. ) -> list[Any]:
  112. flex_heuristics = self.get_config_heuristics(device_type)
  113. return flex_heuristics.get_flex_decode_configs(head_dim, dtype)
  114. def _finalize_template_configs(
  115. self,
  116. template_choices: dict[str, Generator[KernelTemplateChoice, None, None]],
  117. kernel_inputs: KernelInputs,
  118. templates: list[Union[KernelTemplate, ExternKernelChoice]],
  119. op_name: str,
  120. kwarg_overrides: Optional[dict[str, dict[str, Any]]] = None,
  121. ) -> list[KernelTemplateChoice]:
  122. """
  123. This method can be subclassed to perform any override/modification of the choices.
  124. The incoming parameters are cheap (generators), so you can do any overrides without
  125. incurring too much cost. Override this method to customize the kernel template choices
  126. before they are converted to ChoiceCaller objects, which is expensive on template codegen.
  127. The full list of arguments are here to facilitate any overrides you may want to do,
  128. as they can be used to start from scratch for each template if so desired.
  129. Args:
  130. template_choices: Dictionary mapping template UIDs to generators of KernelTemplateChoice objects
  131. kernel_inputs: MMKernelInputs containing input tensor nodes and matrix indices
  132. templates: List of template objects (KernelTemplate or ExternKernelChoice) in use
  133. op_name: Operation name (e.g., "bmm", "baddbmm", "addmm")
  134. kwarg_overrides: Optional dict of kwargs to override for each template heuristic
  135. Returns:
  136. Flattened list of KernelTemplateChoice objects across all templates
  137. """
  138. choices: list[KernelTemplateChoice] = []
  139. for choice_gen in template_choices.values():
  140. choices.extend(choice_gen)
  141. return choices
  142. def get_ktc(
  143. self,
  144. kernel_inputs: KernelInputs,
  145. template: Union[KernelTemplate, ExternKernelChoice],
  146. op_name: str,
  147. kwarg_overrides: Optional[dict[str, Any]] = None,
  148. ) -> Generator[KernelTemplateChoice, None, None]:
  149. """
  150. Utility to get the KernelTemplateChoice generator for a specific input.
  151. This is a per template/op call, whereas get_template_configs is an op wide call (all templates).
  152. Consider when overriding/using at which level you need to make decisions
  153. """
  154. # Extract device_type from kernel_inputs
  155. device_type = kernel_inputs.device_type
  156. assert device_type is not None, "get_ktc requires a valid device type"
  157. # Extract template_name from the template object
  158. template_name = template.uid
  159. # Get the appropriate template-specific heuristic
  160. heuristic = get_template_heuristic(template_name, device_type, op_name)
  161. cs = heuristic.get_template_configs(
  162. kernel_inputs,
  163. op_name,
  164. )
  165. # adjust the kernel inputs to the template-specific heuristic, if needed
  166. # default here is to just return the kernel_inputs as is
  167. inputs_val = heuristic.adjust_kernel_inputs(kernel_inputs, op_name)
  168. extra_kwargs = heuristic.get_extra_kwargs(kernel_inputs, op_name)
  169. # Create KernelTemplateChoice generator using the moved function
  170. overrides = kwarg_overrides or {}
  171. return make_ktc_generator(
  172. template=template,
  173. cs=cs,
  174. extra_kwargs=extra_kwargs,
  175. overrides=overrides,
  176. layout=kernel_inputs.output_layout(),
  177. inputs=inputs_val,
  178. )
  179. def _need_to_fix_layout(
  180. self,
  181. adjusted_choices: list[KernelTemplateChoice],
  182. op_name: str,
  183. ) -> bool:
  184. """
  185. Check if we need to fix the layout instead of keeping it flexible
  186. Args:
  187. ktc: KernelTemplateChoice object
  188. Returns:
  189. True if we need to fix the layout, False otherwise
  190. """
  191. # TODO: debug and fix
  192. # NOTE: on mps, we see issues with flexible layouts on baddmm. This check just makes sure
  193. # that for mps, everything stays as it was before this optimization
  194. if len(adjusted_choices) > 0:
  195. if adjusted_choices[0].inputs.device_type == "mps" and op_name not in [
  196. "mm",
  197. "addmm",
  198. ]:
  199. return True
  200. # Since the following backends are not using get_mm_configs yet through the singular call,
  201. if not (config.max_autotune or config.max_autotune_gemm):
  202. # no danger of using other backends than ATEN
  203. if not config.max_autotune_allow_flexible_layouts and op_name not in [
  204. # The historical implementation for mm and addmm allowed had flexible layouts in the
  205. # not max-autotune world
  206. "mm",
  207. "addmm",
  208. ]:
  209. # TODO: deprecate this by migrating users to the new behavior
  210. return True
  211. return False
  212. if not config.max_autotune_allow_flexible_layouts:
  213. # we always need to fix the layout
  214. return True
  215. # Since the following backends are not using get_template_configs yet through the singular call,
  216. # we don't know if they are a valid choice or not. Instead, just skip the optimization
  217. # defensively.
  218. # TODO(coconutruben): remove this once CPP,CK,CUTLASS are supported
  219. if _use_autotune_backend("CUTLASS"):
  220. return True
  221. if _use_autotune_backend("CK") or _use_autotune_backend("CKTILE"):
  222. return True
  223. if _use_autotune_backend("CPP"):
  224. return True
  225. return any(
  226. not isinstance(ktc.template, ExternKernelChoice) for ktc in adjusted_choices
  227. )
  228. def get_template_configs(
  229. self,
  230. kernel_inputs: KernelInputs,
  231. templates: list[Union[KernelTemplate, ExternKernelChoice]],
  232. op_name: str,
  233. kwarg_overrides: Optional[dict[str, dict[str, Any]]] = None,
  234. ) -> list[ChoiceCaller]:
  235. """
  236. Get list of ChoiceCallers for MM templates using template-specific heuristics.
  237. Args:
  238. kernel_inputs: MMKernelInputs containing input tensor nodes and matrix indices
  239. layout: Output layout
  240. templates: List of template objects (KernelTemplate or ExternKernelChoice)
  241. op_name: Operation name (e.g., "bmm", "baddbmm", "addmm", "mm_plus_mm")
  242. kwarg_overrides: Optional dict of kwargs to override for each template heuristic,
  243. indexed by template.uid. These only override the per config kwargs, not the extra kwargs
  244. Returns:
  245. List of ChoiceCaller objects from the templates
  246. """
  247. if kwarg_overrides is None:
  248. kwarg_overrides = {}
  249. input_tensors = kernel_inputs.nodes()
  250. if len(input_tensors) < 2:
  251. raise ValueError(f"Need at least 2 input tensors, got {len(input_tensors)}")
  252. layout = kernel_inputs.output_layout()
  253. # First pass: Create dict of template.uid to generator of KernelTemplateChoice objects
  254. template_choices = {}
  255. for template in templates:
  256. template_choices[template.uid] = self.get_ktc(
  257. kernel_inputs,
  258. template,
  259. op_name,
  260. kwarg_overrides.get(template.uid, {}),
  261. )
  262. # Second pass: Adjust the template choices
  263. adjusted_choices = self._finalize_template_configs(
  264. template_choices,
  265. kernel_inputs,
  266. templates,
  267. op_name,
  268. kwarg_overrides,
  269. )
  270. # Layout optimization: if all choices are ExternKernelChoice and layout is FixedLayout, convert to FlexibleLayout
  271. if self._need_to_fix_layout(adjusted_choices, op_name):
  272. layout = kernel_inputs.output_layout(flexible=False)
  273. for ktc in adjusted_choices:
  274. ktc.layout = layout
  275. # for good measure, delete the cached ChoiceCaller from the ktc if it existed.
  276. # ExternKernelChoice are cheap to generate
  277. if hasattr(ktc, "_choice"):
  278. del ktc._choice
  279. # Third pass: Convert to ChoiceCaller objects
  280. return [ktc.choice for ktc in adjusted_choices if ktc.choice is not None]
  281. def triton_kernel_kwargs(
  282. self,
  283. kernel_cls: type[TritonKernel],
  284. features: SIMDKernelFeatures,
  285. groups: list[sympy.Expr],
  286. kernel_kwargs: dict[str, Any],
  287. ) -> dict[str, Any]:
  288. """Hook to change the kwargs passed to TritonKernel, used to apply fixed configurations"""
  289. return kernel_kwargs
  290. @staticmethod
  291. def should_use_cooperative_reduction(features: SIMDKernelFeatures) -> bool:
  292. """Heuristic to decide if a cooperative reduction should be used."""
  293. if config.triton.force_cooperative_reductions:
  294. return True
  295. if (
  296. not config.triton.cooperative_reductions
  297. or V.graph.get_current_device_or_throw().type == "cpu"
  298. ):
  299. return False
  300. xhint = V.graph.sizevars.optimization_hint(features.numel, fallback=2)
  301. if xhint <= 8:
  302. threshold = 32768 * xhint
  303. elif xhint <= 16:
  304. threshold = 2097152
  305. else:
  306. return False
  307. # TODO(jansel): should this default on for dynamic shapes?
  308. # TODO(laith) What if hint(features.reduction_numel) >= threshold ?
  309. # shall we compare hints instead
  310. return V.graph.sizevars.statically_known_geq(
  311. features.reduction_numel, threshold
  312. )
  313. @staticmethod
  314. def should_use_persistent_reduction(
  315. features: SIMDKernelFeatures, cooperative_reduction: bool
  316. ) -> bool:
  317. """
  318. Heuristic to decide if a persistent reduction should be used.
  319. """
  320. if not config.triton.persistent_reductions:
  321. return False
  322. threshold = {
  323. ReductionHint.INNER: 1024,
  324. }.get(features.get_reduction_hint(), 64)
  325. if features.get_reduction_hint() not in (
  326. ReductionHint.INNER,
  327. ReductionHint.OUTER_TINY,
  328. ):
  329. bounds = bound_sympy(features.reduction_numel)
  330. lower = bounds.lower
  331. upper = bounds.upper
  332. if not all(
  333. (
  334. (isinstance(bound, int) or bound.is_constant())
  335. and not torch.utils._sympy.numbers.is_infinite(bound)
  336. )
  337. for bound in (lower, upper)
  338. ):
  339. return False
  340. lower = next_power_of_2(int(lower))
  341. upper = next_power_of_2(int(upper))
  342. # If we are are coalescing on xblock (not ReductionHint.INNER) and this is not a tiny kernel
  343. # (not ReductionHint.OUTER_TINY), do not use persistent reduction if it induces tile
  344. # quantization. Persistent reduction forces rblock == rnumel, if the bounds between lower
  345. # and upper are large, for the lower values we will be masking off large % of read/writes,
  346. # when we could expand the coalescing xblock instead.
  347. if lower != upper:
  348. return False
  349. if cooperative_reduction:
  350. # The RSPLIT of cooperative reductions means each thread block is operating on fewer elements
  351. # The default fallback will be used if optimizations hint is not provided. The default fallback
  352. # is >> 32.
  353. threshold *= 32 // min(
  354. V.graph.sizevars.optimization_hint(features.numel), 32
  355. )
  356. # If multi_kernel is enabled, we do more aggressive persistent reduction.
  357. # This may result in some persistent reductions slower than the
  358. # corresponding non-persistent reductions. MultiKernel will do benchmarking
  359. # to pick the faster one.
  360. if config.triton.multi_kernel:
  361. threshold *= 16
  362. return V.graph.sizevars.statically_known_leq(
  363. features.reduction_numel, threshold
  364. ) # type: ignore[arg-types]
  365. @staticmethod
  366. def reduction_split_factor(
  367. device: torch.device,
  368. reduction_numel_hint: int,
  369. numel_hint: int,
  370. inner_reduction: bool,
  371. ) -> int:
  372. """Heuristic to decide the RSPLIT used for split reductions.
  373. When a reduction has a small number of outputs there is not enough parallelism,
  374. so we will do the reduction in two phases."""
  375. props = DeviceProperties.create(device)
  376. num_sm = props.multi_processor_count
  377. warp_size = props.warp_size if props.warp_size is not None else 32
  378. max_threads_per_sm = (
  379. props.max_threads_per_multi_processor
  380. if props.max_threads_per_multi_processor is not None
  381. else 2048
  382. )
  383. min_elements_per_thread = warp_size
  384. max_elements_per_thread = 512
  385. threads_per_sm = max_threads_per_sm
  386. min_elements_per_device = min_elements_per_thread * num_sm * threads_per_sm
  387. max_elements_per_device = max_elements_per_thread * num_sm * threads_per_sm
  388. num_warps = 8
  389. num_threads = warp_size * num_warps
  390. if inner_reduction:
  391. # do heuristics that's close to eager mode for split inner reduction
  392. # we leak reduction autotune configs here, and will need to refactor to avoid this later
  393. if numel_hint >= 2 * num_sm: # don't split if there are enough outputs
  394. return 1
  395. if reduction_numel_hint <= 8192:
  396. return 1
  397. if reduction_numel_hint * numel_hint <= min_elements_per_device:
  398. split_size = min_elements_per_thread
  399. elif reduction_numel_hint * numel_hint < max_elements_per_device:
  400. target_blocks = num_sm * threads_per_sm // (2 * num_threads)
  401. blocks_per_output = (target_blocks + numel_hint - 1) // numel_hint
  402. tmp_split_size = (
  403. reduction_numel_hint + num_threads * blocks_per_output - 1
  404. ) // (num_threads * blocks_per_output)
  405. divisors = sympy.divisors(reduction_numel_hint)
  406. closest = min(divisors, key=lambda x: abs(x - tmp_split_size))
  407. if abs(closest - tmp_split_size) < 30:
  408. # prefer even splits, but never smalle than min_elements_per_thread
  409. split_size = max(closest, min_elements_per_thread)
  410. else:
  411. split_size = tmp_split_size
  412. else:
  413. divisors = sympy.divisors(reduction_numel_hint)
  414. closest = min(divisors, key=lambda x: abs(x - max_elements_per_thread))
  415. if abs(closest - max_elements_per_thread) < 50:
  416. # prefer even splits
  417. split_size = closest
  418. else:
  419. split_size = max_elements_per_thread
  420. return (reduction_numel_hint + split_size * num_threads - 1) // (
  421. split_size * num_threads
  422. )
  423. else:
  424. # TODO the best heuristic currently has XBLOCK (corresponding to numel_hint) 128
  425. # extend to even smaller number of outputs
  426. rvals_per_thread = 4 # comes from heuristics, refactor to not leak here
  427. xvals_per_block = 128
  428. xblocks = (numel_hint + xvals_per_block - 1) // xvals_per_block
  429. if reduction_numel_hint * numel_hint < min_elements_per_device:
  430. split_size = min_elements_per_thread
  431. elif reduction_numel_hint * numel_hint < max_elements_per_device:
  432. target_blocks = num_sm * threads_per_sm // (num_threads)
  433. target_blocks = (target_blocks + xblocks - 1) // xblocks
  434. tmp_split_size = (
  435. reduction_numel_hint + rvals_per_thread * target_blocks - 1
  436. ) // (rvals_per_thread * target_blocks)
  437. divisors = sympy.divisors(reduction_numel_hint)
  438. closest = min(divisors, key=lambda x: abs(x - tmp_split_size))
  439. if abs(tmp_split_size - closest) < 20:
  440. split_size = max(closest, min_elements_per_thread)
  441. else:
  442. split_size = tmp_split_size
  443. else:
  444. divisors = sympy.divisors(reduction_numel_hint)
  445. closest = min(divisors, key=lambda x: abs(x - max_elements_per_thread))
  446. if abs(closest - max_elements_per_thread) < 50:
  447. # prefer even splits
  448. split_size = closest
  449. else:
  450. split_size = max_elements_per_thread
  451. return (reduction_numel_hint + rvals_per_thread * split_size - 1) // (
  452. rvals_per_thread * split_size
  453. )
  454. @staticmethod
  455. def can_fuse(
  456. scheduler: Scheduler,
  457. node1: BaseSchedulerNode,
  458. node2: BaseSchedulerNode,
  459. shared_data_score: int,
  460. ) -> bool:
  461. """
  462. Heuristics to prevent fusion applied to both horizontal and vertical fusions. Heuristics here should not
  463. be needed for correctness and tweaking them may yield additional performance.
  464. See also some related heuristics that can be changed via config:
  465. - config.triton.tiling_prevents_pointwise_fusion
  466. - config.triton.tiling_prevents_reduction_fusion
  467. - config.aggressive_fusion (will cause this function to be called more times)
  468. """
  469. if shared_data_score == 0 and (
  470. not config.aggressive_fusion or node1.is_reduction() or node2.is_reduction()
  471. ):
  472. if is_metric_table_enabled("fusion_failure_due_to_indexing_mismatch"):
  473. common_buf_names: OrderedSet[str] = (
  474. node1.read_writes.buffer_names() & node2.read_writes.buffer_names()
  475. )
  476. if len(common_buf_names) > 0:
  477. get_metric_table("fusion_failure_due_to_indexing_mismatch").add_row(
  478. # pyrefly: ignore [bad-argument-type]
  479. lambda: {
  480. "pre_grad_graph_id": V.graph.graph_id,
  481. "post_grad_graph_id": V.graph.post_grad_graph_id,
  482. "node1_name": node1.get_name(),
  483. "node2_name": node2.get_name(),
  484. "node1_debug_str": write_text(node1.debug_str()),
  485. "node2_debug_str": write_text(node2.debug_str()),
  486. "common_buffer_names": list(common_buf_names), # type: ignore[dict-item]
  487. "failure_reason": scheduler.decide_fusion_fail_reason(
  488. node1, node2, common_buf_names
  489. ),
  490. }
  491. )
  492. WhyNoFuse(node1, node2)("no shared data due to indexing mismatch")
  493. return False
  494. WhyNoFuse(node1, node2)("no shared data")
  495. return False # heuristic not needed for correctness
  496. if (
  497. not node1.is_foreach()
  498. and not node2.is_foreach()
  499. and len(node1.get_nodes()) + len(node2.get_nodes()) > config.max_fusion_size
  500. ):
  501. WhyNoFuse(node1, node2)("exceeds max fusion")
  502. return False # heuristic not needed for correctness
  503. if scheduler.can_fusion_increase_peak_memory(node1, node2):
  504. WhyNoFuse(node1, node2)("Fusion will increase peak memory")
  505. return False
  506. if (
  507. config.max_fusion_unique_io_buffers is not None
  508. and scheduler.fusion_prevent_too_many_reads_and_writes(
  509. node1,
  510. node2,
  511. config.max_fusion_unique_io_buffers,
  512. )
  513. ):
  514. WhyNoFuse(node1, node2)("fusion_prevent_too_many_reads_and_writes")
  515. return False
  516. return True
  517. @staticmethod
  518. def can_fuse_vertical(
  519. scheduler: Scheduler,
  520. node1: BaseSchedulerNode,
  521. node2: BaseSchedulerNode,
  522. shared_data_score: int,
  523. ) -> bool:
  524. """Hook for heuristics to prevent vertical (producer/consumer) fusions"""
  525. return True
  526. @staticmethod
  527. def can_fuse_horizontal(
  528. scheduler: Scheduler,
  529. node1: BaseSchedulerNode,
  530. node2: BaseSchedulerNode,
  531. shared_data_score: int,
  532. ) -> bool:
  533. """Hook for heuristics to prevent horizontal (consumer/consumer) fusions"""
  534. if MixOrderReduction.can_fuse(node1, node2):
  535. # For mix order reduction, we disregard shared data or
  536. # distance.
  537. return True
  538. if shared_data_score < config.score_fusion_memory_threshold:
  539. WhyNoFuse(node1, node2)("score_fusion_memory_threshold")
  540. return False
  541. if scheduler.are_long_distant_nodes(node1, node2):
  542. WhyNoFuse(node1, node2)(
  543. "Nodes are too far away. Fusing them may increase peak memory."
  544. )
  545. return False
  546. return True
  547. @staticmethod
  548. def score_fusion(
  549. scheduler: Scheduler,
  550. node1: BaseSchedulerNode,
  551. node2: BaseSchedulerNode,
  552. ) -> Sortable:
  553. """
  554. Assign a score (higher comes first) to the fusion of node1 and node2.
  555. When different fusions conflict with each other, this is the way we
  556. decide what order to run them in.
  557. Our current score is based on:
  558. - The type of fusion (template/reduction/etc)
  559. - Estimate of the saved memory operations
  560. - Fusions closer together in original graph order
  561. """
  562. memory_score, is_mix_order_reduction = typing.cast(
  563. tuple[int, bool],
  564. scheduler.score_fusion_memory(
  565. node1, node2, return_is_mix_order_reduction=True
  566. ),
  567. )
  568. proximity_score = -max(
  569. abs(node1.min_order - node2.max_order),
  570. abs(node2.min_order - node1.max_order),
  571. )
  572. # prologue fusion always last
  573. if node2.is_template():
  574. template_score = 0
  575. else:
  576. template_score = 1 + (
  577. (node1.is_template() == config.epilogue_fusion_first)
  578. and memory_score > 0
  579. )
  580. type_score = node1.is_reduction() == node2.is_reduction() and memory_score > 0
  581. return FusionScore(
  582. template_score,
  583. type_score,
  584. memory_score,
  585. proximity_score,
  586. )