native_functions.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. from __future__ import annotations
  2. import torchgen.api.meta as meta
  3. import torchgen.api.structured as structured
  4. from torchgen.api.types import kernel_signature
  5. from torchgen.context import with_native_function_and_index
  6. from torchgen.model import BackendIndex, NativeFunction, NativeFunctionsGroup
  7. from torchgen.utils import mapMaybe
  8. def torch_api_key_word_prefix(bankend_index: BackendIndex) -> str:
  9. if bankend_index.external:
  10. return ""
  11. # Although Intel GPU ATen library is out-of-tree, it still utilizes torchgen to produce structured
  12. # kernels. Regarding these produced structured kernels, they should be visible for the Intel GPU ATen
  13. # library. Therefore, we need to add "TORCH_XPU_API" prefix to these structured kernels,
  14. # rather than "TORCH_API". Because the semantic of "TORCH_API" is "hidden" for out-of-tree backends.
  15. # For other in-tree backends like cpu and cuda, they still use "TORCH_API" prefix with "visible" semantic.
  16. device_torch_api_key_word_mapping = {
  17. "XPU": "TORCH_XPU_API",
  18. }
  19. return (
  20. device_torch_api_key_word_mapping.get(
  21. bankend_index.dispatch_key.name, "TORCH_API"
  22. )
  23. + " "
  24. )
  25. @with_native_function_and_index
  26. def gen_unstructured(f: NativeFunction, backend_index: BackendIndex) -> str | None:
  27. sig = kernel_signature(f, backend_index)
  28. metadata = backend_index.get_kernel(f)
  29. if metadata is None:
  30. return None
  31. if "legacy::" in metadata.kernel:
  32. return None
  33. else:
  34. prefix = "static" if backend_index.external else "TORCH_API"
  35. return f"{prefix} {sig.decl(name=metadata.kernel)};"
  36. @with_native_function_and_index
  37. def gen_structured(g: NativeFunctionsGroup, backend_index: BackendIndex) -> list[str]:
  38. meta_name = meta.name(g)
  39. out_args = structured.impl_arguments(g)
  40. metadata = backend_index.get_kernel(g)
  41. if metadata is None:
  42. return []
  43. prefix = torch_api_key_word_prefix(backend_index)
  44. return [
  45. f"""\
  46. struct {prefix}structured_{metadata.kernel} : public at::meta::structured_{meta_name} {{
  47. void impl({", ".join(a.decl() for a in out_args)});
  48. }};
  49. """
  50. ]
  51. # Generates NativeFunctions.h, a list of forward declarations of all
  52. # actual kernel definitions we keep in aten/src/ATen/native/
  53. @with_native_function_and_index
  54. def compute_native_function_declaration(
  55. g: NativeFunctionsGroup | NativeFunction, backend_index: BackendIndex
  56. ) -> list[str]:
  57. metadata = backend_index.get_kernel(g)
  58. if isinstance(g, NativeFunctionsGroup):
  59. if metadata is not None and metadata.structured:
  60. if backend_index.external:
  61. # Structured hasn't been tested with external backends yet.
  62. raise AssertionError(
  63. "Structured external backend functions are not implemented yet."
  64. )
  65. else:
  66. return gen_structured(g, backend_index)
  67. else:
  68. return list(
  69. mapMaybe(lambda f: gen_unstructured(f, backend_index), g.functions())
  70. )
  71. else:
  72. x = gen_unstructured(g, backend_index)
  73. return [] if x is None else [x]