sdpa.py 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. from collections.abc import Sequence
  2. from inspect import getattr_static
  3. from typing import Any, TYPE_CHECKING, TypeGuard
  4. from torch._guards import Source
  5. from torch.backends.cuda import SDPAParams
  6. from torch.fx.proxy import Proxy
  7. from ..bytecode_transformation import create_call_function
  8. from ..exc import unimplemented
  9. from ..source import AttrSource
  10. from .base import VariableTracker
  11. if TYPE_CHECKING:
  12. from torch._dynamo.codegen import PyCodegen
  13. from torch._dynamo.symbolic_convert import InstructionTranslator
  14. PARAM_NAMES = [
  15. "query",
  16. "key",
  17. "value",
  18. "attn_mask",
  19. "dropout",
  20. "is_causal",
  21. "enable_gqa",
  22. ]
  23. class SDPAParamsVariable(VariableTracker):
  24. """Represents the c++ params struct for scaled dot product attention.
  25. This is a read-only container."""
  26. @staticmethod
  27. def create(
  28. tx: "InstructionTranslator", value: Any, source: Source
  29. ) -> VariableTracker:
  30. from .torch import TorchInGraphFunctionVariable
  31. params = [
  32. VariableTracker.build(tx, getattr(value, p), AttrSource(source, p))
  33. for p in PARAM_NAMES
  34. ]
  35. return TorchInGraphFunctionVariable(SDPAParams).call_function(tx, params, {})
  36. def __init__(
  37. self, proxy: Proxy, param_vars: Sequence[VariableTracker], **kwargs: Any
  38. ) -> None:
  39. self.proxy = proxy
  40. self.param_vars = param_vars
  41. super().__init__(**kwargs)
  42. def reconstruct(self, codegen: "PyCodegen") -> None:
  43. assert self.source is None
  44. assert self.param_vars is not None
  45. codegen.add_push_null(
  46. lambda: codegen.load_import_from("torch._C", "_SDPAParams")
  47. )
  48. codegen.foreach(self.param_vars)
  49. codegen.extend_output(create_call_function(len(self.param_vars), False))
  50. def as_proxy(self) -> Proxy:
  51. return self.proxy
  52. def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
  53. import torch._C
  54. from .builder import wrap_fx_proxy
  55. from .misc import GetAttrVariable
  56. try:
  57. getattr_static(torch._C._SDPAParams, name)
  58. except AttributeError:
  59. import torch._dynamo.graph_break_hints as graph_break_hints
  60. unimplemented(
  61. gb_type="unsupported torch._C._SDPAParams attribute",
  62. context=f"name: {name}",
  63. explanation=f"Unable to fetch attribute {name} from torch._C._SDPAParams.",
  64. hints=[
  65. *graph_break_hints.USER_ERROR,
  66. ],
  67. )
  68. proxy = GetAttrVariable.create_getattr_proxy(self.as_proxy(), name)
  69. if self.source is not None:
  70. return wrap_fx_proxy(
  71. tx=tx, proxy=proxy, source=AttrSource(self.source, name)
  72. )
  73. else:
  74. return wrap_fx_proxy(tx=tx, proxy=proxy)
  75. @staticmethod
  76. def is_sdpa_params(value: Any) -> TypeGuard["SDPAParams"]:
  77. return value is SDPAParams