ufunc.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561
  1. from __future__ import annotations
  2. from dataclasses import dataclass
  3. from typing import TYPE_CHECKING
  4. import torchgen.api.ufunc as ufunc
  5. from torchgen.api.translate import translate
  6. from torchgen.api.types import (
  7. BaseCType,
  8. Binding,
  9. CType,
  10. Expr,
  11. NamedCType,
  12. opmath_t,
  13. scalar_t,
  14. StructuredImplSignature,
  15. VectorizedCType,
  16. )
  17. from torchgen.context import with_native_function
  18. from torchgen.model import (
  19. Argument,
  20. BaseTy,
  21. BaseType,
  22. DispatchKey,
  23. NativeFunctionsGroup,
  24. ScalarType,
  25. UfuncKey,
  26. )
  27. from torchgen.utils import OrderedSet
  28. if TYPE_CHECKING:
  29. from collections.abc import Sequence
  30. from torchgen.api.ufunc import UfunctorBindings
  31. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  32. #
  33. # CUDA STUFF
  34. #
  35. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  36. # NB: not bothering to generate dispatch stub forward declaration in header,
  37. # we can just paste it wherever necessary
  38. # TODO: use BackendIndex
  39. # dispatch_key: DispatchKey # only CPU/CUDA right now
  40. # Represents functors for implementing CUDA ufuncs.
  41. # Functors are templated by scalar_t because when USERS instantiate functors
  42. # they are templated. A functor looks something like this:
  43. #
  44. # template <typename scalar_t>
  45. # struct CUDAFunctorOnSelf_add {
  46. # using opmath_t = at::opmath_type<scalar_t>;
  47. # opmath_t other_;
  48. # opmath_t alpha_;
  49. # CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha)
  50. # : other_(other), alpha_(alpha) {}
  51. # __device__ scalar_t operator()(scalar_t self) {
  52. # return ufunc::add(static_cast<opmath_t>(self), other_, alpha_);
  53. # }
  54. # };
  55. #
  56. @dataclass(frozen=True)
  57. class UfunctorSignature:
  58. g: NativeFunctionsGroup
  59. scalar_tensor_idx: int | None
  60. name: str
  61. def arguments(self) -> UfunctorBindings:
  62. return ufunc.ufunctor_arguments(
  63. self.g, scalar_tensor_idx=self.scalar_tensor_idx, scalar_t=scalar_t
  64. )
  65. def fields(self) -> list[Binding]:
  66. # fields are renamed to have a trailing underscore, as is conventional
  67. return [b.rename(f"{b.name}_") for b in self.arguments().ctor]
  68. def returns_type(self) -> CType:
  69. # TODO: don't hardcode; return type will be inferred based on tags on
  70. # the native function
  71. return BaseCType(scalar_t)
  72. def decl_fields(self) -> str:
  73. return "\n".join(f"{f.type} {f.name};" for f in self.fields())
  74. def inline_defn_ctor(self) -> str:
  75. args_str = ", ".join(a.decl() for a in self.arguments().ctor)
  76. # NB: hypothetically could do this with translate but the
  77. # transition here is very regular
  78. init_str = ", ".join(f"{a.name}_({a.name})" for a in self.arguments().ctor)
  79. return f"{self.name}({args_str}) : {init_str} {{}}"
  80. def decl_apply(self) -> str:
  81. args_str = ", ".join(a.decl() for a in self.arguments().apply)
  82. return f"{self.returns_type().cpp_type()} operator()({args_str}) const"
  83. @dataclass(frozen=True)
  84. class UfuncSignature:
  85. g: NativeFunctionsGroup
  86. name: str
  87. compute_t: CType
  88. def arguments(self) -> list[Binding]:
  89. return ufunc.ufunc_arguments(self.g, compute_t=self.compute_t)
  90. def call(self, ctx: Sequence[Binding | Expr]) -> str:
  91. return f"{self.name}({', '.join(a.expr for a in translate(ctx, self.arguments()))})"
  92. # steps:
  93. # 1. take the functional signature
  94. # 2. use api.ufunc to convert it to template signature. this establishes
  95. # the type of the template function
# 3. use api.ufunc (II) to generate a split struct / operator() signature.
# this establishes the context in which we call the template signature
  98. #
  99. # StructuredImplSignature context
  100. # ~> functor constructor sig
  101. #
  102. # Functor constructor context
  103. # ~> functor fields sig
  104. #
  105. # Functor apply context (functor fields + functor apply sig)
  106. # ~> template sig
  107. #
  108. def eligible_for_binary_scalar_specialization(g: NativeFunctionsGroup) -> bool:
  109. num_tensors = sum(
  110. 1 for a in g.functional.func.arguments.flat_non_out if a.type.is_tensor_like()
  111. )
  112. return num_tensors == 2
  113. def compute_ufunc_cuda_functors(
  114. g: NativeFunctionsGroup,
  115. ) -> tuple[dict[ScalarType, dict[UfuncKey, UfunctorSignature]], str]:
  116. # First, build the functors.
  117. ufunctor_sigs: dict[ScalarType, dict[UfuncKey, UfunctorSignature]] = {}
  118. ufunctors: list[str] = []
  119. loops = g.out.ufunc_inner_loop
  120. scalar_tensor_idx_lookup = {
  121. UfuncKey.CUDAFunctorOnSelf: 1,
  122. UfuncKey.CUDAFunctorOnOther: 0,
  123. UfuncKey.CUDAFunctor: None,
  124. }
  125. if eligible_for_binary_scalar_specialization(g):
  126. keys = [
  127. UfuncKey.CUDAFunctorOnSelf,
  128. UfuncKey.CUDAFunctorOnOther,
  129. UfuncKey.CUDAFunctor,
  130. ]
  131. else:
  132. keys = [UfuncKey.CUDAFunctor]
  133. for k in [UfuncKey.CUDAFunctorOnSelf, UfuncKey.CUDAFunctorOnOther]:
  134. if k in loops:
  135. raise AssertionError(f"cannot use {k} on non-binary function")
  136. for k in keys:
  137. # If the key was directly defined, skip functor codegen; we assume the
  138. # user already done it for us
  139. if k in loops:
  140. ufunctor_sig = UfunctorSignature(
  141. g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=loops[k].name
  142. )
  143. for dtype in loops[k].supported_dtypes:
  144. ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig
  145. continue
  146. # Note [ScalarOnly and Generic must match names for CUDA]
  147. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  148. # Otherwise, look in ANY of the generic entries. For simplicity of
  149. # codegen, both ScalarOnly and Generic are defined, the ufunc name
  150. # must match (if they didn't match, we'd have to generate distinct
  151. # functors per dtype, which is awful, so we're not going to do it unless
  152. # someone really forces us to)
  153. ufunc_name = None
  154. supported_dtypes: OrderedSet[ScalarType] = OrderedSet()
  155. for lk in [UfuncKey.ScalarOnly, UfuncKey.Generic]:
  156. if lk not in loops:
  157. continue
  158. if ufunc_name is None:
  159. ufunc_name = loops[lk].name
  160. else:
  161. # See Note [ScalarOnly and Generic must match names for CUDA]
  162. if ufunc_name != loops[lk].name:
  163. raise AssertionError(
  164. "ScalarOnly and Generic must have same ufunc name"
  165. )
  166. supported_dtypes |= loops[lk].supported_dtypes
  167. if ufunc_name is None:
  168. raise AssertionError("ufunc_name must be non-None")
  169. name = f"{k}_{ufunc_name}"
  170. ufunctor_sig = UfunctorSignature(
  171. g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=name
  172. )
  173. for dtype in supported_dtypes:
  174. ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig
  175. ufunc_sig = UfuncSignature(
  176. g, name=f"ufunc::{ufunc_name}", compute_t=BaseCType(opmath_t)
  177. )
  178. apply_ctx = ufunctor_sig.fields() + ufunctor_sig.arguments().apply
  179. ufunctors.append(
  180. f"""
  181. template <typename scalar_t>
  182. struct {ufunctor_sig.name} {{
  183. using opmath_t = at::opmath_type<scalar_t>;
  184. {ufunctor_sig.decl_fields()}
  185. {ufunctor_sig.inline_defn_ctor()}
  186. __device__ {ufunctor_sig.decl_apply()} {{
  187. return {ufunc_sig.call(apply_ctx)};
  188. }}
  189. }};
  190. """
  191. )
  192. return ufunctor_sigs, "\n".join(ufunctors)
  193. @dataclass(frozen=True)
  194. class BinaryScalarSpecializationConfig:
  195. scalar_idx: int
  196. ctor_tensor: str
  197. ufunc_key: UfuncKey
  198. BinaryScalarSpecializationConfigs = [
  199. BinaryScalarSpecializationConfig(
  200. scalar_idx=0,
  201. ctor_tensor="self",
  202. ufunc_key=UfuncKey.CUDAFunctorOnOther,
  203. ),
  204. BinaryScalarSpecializationConfig(
  205. scalar_idx=1,
  206. ctor_tensor="other",
  207. ufunc_key=UfuncKey.CUDAFunctorOnSelf,
  208. ),
  209. ]
  210. def compute_ufunc_cuda_dtype_body(
  211. g: NativeFunctionsGroup,
  212. dtype: ScalarType,
  213. inner_loops: dict[UfuncKey, UfunctorSignature],
  214. parent_ctx: Sequence[Binding],
  215. ) -> str:
  216. body = "using opmath_t = at::opmath_type<scalar_t>;"
  217. body += "if (false) {}\n" # for ease of codegen
  218. for config in BinaryScalarSpecializationConfigs:
  219. if config.ufunc_key not in inner_loops:
  220. continue
  221. ufunctor_sig = inner_loops[config.ufunc_key]
  222. scalar_idx = config.scalar_idx + 1
  223. # Make a copy and at the same time widen the type (not permissible
  224. # without copy; we don't want to mutate the input argument anyway)
  225. ctx: list[Expr | Binding] = list(parent_ctx)
  226. ctx.append(
  227. Expr(
  228. expr=f"iter.scalar_value<opmath_t>({scalar_idx})",
  229. type=NamedCType(config.ctor_tensor, BaseCType(opmath_t)),
  230. )
  231. )
  232. ufunctor_ctor_exprs_str = ", ".join(
  233. a.expr for a in translate(ctx, ufunctor_sig.arguments().ctor)
  234. )
  235. # NB: ufunctor must be allocated before iter.remove_operand is called,
  236. # as it relies on iter
  237. body += f"""\
  238. else if (iter.is_cpu_scalar({scalar_idx})) {{
  239. {ufunctor_sig.name}<scalar_t> ufunctor({ufunctor_ctor_exprs_str});
  240. iter.remove_operand({scalar_idx});
  241. gpu_kernel(iter, ufunctor);
  242. }}"""
  243. ufunctor_sig = inner_loops[UfuncKey.CUDAFunctor]
  244. ufunctor_ctor_exprs_str = ", ".join(
  245. a.expr for a in translate(parent_ctx, ufunctor_sig.arguments().ctor)
  246. )
  247. body += f"""
  248. else {{
  249. gpu_kernel(iter, {ufunctor_sig.name}<scalar_t>({ufunctor_ctor_exprs_str}));
  250. }}
  251. """
  252. return body
  253. @with_native_function
  254. def compute_ufunc_cuda(g: NativeFunctionsGroup) -> str:
  255. # First, build the functors, indexing them by dtype
  256. ufunctor_sigs, ufunctors = compute_ufunc_cuda_functors(g)
  257. # Next, build the conditionals
  258. sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CUDA))
  259. dtype_cases = []
  260. for dtype, inner_ufunc_sigs in ufunctor_sigs.items():
  261. dtype_cases.append(
  262. f"""
  263. AT_DISPATCH_CASE(at::ScalarType::{dtype},
  264. [&]() {{
  265. {compute_ufunc_cuda_dtype_body(g, dtype, inner_ufunc_sigs, sig.arguments())}
  266. }}
  267. )
  268. """
  269. )
  270. dtype_cases_str = "\n".join(dtype_cases)
  271. stub_sig = StubSignature(g)
  272. return f"""
  273. {ufunctors}
  274. {stub_sig.type_defn()};
  275. {stub_sig.dispatch_decl()}
  276. {stub_sig.kernel_defn()} {{
  277. AT_DISPATCH_SWITCH(iter.common_dtype(), "{sig.name}",
  278. {dtype_cases_str}
  279. );
  280. }}
  281. REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name})
  282. {sig.defn()} {{
  283. {stub_sig.direct_call(sig.arguments())};
  284. }}
  285. """
  286. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  287. #
  288. # CPU STUFF
  289. #
  290. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  291. @dataclass(frozen=True)
  292. class StubSignature:
  293. g: NativeFunctionsGroup
  294. @property
  295. def name(self) -> str:
  296. return f"{str(self.g.functional.func.name.name)}_stub"
  297. @property
  298. def kernel_name(self) -> str:
  299. return f"{str(self.g.functional.func.name.name)}_kernel"
  300. @property
  301. def type_name(self) -> str:
  302. return f"{str(self.g.functional.func.name.name)}_fn"
  303. def arguments(self) -> list[Binding]:
  304. return ufunc.stub_arguments(self.g)
  305. def type(self) -> str:
  306. cpp_args = self.arguments()
  307. return f"void(*)(TensorIteratorBase&, {', '.join(a.type for a in cpp_args)})"
  308. def dispatch_decl(self) -> str:
  309. return f"DECLARE_DISPATCH({self.type_name}, {self.name})"
  310. def dispatch_defn(self) -> str:
  311. return f"DEFINE_DISPATCH({self.name})"
  312. def kernel_defn(self) -> str:
  313. return f"void {self.kernel_name}(TensorIteratorBase& iter, {', '.join(a.defn() for a in self.arguments())})"
  314. def type_defn(self) -> str:
  315. return f"using {self.type_name} = {self.type()}"
  316. # must be called from context where this is TensorIteratorBase*
  317. def call(self, ctx: Sequence[Binding]) -> str:
  318. return f"{self.name}(device_type(), *this, {', '.join(a.expr for a in translate(ctx, self.arguments()))})"
  319. # used in CUDA to skip the unnecessary dynamic dispatch
  320. def direct_call(self, ctx: Sequence[Binding]) -> str:
  321. return f"{self.kernel_name}(*this, {', '.join(a.expr for a in translate(ctx, self.arguments()))})"
  322. @with_native_function
  323. def compute_ufunc_cpu(g: NativeFunctionsGroup) -> str:
  324. stub_sig = StubSignature(g)
  325. sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CPU))
  326. return f"""
  327. {stub_sig.type_defn()};
  328. {stub_sig.dispatch_decl()}
  329. {stub_sig.dispatch_defn()};
  330. {sig.defn()} {{
  331. {stub_sig.call(sig.arguments())};
  332. }}
  333. """
  334. def compute_ufunc_cpu_dtype_body(
  335. g: NativeFunctionsGroup,
  336. dtype: ScalarType,
  337. inner_loops: dict[UfuncKey, UfuncSignature],
  338. parent_ctx: Sequence[Binding],
  339. ) -> str:
  340. if UfuncKey.CPUScalar not in inner_loops:
  341. raise AssertionError(f"{dtype}, {inner_loops.keys()}")
  342. if not inner_loops.keys() <= {UfuncKey.CPUScalar, UfuncKey.CPUVector}:
  343. raise AssertionError(
  344. f"inner_loops keys must be subset of CPUScalar/CPUVector, got {inner_loops.keys()}"
  345. )
  346. scalar_loop = inner_loops[UfuncKey.CPUScalar]
  347. vec_loop = None
  348. if UfuncKey.CPUVector in inner_loops:
  349. vec_loop = inner_loops[UfuncKey.CPUVector]
  350. # NB: We DON'T use translate here, because translate is
  351. # incapable of CSE'ing the scalar accesses in case it is also
  352. # used by Vectorized; also, the unpacking here is very simple
  353. # and only affects Scalar; everything else is implicitly captured
  354. # by the lambda
  355. # Setup scalar in scope
  356. body = []
  357. ctx = []
  358. for b in parent_ctx:
  359. if isinstance(b.argument, Argument) and b.argument.type != BaseType(
  360. BaseTy.Scalar
  361. ):
  362. continue
  363. body.append(f"auto _s_{b.name} = {b.name}.to<scalar_t>();")
  364. ctx.append(Expr(f"_s_{b.name}", NamedCType(b.nctype.name, BaseCType(scalar_t))))
  365. if vec_loop is not None:
  366. for b in parent_ctx:
  367. if isinstance(b.argument, Argument) and b.argument.type != BaseType(
  368. BaseTy.Scalar
  369. ):
  370. continue
  371. body.append(
  372. f"auto _v_{b.name} = at::vec::Vectorized<scalar_t>(_s_{b.name});"
  373. )
  374. ctx.append(
  375. Expr(
  376. f"_v_{b.name}",
  377. NamedCType(b.nctype.name, VectorizedCType(BaseCType(scalar_t))),
  378. )
  379. )
  380. # Setup lambda signature
  381. # NB: simplified version of ufunctor_arguments
  382. scalar_bindings = []
  383. vec_bindings = []
  384. for a in g.functional.func.arguments.flat_non_out:
  385. if not a.type.is_tensor_like():
  386. continue
  387. if a.type != BaseType(BaseTy.Tensor):
  388. raise AssertionError(f"Expected Tensor type, got {a.type}")
  389. scalar_bindings.append(
  390. Binding(
  391. name=a.name,
  392. nctype=NamedCType(a.name, BaseCType(scalar_t)),
  393. argument=a,
  394. )
  395. )
  396. if vec_loop is not None:
  397. vec_bindings.append(
  398. Binding(
  399. name=a.name,
  400. nctype=NamedCType(a.name, VectorizedCType(BaseCType(scalar_t))),
  401. argument=a,
  402. )
  403. )
  404. def with_ctx(b: Sequence[Binding]) -> list[Expr | Binding]:
  405. r: list[Expr | Binding] = []
  406. r.extend(ctx)
  407. r.extend(b)
  408. return r
  409. body_str = "\n".join(body)
  410. if vec_loop is not None:
  411. return f"""
  412. {body_str}
  413. cpu_kernel_vec(iter,
  414. [=]({", ".join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }},
  415. [=]({", ".join(b.decl() for b in vec_bindings)}) {{ return {vec_loop.call(with_ctx(vec_bindings))}; }}
  416. );
  417. """
  418. else:
  419. return f"""
  420. {body_str}
  421. cpu_kernel(iter,
  422. [=]({", ".join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }}
  423. );
  424. """
  425. @with_native_function
  426. def compute_ufunc_cpu_kernel(g: NativeFunctionsGroup) -> str:
  427. stub_sig = StubSignature(g)
  428. # Reindex the ufunc by dtypes; processing generic/scalaronly as well
  429. loops = g.out.ufunc_inner_loop
  430. ufunc_sigs: dict[ScalarType, dict[UfuncKey, UfuncSignature]] = {}
  431. for k in [UfuncKey.CPUScalar, UfuncKey.CPUVector]:
  432. lks = []
  433. # ORDER MATTERS: this specifies overriding precedence
  434. if k in loops: # should happen rarely
  435. lks.append(k)
  436. if UfuncKey.ScalarOnly in loops and k is UfuncKey.CPUScalar:
  437. lks.append(UfuncKey.ScalarOnly)
  438. if UfuncKey.Generic in loops:
  439. lks.append(UfuncKey.Generic)
  440. # TODO: don't hardcode ufunc:: namespace here, should be centralized smh
  441. for lk in lks:
  442. for dtype in loops[lk].supported_dtypes:
  443. compute_t: CType
  444. if k is UfuncKey.CPUScalar:
  445. compute_t = BaseCType(scalar_t)
  446. elif k is UfuncKey.CPUVector:
  447. compute_t = VectorizedCType(BaseCType(scalar_t))
  448. else:
  449. raise AssertionError
  450. inner_ufunc_sigs = ufunc_sigs.setdefault(dtype, {})
  451. if k not in inner_ufunc_sigs:
  452. inner_ufunc_sigs[k] = UfuncSignature(
  453. g, name=f"ufunc::{loops[lk].name}", compute_t=compute_t
  454. )
  455. # Build the conditionals
  456. dtype_cases = []
  457. for dtype, inner_ufunc_sigs in ufunc_sigs.items():
  458. dtype_cases.append(
  459. f"""
  460. AT_DISPATCH_CASE(at::ScalarType::{dtype},
  461. [&]() {{
  462. {compute_ufunc_cpu_dtype_body(g, dtype, inner_ufunc_sigs, stub_sig.arguments())}
  463. }}
  464. )
  465. """
  466. )
  467. dtype_cases_str = "\n".join(dtype_cases)
  468. return f"""
  469. namespace {{
  470. {stub_sig.kernel_defn()} {{
  471. AT_DISPATCH_SWITCH(iter.common_dtype(), "{stub_sig.name}",
  472. {dtype_cases_str}
  473. );
  474. }}
  475. }} // anonymous namespace
  476. {stub_sig.type_defn()};
  477. {stub_sig.dispatch_decl()}
  478. REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name})
  479. """