ufunc.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. from __future__ import annotations
  2. from dataclasses import dataclass
  3. import torchgen.api.types as api_types
  4. from torchgen.api import cpp, structured
  5. from torchgen.api.types import (
  6. ArgName,
  7. BaseCppType,
  8. BaseCType,
  9. Binding,
  10. ConstRefCType,
  11. CType,
  12. NamedCType,
  13. scalarT,
  14. )
  15. from torchgen.model import (
  16. Argument,
  17. BaseTy,
  18. BaseType,
  19. DispatchKey,
  20. FunctionSchema,
  21. NativeFunctionsGroup,
  22. Type,
  23. )
  24. def schema_kernel_name(func: FunctionSchema, dispatch_key: DispatchKey) -> str:
  25. if not func.is_out_fn():
  26. raise AssertionError("ufunc.kernel_name should only be invoked on out schemas")
  27. return f"ufunc_{func.name.name}_{dispatch_key}"
  28. def kernel_name(g: NativeFunctionsGroup, dispatch_key: DispatchKey) -> str:
  29. return schema_kernel_name(g.out.func, dispatch_key)
  30. # Tensors are omitted (as they are stored in TensorIterator), everything else is
  31. # passed along (technically, we can pass tensors along too, it just wastes
  32. # argument registers)
  33. #
  34. # NB: used for CPU only
  35. def dispatchstub_type(t: Type, *, binds: ArgName) -> NamedCType | None:
  36. # Dispatch stubs are always plain ints
  37. r = cpp.valuetype_type(t, binds=binds, symint=False)
  38. if r is not None:
  39. return r
  40. if t == BaseType(BaseTy.Scalar):
  41. return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
  42. elif t == BaseType(BaseTy.Tensor):
  43. return None
  44. else:
  45. raise AssertionError(f"unrecognized type {repr(t)}")
  46. def opmath_type(scalar_t: BaseCppType) -> BaseCppType:
  47. if scalar_t == api_types.scalar_t:
  48. return api_types.opmath_t
  49. raise NotImplementedError
  50. # NB: Tensors in constructor are stored in opmath_t, not scalar_t
  51. # because Tensor in constructor = its a scalar tensor partially applied =
  52. # it can be higher precision and we want to compute in that higher precision
  53. #
  54. # NB: CUDA only
  55. def ufunctor_ctor_type(t: Type, *, binds: ArgName, scalar_t: BaseCppType) -> NamedCType:
  56. r = cpp.valuetype_type(t, binds=binds, symint=False)
  57. if r is not None:
  58. return r
  59. if t == BaseType(BaseTy.Scalar):
  60. return NamedCType(binds, BaseCType(opmath_type(scalar_t)))
  61. elif t == BaseType(BaseTy.Tensor):
  62. return NamedCType(binds, BaseCType(opmath_type(scalar_t)))
  63. else:
  64. raise AssertionError(f"unrecognized type {repr(t)}")
  65. # Only Tensors ever get passed directly to operator()
  66. #
  67. # NB: CUDA only
  68. # (Actually, this works for CPU too)
  69. def ufunctor_apply_type(
  70. t: Type, *, binds: ArgName, scalar_t: BaseCppType
  71. ) -> NamedCType:
  72. if t == BaseType(BaseTy.Tensor):
  73. return NamedCType(binds, BaseCType(scalar_t))
  74. else:
  75. raise AssertionError(f"unrecognized type {repr(t)}")
  76. # The actual ufunc template function the user writes. Everything here
  77. # is done in the computation type. compute_t is opmath_t in CUDA and scalar_t
  78. # in CPU
  79. def ufunc_type(t: Type, *, binds: ArgName, compute_t: CType) -> NamedCType:
  80. r = cpp.valuetype_type(t, binds=binds, symint=False)
  81. if r is not None:
  82. return r
  83. if t == BaseType(BaseTy.Scalar):
  84. return NamedCType(binds, compute_t)
  85. elif t == BaseType(BaseTy.Tensor):
  86. return NamedCType(binds, compute_t)
  87. else:
  88. raise AssertionError(f"unrecognized type {repr(t)}")
  89. def ufunctor_ctor_argument(a: Argument, scalar_t: BaseCppType) -> Binding:
  90. return Binding(
  91. nctype=ufunctor_ctor_type(a.type, binds=a.name, scalar_t=scalar_t),
  92. name=a.name,
  93. default=None,
  94. argument=a,
  95. )
  96. def ufunctor_apply_argument(a: Argument, scalar_t: BaseCppType) -> Binding:
  97. return Binding(
  98. nctype=ufunctor_apply_type(a.type, binds=a.name, scalar_t=scalar_t),
  99. name=a.name,
  100. default=None,
  101. argument=a,
  102. )
  103. def ufunc_argument(a: Argument, compute_t: CType) -> Binding:
  104. return Binding(
  105. nctype=ufunc_type(a.type, binds=a.name, compute_t=compute_t),
  106. name=a.name,
  107. default=None,
  108. argument=a,
  109. )
  110. @dataclass(frozen=True)
  111. class UfunctorBindings:
  112. ctor: list[Binding]
  113. apply: list[Binding]
  114. # ufunctors are a CUDA-only concept representing functors that take some of
  115. # their arguments on a host-side constructor, and the rest in the device-side
  116. # apply. E.g.,
  117. #
  118. # template <typename scalar_t>
  119. # struct CUDAFunctorOnSelf_add {
  120. # using opmath_t = at::opmath_type<scalar_t>;
  121. # opmath_t other_;
  122. # opmath_t alpha_;
  123. # CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha) : other_(other), alpha_(alpha) {}
  124. # __device__ scalar_t operator()(scalar_t self) {
  125. # return ufunc::add(static_cast<opmath_t>(self), other_, alpha_);
  126. # }
  127. # };
  128. #
  129. # The ctor refers to the constructor CUDAFunctorOnSelf_add, while apply refers
  130. # to the operator() definition
  131. def ufunctor_arguments(
  132. g: NativeFunctionsGroup, *, scalar_tensor_idx: int | None, scalar_t: BaseCppType
  133. ) -> UfunctorBindings:
  134. ctor = []
  135. apply = []
  136. for a in g.functional.func.arguments.flat_non_out:
  137. if a.type.is_tensor_like():
  138. if scalar_tensor_idx == 0:
  139. # put it in the ctor anyway
  140. ctor.append(ufunctor_ctor_argument(a, scalar_t=scalar_t))
  141. scalar_tensor_idx = None
  142. else:
  143. if scalar_tensor_idx is not None:
  144. scalar_tensor_idx -= 1
  145. apply.append(ufunctor_apply_argument(a, scalar_t=scalar_t))
  146. else:
  147. ctor.append(ufunctor_ctor_argument(a, scalar_t=scalar_t))
  148. if scalar_tensor_idx is not None:
  149. raise AssertionError("scalar_tensor_idx should be None at end of processing")
  150. return UfunctorBindings(ctor=ctor, apply=apply)
  151. # ufuncs are the inner loop template functions that you wrote in ufunc/add.h
  152. # which do the actual computation in question. E.g.,
  153. #
  154. # template <typename T>
  155. # C10_HOST_DEVICE T add(T self, T other, T alpha) __ubsan_ignore_undefined__ {
  156. # return self + alpha * other;
  157. # }
  158. #
  159. # In this file, we refer to T as compute_t which is bound by caller
  160. def ufunc_arguments(g: NativeFunctionsGroup, *, compute_t: CType) -> list[Binding]:
  161. return [
  162. ufunc_argument(a, compute_t=compute_t)
  163. for a in g.functional.func.arguments.flat_non_out
  164. ]
  165. # Stubs are the DispatchStub trampolines that CPU kernels use to get to their
  166. # vectorized versions. E.g.,
  167. #
  168. # using structured_binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha);
  169. # DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub);
  170. def stub_arguments(g: NativeFunctionsGroup) -> list[Binding]:
  171. # stubs drop all tensor arguments (they are implicit in the TensorIterator
  172. # argument and keep everything else)
  173. return [
  174. r
  175. for a in g.out.func.arguments.flat_non_out
  176. if not a.type.is_tensor_like()
  177. for r in structured.argument(a)
  178. ]