decomposition.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320
  1. # mypy: allow-untyped-decorators
  2. import functools
  3. import logging
  4. import math
  5. import operator
  6. import sys
  7. import typing
  8. from collections.abc import Callable
  9. from typing import Any, Optional, TypeAlias, TypeVar, Union
  10. from typing_extensions import ParamSpec
  11. import torch
  12. import torch._decomp as decomp
  13. import torch._prims_common as utils
  14. import torch.ao.quantization.fx._decomposed
  15. from torch._decomp import (
  16. core_aten_decompositions,
  17. get_decompositions,
  18. remove_decompositions,
  19. )
  20. from torch._decomp.decompositions import (
  21. _grid_sampler_2d as decomp_grid_sampler_2d,
  22. _index_add,
  23. embedding_dense_backward as decomp_embedding_dense_backward,
  24. pw_cast_for_opmath,
  25. pw_cast_for_opmath_non_tensor_args,
  26. )
  27. from torch._decomp.decompositions_for_rng import extra_random_decomps
  28. from torch._dynamo.utils import counters
  29. from torch._environment import is_fbcode
  30. from torch._higher_order_ops.out_dtype import out_dtype
  31. from torch._inductor.utils import pad_listlike
  32. from torch._prims_common import (
  33. elementwise_dtypes,
  34. ELEMENTWISE_TYPE_PROMOTION_KIND,
  35. type_to_dtype,
  36. )
  37. from torch._refs import native_layer_norm as decomp_native_layer_norm
  38. from torch.fx.experimental.symbolic_shapes import guard_or_false, statically_known_true
  39. from torch.utils._ordered_set import OrderedSet
  40. from . import config, inductor_prims
  41. from .utils import (
  42. is_gpu,
  43. needs_fallback_due_to_atomic_add_limitations,
  44. use_scatter_fallback,
  45. )
# Type variables used to type the decorator returned by register_decomposition.
_T = TypeVar("_T")
_P = ParamSpec("_P")
# Anything accepted as an operator argument by register_decomposition:
# an OperatorBase instance or a whole OpOverloadPacket.
_GenericOperator: TypeAlias = Union[
    torch._ops.OperatorBase, torch._ops.OpOverloadPacket
]
# Module-level logger.
log = logging.getLogger(__name__)
# Short aliases for the operator namespaces referenced throughout this module.
aten = torch.ops.aten
prims = torch.ops.prims
quantized = torch.ops.quantized
_quantized = torch.ops._quantized
quantized_decomposed = torch.ops.quantized_decomposed
# Decompositions Inductor opts into, looked up via get_decompositions.
inductor_decompositions = get_decompositions(
    [
        aten._adaptive_avg_pool2d_backward,
        aten.index_select,
        aten.addmv,
        aten.arange,
        aten.bitwise_and_,
        aten.bitwise_or_,
        aten.clamp_min_,
        aten.dist,
        aten.elu,
        aten.empty_like,
        aten.flip,
        aten.gelu,
        aten.hardtanh,
        aten.lcm,
        aten.leaky_relu,
        aten.linalg_vector_norm,
        aten._log_softmax,
        aten.max_pool2d_with_indices_backward,
        aten._native_batch_norm_legit,
        aten._native_batch_norm_legit_functional,
        aten._native_batch_norm_legit_no_training,
        aten._batch_norm_with_update,
        aten._batch_norm_with_update_functional,
        aten._batch_norm_no_update,
        aten.batch_norm_backward,
        aten.native_batch_norm,
        aten.native_group_norm,
        aten.native_layer_norm,
        aten.nll_loss2d_backward,
        aten.permute_copy,
        aten.rrelu_with_noise_backward,
        aten._softmax,
        aten.sin_,
        aten.sqrt_,
        out_dtype,
        aten._to_copy,
        aten.tril_indices,
        aten.triu_indices,
        aten.unbind_copy.int,
        aten.upsample_bilinear2d.vec,
        quantized.linear_dynamic_fp16_unpacked_weight,
        _quantized.wrapped_quantized_linear,
    ]
)
# Combined table: core ATen decompositions plus the Inductor-specific ones.
# On a key collision the Inductor entry wins because it is unpacked last.
decompositions = {**core_aten_decompositions(), **inductor_decompositions}
# Remove unwanted decompositions included via the core ATen decompositions from
# the Inductor decomp table.
# Several of these (e.g. embedding_dense_backward, native_layer_norm, index_add,
# silu) are re-registered further down with Inductor-specific wrappers.
decomps_to_exclude: list[Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket]] = [
    aten._unsafe_index,
    aten._unsafe_masked_index,
    aten._unsafe_masked_index_put_accumulate,
    aten._scaled_dot_product_flash_attention_for_cpu.default,  # See comments in torch/_decomp/decompositions.py
    aten._softmax_backward_data,
    aten.clamp_max,
    aten.clamp_min,
    aten.embedding_dense_backward,  # we fall back on xpu
    aten.native_layer_norm,  # we fall back on mtia
    aten.index_add,  # we conditionally call this decomp
    aten.glu,  # inductor lowers this directly
    aten.select_scatter,  # need to be in the ATen graph in order for it to work with the re-inplacing pass
    aten.slice_scatter,  # need to be in the ATen graph in order for it to work with the re-inplacing pass
    aten.silu,  # inductor uses exact eager decomposition
    aten.split.Tensor,  # inductor lowers this directly
    aten.squeeze,  # inductor lowers this directly
    aten.sum,  # inductor lowers this directly
    aten.unbind,  # inductor lowers this directly
    aten.baddbmm,  # upcasts to fp32, perf issue
]
# Mutates `decompositions` in place, dropping the entries listed above.
remove_decompositions(decompositions, decomps_to_exclude)
  128. def register_decomposition(
  129. ops: Union[_GenericOperator, list[_GenericOperator]],
  130. ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
  131. for op in ops if isinstance(ops, list) else [ops]:
  132. if op in decompositions:
  133. log.warning("duplicate decomp: %s", ops)
  134. return decomp.register_decomposition(ops, decompositions)
  135. @register_decomposition([aten.embedding_dense_backward])
  136. def _embedding_dense_backward(
  137. grad_output: torch.Tensor,
  138. indices: torch.Tensor,
  139. num_weights: int,
  140. padding_idx: int,
  141. scale_grad_by_freq: bool,
  142. ) -> torch.Tensor:
  143. # TODO: check if XE4 still need this fallback
  144. # check torch.xpu.get_device_properties(grad_output.device).architecture
  145. if grad_output.is_xpu:
  146. return NotImplemented
  147. # We can write a util function to update decomp table if we have more ops to fallback.
  148. return decomp_embedding_dense_backward(
  149. grad_output, indices, num_weights, padding_idx, scale_grad_by_freq
  150. )
  151. @register_decomposition(aten.native_layer_norm)
  152. def _native_layer_norm(
  153. input: torch.Tensor,
  154. normalized_shape: utils.ShapeType,
  155. weight: Optional[torch.Tensor],
  156. bias: Optional[torch.Tensor],
  157. eps: float,
  158. ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  159. if input.is_mtia:
  160. return NotImplemented
  161. # We can write a util function to update decomp table if we have more ops to fallback.
  162. return decomp_native_layer_norm(input, normalized_shape, weight, bias, eps)
  163. @register_decomposition([aten.sym_constrain_range_for_size.default])
  164. def sym_constrain_range_for_size(
  165. symbol: torch.SymInt,
  166. *,
  167. min: Optional[torch.types.Number] = None,
  168. max: Optional[torch.types.Number] = None,
  169. ) -> None:
  170. return
  171. @register_decomposition([aten.clamp])
  172. @pw_cast_for_opmath_non_tensor_args
  173. def clamp(
  174. x: torch.Tensor,
  175. min: Optional[torch.types.Number] = None,
  176. max: Optional[torch.types.Number] = None,
  177. ) -> torch.Tensor:
  178. if min is not None:
  179. x = x.clamp_min(min)
  180. if max is not None:
  181. x = x.clamp_max(max)
  182. return x
  183. # Inductor-specific SiLU decomposition for exact eager matching.
  184. # The core decomposition uses x * sigmoid(x), but this form
  185. # x / (1 + exp(-x)) matches eager execution more precisely.
  186. @register_decomposition([aten.silu])
  187. @pw_cast_for_opmath
  188. def silu(x: torch.Tensor) -> torch.Tensor:
  189. return x / (1 + x.neg().exp())
  190. @register_decomposition([aten.full])
  191. def full(
  192. size: list[Union[int, torch.SymInt]],
  193. fill_value: torch.types.Number,
  194. **kwargs: Any,
  195. ) -> torch.Tensor:
  196. dtype = kwargs.get("dtype")
  197. if dtype is None:
  198. kwargs["dtype"] = type_to_dtype(type(fill_value))
  199. return torch.full(size, fill_value, **kwargs)
  200. return NotImplemented
  201. @register_decomposition([aten.index_add])
  202. def index_add(
  203. x: torch.Tensor,
  204. dim: int,
  205. index: torch.Tensor,
  206. tensor: torch.Tensor,
  207. *,
  208. alpha: torch.types.Number = 1,
  209. ) -> torch.Tensor:
  210. # If we are not in fbcode and dtype is bfloat16
  211. # fallback to index_add kernel
  212. # see https://github.com/pytorch/pytorch/issues/137425 for details
  213. if not is_fbcode() and x.dtype == torch.bfloat16:
  214. return NotImplemented
  215. else:
  216. return _index_add(x, dim, index, tensor, inplace=False, alpha=alpha)
  217. # Not really sure how to put this into the main library. PrimTorch wants
  218. # empty_permuted to go to the prim, and typically users don't really want
  219. # to decompose to empty_strided (but inductor is OK with it, because we are
  220. # cool with strides and everything goes to empty_strided)
  221. @register_decomposition([aten.empty_permuted.default])
  222. def empty_permuted(
  223. size: list[Union[int, torch.SymInt]],
  224. physical_layout: list[int],
  225. **kwargs: Any,
  226. ) -> torch.Tensor:
  227. is_identity = list(physical_layout) == list(range(len(physical_layout)))
  228. if is_identity:
  229. return torch.empty(size, **kwargs)
  230. else:
  231. perm = [0] * len(size)
  232. for p, l in enumerate(physical_layout):
  233. perm[l] = p
  234. return torch.empty([size[l] for l in physical_layout], **kwargs).permute(perm)
  235. @register_decomposition([aten.convolution_backward])
  236. def convolution_backward(
  237. grad_output: torch.Tensor,
  238. input: torch.Tensor,
  239. weight: torch.Tensor,
  240. bias_sizes: list[int],
  241. stride: Union[int, list[int]],
  242. padding: Union[int, list[int]],
  243. dilation: Union[int, list[int]],
  244. transposed: bool,
  245. output_padding: list[int],
  246. groups: int,
  247. output_mask: list[bool],
  248. ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  249. if not output_mask[2] or not is_gpu(grad_output.device.type):
  250. return NotImplemented
  251. grad_bias = aten.sum(grad_output, [0] + list(range(2, grad_output.dim())))
  252. grad_inp, grad_weight, _ = aten.convolution_backward(
  253. grad_output,
  254. input,
  255. weight,
  256. bias_sizes,
  257. stride,
  258. padding,
  259. dilation,
  260. transposed,
  261. output_padding,
  262. groups,
  263. [output_mask[0], output_mask[1], False],
  264. )
  265. return (grad_inp, grad_weight, grad_bias)
  266. @register_decomposition([aten.round.decimals])
  267. def round_dec(x: torch.Tensor, decimals: int = 0) -> torch.Tensor:
  268. ten_pow_decimals = 10.0**decimals
  269. return aten.round(x * ten_pow_decimals) * (1.0 / ten_pow_decimals)
  270. @register_decomposition([aten.bmm])
  271. @pw_cast_for_opmath
  272. def bmm(
  273. self: torch.Tensor,
  274. batch2: torch.Tensor,
  275. out_dtype: Optional[torch.dtype] = None,
  276. ) -> torch.Tensor:
  277. # TODO: Re-enable for mps once our reductions are performant enough
  278. # (https://github.com/pytorch/pytorch/issues/150121)
  279. if config.coordinate_descent_tuning and self.device.type not in ["cpu", "mps"]:
  280. if statically_known_true(self.shape[1] == 1) or statically_known_true(
  281. batch2.shape[2] == 1
  282. ):
  283. out = (self.unsqueeze(-1) * batch2.unsqueeze(1)).sum(dim=2)
  284. return out
  285. if self.device.type == "cpu":
  286. if statically_known_true(self.size(1) == 1) and statically_known_true(
  287. batch2.size(-1) == 1
  288. ):
  289. counters["inductor"]["decompose_bmm"] += 1
  290. return torch.sum(
  291. self.squeeze(1) * batch2.squeeze(-1), dim=1, keepdim=True
  292. ).unsqueeze(1)
  293. return NotImplemented
  294. @register_decomposition([aten.addmm])
  295. @pw_cast_for_opmath
  296. def addmm(
  297. self: torch.Tensor,
  298. mat1: torch.Tensor,
  299. mat2: torch.Tensor,
  300. out_dtype: Optional[torch.dtype] = None,
  301. beta: torch.types.Number = 1,
  302. alpha: torch.types.Number = 1,
  303. ) -> torch.Tensor:
  304. if self.device.type == "cpu":
  305. if statically_known_true(mat1.size(0) == 1) and statically_known_true(
  306. mat2.size(-1) == 1
  307. ):
  308. counters["inductor"]["decompose_addmm"] += 1
  309. out = torch.sum(
  310. mat1.squeeze(0) * mat2.squeeze(-1), dim=0, keepdim=True
  311. ).unsqueeze(0)
  312. return alpha * out + beta * self
  313. if (
  314. statically_known_true(mat1.size(0) == 1)
  315. and guard_or_false(mat2.size(0) <= 16)
  316. and guard_or_false(mat2.size(1) <= 16)
  317. ):
  318. counters["inductor"]["decompose_addmm"] += 1
  319. out = (mat1.T * mat2).sum(dim=0, keepdim=True)
  320. return alpha * out + beta * self
  321. return NotImplemented
  322. @register_decomposition([aten.mm])
  323. @pw_cast_for_opmath
  324. def mm(
  325. self: torch.Tensor,
  326. input2: torch.Tensor,
  327. out_dtype: Optional[torch.dtype] = None,
  328. ) -> torch.Tensor:
  329. # Our matrix vector multiplies only achieve peak bandwidth with coordinate descent tuning.
  330. # todo: Look into why and fix it (hopefully)
  331. # TODO: Re-enable for mps once our reductions are performant enough
  332. # (https://github.com/pytorch/pytorch/issues/150121)
  333. if config.coordinate_descent_tuning and self.device.type not in ["cpu", "mps"]:
  334. if statically_known_true(self.shape[0] == 1) or statically_known_true(
  335. input2.shape[1] == 1
  336. ):
  337. return (self.unsqueeze(2) * input2.unsqueeze(0)).sum(dim=1)
  338. if self.device.type == "cpu":
  339. if (
  340. statically_known_true(self.size(-1) == 1)
  341. and statically_known_true(self.size(0) > 0)
  342. and statically_known_true(input2.size(0) == 1)
  343. and (self.dtype == input2.dtype)
  344. and guard_or_false((torch.numel(self) + torch.numel(input2)) <= 32)
  345. ):
  346. counters["inductor"]["decompose_mm"] += 1
  347. return self * input2
  348. if statically_known_true(self.size(0) == 1) and statically_known_true(
  349. input2.size(-1) == 1
  350. ):
  351. counters["inductor"]["decompose_mm"] += 1
  352. return torch.sum(
  353. self.squeeze(0) * input2.squeeze(-1), dim=0, keepdim=True
  354. ).unsqueeze(0)
  355. return NotImplemented
  356. # This pass does two things:
  357. # - Eliminate cat when there is only one tensor input
  358. # - Normalize cat calls, so that legacy empty 1-D tensors are removed (NB: we
  359. # don't remove ALL empty tensors, only the naughty ones)
  360. @register_decomposition([aten.cat.default])
  361. def cat(
  362. tensors: list[torch.Tensor],
  363. dim: int = 0,
  364. ) -> torch.Tensor:
  365. def non_empty_tensor(x: torch.Tensor) -> bool:
  366. # For better or worse, this is a valid cat:
  367. #
  368. # torch.cat([torch.randn(2, 2, 4), torch.randn(0), torch.randn(3, 2, 4)])
  369. #
  370. # We'd like to eliminate naughtiness like this for downstream passes
  371. # like split_cat. The easiest way is to just drop such inputs
  372. # (guarding that they are non-zero).
  373. #
  374. # Is it permissible for this filtering to be size-oblivious? A case
  375. # where this could matter is cat([(2, 2), (u0,)], dim=0); if u0
  376. # happened to be zero, we would have liked to have filtered it out.
  377. # But actually, the ONLY way this could have passed is if u0 == 0,
  378. # so by the time we get here we have already installed a deferred
  379. # runtime assert forcing u0 to be zero. So if this hasn't happened,
  380. # we know that the unbacked SymInt has appropriate size and there are
  381. # no problems.
  382. if len(x.shape) == 1 and guard_or_false(x.shape[0] == 0):
  383. return False
  384. if dim < len(x.shape) and guard_or_false(x.shape[dim] == 0):
  385. return False
  386. return True
  387. filtered_tensors = list(filter(non_empty_tensor, tensors))
  388. if len(filtered_tensors) == 1:
  389. # check dtype promotion
  390. promoted_dtype = elementwise_dtypes(
  391. *tensors,
  392. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
  393. )[1]
  394. filtered_t = filtered_tensors[0]
  395. return (
  396. filtered_t.clone()
  397. if promoted_dtype == filtered_t.dtype
  398. else filtered_t.to(dtype=promoted_dtype)
  399. )
  400. elif 1 < len(filtered_tensors) < len(tensors):
  401. # on the first call, when we remove empty tensors, we redispatch recursively
  402. return aten.cat.default(filtered_tensors, dim)
  403. # optimization, avoid concat for single, repeated input
  404. if len(filtered_tensors) > 1 and all(
  405. t is filtered_tensors[0] for t in filtered_tensors
  406. ):
  407. inp = filtered_tensors[0]
  408. shape = list(inp.shape)
  409. dim = dim + len(inp.shape) if dim < 0 else dim
  410. shape.insert(dim, len(filtered_tensors))
  411. return inp.unsqueeze(dim).expand(*shape).flatten(dim, dim + 1).clone()
  412. # when no 'filtering' has occurred, we raise to prevent infinite recursion (no more decomposition needed)
  413. return NotImplemented
  414. @register_decomposition([aten.angle])
  415. def angle(x: torch.Tensor) -> torch.Tensor:
  416. if x.is_complex():
  417. return torch.where(
  418. torch.isnan(x.real), float("nan"), torch.atan2(x.imag, x.real)
  419. )
  420. # when x is real number
  421. # if x >= 0, return 0
  422. # if x < 0, return pi
  423. # if x is nan, return nan
  424. _, dtype = elementwise_dtypes(
  425. x,
  426. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  427. )
  428. pi = torch.scalar_tensor(math.pi, dtype=dtype, device=x.device)
  429. ret = torch.where(x < 0, pi, 0.0)
  430. return torch.where(torch.isnan(x), float("nan"), ret)
  431. @register_decomposition([aten.add])
  432. def add(
  433. x: torch.Tensor,
  434. y: torch.Tensor,
  435. *,
  436. alpha: Optional[torch.types.Number] = None,
  437. ) -> torch.Tensor:
  438. # Require both x and y to be complex tensors.
  439. x_is_complex_tensor = torch.is_tensor(x) and x.is_complex()
  440. y_is_complex_tensor = torch.is_tensor(y) and y.is_complex()
  441. if not x_is_complex_tensor or not y_is_complex_tensor:
  442. return NotImplemented
  443. def _requires_fallback(tensor: torch.Tensor) -> bool:
  444. if tensor.ndim == 0:
  445. return False
  446. # Viewing complex tensors as their real dtype requires the last stride to be 1.
  447. return tensor.stride()[-1] != 1
  448. output_size_zero = False
  449. if x.ndim == 0 and y.ndim == 0:
  450. output_size_zero = True
  451. if x.ndim == 0:
  452. x = x.reshape(1)
  453. if y.ndim == 0:
  454. y = y.reshape(1)
  455. z = y
  456. if alpha is not None:
  457. z = alpha * y
  458. complex_type = torch.promote_types(x.dtype, y.dtype)
  459. if _requires_fallback(x) or _requires_fallback(z):
  460. return NotImplemented
  461. # For complex typed `x`, `x.view(x.real.dtype)` doubles the last dimension and can cause problem
  462. # when broadcasting the add.
  463. def reshape_tensor_complex(tensor: torch.Tensor) -> torch.Tensor:
  464. """Reshape tensor from [*initial_dims, last_dim] to *initial_dims, last_dim/2, 2]"""
  465. # Get the current shape of the tensor
  466. *initial_dims, last_dim = tensor.shape
  467. # Check if the last dimension is even. We should never reach here since `x.view(x.real.dtype)`
  468. # doubles the last dimension for complex numbers.
  469. if last_dim % 2 != 0:
  470. raise AssertionError(
  471. "The size of the last dimension must be even to reshape it to [..., last_dim/2, 2]"
  472. )
  473. # Reshape the tensor
  474. new_shape = (*initial_dims, last_dim // 2, 2)
  475. reshaped_tensor = tensor.view(new_shape)
  476. return reshaped_tensor
  477. # Manually resolve complex tensors, as .is_conj() is unreliable after cloning during compilation.
  478. x = x + 0
  479. z = z + 0
  480. x_reshaped = reshape_tensor_complex(x.view(x.real.dtype))
  481. z_reshaped = reshape_tensor_complex(z.view(y.real.dtype))
  482. result = torch.flatten(x_reshaped + z_reshaped, start_dim=-2).view(complex_type)
  483. if output_size_zero:
  484. return result[0]
  485. return result
  486. @register_decomposition([aten.conj_physical])
  487. def conj_physical(self: torch.Tensor) -> torch.Tensor:
  488. if self.is_complex():
  489. return NotImplemented
  490. return self
  491. @register_decomposition([aten.lift, aten.detach_])
  492. def lift(self: torch.Tensor) -> torch.Tensor:
  493. return self
  494. @register_decomposition([aten.fmin, prims.fmin])
  495. def fmin(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
  496. return torch.where(torch.isnan(other) | (other > self), self, other)
  497. @register_decomposition([aten.fmax, prims.fmax])
  498. def fmax(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
  499. return torch.where(torch.isnan(other) | (other < self), self, other)
  500. @register_decomposition(aten.amax)
  501. def amax(
  502. self: torch.Tensor,
  503. dim: Optional[int] = None,
  504. keepdim: bool = False,
  505. ) -> torch.Tensor:
  506. if self.dtype == torch.bool:
  507. return torch.any(self, dim=dim, keepdim=keepdim)
  508. return NotImplemented
  509. @register_decomposition(aten.amin)
  510. def amin(
  511. self: torch.Tensor,
  512. dim: Optional[int] = None,
  513. keepdim: bool = False,
  514. ) -> torch.Tensor:
  515. if self.dtype == torch.bool:
  516. return torch.all(self, dim=dim, keepdim=keepdim)
  517. return NotImplemented
  518. @register_decomposition([aten.narrow_copy])
  519. def narrow_copy(
  520. self: torch.Tensor,
  521. dim: int,
  522. start: int,
  523. length: int,
  524. ) -> torch.Tensor:
  525. # Use memory_format=torch.contiguous_format to ensure correct strides.
  526. # For empty tensors, a plain clone() preserves the input view's strides.
  527. return torch.narrow(self, dim, start, length).clone(
  528. memory_format=torch.contiguous_format
  529. )
  530. @register_decomposition([aten.view_copy.default])
  531. def view_copy_default(
  532. self: torch.Tensor,
  533. size: list[Union[int, torch.SymInt]],
  534. ) -> torch.Tensor:
  535. return aten.view(self, size).clone()
  536. @register_decomposition([aten.view_copy.dtype])
  537. def view_copy_dtype(
  538. self: torch.Tensor,
  539. dtype: torch.dtype,
  540. ) -> torch.Tensor:
  541. return self.clone().view(dtype)
  542. def _get_shape_permutation_like(
  543. self: torch.Tensor,
  544. ) -> tuple[utils.ShapeType, utils.StrideType]:
  545. physical_layout, _ = utils.compute_elementwise_output_logical_to_physical_perm(self)
  546. shape = [self.shape[l] for l in physical_layout]
  547. permutation = [0] * len(shape)
  548. for p, l in enumerate(physical_layout):
  549. permutation[l] = p
  550. return (shape, permutation)
  551. @register_decomposition(aten.full_like)
  552. def full_like(
  553. self: torch.Tensor,
  554. fill_value: Union[int, float],
  555. *,
  556. dtype: Optional[torch.dtype] = None,
  557. layout: Optional[torch.layout] = None,
  558. device: Optional[torch.device] = None,
  559. pin_memory: bool = False,
  560. requires_grad: bool = False,
  561. memory_format: torch.memory_format = torch.preserve_format,
  562. ) -> torch.Tensor:
  563. dtype = self.dtype if dtype is None else dtype
  564. layout = self.layout if layout is None else layout
  565. device = self.device if device is None else device
  566. if memory_format != torch.preserve_format:
  567. result = torch.full(
  568. self.shape,
  569. fill_value,
  570. dtype=dtype,
  571. layout=layout,
  572. device=device,
  573. pin_memory=pin_memory,
  574. requires_grad=requires_grad,
  575. )
  576. return result.to(memory_format=memory_format)
  577. else:
  578. assert layout == torch.strided
  579. shape, permutation = _get_shape_permutation_like(self)
  580. result = torch.full(
  581. shape,
  582. fill_value,
  583. dtype=dtype,
  584. layout=layout,
  585. device=device,
  586. pin_memory=pin_memory,
  587. requires_grad=requires_grad,
  588. )
  589. if permutation == list(range(len(permutation))):
  590. return result
  591. return result.permute(permutation).clone()
  592. def _rand_like(
  593. rand_fn: Callable[..., torch.Tensor],
  594. self: torch.Tensor,
  595. *,
  596. dtype: Optional[torch.dtype] = None,
  597. device: Optional[torch.device] = None,
  598. memory_format: torch.memory_format = torch.preserve_format,
  599. **kwargs: Any,
  600. ) -> torch.Tensor:
  601. dtype = self.dtype if dtype is None else dtype
  602. device = self.device if device is None else device
  603. if memory_format != torch.preserve_format:
  604. return rand_fn(
  605. self.shape,
  606. dtype=dtype,
  607. device=device,
  608. **kwargs,
  609. ).to(memory_format=memory_format)
  610. shape, permutation = _get_shape_permutation_like(self)
  611. result = rand_fn(
  612. shape,
  613. dtype=dtype,
  614. device=device,
  615. **kwargs,
  616. )
  617. if permutation == list(range(len(permutation))):
  618. return result
  619. return result.permute(permutation).clone()
  620. @register_decomposition(aten.rand_like)
  621. def rand_like(self: torch.Tensor, **kwargs: Any) -> torch.Tensor:
  622. return _rand_like(torch.rand, self, **kwargs)
  623. @register_decomposition(aten.randn_like)
  624. def randn_like(self: torch.Tensor, **kwargs: Any) -> torch.Tensor:
  625. return _rand_like(torch.randn, self, **kwargs)
  626. @register_decomposition(aten.randint_like.default)
  627. def randint_like(self: torch.Tensor, high: int, **kwargs: Any) -> torch.Tensor:
  628. return _rand_like(functools.partial(aten.randint.low, 0, high), self, **kwargs)
  629. @register_decomposition(aten.randint_like.low_dtype)
  630. def randint_like_low(
  631. self: torch.Tensor, low: int, high: int, **kwargs: Any
  632. ) -> torch.Tensor:
  633. return _rand_like(functools.partial(aten.randint.low, low, high), self, **kwargs)
  634. @register_decomposition(aten.randint.default)
  635. def randint(
  636. high: int,
  637. size: list[Union[int, torch.SymInt]],
  638. **kwargs: Any,
  639. ) -> torch.Tensor:
  640. return aten.randint.low(0, high, size, **kwargs)
  641. @register_decomposition(quantized.linear_dynamic_fp16_unpacked_weight.default)
  642. def linear_dynamic_fp16_unpacked_weight(
  643. input: torch.Tensor,
  644. weight: torch.Tensor,
  645. bias: Optional[torch.Tensor] = None,
  646. ) -> torch.Tensor:
  647. packed_weight = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16(weight)
  648. return torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight(
  649. input, packed_weight, bias, weight.size()[0]
  650. )
  651. @register_decomposition(_quantized.wrapped_quantized_linear.default)
  652. def wrapped_quantized_linear(
  653. input: torch.Tensor,
  654. input_scale: torch.Tensor,
  655. input_zero_point: torch.Tensor,
  656. weight: torch.Tensor,
  657. weight_scale: torch.Tensor,
  658. weight_zero_point: torch.Tensor,
  659. bias: torch.Tensor,
  660. out_scale: torch.Tensor,
  661. out_zero_point: torch.Tensor,
  662. out_channel: int,
  663. ) -> torch.Tensor:
  664. packed_weight = torch.ops._quantized._wrapped_linear_prepack(
  665. weight, weight_scale, weight_zero_point, bias
  666. )
  667. return torch.ops._quantized._wrapped_quantized_linear_prepacked(
  668. input,
  669. input_scale,
  670. input_zero_point,
  671. packed_weight,
  672. out_scale,
  673. out_zero_point,
  674. out_channel,
  675. )
  676. @register_decomposition(torch.ops.quantized.embedding_bag_byte_unpack)
  677. def q_embedding_bag_byte_unpack_decomp(packed: torch.Tensor) -> torch.Tensor:
  678. def bitcast_u8_to_f32(u8: torch.Tensor) -> torch.Tensor:
  679. x, y, z, w = (u8[..., n].to(torch.int32) for n in (0, 1, 2, 3))
  680. if sys.byteorder == "little":
  681. return (x + (y << 8) + (z << 16) + (w << 24)).view(torch.float32)[..., None]
  682. else:
  683. return ((x << 24) + (y << 16) + (z << 8) + w).view(torch.float32)[..., None]
  684. scales = bitcast_u8_to_f32(packed[..., -8:-4])
  685. offsets = bitcast_u8_to_f32(packed[..., -4:])
  686. return packed[..., :-8].to(torch.float32) * scales + offsets
  687. @register_decomposition([aten.grid_sampler_2d])
  688. @pw_cast_for_opmath
  689. def grid_sampler_2d(
  690. a: torch.Tensor,
  691. grid: torch.Tensor,
  692. interpolation_mode: int = 0,
  693. padding_mode: int = 0,
  694. align_corners: bool = False,
  695. ) -> torch.Tensor:
  696. # We do not expand the grid (_expand_grid=False) on cpu for performance reasons
  697. # Experimenting locally it was found that compiled CUDA code is accelerated by ~5x
  698. # and CPU code by ~2x on bicubic mode, if we expand the grid from (N, H, W, 2) into (N, C, H, W, 2)
  699. # However, this leads to a slowdown around ~0.8x on CPU bilinear mode, channels first.
  700. # Thus we apply this hack to not expand the grid for this case.
  701. _expand_grid = not (
  702. a.device == torch.device("cpu")
  703. and interpolation_mode == 0
  704. and a.is_contiguous(memory_format=torch.contiguous_format)
  705. )
  706. output = decomp_grid_sampler_2d(
  707. a,
  708. grid=grid,
  709. interpolation_mode=interpolation_mode,
  710. padding_mode=padding_mode,
  711. align_corners=align_corners,
  712. _expand_grid=_expand_grid,
  713. )
  714. return output
  715. # _foreach_addcmul.Scalar decomposition - uses mul+add instead of FMA
  716. # When emulate_precision_casts is enabled, we skip this decomposition
  717. # and use the inductor lowering which preserves FMA semantics
  718. @register_decomposition(aten._foreach_addcmul.Scalar)
  719. def _foreach_addcmul_scalar(
  720. self: list[torch.Tensor],
  721. left_tensors: list[torch.Tensor],
  722. right_tensors: list[torch.Tensor],
  723. scalar: float = 1,
  724. ) -> list[torch.Tensor]:
  725. return aten._foreach_add.List(
  726. self, aten._foreach_mul.List(left_tensors, right_tensors), alpha=scalar
  727. )
  728. @register_decomposition(aten._foreach_addcdiv.Scalar)
  729. def _foreach_addcdiv_scalar(
  730. self: list[torch.Tensor],
  731. left_tensors: list[torch.Tensor],
  732. right_tensors: list[torch.Tensor],
  733. scalar: float = 1,
  734. ) -> list[torch.Tensor]:
  735. return aten._foreach_add.List(
  736. self, aten._foreach_div.List(left_tensors, right_tensors), alpha=scalar
  737. )
  738. @register_decomposition(aten._foreach_lerp.Scalar)
  739. def _foreach_lerp_scalar(
  740. start_tensors: list[torch.Tensor],
  741. end_tensors: list[torch.Tensor],
  742. weight: torch.types.Number,
  743. ) -> list[torch.Tensor]:
  744. return aten._foreach_add.List(
  745. start_tensors,
  746. aten._foreach_mul.Scalar(
  747. aten._foreach_sub.List(end_tensors, start_tensors), weight
  748. ),
  749. )
  750. @register_decomposition(aten._foreach_lerp.ScalarList)
  751. def _foreach_lerp_scalarlist(
  752. start_tensors: list[torch.Tensor],
  753. end_tensors: list[torch.Tensor],
  754. scalars: list[torch.types.Number],
  755. ) -> list[torch.Tensor]:
  756. return aten._foreach_add.List(
  757. start_tensors,
  758. aten._foreach_mul.ScalarList(
  759. aten._foreach_sub.List(end_tensors, start_tensors), scalars
  760. ),
  761. )
  762. @aten.miopen_batch_norm.default.py_impl(torch._C.DispatchKey.Autograd)
  763. @register_decomposition(aten.miopen_batch_norm)
  764. def miopen_batch_norm(
  765. input: torch.Tensor,
  766. weight: torch.Tensor,
  767. bias: typing.Optional[torch.Tensor],
  768. running_mean: typing.Optional[torch.Tensor],
  769. running_var: typing.Optional[torch.Tensor],
  770. training: bool,
  771. exponential_average_factor: float,
  772. epsilon: float,
  773. ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  774. a, b, c = aten.native_batch_norm(
  775. input,
  776. weight,
  777. bias,
  778. running_mean,
  779. running_var,
  780. training,
  781. exponential_average_factor,
  782. epsilon,
  783. )
  784. if training:
  785. return (a, b, c)
  786. return (
  787. a,
  788. weight.new_zeros((0,)),
  789. weight.new_zeros((0,)),
  790. )
  791. @functools.cache
  792. def fast_random_decomps() -> dict[Any, Callable[..., Any]]:
  793. return {**decompositions, **extra_random_decomps}
  794. # TODO(aakhundov): replace this (and the above) Any by more
  795. # specific type and fix all the cascading mypy errors
  796. def select_decomp_table() -> dict[Any, Callable[..., Any]]:
  797. """decomps can change based on config"""
  798. if config.fallback_random:
  799. return decompositions
  800. if config.fallback_embedding_bag_byte_unpack:
  801. # remove q_embedding_bag_byte_unpack_decomp from decompositions
  802. decompositions.pop(torch.ops.quantized.embedding_bag_byte_unpack.default, None)
  803. return decompositions
  804. result = fast_random_decomps()
  805. if config.emulate_precision_casts:
  806. # When emulating precision casts, skip decomposition of addcmul ops
  807. # so that we use the inductor lowering which preserves FMA semantics.
  808. # For _foreach_addcdiv, we use the native CUDA kernel.
  809. # The decomposed version uses separate mul+add/div+add ops which don't match
  810. # eager's FMA rounding behavior.
  811. # Note: We check against OpOverloadPacket to match all overloads (default, out, etc.)
  812. ops_to_skip = OrderedSet(
  813. [
  814. aten.addcmul,
  815. aten._foreach_addcmul.Scalar,
  816. aten._foreach_addcdiv.Scalar,
  817. ]
  818. )
  819. def should_skip(op: Any) -> bool:
  820. # Check if op is directly in the skip set
  821. if op in ops_to_skip:
  822. return True
  823. # For OpOverload, also check if its OpOverloadPacket is in the skip set
  824. if hasattr(op, "overloadpacket"):
  825. return op.overloadpacket in ops_to_skip
  826. return False
  827. result = {k: v for k, v in result.items() if not should_skip(k)}
  828. return result
  829. @register_decomposition(aten.masked_scatter)
  830. def masked_scatter(
  831. self: torch.Tensor,
  832. mask: torch.Tensor,
  833. source: torch.Tensor,
  834. ) -> torch.Tensor:
  835. from .codegen.common import BackendFeature, has_backend_feature
  836. if has_backend_feature(self.device, BackendFeature.MASKED_SCATTER_WITH_INDEX):
  837. # This two-step algorithm is the same as eager CUDA, for eager CPU we
  838. # use a 1-shot serial iteration.
  839. self, mask = aten.broadcast_tensors([self, mask])
  840. source_idx = mask.reshape(-1).cumsum(0) - 1
  841. self_flat, mask_flat, source_flat = (x.flatten() for x in (self, mask, source))
  842. result = aten._unsafe_masked_index(source_flat, mask_flat, [source_idx], 0)
  843. return torch.where(mask_flat, result, self_flat).view(self.shape)
  844. return NotImplemented
  845. @register_decomposition(quantized_decomposed.choose_qparams.tensor)
  846. def choose_qparams_tensor(
  847. input: torch.Tensor,
  848. quant_min: int,
  849. quant_max: int,
  850. eps: float,
  851. dtype: torch.dtype,
  852. ) -> tuple[torch.Tensor, torch.Tensor]:
  853. min_val, max_val = torch.aminmax(input)
  854. scale = (max_val - min_val) / float(quant_max - quant_min)
  855. scale = torch.max(scale, torch.Tensor([eps]))
  856. zero_point = quant_min - torch.round(min_val / scale).to(torch.int)
  857. zero_point = torch.clamp(zero_point, quant_min, quant_max)
  858. return scale.to(torch.float64), zero_point.to(torch.int64)
  859. @register_decomposition(aten.put)
  860. def put(
  861. self: torch.Tensor,
  862. index: torch.Tensor,
  863. source: torch.Tensor,
  864. accumulate: bool = False,
  865. ) -> torch.Tensor:
  866. flattened = self.flatten()
  867. flattened = torch.index_put(
  868. flattened, [index], source.reshape(index.shape), accumulate
  869. )
  870. return flattened.reshape(self.shape)
  871. @register_decomposition(aten.put_)
  872. def put_(
  873. self: torch.Tensor,
  874. index: torch.Tensor,
  875. source: torch.Tensor,
  876. accumulate: bool = False,
  877. ) -> torch.Tensor:
  878. out = aten.put(self, index, source, accumulate=accumulate)
  879. return self.copy_(out)
  880. @register_decomposition(aten._softmax_backward_data.default)
  881. @pw_cast_for_opmath
  882. def _softmax_backward_data(
  883. grad_output: torch.Tensor,
  884. output: torch.Tensor,
  885. dim: int,
  886. input_dtype: torch.dtype,
  887. ) -> torch.Tensor:
  888. new_grad_output = grad_output * output
  889. sum_new_grad = torch.sum(new_grad_output, dim=dim, keepdim=True)
  890. # grad_input = new_grad_output - output * sum_new_grad
  891. grad_input = inductor_prims.fma(-output, sum_new_grad, new_grad_output)
  892. # CPU kernel doesn't respect input_dtype, but following check doesn't work for meta tensor
  893. # if grad_output.device == torch.device("cpu"):
  894. # return grad_input.contiguous()
  895. if grad_output.dtype != input_dtype:
  896. grad_input = grad_input.to(input_dtype)
  897. return grad_input.contiguous()
@register_decomposition(aten.index_reduce)
def index_reduce(
    self: torch.Tensor,
    dim: int,
    index: torch.Tensor,
    src: torch.Tensor,
    reduction_type: str,
    *,
    include_self: bool = True,
) -> torch.Tensor:
    """Decompose ``index_reduce`` into ``index_add`` or ``scatter_reduce``.

    * ``"mean"`` (when atomic-add limitations do not force a fallback for
      ``self.dtype``) is computed as an ``index_add`` of the values divided by
      an ``index_add`` of per-slot counts.
    * Any other case is rewritten as ``scatter_reduce`` with an index tensor
      expanded from ``index``, unless ``use_scatter_fallback`` says to fall
      back, in which case ``NotImplemented`` is returned.
    """
    if reduction_type == "mean" and not needs_fallback_due_to_atomic_add_limitations(
        self.dtype
    ):
        # Floating point / complex dtypes use true division; others floor-divide.
        true_division = self.dtype.is_floating_point or self.dtype.is_complex
        ones = torch.ones_like(src)
        if include_self:
            # The existing value counts as one contribution to every slot.
            out = self
            counts = torch.ones_like(self).index_add(dim, index, ones)
        else:
            # Indexed slots start from zero; untouched slots keep their value.
            out = self.index_fill(dim, index, 0)
            counts = torch.zeros_like(self).index_add(dim, index, ones)
            # Slots never indexed have count 0; use 1 to avoid dividing by zero.
            # NOTE(review): placement inside this branch is reconstructed from a
            # paste that lost indentation — confirm against the original file.
            counts = counts.masked_fill(counts < 1, 1)
        out = out.index_add(dim, index, src)
        return out / counts if true_division else out // counts

    if use_scatter_fallback(
        aten.scatter_reduce_.two,
        reduction_type,
        self.dtype,
        src.dtype,
        src.device.type,
        True,
    ):
        return NotImplemented

    # Number of elements of ``self`` per position along ``dim``, i.e. the
    # product of all other dimension sizes.
    # pyrefly: ignore [missing-attribute]
    repeats = self.shape[dim + 1 :].numel() * self.shape[:dim].numel()
    # Laid out as (index entries, dims after ``dim``, dims before ``dim``).
    index_shape = (index.numel(), *self.shape[dim + 1 :], *self.shape[:dim])
    # Permutation that reorders the above to (dims before, index entries, dims after).
    perm = (*range(self.ndim - dim, self.ndim), 0, *range(1, self.ndim - dim))
    scatter_index = (
        index.to(torch.int64)
        .repeat_interleave(repeats)
        .reshape(index_shape)
        .permute(perm)
    )
    return self.scatter_reduce(
        dim,
        scatter_index,
        src,
        reduction_type,
        include_self=include_self,
    )
  948. def _max_pool_with_indices(
  949. x: torch.Tensor,
  950. kernel_size: list[int],
  951. stride: Optional[Union[int, list[int]]],
  952. padding: Union[int, list[int]],
  953. dilation: Union[int, list[int]],
  954. ceil_mode: bool,
  955. dim: int,
  956. ) -> tuple[torch.Tensor, torch.Tensor]:
  957. if dilation == 1:
  958. dilation = [1] * dim
  959. if padding == 0:
  960. padding = [0] * dim
  961. if not stride:
  962. stride = kernel_size
  963. # pyrefly: ignore [bad-assignment]
  964. kernel_size = pad_listlike(kernel_size, dim)
  965. # pyrefly: ignore [bad-assignment]
  966. dilation = pad_listlike(dilation, dim)
  967. # pyrefly: ignore [bad-assignment]
  968. padding = pad_listlike(padding, dim)
  969. # pyrefly: ignore [bad-assignment]
  970. stride = pad_listlike(stride, dim)
  971. window_size = functools.reduce(operator.mul, kernel_size)
  972. # We fallback when using non-default dilation or when the window size is too large
  973. if (
  974. torch._inductor.lowering.should_fallback_max_pool_with_indices(
  975. kernel_size, n_dim=dim
  976. )
  977. or window_size > torch.iinfo(torch.int8).max
  978. ):
  979. return NotImplemented
  980. vals, offsets = prims._low_memory_max_pool_with_offsets(
  981. x,
  982. kernel_size,
  983. stride,
  984. padding,
  985. dilation,
  986. ceil_mode,
  987. )
  988. indices = prims._low_memory_max_pool_offsets_to_indices(
  989. offsets,
  990. kernel_size,
  991. x.shape[-dim:],
  992. stride,
  993. padding,
  994. dilation,
  995. )
  996. return vals, indices
  997. @register_decomposition(aten.max_pool2d_with_indices)
  998. def max_pool2d_with_indices(
  999. x: torch.Tensor,
  1000. kernel_size: list[int],
  1001. stride: Optional[Union[int, list[int]]] = None,
  1002. padding: Union[int, list[int]] = 0,
  1003. dilation: Union[int, list[int]] = 1,
  1004. ceil_mode: bool = False,
  1005. ) -> tuple[torch.Tensor, torch.Tensor]:
  1006. return _max_pool_with_indices(
  1007. x, kernel_size, stride, padding, dilation, ceil_mode, dim=2
  1008. )
  1009. @register_decomposition(aten.max_pool3d_with_indices)
  1010. def max_pool3d_with_indices(
  1011. x: torch.Tensor,
  1012. kernel_size: list[int],
  1013. stride: Optional[Union[int, list[int]]] = None,
  1014. padding: Union[int, list[int]] = 0,
  1015. dilation: Union[int, list[int]] = 1,
  1016. ceil_mode: bool = False,
  1017. ) -> tuple[torch.Tensor, torch.Tensor]:
  1018. return _max_pool_with_indices(
  1019. x, kernel_size, stride, padding, dilation, ceil_mode, dim=3
  1020. )
  1021. @register_decomposition(aten.adaptive_max_pool2d)
  1022. def adaptive_max_pool2d(
  1023. x: torch.Tensor, output_size: list[int]
  1024. ) -> tuple[torch.Tensor, torch.Tensor]:
  1025. *batch, h_in, w_in = x.shape
  1026. h_out, w_out = output_size
  1027. if h_out == 0 or w_out == 0:
  1028. o_size = [*batch, h_out, w_out]
  1029. return x.new_empty(o_size), x.new_empty(o_size, dtype=torch.int64)
  1030. if h_in % h_out == 0 and w_in % w_out == 0:
  1031. kernel_size = [h_in // h_out, w_in // w_out]
  1032. return aten.max_pool2d_with_indices(x, kernel_size)
  1033. return NotImplemented
  1034. @register_decomposition(aten.searchsorted.Scalar)
  1035. def searchsorted_scalar(
  1036. sorted_sequence: torch.Tensor,
  1037. self: torch.types.Number,
  1038. *,
  1039. out_int32: bool = False,
  1040. right: bool = False,
  1041. side: Optional[str] = None,
  1042. sorter: Optional[torch.Tensor] = None,
  1043. ) -> torch.Tensor:
  1044. return aten.searchsorted(
  1045. sorted_sequence,
  1046. torch.tensor([self], device=sorted_sequence.device),
  1047. out_int32=out_int32,
  1048. right=right,
  1049. side=side,
  1050. sorter=sorter,
  1051. )[0]
  1052. @register_decomposition(aten.bucketize.Scalar)
  1053. def bucketize_scalar(
  1054. self: torch.types.Number,
  1055. boundaries: torch.Tensor,
  1056. *,
  1057. out_int32: bool = False,
  1058. right: bool = False,
  1059. ) -> torch.Tensor:
  1060. return aten.bucketize(
  1061. torch.tensor([self], device=boundaries.device),
  1062. boundaries,
  1063. out_int32=out_int32,
  1064. right=right,
  1065. ).squeeze(0)
  1066. @register_decomposition(aten.rrelu_with_noise_functional)
  1067. def rrelu_with_noise_functional(
  1068. self: torch.Tensor,
  1069. noise: torch.Tensor,
  1070. lower: float = 0.125,
  1071. upper: float = 0.3333333333333333,
  1072. training: bool = False,
  1073. generator: Optional[torch.Generator] = None,
  1074. ) -> tuple[torch.Tensor, torch.Tensor]:
  1075. if training:
  1076. not_positive = self <= 0
  1077. r = aten.uniform(self, lower, upper, generator=generator)
  1078. output = torch.where(not_positive, self * r, self)
  1079. noise_out = torch.where(not_positive, r, 1)
  1080. return output, noise_out
  1081. else:
  1082. negative_slope = (lower + upper) / 2
  1083. return aten.leaky_relu(self, negative_slope), torch.Tensor()
  1084. @register_decomposition(aten.repeat_interleave.Tensor)
  1085. def repeat_interleave_Tensor(
  1086. repeat: torch.Tensor,
  1087. output_size: Optional[int] = None,
  1088. ) -> torch.Tensor:
  1089. if config.triton.autotune_at_compile_time:
  1090. # We can't compile-time auto-tune this because
  1091. # it expects specific data in `repeat`
  1092. return NotImplemented
  1093. if output_size is None or type(output_size) is not int:
  1094. return NotImplemented
  1095. if repeat.device.type == "mps":
  1096. return NotImplemented
  1097. assert repeat.dtype in [torch.int32, torch.int64]
  1098. assert repeat.ndim == 1
  1099. cumsum = repeat.cumsum(0)
  1100. pos = torch.arange(output_size, device=repeat.device)
  1101. indices = torch.searchsorted(
  1102. cumsum, pos, out_int32=(repeat.dtype == torch.int32), right=True
  1103. )
  1104. return torch.clamp(indices, max=repeat.size(0) - 1)
  1105. # intentionally not regiestered
  1106. def conv1d_to_conv2d(
  1107. input: torch.Tensor,
  1108. weight: torch.Tensor,
  1109. bias: Optional[torch.Tensor] = None,
  1110. stride: tuple[int] = (1,),
  1111. padding: tuple[int] = (0,),
  1112. dilation: tuple[int] = (1,),
  1113. groups: int = 1,
  1114. ) -> torch.Tensor:
  1115. # Shapes:
  1116. # input: (N, C_in, L_in)
  1117. # weight: (C_out, C_in // groups, K)
  1118. # bias: (C_out,)
  1119. assert input.dim() == 3 and weight.dim() == 3, (
  1120. "Expect (N,C_in,L) and (C_out,C_in//groups,K)"
  1121. )
  1122. # pyrefly: ignore [bad-assignment]
  1123. stride = stride[0]
  1124. # pyrefly: ignore [bad-assignment]
  1125. padding = padding[0]
  1126. # pyrefly: ignore [bad-assignment]
  1127. dilation = dilation[0]
  1128. # Unsqueeze to make input 2D: (N,C,L) -> (N,C,L,1)
  1129. input_2d = input.unsqueeze(-1)
  1130. # Unsqueeze kernel: (C_out,C_in/groups,K) -> (C_out,C_in/groups,K,1)
  1131. weight_2d = weight.unsqueeze(-1)
  1132. # Call conv2d with adjusted args
  1133. out_2d = aten.conv2d.default(
  1134. input_2d,
  1135. weight_2d,
  1136. bias,
  1137. stride=(stride, 1),
  1138. padding=(padding, 0),
  1139. dilation=(dilation, 1),
  1140. groups=groups,
  1141. )
  1142. # Squeeze dummy dimension back out: (N,C_out,L_out,1) -> (N,C_out,L_out)
  1143. return out_2d.squeeze(-1)