_impl.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547
  1. """Implementations of ONNX operators as native Torch ops.
  2. NOTE: Fake implementations:
  3. Refer to https://docs.pytorch.org/docs/stable/library.html#torch.library.register_fake
  4. for more details on how to create fake kernels.
  5. """
  6. # flake8: noqa: B950
  7. import math
  8. from collections.abc import Callable
  9. from typing import Optional, TypeVar
  10. from typing_extensions import ParamSpec
  11. import torch
  12. from torch.onnx.ops import _dtype_mappings
  13. # Use ParamSpec for better type preservation instead of bound Callable TypeVar
  14. _P = ParamSpec("_P")
  15. _R = TypeVar("_R")
  16. # ONNX to ATen decomp table
  17. ONNX_ATEN_DECOMP_TABLE: dict[torch._ops.OpOverload, Callable] = {}
  18. _ATTENTION_23_ALLOWED_INTERMEDIATE_PRECISIONS = frozenset(
  19. {
  20. 1, # FLOAT
  21. 10, # FLOAT16
  22. 11, # DOUBLE
  23. 16, # BFLOAT16
  24. }
  25. )
  26. def _onnx_op(
  27. op_type: str, opset_version: int, fake_impl: Callable[_P, _R]
  28. ) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
  29. """Decorator to register an ONNX operator with a custom implementation."""
  30. def decorator(func: Callable[_P, _R]) -> Callable[_P, _R]:
  31. overload = f"opset{opset_version}"
  32. torch_op = torch.library.custom_op(
  33. f"onnx::{op_type}.{overload}", mutates_args=()
  34. )(func)
  35. ONNX_ATEN_DECOMP_TABLE[getattr(getattr(torch.ops.onnx, op_type), overload)] = (
  36. func # type: ignore[assignment]
  37. )
  38. torch_op.register_fake(fake_impl)
  39. return torch_op # type: ignore[return-value]
  40. return decorator
  41. def _rotary_embedding_23_fake_impl(
  42. x: torch.Tensor,
  43. cos_cache: torch.Tensor,
  44. sin_cache: torch.Tensor,
  45. position_ids: Optional[torch.Tensor] = None,
  46. *,
  47. interleaved: bool = False,
  48. num_heads: int = 0,
  49. rotary_embedding_dim: int = 0,
  50. ) -> torch.Tensor:
  51. """Fake implementation for RotaryEmbedding-23 for torch.compile purposes."""
  52. return x.clone()
  53. @_onnx_op("RotaryEmbedding", 23, _rotary_embedding_23_fake_impl)
  54. def rotary_embedding_23(
  55. x: torch.Tensor,
  56. cos_cache: torch.Tensor,
  57. sin_cache: torch.Tensor,
  58. position_ids: Optional[torch.Tensor] = None,
  59. *,
  60. interleaved: bool = False,
  61. num_heads: int = 0,
  62. rotary_embedding_dim: int = 0,
  63. ) -> torch.Tensor:
  64. """RotaryEmbedding-23 https://onnx.ai/onnx/operators/onnx__RotaryEmbedding.html#rotaryembedding-23"""
  65. # x has shape (batch_size, num_heads, sequence_length, head_size)
  66. # or (batch_size, sequence_length, hidden_size)
  67. input_shape = x.shape
  68. input_rank = len(input_shape)
  69. batch_size = input_shape[0]
  70. sequence_length = input_shape[-2]
  71. # Validate position_ids and caches match x
  72. if position_ids is not None:
  73. torch._check(
  74. position_ids.dim() == 2,
  75. lambda: f"position_ids must be 2D when provided. Received shape {position_ids.shape}",
  76. )
  77. torch._check(
  78. position_ids.shape[0] == batch_size,
  79. lambda: f"position_ids first dim (batch) must match x.shape[0] ({batch_size}). Received {position_ids.shape[0]}",
  80. )
  81. torch._check(
  82. position_ids.shape[1] == sequence_length,
  83. lambda: f"position_ids second dim (sequence) must match x.shape[-2] ({sequence_length}). Received {position_ids.shape[1]}",
  84. )
  85. torch._check(
  86. cos_cache.dim() == 2 and sin_cache.dim() == 2,
  87. lambda: "cos_cache/sin_cache must be 2D when position_ids is provided. "
  88. f"Received cos_cache shape {cos_cache.shape}, sin_cache shape {sin_cache.shape}",
  89. )
  90. else:
  91. torch._check(
  92. cos_cache.dim() == 3 and sin_cache.dim() == 3,
  93. lambda: "cos_cache/sin_cache must be 3D when position_ids is not provided. "
  94. f"Received cos_cache shape {cos_cache.shape}, sin_cache shape {sin_cache.shape}",
  95. )
  96. # First ensure x has shape [batch_size, num_heads, seq_len, head_size]
  97. # So that the rotation logic can be shared with reshaped 3D inputs
  98. if input_rank == 4:
  99. # Reshape from (batch_size, num_heads, seq_len, head_size)
  100. # to [batch_size, seq_len, num_heads, head_size]
  101. x = torch.permute(x, (0, 2, 1, 3))
  102. elif input_rank == 3:
  103. torch._check(
  104. num_heads != 0,
  105. lambda: f"num_heads must be provided for 3D inputs. Received input tensor with shape {input_shape}",
  106. )
  107. hidden_size = input_shape[2]
  108. head_size = hidden_size // num_heads
  109. new_shape = [batch_size, sequence_length, num_heads, head_size]
  110. x = torch.reshape(x, new_shape)
  111. torch._check(len(x.shape) == 4, lambda: "x should be a 4D tensor by now")
  112. head_size = x.shape[3]
  113. # Fully or partially perform rotation on x based on rotary_embedding_dim attribute
  114. if rotary_embedding_dim == 0:
  115. # If rotary_embedding_dim not provided, perform full rotation by using head_size
  116. rotary_embedding_dim = head_size
  117. x_rotate = x[:, :, :, :rotary_embedding_dim]
  118. x_not_rotate = x[:, :, :, rotary_embedding_dim:]
  119. rotary_embedding_dim_half = rotary_embedding_dim // 2
  120. # Retrieve sin and cos caches using position ids
  121. if position_ids is not None:
  122. cos = cos_cache[
  123. position_ids
  124. ] # Shape: [batch_size, sequence_length, head_size/2]
  125. sin = sin_cache[
  126. position_ids
  127. ] # Shape: [batch_size, sequence_length, head_size/2]
  128. else:
  129. cos = cos_cache # Shape: [batch_size, sequence_length, rotary_embedding_dim/2]
  130. sin = sin_cache # Shape: [batch_size, sequence_length, rotary_embedding_dim/2]
  131. torch._check(
  132. cos.shape[0] == batch_size and cos.shape[1] == sequence_length,
  133. lambda: f"cos has shape {cos.shape} but expected (batch={batch_size}, seq={sequence_length}, ...)",
  134. )
  135. torch._check(
  136. sin.shape[0] == batch_size and sin.shape[1] == sequence_length,
  137. lambda: f"sin has shape {sin.shape} but expected (batch={batch_size}, seq={sequence_length}, ...)",
  138. )
  139. torch._check(
  140. cos.shape[-1] == rotary_embedding_dim_half,
  141. lambda: f"Last dimension of cos cache ({cos.shape[-1]}) should match rotary_embedding_dim/2 ({rotary_embedding_dim_half}).",
  142. )
  143. torch._check(
  144. sin.shape[-1] == rotary_embedding_dim_half,
  145. lambda: f"Last dimension of sin cache ({sin.shape[-1]}) should match rotary_embedding_dim/2 ({rotary_embedding_dim_half}).",
  146. )
  147. cos = torch.unsqueeze(
  148. cos, 2
  149. ) # Shape: [batch_size, sequence_length, 1, rotary_embedding_dim/2]
  150. sin = torch.unsqueeze(
  151. sin, 2
  152. ) # Shape: [batch_size, sequence_length, 1, rotary_embedding_dim/2]
  153. # Either divide the x in halves or interleave (based on interleaved attribute)
  154. if interleaved:
  155. x1 = x_rotate[:, :, :, 0::2]
  156. x2 = x_rotate[:, :, :, 1::2]
  157. else:
  158. x1, x2 = torch.chunk(x_rotate, 2, dim=-1)
  159. # Calculate real and imaginary values
  160. real = cos * x1 - sin * x2
  161. imag = sin * x1 + cos * x2
  162. # Inserted rotated embeddings back to the original x
  163. if interleaved:
  164. # x_rotate[:, :, :, 0::2] = real
  165. # x_rotate[:, :, :, 1::2] = imag
  166. real = torch.unsqueeze(real, -1)
  167. imag = torch.unsqueeze(imag, -1)
  168. x_rotate_concat = torch.cat((real, imag), dim=-1)
  169. x_rotate = torch.reshape(x_rotate_concat, x_rotate.shape)
  170. else:
  171. x_rotate = torch.cat((real, imag), dim=-1)
  172. output = torch.cat((x_rotate, x_not_rotate), dim=-1)
  173. if input_rank == 3:
  174. return torch.reshape(output, input_shape)
  175. # Return the dimensions to the original order
  176. return torch.permute(output, (0, 2, 1, 3))
  177. def _get_scale_factor(scale: Optional[float], head_size: int) -> float:
  178. """Get the scale factor for attention computation."""
  179. return scale if scale is not None else (1.0 / math.sqrt(head_size))
  180. def _reshape_3d_to_4d(
  181. tensor: torch.Tensor, batch_size: int, num_heads: int
  182. ) -> torch.Tensor:
  183. """Reshape 3D tensor to 4D for multi-head attention."""
  184. sequence_length, hidden_size = tensor.shape[1], tensor.shape[2]
  185. head_size = hidden_size // num_heads
  186. return (
  187. tensor.view(batch_size, sequence_length, num_heads, head_size)
  188. .transpose(1, 2)
  189. .contiguous()
  190. )
  191. def _get_qk_output_for_aten_spda(
  192. Q: torch.Tensor,
  193. K: torch.Tensor,
  194. current_q_num_heads: int,
  195. current_kv_num_heads: int,
  196. scale: Optional[float],
  197. qk_matmul_output_mode: int,
  198. ) -> torch.Tensor:
  199. """Get QK output tensor based on the specified mode."""
  200. if qk_matmul_output_mode == 0:
  201. return _compute_qk_output_for_mode_0(
  202. Q, K, current_q_num_heads, current_kv_num_heads, scale
  203. )
  204. else:
  205. # For other modes, return a zero tensor with correct shape
  206. return torch.zeros_like(torch.matmul(Q, K.transpose(-2, -1)))
  207. def _validate_gqa_configuration(
  208. current_q_num_heads: int, current_kv_num_heads: int
  209. ) -> None:
  210. """Validate Group Query Attention configuration."""
  211. torch._check(
  212. current_q_num_heads % current_kv_num_heads == 0,
  213. lambda: f"q_num_heads ({current_q_num_heads}) must be divisible by kv_num_heads ({current_kv_num_heads}) for GQA",
  214. )
  215. def _compute_qk_output_for_mode_0(
  216. Q: torch.Tensor,
  217. K: torch.Tensor,
  218. current_q_num_heads: int,
  219. current_kv_num_heads: int,
  220. scale: Optional[float],
  221. ) -> torch.Tensor:
  222. """Helper function to compute QK output for qk_matmul_output_mode == 0."""
  223. # Handle GQA manually for QK output
  224. K_for_qk = K
  225. if current_q_num_heads != current_kv_num_heads:
  226. repeat_factor = current_q_num_heads // current_kv_num_heads
  227. K_for_qk = K.repeat_interleave(repeat_factor, dim=1)
  228. scale_factor = _get_scale_factor(scale, Q.shape[3])
  229. # Scale both Q and K by sqrt(scale_factor) for numerical stability
  230. sqrt_scale = math.sqrt(scale_factor)
  231. Q_scaled = Q * sqrt_scale
  232. K_scaled = K_for_qk * sqrt_scale
  233. return torch.matmul(Q_scaled, K_scaled.transpose(-2, -1))
  234. def _attention_23_fake_impl(
  235. Q: torch.Tensor,
  236. K: torch.Tensor,
  237. V: torch.Tensor,
  238. attn_mask: Optional[torch.Tensor] = None,
  239. past_key: Optional[torch.Tensor] = None,
  240. past_value: Optional[torch.Tensor] = None,
  241. *,
  242. is_causal: bool = False,
  243. kv_num_heads: int = 0,
  244. q_num_heads: int = 0,
  245. qk_matmul_output_mode: int = 0,
  246. scale: Optional[float] = None,
  247. softcap: float = 0.0,
  248. softmax_precision: Optional[int] = None,
  249. ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
  250. """Fake implementation for Attention-23 for torch.compile purposes."""
  251. batch_size = Q.shape[0]
  252. # Handle 3D vs 4D input shapes
  253. if len(Q.shape) == 3:
  254. # 3D input: (batch_size, sequence_length, hidden_size)
  255. q_sequence_length = Q.shape[1]
  256. output_shape = Q.shape # Same shape as Q for 3D output
  257. # For present_key and present_value, we need 4D shapes
  258. if past_key is not None:
  259. present_key_shape = (
  260. batch_size,
  261. kv_num_heads,
  262. past_key.shape[2] + K.shape[1], # Combined sequence length
  263. K.shape[2] // kv_num_heads, # head_size
  264. )
  265. else:
  266. present_key_shape = (
  267. batch_size,
  268. kv_num_heads,
  269. K.shape[1], # sequence_length
  270. K.shape[2] // kv_num_heads, # head_size
  271. )
  272. present_value_shape = present_key_shape # Same shape as present_key
  273. # QK output shape for 3D input (reshaped to 4D internally)
  274. qk_output_shape = (
  275. batch_size,
  276. q_num_heads,
  277. q_sequence_length,
  278. present_key_shape[2], # kv_sequence_length
  279. )
  280. else:
  281. # 4D input: (batch_size, num_heads, sequence_length, head_size)
  282. q_sequence_length = Q.shape[2]
  283. # Same shape as Q for 4D output
  284. output_shape = Q.shape # type: ignore[assignment]
  285. # Handle past key/value concatenation
  286. if past_key is not None:
  287. present_key_shape = (
  288. K.shape[0], # batch_size
  289. K.shape[1], # num_heads
  290. past_key.shape[2] + K.shape[2], # Combined sequence length
  291. K.shape[3], # head_size
  292. )
  293. else:
  294. present_key_shape = K.shape # type: ignore[assignment]
  295. present_value_shape = present_key_shape # Same shape as present_key
  296. # QK output shape
  297. qk_output_shape = (
  298. Q.shape[0], # batch_size
  299. Q.shape[1], # q_num_heads
  300. Q.shape[2], # q_sequence_length
  301. present_key_shape[2], # kv_sequence_length
  302. )
  303. # Create fake tensors with correct shapes and dtypes
  304. output = torch.empty(output_shape, dtype=Q.dtype, device=Q.device)
  305. present_key = torch.empty(present_key_shape, dtype=K.dtype, device=K.device)
  306. present_value = torch.empty(present_value_shape, dtype=V.dtype, device=V.device)
  307. qk_output = torch.empty(qk_output_shape, dtype=Q.dtype, device=Q.device)
  308. return output, present_key, present_value, qk_output
  309. @_onnx_op("Attention", 23, _attention_23_fake_impl)
  310. def attention_23(
  311. Q: torch.Tensor,
  312. K: torch.Tensor,
  313. V: torch.Tensor,
  314. attn_mask: Optional[torch.Tensor] = None,
  315. past_key: Optional[torch.Tensor] = None,
  316. past_value: Optional[torch.Tensor] = None,
  317. *,
  318. is_causal: bool = False,
  319. kv_num_heads: int = 0,
  320. q_num_heads: int = 0,
  321. qk_matmul_output_mode: int = 0,
  322. scale: Optional[float] = None,
  323. softcap: float = 0.0,
  324. softmax_precision: Optional[int] = None,
  325. ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
  326. """Attention-23 https://onnx.ai/onnx/operators/onnx__Attention.html#attention-23"""
  327. num_head_dim, sequence_dim, head_dim = 1, 2, 3
  328. # Store original input shape to determine output shape
  329. input_shape_len = len(Q.shape)
  330. batch_size = Q.shape[0]
  331. # Reshape 3D inputs to 4D format
  332. if len(Q.shape) == 3:
  333. torch._check(
  334. q_num_heads != 0 and kv_num_heads != 0,
  335. lambda: "q_num_heads and kv_num_heads must be provided for 3D inputs",
  336. )
  337. q_sequence_length = Q.shape[1]
  338. Q = _reshape_3d_to_4d(Q, batch_size, q_num_heads)
  339. K = _reshape_3d_to_4d(K, batch_size, kv_num_heads)
  340. V = _reshape_3d_to_4d(V, batch_size, kv_num_heads)
  341. torch._check(
  342. len(Q.shape) == 4 and len(K.shape) == 4 and len(V.shape) == 4,
  343. lambda: "Q, K, and V should be 4D tensors by now",
  344. )
  345. # Calculate scale factor if not provided
  346. q_head_size = Q.shape[head_dim]
  347. scale = _get_scale_factor(scale, q_head_size)
  348. # Handle past key/value caches
  349. present_key = (
  350. torch.cat([past_key, K], dim=sequence_dim)
  351. if past_key is not None
  352. else K.clone()
  353. )
  354. present_value = (
  355. torch.cat([past_value, V], dim=sequence_dim)
  356. if past_value is not None
  357. else V.clone()
  358. )
  359. # Update K and V to include past states
  360. K, V = present_key, present_value
  361. # Get current dimensions
  362. current_q_num_heads = Q.shape[num_head_dim]
  363. current_kv_num_heads = K.shape[num_head_dim]
  364. q_sequence_length = Q.shape[sequence_dim]
  365. kv_sequence_length = K.shape[sequence_dim]
  366. # Check if we can use the optimized scaled_dot_product_attention (most optimized)
  367. can_use_sdpa = (
  368. softcap == 0.0 # No softcap
  369. and qk_matmul_output_mode == 0 # Default QK output mode
  370. and softmax_precision is None # No custom softmax precision
  371. and (attn_mask is None or attn_mask.dtype == torch.bool)
  372. )
  373. _validate_gqa_configuration(current_q_num_heads, current_kv_num_heads)
  374. if can_use_sdpa:
  375. # Use PyTorch's optimized scaled_dot_product_attention
  376. output = torch.nn.functional.scaled_dot_product_attention(
  377. Q,
  378. K,
  379. V,
  380. attn_mask=attn_mask,
  381. dropout_p=0.0,
  382. is_causal=is_causal,
  383. scale=scale,
  384. enable_gqa=bool(
  385. current_q_num_heads != current_kv_num_heads
  386. ), # Ensure enable_gqa is not SymBool
  387. )
  388. qk_output = _get_qk_output_for_aten_spda(
  389. Q,
  390. K,
  391. current_q_num_heads,
  392. current_kv_num_heads,
  393. scale,
  394. qk_matmul_output_mode,
  395. )
  396. else:
  397. # Fallback to manual implementation for complex cases
  398. # Handle Group Query Attention (GQA) and Multi-Query Attention (MQA)
  399. if current_q_num_heads != current_kv_num_heads:
  400. repeat_factor = current_q_num_heads // current_kv_num_heads
  401. K = K.repeat_interleave(repeat_factor, dim=num_head_dim)
  402. V = V.repeat_interleave(repeat_factor, dim=num_head_dim)
  403. # Create attention bias
  404. attn_bias = torch.zeros(
  405. q_sequence_length, kv_sequence_length, dtype=Q.dtype, device=Q.device
  406. )
  407. # Apply causal masking
  408. if is_causal:
  409. torch._check(
  410. attn_mask is None, lambda: "Cannot use both is_causal and attn_mask"
  411. )
  412. causal_mask = torch.tril(
  413. torch.ones(
  414. q_sequence_length,
  415. kv_sequence_length,
  416. dtype=torch.bool,
  417. device=Q.device,
  418. )
  419. )
  420. attn_bias = attn_bias.masked_fill(~causal_mask, float("-inf"))
  421. # Apply attention mask
  422. if attn_mask is not None:
  423. if attn_mask.dtype == torch.bool:
  424. # Boolean mask: True means participate in attention
  425. attn_bias = attn_bias.masked_fill(~attn_mask, float("-inf"))
  426. else:
  427. # Float mask: added to attention scores
  428. attn_bias = attn_bias + attn_mask
  429. # Apply scaling factor
  430. scale_factor = _get_scale_factor(scale, Q.shape[3])
  431. # Scale both Q and K by sqrt(scale_factor) for numerical stability
  432. sqrt_scale = math.sqrt(scale_factor)
  433. Q_scaled = Q * sqrt_scale
  434. K_scaled = K * sqrt_scale
  435. # Compute Q @ K^T
  436. qk_matmul_output = torch.matmul(Q_scaled, K_scaled.transpose(-2, -1))
  437. # Initialize QK output based on mode
  438. qk_output = qk_matmul_output # Default case for mode 0
  439. # Add attention bias
  440. qk_with_bias = qk_matmul_output + attn_bias
  441. if qk_matmul_output_mode == 1:
  442. qk_output = qk_with_bias
  443. # Apply softcap if provided
  444. if softcap > 0.0:
  445. qk_with_bias = softcap * torch.tanh(qk_with_bias / softcap)
  446. if qk_matmul_output_mode == 2:
  447. qk_output = qk_with_bias
  448. # Apply softmax with optional precision casting
  449. if softmax_precision is not None:
  450. # Map ONNX data type to torch dtype
  451. if softmax_precision in _ATTENTION_23_ALLOWED_INTERMEDIATE_PRECISIONS:
  452. original_dtype = qk_with_bias.dtype
  453. qk_with_bias = qk_with_bias.to(
  454. _dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE[softmax_precision]
  455. )
  456. qk_softmax = torch.softmax(qk_with_bias, dim=-1)
  457. qk_softmax = qk_softmax.to(original_dtype)
  458. else:
  459. qk_softmax = torch.softmax(qk_with_bias, dim=-1)
  460. else:
  461. qk_softmax = torch.softmax(qk_with_bias, dim=-1)
  462. if qk_matmul_output_mode == 3:
  463. qk_output = qk_softmax
  464. # Compute attention output
  465. output = torch.matmul(qk_softmax, V)
  466. # Reshape output back to 3D if input was 3D
  467. if input_shape_len == 3:
  468. # output: (batch_size, q_num_heads, q_sequence_length, v_head_size) -> (batch_size, q_sequence_length, hidden_size)
  469. output = (
  470. output.transpose(1, 2).contiguous().view(batch_size, q_sequence_length, -1)
  471. )
  472. return output, present_key, present_value, qk_output