native.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. from __future__ import annotations
  2. from typing import TYPE_CHECKING
  3. from typing_extensions import assert_never
  4. from torchgen import local
  5. from torchgen.api import cpp
  6. from torchgen.api.types import (
  7. ArgName,
  8. BaseCType,
  9. Binding,
  10. boolT,
  11. ConstRefCType,
  12. CType,
  13. deviceT,
  14. layoutT,
  15. ListCType,
  16. MutRefCType,
  17. NamedCType,
  18. OptionalCType,
  19. scalarT,
  20. scalarTypeT,
  21. tensorT,
  22. )
  23. from torchgen.model import (
  24. Argument,
  25. FunctionSchema,
  26. Return,
  27. SelfArgument,
  28. TensorOptionsArguments,
  29. Type,
  30. )
  31. if TYPE_CHECKING:
  32. from collections.abc import Sequence
  33. # This file describes the translation of JIT schema to the native functions API.
  34. # This looks a lot like the C++ API (which makes historical sense, because the
  35. # idea was you wrote native functions to implement functions in the C++ API),
  36. # but over time we have evolved the C++ API without actually changing our
  37. # native:: kernels. The intention is to make native API and dispatcher API
  38. # line up as closely as possible, since this results in the least overhead
  39. # (no translation is needed from dispatcher API to native API).
  40. #
  41. # NB: this is symint aware, you will get the non-SymInt variant for some
  42. # dispatch entries and SymInt for others.
  43. def name(func: FunctionSchema) -> str:
  44. name = str(func.name.name)
  45. # TODO: delete this!
  46. if func.is_out_fn():
  47. name += "_out"
  48. if func.name.overload_name:
  49. name += f"_{func.name.overload_name}"
  50. return name
  51. def argumenttype_type(
  52. t: Type, *, mutable: bool, binds: ArgName, symint: bool
  53. ) -> NamedCType:
  54. if str(t) == "Tensor?":
  55. tensor_type: OptionalCType = OptionalCType(BaseCType(tensorT))
  56. if mutable and not local.use_const_ref_for_mutable_tensors():
  57. return NamedCType(binds, MutRefCType(tensor_type))
  58. else:
  59. return NamedCType(binds, ConstRefCType(tensor_type))
  60. elif str(t) == "Tensor?[]":
  61. return NamedCType(
  62. binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT))))
  63. )
  64. elif str(t) == "Scalar":
  65. return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
  66. elif str(t) == "Scalar?":
  67. return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
  68. return cpp.argumenttype_type(t, mutable=mutable, binds=binds, symint=symint)
  69. def returns_type(rs: Sequence[Return], *, symint: bool) -> CType:
  70. return cpp.returns_type(rs, symint=symint)
  71. def argument_type(a: Argument, *, binds: ArgName, symint: bool) -> NamedCType:
  72. return argumenttype_type(a.type, mutable=a.is_write, binds=binds, symint=symint)
  73. def argument(
  74. a: Argument | SelfArgument | TensorOptionsArguments,
  75. *,
  76. is_out: bool,
  77. symint: bool,
  78. ) -> list[Binding]:
  79. # Ideally, we NEVER default native functions. However, there are a number
  80. # of functions that call native:: directly and rely on the defaulting
  81. # existing. So for BC, we generate defaults for non-out variants (but not
  82. # for out variants, where it is impossible to generate an appropriate
  83. # default)
  84. should_default = not is_out
  85. if isinstance(a, Argument):
  86. default: str | None = None
  87. if should_default and a.default is not None:
  88. default = cpp.default_expr(a.default, a.type, symint=symint)
  89. return [
  90. Binding(
  91. nctype=argument_type(a, binds=a.name, symint=symint),
  92. name=a.name,
  93. default=default,
  94. argument=a,
  95. )
  96. ]
  97. elif isinstance(a, SelfArgument):
  98. # Erase SelfArgument from the distinction
  99. return argument(a.argument, is_out=is_out, symint=symint)
  100. elif isinstance(a, TensorOptionsArguments):
  101. default = None
  102. if should_default:
  103. default = "{}"
  104. # TODO: Not sure why the arguments assigned here are for
  105. # TensorOptionsArguments and not the constituent pieces. It seems
  106. # to matter
  107. return [
  108. Binding(
  109. nctype=NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))),
  110. name="dtype",
  111. default=default,
  112. argument=a,
  113. ),
  114. Binding(
  115. nctype=NamedCType("layout", OptionalCType(BaseCType(layoutT))),
  116. name="layout",
  117. default=default,
  118. argument=a,
  119. ),
  120. Binding(
  121. nctype=NamedCType("device", OptionalCType(BaseCType(deviceT))),
  122. name="device",
  123. default=default,
  124. argument=a,
  125. ),
  126. Binding(
  127. nctype=NamedCType("pin_memory", OptionalCType(BaseCType(boolT))),
  128. name="pin_memory",
  129. default=default,
  130. argument=a,
  131. ),
  132. ]
  133. else:
  134. assert_never(a)
  135. def arguments(func: FunctionSchema, *, symint: bool) -> list[Binding]:
  136. args: list[Argument | TensorOptionsArguments | SelfArgument] = []
  137. args.extend(func.arguments.non_out)
  138. args.extend(func.arguments.out)
  139. return [
  140. r for arg in args for r in argument(arg, symint=symint, is_out=func.is_out_fn())
  141. ]