indexing.h 54 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. // Metal indexing primitives
  3. #pragma once
  4. #include <c10/metal/common.h>
  5. #include <c10/metal/utils.h>
  6. #include <metal_stdlib>
  7. namespace c10 {
  8. namespace metal {
  // Converts per-dimension coordinates into a linear offset:
  //   offset = idx[0] * strides[0] + ... + idx[ndim - 1] * strides[ndim - 1]
  // Each stride is converted to T before multiplying, so the whole
  // accumulation is carried out in T.
  // NOTE(review): when T = int (as the strided kernels in this file use it),
  // strides are narrowed to 32 bits and the sum is accumulated in 32 bits.
  template <typename T>
  inline T offset_from_coord(
      thread T idx[max_ndim],
      constant long* strides,
      uint ndim) {
    T offset = 0;
    uint dim = 0;
    while (dim < ndim) {
      offset += T(strides[dim]) * idx[dim];
      ++dim;
    }
    return offset;
  }
  // Splits a flat thread index into per-dimension coordinates, with dimension 0
  // varying fastest: pos[0] = idx % sizes[0], then idx is divided by sizes[0]
  // and the process repeats for the next dimension. Only the first `ndim`
  // entries of `pos` are written.
  template <typename T>
  inline void pos_from_thread_index(
      T idx,
      thread T pos[max_ndim],
      constant long* sizes,
      uint ndim) {
    for (uint dim = 0; dim < ndim; ++dim) {
      const T extent = T(sizes[dim]);
      pos[dim] = idx % extent;
      idx /= extent;
    }
  }
  // Maps a flat thread index directly to a strided offset. The index is first
  // split into coordinates over `sizes`, then the coordinates are weighted by
  // `strides`. All intermediate math is done in `long`.
  inline long offset_from_thread_index(
      long idx,
      constant long* sizes,
      constant long* strides,
      uint ndim) {
    long coords[max_ndim];
    pos_from_thread_index(idx, coords, sizes, ndim);
    return offset_from_coord(coords, strides, ndim);
  }
  // Elementwise unary kernel for contiguous buffers: thread `index` applies
  // the functor F to input[index] and writes the result to output[index].
  template <typename T, typename F>
  kernel void unary_dense(
      device result_of<F, T>* output [[buffer(0)]],
      constant T* input [[buffer(1)]],
      uint index [[thread_position_in_grid]]) {
    F op;
    output[index] = op(input[index]);
  }
  // Elementwise unary kernel for strided buffers. The thread index is split
  // into coordinates over `sizes`. The input and output element offsets are
  // then derived independently from their own stride arrays.
  template <typename T, typename F>
  kernel void unary_strided(
      device result_of<F, T>* output [[buffer(0)]],
      constant T* input [[buffer(1)]],
      constant long* sizes [[buffer(2)]],
      constant long* input_strides [[buffer(3)]],
      constant long* output_strides [[buffer(4)]],
      constant uint& ndim [[buffer(5)]],
      uint index [[thread_position_in_grid]]) {
    int coord[max_ndim];
    pos_from_thread_index(int(index), coord, sizes, ndim);
    const int src = offset_from_coord(coord, input_strides, ndim);
    const int dst = offset_from_coord(coord, output_strides, ndim);
    F op;
    output[dst] = op(input[src]);
  }
  // Explicitly instantiates unary_dense and unary_strided for NAME##_functor
  // applied to DTYPE0 inputs. They are exported under the host names
  // "<NAME>_dense_<DTYPE1>_<DTYPE0>" and "<NAME>_strided_<DTYPE1>_<DTYPE0>".
  // The static_assert requires DTYPE1 to equal the functor's result type.
  #define REGISTER_UNARY_OP(NAME, DTYPE0, DTYPE1) \
    static_assert( \
        ::metal::is_same_v< \
            DTYPE1, \
            ::c10::metal::result_of<NAME##_functor, DTYPE0>>, \
        "Output dtype mismatch for unary op " #NAME " and input " #DTYPE0); \
    template [[host_name(#NAME "_dense_" #DTYPE1 "_" #DTYPE0)]] \
    kernel void ::c10::metal::unary_dense<DTYPE0, NAME##_functor>( \
        device ::c10::metal::result_of<NAME##_functor, DTYPE0> * output, \
        constant DTYPE0 * input, \
        uint index); \
    template [[host_name(#NAME "_strided_" #DTYPE1 "_" #DTYPE0)]] \
    kernel void ::c10::metal::unary_strided<DTYPE0, NAME##_functor>( \
        device ::c10::metal::result_of<NAME##_functor, DTYPE0> * output, \
        constant DTYPE0 * input, \
        constant long* sizes, \
        constant long* input_strides, \
        constant long* output_strides, \
        constant uint& ndim, \
        uint index)
  // Defines `NAME##_functor`, which applies the function NAME elementwise.
  // Integral inputs are first converted to float and yield a float result.
  // Floating-point inputs yield a result of the same type as the input.
  #define DEFINE_UNARY_FLOATING_FUNCTOR(NAME) \
    struct NAME##_functor { \
      template <typename T> \
      inline ::metal::enable_if_t<::metal::is_integral_v<T>, float> \
      operator()(const T x) { \
        return NAME(static_cast<float>(x)); \
      } \
      template <typename T> \
      inline ::metal::enable_if_t<::metal::is_floating_point_v<T>, T> \
      operator()(const T x) { \
        return T(NAME(x)); \
      } \
    }
  // Contiguous unary kernel with an extra scalar `alpha` argument: thread
  // `index` writes F(input[index], alpha) to output[index].
  template <typename T, typename T2, typename F>
  kernel void unary_alpha_dense(
      device result_of<F, T, T2>* output [[buffer(0)]],
      constant T* input [[buffer(1)]],
      constant T2& alpha [[buffer(2)]],
      uint index [[thread_position_in_grid]]) {
    F op;
    output[index] = op(input[index], alpha);
  }
  // Strided unary kernel with an extra scalar `alpha` argument. Coordinates
  // come from splitting the thread index over `sizes`. Input and output
  // element offsets use their respective stride arrays.
  template <typename T, typename T2, typename F>
  kernel void unary_alpha_strided(
      device result_of<F, T, T2>* output [[buffer(0)]],
      constant T* input [[buffer(1)]],
      constant long* sizes [[buffer(2)]],
      constant long* input_strides [[buffer(3)]],
      constant long* output_strides [[buffer(4)]],
      constant uint& ndim [[buffer(5)]],
      constant T2& alpha [[buffer(6)]],
      uint index [[thread_position_in_grid]]) {
    int coord[max_ndim];
    pos_from_thread_index(int(index), coord, sizes, ndim);
    const int src = offset_from_coord(coord, input_strides, ndim);
    const int dst = offset_from_coord(coord, output_strides, ndim);
    F op;
    output[dst] = op(input[src], alpha);
  }
  // Explicitly instantiates unary_alpha_dense and unary_alpha_strided for
  // NAME##_functor with input DTYPEI and alpha DTYPEA. They are exported as
  // "<NAME>_dense_<DTYPEO>_<DTYPEI>_<DTYPEA>" and
  // "<NAME>_strided_<DTYPEO>_<DTYPEI>_<DTYPEA>". The static_assert requires
  // DTYPEO to equal the functor's result type.
  #define REGISTER_UNARY_ALPHA_OP(NAME, DTYPEI, DTYPEA, DTYPEO) \
    static_assert( \
        ::metal::is_same_v< \
            DTYPEO, \
            ::c10::metal::result_of<NAME##_functor, DTYPEI, DTYPEA>>, \
        "Output dtype mismatch for unary op " #NAME " and input " #DTYPEI); \
    template [[host_name(#NAME "_dense_" #DTYPEO "_" #DTYPEI "_" #DTYPEA)]] \
    kernel void ::c10::metal::unary_alpha_dense<DTYPEI, DTYPEA, NAME##_functor>( \
        device ::c10::metal::result_of<NAME##_functor, DTYPEI, DTYPEA> * output, \
        constant DTYPEI * input, \
        constant DTYPEA & alpha, \
        uint index); \
    template [[host_name(#NAME "_strided_" #DTYPEO "_" #DTYPEI "_" #DTYPEA)]] \
    kernel void ::c10::metal::unary_alpha_strided<DTYPEI, DTYPEA, NAME##_functor>( \
        device ::c10::metal::result_of<NAME##_functor, DTYPEI, DTYPEA> * output, \
        constant DTYPEI * input, \
        constant long* sizes, \
        constant long* input_strides, \
        constant long* output_strides, \
        constant uint& ndim, \
        constant DTYPEA& alpha, \
        uint index)
  // Reads a T located `offs` bytes past `ptr` in a constant-space buffer.
  // `offs` is a byte offset, not an element index.
  template <typename T>
  inline T val_at_offs(constant void* ptr, long offs) {
    constant char* bytes = static_cast<constant char*>(ptr);
    return *reinterpret_cast<constant T*>(bytes + offs);
  }
  // Reads a T located `offs` bytes past `ptr` in a device-space buffer.
  // `offs` is a byte offset, not an element index; no dtype conversion is
  // performed (the bytes are reinterpreted as T).
  template <typename T>
  inline T val_at_offs(device void* ptr, long offs) {
    return *reinterpret_cast<device T*>(static_cast<device char*>(ptr) + offs);
  }
  // Value at offset with dynamic cast from provided type.
  // Reads one element located `offs` bytes past `ptr`, interprets those bytes
  // as the runtime dtype `type`, and converts the value to T with cast_to<T>.
  // `ptr` is forwarded to the two-argument val_at_offs overloads, so P is
  // expected to be `constant void*` or `device void*`.
  template <typename T, typename P>
  inline T val_at_offs(P ptr, long offs, ScalarType type) {
    switch (type) {
      case ScalarType::Bool:
        return cast_to<T>(val_at_offs<bool>(ptr, offs));
      case ScalarType::Byte:
        return cast_to<T>(val_at_offs<uchar>(ptr, offs));
      case ScalarType::Char:
        return cast_to<T>(val_at_offs<char>(ptr, offs));
      case ScalarType::Short:
        return cast_to<T>(val_at_offs<short>(ptr, offs));
      case ScalarType::Int:
        return cast_to<T>(val_at_offs<int>(ptr, offs));
      case ScalarType::Long:
        return cast_to<T>(val_at_offs<long>(ptr, offs));
      // Floats
      case ScalarType::Float:
        return cast_to<T>(val_at_offs<float>(ptr, offs));
      case ScalarType::Half:
        return cast_to<T>(val_at_offs<half>(ptr, offs));
      case ScalarType::BFloat16:
        return cast_to<T>(val_at_offs<bfloat>(ptr, offs));
      // Complex (read as two-component half2 / float2 vectors)
      case ScalarType::ComplexHalf:
        return cast_to<T>(val_at_offs<half2>(ptr, offs));
      case ScalarType::ComplexFloat:
        return cast_to<T>(val_at_offs<float2>(ptr, offs));
    }
    // `type` comes from a kernel argument. If it matches none of the cases
    // above, control would otherwise fall off the end of this non-void
    // function, which is undefined behavior. Return a well-defined zero
    // instead. This reuses the cast_to<T>(int) instantiation of the Int case.
    return cast_to<T>(0);
  }
  // Returns a writable reference to the T located `offs` bytes past `ptr`
  // in a device-space buffer. `offs` is a byte offset.
  template <typename T>
  inline device T& ref_at_offs(device void* ptr, long offs) {
    device char* bytes = static_cast<device char*>(ptr);
    return *reinterpret_cast<device T*>(bytes + offs);
  }
  // Binary elementwise ops kernels
  // The following flavors are available. Each one also has a binary_alpha_*
  // counterpart that takes an extra scalar `alpha` argument:
  // - binary_dense - input, other and output are dense; input and other share
  //   the same type
  // - binary_strided - all inputs are of the same type, but some elements are
  //   strided
  // - binary_dense_cast - inputs are dense, but of different dtypes
  // - binary_strided_cast - inputs or output are strided and of different
  //   dtypes
  // - binary_dense_broadcast / binary_dense_broadcast_rhs - one input is
  //   dense, the other one is broadcastable (plus `_cast` variants for mixed
  //   dtypes)
  // - binary_dense_scalar / binary_dense_scalar_lhs - one operand is a single
  //   scalar value (plus `_cast` variants for mixed dtypes)
  // Note about accuracy (for more info see
  // https://github.com/pytorch/pytorch/issues/152736): sometimes a kernel is
  // invoked to produce `half` output while one of the arguments is `float`.
  // In that case the arguments should be upcast to float rather than
  // downcast to half. At the moment this is expressed with the optional
  // `om_t` template argument, which stands for opmath_type. Depending on the
  // kernel, it defaults to either `T` or `opmath_t<T>`, but it can be
  // something else.
  // Strided binary kernel where `input` and `other` both hold values of type
  // T. Offsets are applied in bytes through val_at_offs / ref_at_offs on
  // untyped pointers. Both operands are converted to om_t before F is
  // applied, and the result is cast to result_of<F, T, T>. Only ndim.x (the
  // number of dimensions) is read by this kernel.
  template <typename T, typename F, typename om_t = T>
  kernel void binary_strided(
      device void* output [[buffer(0)]],
      constant void* input [[buffer(1)]],
      constant void* other [[buffer(2)]],
      constant long* sizes [[buffer(3)]],
      constant long* output_strides [[buffer(4)]],
      constant long* input_strides [[buffer(5)]],
      constant long* other_strides [[buffer(6)]],
      constant uint3& ndim [[buffer(7)]],
      uint index [[thread_position_in_grid]]) {
    using res_t = result_of<F, T, T>;
    const uint rank = ndim.x;
    int coord[max_ndim];
    pos_from_thread_index(int(index), coord, sizes, rank);
    const int lhs_at = offset_from_coord(coord, input_strides, rank);
    const int rhs_at = offset_from_coord(coord, other_strides, rank);
    const int out_at = offset_from_coord(coord, output_strides, rank);
    const T lhs = val_at_offs<T>(input, lhs_at);
    const T rhs = val_at_offs<T>(other, rhs_at);
    F op;
    ref_at_offs<res_t>(output, out_at) =
        static_cast<res_t>(op(om_t(lhs), om_t(rhs)));
  }
  // Strided binary kernel with an extra scalar `alpha`. Both operands are T.
  // Offsets are byte offsets, and only ndim.x (the number of dimensions) is
  // read. The result of F(lhs, rhs, alpha) is stored as
  // result_of<F, T, T, T2>.
  template <typename T, typename T2, typename F>
  kernel void binary_alpha_strided(
      device void* output [[buffer(0)]],
      constant void* input [[buffer(1)]],
      constant void* other [[buffer(2)]],
      constant T2& alpha [[buffer(3)]],
      constant long* sizes [[buffer(4)]],
      constant long* output_strides [[buffer(5)]],
      constant long* input_strides [[buffer(6)]],
      constant long* other_strides [[buffer(7)]],
      constant uint3& ndim [[buffer(8)]],
      uint index [[thread_position_in_grid]]) {
    using res_t = result_of<F, T, T, T2>;
    const uint rank = ndim.x;
    int coord[max_ndim];
    pos_from_thread_index(int(index), coord, sizes, rank);
    const int lhs_at = offset_from_coord(coord, input_strides, rank);
    const int rhs_at = offset_from_coord(coord, other_strides, rank);
    const int out_at = offset_from_coord(coord, output_strides, rank);
    const T lhs = val_at_offs<T>(input, lhs_at);
    const T rhs = val_at_offs<T>(other, rhs_at);
    F op;
    ref_at_offs<res_t>(output, out_at) = op(lhs, rhs, alpha);
  }
  // Strided binary kernel whose operands may have different runtime dtypes.
  // ndim_types packs: x = number of dimensions, y = ScalarType of `input`,
  // z = ScalarType of `other`. Each operand is read with a dynamic cast to
  // om_t, and the result of F is cast to result_of<F, T, T>.
  template <typename T, typename F, typename om_t = opmath_t<T>>
  kernel void binary_strided_cast(
      device void* output [[buffer(0)]],
      constant void* input [[buffer(1)]],
      constant void* other [[buffer(2)]],
      constant long* sizes [[buffer(3)]],
      constant long* output_strides [[buffer(4)]],
      constant long* input_strides [[buffer(5)]],
      constant long* other_strides [[buffer(6)]],
      constant uint4& ndim_types [[buffer(7)]],
      uint index [[thread_position_in_grid]]) {
    using res_t = result_of<F, T, T>;
    const uint rank = ndim_types.x;
    const auto lhs_type = static_cast<ScalarType>(ndim_types.y);
    const auto rhs_type = static_cast<ScalarType>(ndim_types.z);
    int coord[max_ndim];
    pos_from_thread_index(int(index), coord, sizes, rank);
    const int lhs_at = offset_from_coord(coord, input_strides, rank);
    const int rhs_at = offset_from_coord(coord, other_strides, rank);
    const int out_at = offset_from_coord(coord, output_strides, rank);
    const om_t lhs = val_at_offs<om_t>(input, lhs_at, lhs_type);
    const om_t rhs = val_at_offs<om_t>(other, rhs_at, rhs_type);
    F op;
    ref_at_offs<res_t>(output, out_at) = static_cast<res_t>(op(lhs, rhs));
  }
  // Strided binary kernel with `alpha` whose operands may have different
  // runtime dtypes. ndim_types packs: x = number of dimensions,
  // y = ScalarType of `input`, z = ScalarType of `other`. Operands are
  // dynamically cast to T before F(lhs, rhs, alpha) is evaluated.
  template <typename T, typename T2, typename F>
  kernel void binary_alpha_strided_cast(
      device void* output [[buffer(0)]],
      constant void* input [[buffer(1)]],
      constant void* other [[buffer(2)]],
      constant T2& alpha [[buffer(3)]],
      constant long* sizes [[buffer(4)]],
      constant long* output_strides [[buffer(5)]],
      constant long* input_strides [[buffer(6)]],
      constant long* other_strides [[buffer(7)]],
      constant uint4& ndim_types [[buffer(8)]],
      uint index [[thread_position_in_grid]]) {
    using res_t = result_of<F, T, T, T2>;
    const uint rank = ndim_types.x;
    const auto lhs_type = static_cast<ScalarType>(ndim_types.y);
    const auto rhs_type = static_cast<ScalarType>(ndim_types.z);
    int coord[max_ndim];
    pos_from_thread_index(int(index), coord, sizes, rank);
    const int lhs_at = offset_from_coord(coord, input_strides, rank);
    const int rhs_at = offset_from_coord(coord, other_strides, rank);
    const int out_at = offset_from_coord(coord, output_strides, rank);
    const T lhs = val_at_offs<T>(input, lhs_at, lhs_type);
    const T rhs = val_at_offs<T>(other, rhs_at, rhs_type);
    F op;
    ref_at_offs<res_t>(output, out_at) = op(lhs, rhs, alpha);
  }
  // Contiguous binary kernel. Both operands are T and are converted to om_t
  // before F is applied. The result is cast to result_of<F, T, T>.
  template <typename T, typename F, typename om_t = opmath_t<T>>
  kernel void binary_dense(
      device result_of<F, T, T>* out [[buffer(0)]],
      constant T* input [[buffer(1)]],
      constant T* other [[buffer(2)]],
      uint tid [[thread_position_in_grid]]) {
    F op;
    const auto value = op(om_t(input[tid]), om_t(other[tid]));
    out[tid] = static_cast<result_of<F, T, T>>(value);
  }
  // Contiguous binary kernel with a scalar `alpha`: thread `tid` writes
  // F(input[tid], other[tid], alpha) to out[tid].
  template <typename T, typename T2, typename F>
  kernel void binary_alpha_dense(
      device result_of<F, T, T, T2>* out [[buffer(0)]],
      constant T* input [[buffer(1)]],
      constant T* other [[buffer(2)]],
      constant T2& alpha [[buffer(3)]],
      uint tid [[thread_position_in_grid]]) {
    F op;
    out[tid] = op(input[tid], other[tid], alpha);
  }
  // Contiguous binary kernel whose operands may have different runtime
  // dtypes. sizes_types packs: x / y = per-element byte size of
  // input / other, z / w = ScalarType of input / other. Operands are
  // dynamically cast to om_t, and the result is cast to result_of<F, T, T>.
  // NOTE(review): `tid * sizes_types.*` is computed in 32-bit uint before it
  // is widened to the `long` offset parameter.
  template <typename T, typename F, typename om_t = T>
  kernel void binary_dense_cast(
      device result_of<F, T, T>* out [[buffer(0)]],
      constant void* input [[buffer(1)]],
      constant void* other [[buffer(2)]],
      constant uint4& sizes_types [[buffer(3)]],
      uint tid [[thread_position_in_grid]]) {
    using res_t = result_of<F, T, T>;
    const auto lhs_type = static_cast<ScalarType>(sizes_types.z);
    const auto rhs_type = static_cast<ScalarType>(sizes_types.w);
    const om_t lhs = val_at_offs<om_t>(input, tid * sizes_types.x, lhs_type);
    const om_t rhs = val_at_offs<om_t>(other, tid * sizes_types.y, rhs_type);
    F op;
    out[tid] = static_cast<res_t>(op(lhs, rhs));
  }
  // Contiguous binary kernel with `alpha` whose operands may have different
  // runtime dtypes. sizes_types packs: x / y = per-element byte size of
  // input / other, z / w = ScalarType of input / other. Operands are
  // dynamically cast to T.
  template <typename T, typename T2, typename F>
  kernel void binary_alpha_dense_cast(
      device result_of<F, T, T, T2>* out [[buffer(0)]],
      constant void* input [[buffer(1)]],
      constant void* other [[buffer(2)]],
      constant T2& alpha [[buffer(3)]],
      constant uint4& sizes_types [[buffer(4)]],
      uint tid [[thread_position_in_grid]]) {
    const auto lhs_type = static_cast<ScalarType>(sizes_types.z);
    const auto rhs_type = static_cast<ScalarType>(sizes_types.w);
    const T lhs = val_at_offs<T>(input, tid * sizes_types.x, lhs_type);
    const T rhs = val_at_offs<T>(other, tid * sizes_types.y, rhs_type);
    F op;
    out[tid] = op(lhs, rhs, alpha);
  }
  // Contiguous binary kernel where the right-hand operand `broadcast` repeats
  // with period `broadcast_numel`. Operands are converted to om_t, and the
  // result is cast to result_of<F, T, T>.
  template <typename T, typename F, typename om_t = opmath_t<T>>
  kernel void binary_dense_broadcast(
      device result_of<F, T, T>* out [[buffer(0)]],
      constant T* input [[buffer(1)]],
      constant T* broadcast [[buffer(2)]],
      constant long& broadcast_numel [[buffer(3)]],
      uint tid [[thread_position_in_grid]]) {
    const long bcast_idx = tid % broadcast_numel;
    F op;
    const auto value = op(om_t(input[tid]), om_t(broadcast[bcast_idx]));
    out[tid] = static_cast<result_of<F, T, T>>(value);
  }
  // Mirror of binary_dense_broadcast: the repeating operand `broadcast`
  // (period `broadcast_numel`) is the LEFT argument of F, and the dense
  // `input` is the right one.
  template <typename T, typename F, typename om_t = opmath_t<T>>
  kernel void binary_dense_broadcast_rhs(
      device result_of<F, T, T>* out [[buffer(0)]],
      constant T* broadcast [[buffer(1)]],
      constant T* input [[buffer(2)]],
      constant long& broadcast_numel [[buffer(3)]],
      uint tid [[thread_position_in_grid]]) {
    const long bcast_idx = tid % broadcast_numel;
    F op;
    const auto value = op(om_t(broadcast[bcast_idx]), om_t(input[tid]));
    out[tid] = static_cast<result_of<F, T, T>>(value);
  }
  // Contiguous binary kernel with `alpha`, where the right-hand operand
  // `broadcast` repeats with period `broadcast_numel`.
  template <typename T, typename T2, typename F>
  kernel void binary_alpha_dense_broadcast(
      device result_of<F, T, T, T2>* out [[buffer(0)]],
      constant T* input [[buffer(1)]],
      constant T* broadcast [[buffer(2)]],
      constant long& broadcast_numel [[buffer(3)]],
      constant T2& alpha [[buffer(4)]],
      uint tid [[thread_position_in_grid]]) {
    const long bcast_idx = tid % broadcast_numel;
    F op;
    out[tid] = op(input[tid], broadcast[bcast_idx], alpha);
  }
  // Contiguous binary kernel with `alpha`, where the repeating operand
  // `broadcast` (period `broadcast_numel`) is the LEFT argument of F.
  template <typename T, typename T2, typename F>
  kernel void binary_alpha_dense_broadcast_rhs(
      device result_of<F, T, T, T2>* out [[buffer(0)]],
      constant T* broadcast [[buffer(1)]],
      constant T* input [[buffer(2)]],
      constant long& broadcast_numel [[buffer(3)]],
      constant T2& alpha [[buffer(4)]],
      uint tid [[thread_position_in_grid]]) {
    const long bcast_idx = tid % broadcast_numel;
    F op;
    out[tid] = op(broadcast[bcast_idx], input[tid], alpha);
  }
  // Mixed-dtype variant of binary_dense_broadcast. sizes_types packs:
  // x / y = per-element byte size of input / broadcast, z / w = ScalarType of
  // input / broadcast. Operands are dynamically cast to om_t.
  template <typename T, typename F, typename om_t = T>
  kernel void binary_dense_broadcast_cast(
      device result_of<F, T, T>* out [[buffer(0)]],
      constant void* input [[buffer(1)]],
      constant void* broadcast [[buffer(2)]],
      constant long& broadcast_numel [[buffer(3)]],
      constant uint4& sizes_types [[buffer(4)]],
      uint tid [[thread_position_in_grid]]) {
    using res_t = result_of<F, T, T>;
    const long bcast_idx = tid % broadcast_numel;
    const auto lhs_type = static_cast<ScalarType>(sizes_types.z);
    const auto rhs_type = static_cast<ScalarType>(sizes_types.w);
    const om_t lhs = val_at_offs<om_t>(input, tid * sizes_types.x, lhs_type);
    const om_t rhs =
        val_at_offs<om_t>(broadcast, bcast_idx * sizes_types.y, rhs_type);
    F op;
    out[tid] = static_cast<res_t>(op(lhs, rhs));
  }
  // Mixed-dtype variant of binary_dense_broadcast_rhs (the repeating operand
  // is the left argument). sizes_types packs: x / y = per-element byte size
  // of broadcast / input, z / w = ScalarType of broadcast / input.
  template <typename T, typename F, typename om_t = T>
  kernel void binary_dense_broadcast_rhs_cast(
      device result_of<F, T, T>* out [[buffer(0)]],
      constant void* broadcast [[buffer(1)]],
      constant void* input [[buffer(2)]],
      constant long& broadcast_numel [[buffer(3)]],
      constant uint4& sizes_types [[buffer(4)]],
      uint tid [[thread_position_in_grid]]) {
    using res_t = result_of<F, T, T>;
    const long bcast_idx = tid % broadcast_numel;
    const auto lhs_type = static_cast<ScalarType>(sizes_types.z);
    const auto rhs_type = static_cast<ScalarType>(sizes_types.w);
    const om_t lhs =
        val_at_offs<om_t>(broadcast, bcast_idx * sizes_types.x, lhs_type);
    const om_t rhs = val_at_offs<om_t>(input, tid * sizes_types.y, rhs_type);
    F op;
    out[tid] = static_cast<res_t>(op(lhs, rhs));
  }
  // Mixed-dtype, `alpha`-taking variant of binary_dense_broadcast.
  // sizes_types packs: x / y = per-element byte size of input / broadcast,
  // z / w = ScalarType of input / broadcast. Operands are cast to T.
  template <typename T, typename T2, typename F>
  kernel void binary_alpha_dense_broadcast_cast(
      device result_of<F, T, T, T2>* out [[buffer(0)]],
      constant void* input [[buffer(1)]],
      constant void* broadcast [[buffer(2)]],
      constant long& broadcast_numel [[buffer(3)]],
      constant T2& alpha [[buffer(4)]],
      constant uint4& sizes_types [[buffer(5)]],
      uint tid [[thread_position_in_grid]]) {
    const long bcast_idx = tid % broadcast_numel;
    const auto lhs_type = static_cast<ScalarType>(sizes_types.z);
    const auto rhs_type = static_cast<ScalarType>(sizes_types.w);
    const T lhs = val_at_offs<T>(input, tid * sizes_types.x, lhs_type);
    const T rhs =
        val_at_offs<T>(broadcast, bcast_idx * sizes_types.y, rhs_type);
    F op;
    out[tid] = op(lhs, rhs, alpha);
  }
  // Mixed-dtype, `alpha`-taking variant of binary_dense_broadcast_rhs (the
  // repeating operand is the left argument). sizes_types packs:
  // x / y = per-element byte size of broadcast / input,
  // z / w = ScalarType of broadcast / input. Operands are cast to T.
  template <typename T, typename T2, typename F>
  kernel void binary_alpha_dense_broadcast_rhs_cast(
      device result_of<F, T, T, T2>* out [[buffer(0)]],
      constant void* broadcast [[buffer(1)]],
      constant void* input [[buffer(2)]],
      constant long& broadcast_numel [[buffer(3)]],
      constant T2& alpha [[buffer(4)]],
      constant uint4& sizes_types [[buffer(5)]],
      uint tid [[thread_position_in_grid]]) {
    const long bcast_idx = tid % broadcast_numel;
    const auto lhs_type = static_cast<ScalarType>(sizes_types.z);
    const auto rhs_type = static_cast<ScalarType>(sizes_types.w);
    const T lhs =
        val_at_offs<T>(broadcast, bcast_idx * sizes_types.x, lhs_type);
    const T rhs = val_at_offs<T>(input, tid * sizes_types.y, rhs_type);
    F op;
    out[tid] = op(lhs, rhs, alpha);
  }
  // Contiguous binary kernel whose right-hand operand is the single value
  // scalar[0]. Operands are converted to om_t, and the result is cast to
  // result_of<F, T, T>.
  template <typename T, typename F, typename om_t = opmath_t<T>>
  kernel void binary_dense_scalar(
      device result_of<F, T, T>* out [[buffer(0)]],
      constant T* input [[buffer(1)]],
      device T* scalar [[buffer(2)]],
      uint tid [[thread_position_in_grid]]) {
    F op;
    const auto value = op(om_t(input[tid]), om_t(scalar[0]));
    out[tid] = static_cast<result_of<F, T, T>>(value);
  }
  // Contiguous binary kernel whose LEFT operand is the single value scalar[0].
  // Operands are converted to om_t, and the result is cast to
  // result_of<F, T, T>.
  template <typename T, typename F, typename om_t = opmath_t<T>>
  kernel void binary_dense_scalar_lhs(
      device result_of<F, T, T>* out [[buffer(0)]],
      device T* scalar [[buffer(1)]],
      constant T* input [[buffer(2)]],
      uint tid [[thread_position_in_grid]]) {
    F op;
    const auto value = op(om_t(scalar[0]), om_t(input[tid]));
    out[tid] = static_cast<result_of<F, T, T>>(value);
  }
  // Mixed-dtype variant of binary_dense_scalar. sizes_types packs:
  // x = per-element byte size of input, z / w = ScalarType of input / scalar.
  // The y component is not read here. Operands are dynamically cast to om_t.
  template <typename T, typename F, typename om_t = T>
  kernel void binary_dense_scalar_cast(
      device result_of<F, T, T>* out [[buffer(0)]],
      constant void* input [[buffer(1)]],
      device void* scalar [[buffer(2)]],
      constant uint4& sizes_types [[buffer(3)]],
      uint tid [[thread_position_in_grid]]) {
    using res_t = result_of<F, T, T>;
    const auto lhs_type = static_cast<ScalarType>(sizes_types.z);
    const auto rhs_type = static_cast<ScalarType>(sizes_types.w);
    const om_t lhs = val_at_offs<om_t>(input, tid * sizes_types.x, lhs_type);
    const om_t rhs = val_at_offs<om_t>(scalar, 0, rhs_type);
    F op;
    out[tid] = static_cast<res_t>(op(lhs, rhs));
  }
  // Mixed-dtype variant of binary_dense_scalar_lhs. sizes_types packs:
  // y = per-element byte size of input, z / w = ScalarType of scalar / input.
  // The x component is not read here. Operands are dynamically cast to om_t.
  template <typename T, typename F, typename om_t = T>
  kernel void binary_dense_scalar_lhs_cast(
      device result_of<F, T, T>* out [[buffer(0)]],
      device void* scalar [[buffer(1)]],
      constant void* input [[buffer(2)]],
      constant uint4& sizes_types [[buffer(3)]],
      uint tid [[thread_position_in_grid]]) {
    using res_t = result_of<F, T, T>;
    const auto lhs_type = static_cast<ScalarType>(sizes_types.z);
    const auto rhs_type = static_cast<ScalarType>(sizes_types.w);
    const om_t lhs = val_at_offs<om_t>(scalar, 0, lhs_type);
    const om_t rhs = val_at_offs<om_t>(input, tid * sizes_types.y, rhs_type);
    F op;
    out[tid] = static_cast<res_t>(op(lhs, rhs));
  }
  // Contiguous binary kernel with `alpha` whose right-hand operand is the
  // single value scalar[0].
  template <typename T, typename T2, typename F>
  kernel void binary_alpha_dense_scalar(
      device result_of<F, T, T, T2>* out [[buffer(0)]],
      constant T* input [[buffer(1)]],
      device T* scalar [[buffer(2)]],
      constant T2& alpha [[buffer(3)]],
      uint tid [[thread_position_in_grid]]) {
    F op;
    out[tid] = op(input[tid], scalar[0], alpha);
  }
  // Contiguous binary kernel with `alpha` whose LEFT operand is the single
  // value scalar[0].
  template <typename T, typename T2, typename F>
  kernel void binary_alpha_dense_scalar_lhs(
      device result_of<F, T, T, T2>* out [[buffer(0)]],
      device T* scalar [[buffer(1)]],
      constant T* input [[buffer(2)]],
      constant T2& alpha [[buffer(3)]],
      uint tid [[thread_position_in_grid]]) {
    F op;
    out[tid] = op(scalar[0], input[tid], alpha);
  }
  // Mixed-dtype, `alpha`-taking variant of binary_dense_scalar. sizes_types
  // packs: x = per-element byte size of input, z / w = ScalarType of
  // input / scalar. Operands are dynamically cast to T.
  template <typename T, typename T2, typename F>
  kernel void binary_alpha_dense_scalar_cast(
      device result_of<F, T, T, T2>* out [[buffer(0)]],
      constant void* input [[buffer(1)]],
      device void* scalar [[buffer(2)]],
      constant T2& alpha [[buffer(3)]],
      constant uint4& sizes_types [[buffer(4)]],
      uint tid [[thread_position_in_grid]]) {
    const auto lhs_type = static_cast<ScalarType>(sizes_types.z);
    const auto rhs_type = static_cast<ScalarType>(sizes_types.w);
    const T lhs = val_at_offs<T>(input, tid * sizes_types.x, lhs_type);
    const T rhs = val_at_offs<T>(scalar, 0, rhs_type);
    F op;
    out[tid] = op(lhs, rhs, alpha);
  }
  // Mixed-dtype, `alpha`-taking variant of binary_dense_scalar_lhs.
  // sizes_types packs: y = per-element byte size of input, z / w = ScalarType
  // of scalar / input. Operands are dynamically cast to T.
  template <typename T, typename T2, typename F>
  kernel void binary_alpha_dense_scalar_lhs_cast(
      device result_of<F, T, T, T2>* out [[buffer(0)]],
      device void* scalar [[buffer(1)]],
      constant void* input [[buffer(2)]],
      constant T2& alpha [[buffer(3)]],
      constant uint4& sizes_types [[buffer(4)]],
      uint tid [[thread_position_in_grid]]) {
    const auto lhs_type = static_cast<ScalarType>(sizes_types.z);
    const auto rhs_type = static_cast<ScalarType>(sizes_types.w);
    const T lhs = val_at_offs<T>(scalar, 0, lhs_type);
    const T rhs = val_at_offs<T>(input, tid * sizes_types.y, rhs_type);
    F op;
    out[tid] = op(lhs, rhs, alpha);
  }
  // Explicitly instantiates every non-alpha binary kernel flavor for
  // NAME##_functor with input dtype DTYPEI and opmath type OMT.
  // Typed flavors are exported as "<NAME>_<flavor>_<DTYPEO>_<DTYPEI>".
  // Mixed-dtype `_cast` flavors are exported as
  // "<NAME>_<flavor>_cast_<DTYPEI>". The static_assert requires DTYPEO to
  // equal the functor's result type for two DTYPEI operands.
  #define REGISTER_BINARY_OP_(NAME, DTYPEI, DTYPEO, OMT) \
    static_assert( \
        ::metal::is_same_v< \
            DTYPEO, \
            ::c10::metal::result_of<NAME##_functor, DTYPEI, DTYPEI>>, \
        "Output dtype mismatch for binary op " #NAME " and input " #DTYPEI); \
    template [[host_name(#NAME "_strided_" #DTYPEO "_" #DTYPEI)]] \
    kernel void ::c10::metal::binary_strided<DTYPEI, NAME##_functor, OMT>( \
        device void* out, \
        constant void* input, \
        constant void* other, \
        constant long* sizes, \
        constant long* output_strides, \
        constant long* input_strides, \
        constant long* other_strides, \
        constant uint3& ndim, \
        uint tid); \
    template [[host_name(#NAME "_strided_cast_" #DTYPEI)]] \
    kernel void ::c10::metal::binary_strided_cast<DTYPEI, NAME##_functor, OMT>( \
        device void* out, \
        constant void* input, \
        constant void* other, \
        constant long* sizes, \
        constant long* output_strides, \
        constant long* input_strides, \
        constant long* other_strides, \
        constant uint4& ndim_types, \
        uint tid); \
    template [[host_name(#NAME "_dense_" #DTYPEO "_" #DTYPEI)]] \
    kernel void ::c10::metal::binary_dense<DTYPEI, NAME##_functor, OMT>( \
        device ::c10::metal::result_of<NAME##_functor, DTYPEI, DTYPEI> * out_, \
        constant DTYPEI * input_, \
        constant DTYPEI * other_, \
        uint tid); \
    template [[host_name(#NAME "_dense_cast_" #DTYPEI)]] \
    kernel void ::c10::metal::binary_dense_cast<DTYPEI, NAME##_functor, OMT>( \
        device ::c10::metal::result_of<NAME##_functor, DTYPEI, DTYPEI> * out_, \
        constant void* input, \
        constant void* other, \
        constant uint4& sizes_types, \
        uint tid); \
    template [[host_name(#NAME "_dense_broadcast_" #DTYPEO "_" #DTYPEI)]] \
    kernel void ::c10::metal::binary_dense_broadcast<DTYPEI, NAME##_functor, OMT>( \
        device ::c10::metal::result_of<NAME##_functor, DTYPEI, DTYPEI> * out_, \
        constant DTYPEI * input_, \
        constant DTYPEI * broadcast_, \
        constant long& broadcast_numel, \
        uint tid); \
    template [[host_name(#NAME "_dense_broadcast_rhs_" #DTYPEO "_" #DTYPEI)]] \
    kernel void ::c10::metal::binary_dense_broadcast_rhs<DTYPEI, NAME##_functor, OMT>( \
        device ::c10::metal::result_of<NAME##_functor, DTYPEI, DTYPEI> * out_, \
        constant DTYPEI * broadcast_, \
        constant DTYPEI * input_, \
        constant long& broadcast_numel, \
        uint tid); \
    template [[host_name(#NAME "_dense_broadcast_cast_" #DTYPEI)]] \
    kernel void ::c10::metal::binary_dense_broadcast_cast<DTYPEI, NAME##_functor, OMT>( \
        device ::c10::metal::result_of<NAME##_functor, DTYPEI, DTYPEI> * out_, \
        constant void* input_, \
        constant void* broadcast_, \
        constant long& broadcast_numel, \
        constant uint4& sizes_types, \
        uint tid); \
    template [[host_name(#NAME "_dense_broadcast_rhs_cast_" #DTYPEI)]] \
    kernel void ::c10::metal::binary_dense_broadcast_rhs_cast<DTYPEI, NAME##_functor, OMT>( \
        device ::c10::metal::result_of<NAME##_functor, DTYPEI, DTYPEI> * out_, \
        constant void* broadcast_, \
        constant void* input_, \
        constant long& broadcast_numel, \
        constant uint4& sizes_types, \
        uint tid); \
    template [[host_name(#NAME "_dense_scalar_" #DTYPEO "_" #DTYPEI)]] \
    kernel void ::c10::metal::binary_dense_scalar<DTYPEI, NAME##_functor, OMT>( \
        device ::c10::metal::result_of<NAME##_functor, DTYPEI, DTYPEI> * out_, \
        constant DTYPEI * input_, \
        device DTYPEI * scalar_, \
        uint tid); \
    template [[host_name(#NAME "_dense_scalar_lhs_" #DTYPEO "_" #DTYPEI)]] \
    kernel void ::c10::metal::binary_dense_scalar_lhs<DTYPEI, NAME##_functor, OMT>( \
        device ::c10::metal::result_of<NAME##_functor, DTYPEI, DTYPEI> * out_, \
        device DTYPEI * scalar_, \
        constant DTYPEI * input_, \
        uint tid); \
    template [[host_name(#NAME "_dense_scalar_cast_" #DTYPEI)]] \
    kernel void ::c10::metal::binary_dense_scalar_cast<DTYPEI, NAME##_functor, OMT>( \
        device ::c10::metal::result_of<NAME##_functor, DTYPEI, DTYPEI> * out_, \
        constant void* input_, \
        device void* scalar_, \
        constant uint4& sizes_types, \
        uint tid); \
    template [[host_name(#NAME "_dense_scalar_lhs_cast_" #DTYPEI)]] \
    kernel void ::c10::metal::binary_dense_scalar_lhs_cast<DTYPEI, NAME##_functor, OMT>( \
        device ::c10::metal::result_of<NAME##_functor, DTYPEI, DTYPEI> * out_, \
        device void* scalar_, \
        constant void* input_, \
        constant uint4& sizes_types, \
        uint tid)
  683. // OpMath Binary Op promotes inputs to higher precision type before Functor call
  684. #define REGISTER_OPMATH_BINARY_OP(NAME, DTYPEI, DTYPEO) \
  685. REGISTER_BINARY_OP_(NAME, DTYPEI, DTYPEO, ::c10::metal::opmath_t<DTYPEI>)
  686. #define REGISTER_BINARY_OP(NAME, DTYPEI, DTYPEO) \
  687. REGISTER_BINARY_OP_(NAME, DTYPEI, DTYPEO, DTYPEI)
// Registers the "alpha" variants of a binary op: kernels that take an extra
// scalar `alpha` of type DTYPEA (passed as `constant DTYPEA&`) in addition to
// the two operands.
//
// Parameters (all macro arguments are stringized into the host names):
//   NAME   - op name; the functor used is NAME##_functor.
//   DTYPEI - element type of both inputs.
//   DTYPEA - type of the alpha scalar.
//   DTYPEO - output type; the static_assert requires it to be exactly
//            result_of<NAME##_functor, DTYPEI, DTYPEI, DTYPEA>.
//
// Explicitly instantiates 12 kernel flavors from ::c10::metal:
//   binary_alpha_{strided, strided_cast, dense, dense_cast,
//                 dense_broadcast, dense_broadcast_rhs,
//                 dense_broadcast_cast, dense_broadcast_rhs_cast,
//                 dense_scalar, dense_scalar_lhs,
//                 dense_scalar_cast, dense_scalar_lhs_cast}
// Host names are:
//   - NAME_<flavor>_DTYPEO_DTYPEI_DTYPEA for the non-cast flavors;
//   - NAME_<flavor>_DTYPEI_DTYPEA for the *_cast flavors, which omit DTYPEO
//     and receive their inputs as untyped `void*` buffers.
// Unlike the REGISTER_BINARY_OP_ wrappers, this macro takes no op-math type
// argument.
// The expansion ends without a trailing semicolon; the caller supplies it.
// NOTE: no comments inside the macro body - a `//` comment on a
// backslash-continued line would swallow the following line.
#define REGISTER_BINARY_ALPHA_OP(NAME, DTYPEI, DTYPEA, DTYPEO) \
  static_assert( \
      ::metal::is_same_v< \
          DTYPEO, \
          ::c10::metal::result_of<NAME##_functor, DTYPEI, DTYPEI, DTYPEA>>, \
      "Output dtype mismatch for binary op " #NAME " and input " #DTYPEI); \
  template [[host_name(#NAME "_strided_" #DTYPEO "_" #DTYPEI \
                       "_" #DTYPEA)]] kernel void ::c10::metal:: \
      binary_alpha_strided<DTYPEI, DTYPEA, NAME##_functor>( \
          device void* out, \
          constant void* input, \
          constant void* other, \
          constant DTYPEA& alpha, \
          constant long* sizes, \
          constant long* output_strides, \
          constant long* input_strides, \
          constant long* other_strides, \
          constant uint3& ndim, \
          uint tid); \
  template [[host_name(#NAME "_strided_cast_" #DTYPEI \
                       "_" #DTYPEA)]] kernel void ::c10::metal:: \
      binary_alpha_strided_cast<DTYPEI, DTYPEA, NAME##_functor>( \
          device void* out, \
          constant void* input, \
          constant void* other, \
          constant DTYPEA& alpha, \
          constant long* sizes, \
          constant long* output_strides, \
          constant long* input_strides, \
          constant long* other_strides, \
          constant uint4& ndim_types, \
          uint tid); \
  template [[host_name(#NAME "_dense_" #DTYPEO "_" #DTYPEI \
                       "_" #DTYPEA)]] kernel void ::c10::metal:: \
      binary_alpha_dense<DTYPEI, DTYPEA, NAME##_functor>( \
          device ::c10::metal:: \
              result_of<NAME##_functor, DTYPEI, DTYPEI, DTYPEA> * \
              out_, \
          constant DTYPEI * input_, \
          constant DTYPEI * other_, \
          constant DTYPEA & alpha, \
          uint tid); \
  template \
      [[host_name(#NAME "_dense_cast_" #DTYPEI "_" #DTYPEA)]] kernel void :: \
          c10::metal::binary_alpha_dense_cast<DTYPEI, DTYPEA, NAME##_functor>( \
              device ::c10::metal:: \
                  result_of<NAME##_functor, DTYPEI, DTYPEI, DTYPEA> * \
                  out_, \
              constant void* input, \
              constant void* other, \
              constant DTYPEA& alpha, \
              constant uint4& sizes_types, \
              uint tid); \
  template [[host_name(#NAME "_dense_broadcast_" #DTYPEO "_" #DTYPEI \
                       "_" #DTYPEA)]] kernel void ::c10::metal:: \
      binary_alpha_dense_broadcast<DTYPEI, DTYPEA, NAME##_functor>( \
          device ::c10::metal:: \
              result_of<NAME##_functor, DTYPEI, DTYPEI, DTYPEA> * \
              out_, \
          constant DTYPEI * input_, \
          constant DTYPEI * broadcast_, \
          constant long& broadcast_numel, \
          constant DTYPEA& alpha, \
          uint tid); \
  template [[host_name(#NAME "_dense_broadcast_rhs_" #DTYPEO "_" #DTYPEI \
                       "_" #DTYPEA)]] kernel void ::c10::metal:: \
      binary_alpha_dense_broadcast_rhs<DTYPEI, DTYPEA, NAME##_functor>( \
          device ::c10::metal:: \
              result_of<NAME##_functor, DTYPEI, DTYPEI, DTYPEA> * \
              out_, \
          constant DTYPEI * broadcast_, \
          constant DTYPEI * input_, \
          constant long& broadcast_numel, \
          constant DTYPEA& alpha, \
          uint tid); \
  template [[host_name(#NAME "_dense_broadcast_cast_" #DTYPEI \
                       "_" #DTYPEA)]] kernel void ::c10::metal:: \
      binary_alpha_dense_broadcast_cast<DTYPEI, DTYPEA, NAME##_functor>( \
          device ::c10::metal:: \
              result_of<NAME##_functor, DTYPEI, DTYPEI, DTYPEA> * \
              out_, \
          constant void* input_, \
          constant void* broadcast_, \
          constant long& broadcast_numel, \
          constant DTYPEA& alpha, \
          constant uint4& sizes_types, \
          uint tid); \
  template [[host_name(#NAME "_dense_broadcast_rhs_cast_" #DTYPEI \
                       "_" #DTYPEA)]] kernel void ::c10::metal:: \
      binary_alpha_dense_broadcast_rhs_cast<DTYPEI, DTYPEA, NAME##_functor>( \
          device ::c10::metal:: \
              result_of<NAME##_functor, DTYPEI, DTYPEI, DTYPEA> * \
              out_, \
          constant void* broadcast_, \
          constant void* input_, \
          constant long& broadcast_numel, \
          constant DTYPEA& alpha, \
          constant uint4& sizes_types, \
          uint tid); \
  template [[host_name(#NAME "_dense_scalar_" #DTYPEO "_" #DTYPEI \
                       "_" #DTYPEA)]] kernel void ::c10::metal:: \
      binary_alpha_dense_scalar<DTYPEI, DTYPEA, NAME##_functor>( \
          device ::c10::metal:: \
              result_of<NAME##_functor, DTYPEI, DTYPEI, DTYPEA> * \
              out_, \
          constant DTYPEI * input_, \
          device DTYPEI * scalar_, \
          constant DTYPEA & alpha, \
          uint tid); \
  template [[host_name(#NAME "_dense_scalar_lhs_" #DTYPEO "_" #DTYPEI \
                       "_" #DTYPEA)]] kernel void ::c10::metal:: \
      binary_alpha_dense_scalar_lhs<DTYPEI, DTYPEA, NAME##_functor>( \
          device ::c10::metal:: \
              result_of<NAME##_functor, DTYPEI, DTYPEI, DTYPEA> * \
              out_, \
          device DTYPEI * scalar_, \
          constant DTYPEI * input_, \
          constant DTYPEA & alpha, \
          uint tid); \
  template [[host_name(#NAME "_dense_scalar_cast_" #DTYPEI \
                       "_" #DTYPEA)]] kernel void ::c10::metal:: \
      binary_alpha_dense_scalar_cast<DTYPEI, DTYPEA, NAME##_functor>( \
          device ::c10::metal:: \
              result_of<NAME##_functor, DTYPEI, DTYPEI, DTYPEA> * \
              out_, \
          constant void* input_, \
          device void* scalar_, \
          constant DTYPEA& alpha, \
          constant uint4& sizes_types, \
          uint tid); \
  template [[host_name(#NAME "_dense_scalar_lhs_cast_" #DTYPEI \
                       "_" #DTYPEA)]] kernel void ::c10::metal:: \
      binary_alpha_dense_scalar_lhs_cast<DTYPEI, DTYPEA, NAME##_functor>( \
          device ::c10::metal:: \
              result_of<NAME##_functor, DTYPEI, DTYPEI, DTYPEA> * \
              out_, \
          device void* scalar_, \
          constant void* input_, \
          constant DTYPEA& alpha, \
          constant uint4& sizes_types, \
          uint tid)
// Ternary elementwise op kernels.
// Four flavors are currently available:
// - ternary_dense: input, other1, other2 and output are all dense; the inputs
//   share dtype T and the output has the functor's result type.
// - ternary_strided: all inputs share the same dtype, but at least one tensor
//   is strided.
// - ternary_dense_cast: all tensors are dense, but the inputs may have
//   different dtypes.
// - ternary_strided_cast: at least one tensor is strided and the inputs may
//   have different dtypes.
// Note about accuracy (see https://github.com/pytorch/pytorch/issues/152736):
// when a kernel produces a `half` output but one of its arguments is `float`,
// the arguments should be upcast to `float` rather than downcast to `half`.
// This is expressed with the optional `om_t` template parameter (short for
// "opmath type"), which is the type the inputs are converted to before the
// functor is called. Depending on the kernel it defaults to the input type `T`
// or to `opmath_t<T>`. The registration macros below always set it explicitly.
  843. template <typename T, typename F, typename om_t = T>
  844. kernel void ternary_strided(
  845. device void* output [[buffer(0)]],
  846. constant void* input [[buffer(1)]],
  847. constant void* other1 [[buffer(2)]],
  848. constant void* other2 [[buffer(3)]],
  849. constant long* sizes [[buffer(4)]],
  850. constant long* output_strides [[buffer(5)]],
  851. constant long* input_strides [[buffer(6)]],
  852. constant long* other1_strides [[buffer(7)]],
  853. constant long* other2_strides [[buffer(8)]],
  854. constant uint& ndim [[buffer(9)]],
  855. uint index [[thread_position_in_grid]]) {
  856. F f;
  857. using res_t = result_of<F, T, T, T>;
  858. int pos[max_ndim];
  859. pos_from_thread_index(int(index), pos, sizes, ndim);
  860. const auto input_offs = offset_from_coord(pos, input_strides, ndim);
  861. const auto other1_offs = offset_from_coord(pos, other1_strides, ndim);
  862. const auto other2_offs = offset_from_coord(pos, other2_strides, ndim);
  863. const auto output_offs = offset_from_coord(pos, output_strides, ndim);
  864. const auto a = val_at_offs<T>(input, input_offs);
  865. const auto b = val_at_offs<T>(other1, other1_offs);
  866. const auto c = val_at_offs<T>(other2, other2_offs);
  867. ref_at_offs<res_t>(output, output_offs) =
  868. static_cast<res_t>(f(om_t(a), om_t(b), om_t(c)));
  869. }
  870. template <typename T, typename F, typename om_t = opmath_t<T>>
  871. kernel void ternary_strided_cast(
  872. device void* output [[buffer(0)]],
  873. constant void* input [[buffer(1)]],
  874. constant void* other1 [[buffer(2)]],
  875. constant void* other2 [[buffer(3)]],
  876. constant long* sizes [[buffer(4)]],
  877. constant long* output_strides [[buffer(5)]],
  878. constant long* input_strides [[buffer(6)]],
  879. constant long* other1_strides [[buffer(7)]],
  880. constant long* other2_strides [[buffer(8)]],
  881. constant uint& ndim [[buffer(9)]],
  882. constant uint4& types [[buffer(10)]],
  883. uint index [[thread_position_in_grid]]) {
  884. F f;
  885. using res_t = result_of<F, T, T, T>;
  886. int pos[max_ndim];
  887. pos_from_thread_index(int(index), pos, sizes, ndim);
  888. const auto input_offs = offset_from_coord(pos, input_strides, ndim);
  889. const auto other1_offs = offset_from_coord(pos, other1_strides, ndim);
  890. const auto other2_offs = offset_from_coord(pos, other2_strides, ndim);
  891. const auto output_offs = offset_from_coord(pos, output_strides, ndim);
  892. const auto a =
  893. val_at_offs<om_t>(input, input_offs, static_cast<ScalarType>(types.x));
  894. const auto b =
  895. val_at_offs<om_t>(other1, other1_offs, static_cast<ScalarType>(types.y));
  896. const auto c =
  897. val_at_offs<om_t>(other2, other2_offs, static_cast<ScalarType>(types.z));
  898. ref_at_offs<res_t>(output, output_offs) = static_cast<res_t>(f(a, b, c));
  899. }
  900. template <typename T, typename F, typename om_t = opmath_t<T>>
  901. kernel void ternary_dense(
  902. device result_of<F, T, T, T>* out [[buffer(0)]],
  903. constant T* input [[buffer(1)]],
  904. constant T* other1 [[buffer(2)]],
  905. constant T* other2 [[buffer(3)]],
  906. uint tid [[thread_position_in_grid]]) {
  907. F f;
  908. using res_t = result_of<F, T, T, T>;
  909. out[tid] = static_cast<res_t>(
  910. f(om_t(input[tid]), om_t(other1[tid]), om_t(other2[tid])));
  911. }
  912. template <typename T, typename F, typename om_t = T>
  913. kernel void ternary_dense_cast(
  914. device result_of<F, T, T, T>* out [[buffer(0)]],
  915. constant void* input [[buffer(1)]],
  916. constant void* other1 [[buffer(2)]],
  917. constant void* other2 [[buffer(3)]],
  918. constant uint3& sizes [[buffer(4)]],
  919. constant uint3& types [[buffer(5)]],
  920. uint tid [[thread_position_in_grid]]) {
  921. F f;
  922. using res_t = result_of<F, T, T, T>;
  923. const auto a =
  924. val_at_offs<om_t>(input, tid * sizes.x, static_cast<ScalarType>(types.x));
  925. const auto b = val_at_offs<om_t>(
  926. other1, tid * sizes.y, static_cast<ScalarType>(types.y));
  927. const auto c = val_at_offs<om_t>(
  928. other2, tid * sizes.z, static_cast<ScalarType>(types.z));
  929. out[tid] = static_cast<res_t>(f(a, b, c));
  930. }
  931. #define REGISTER_TERNARY_OP_(NAME, DTYPEI, DTYPEO, OMT) \
  932. static_assert( \
  933. ::metal::is_same_v< \
  934. DTYPEO, \
  935. ::c10::metal::result_of<NAME##_functor, DTYPEI, DTYPEI, DTYPEI>>, \
  936. "Output dtype mismatch for ternary op " #NAME " and input " #DTYPEI); \
  937. template [[host_name(#NAME "_strided_" #DTYPEO "_" #DTYPEI)]] kernel void :: \
  938. c10::metal::ternary_strided<DTYPEI, NAME##_functor, OMT>( \
  939. device void* out, \
  940. constant void* input, \
  941. constant void* other1, \
  942. constant void* other2, \
  943. constant long* sizes, \
  944. constant long* output_strides, \
  945. constant long* input_strides, \
  946. constant long* other1_strides, \
  947. constant long* other2_strides, \
  948. constant uint& ndim, \
  949. uint tid); \
  950. template [[host_name(#NAME "_strided_cast_" #DTYPEI)]] kernel void ::c10:: \
  951. metal::ternary_strided_cast<DTYPEI, NAME##_functor, OMT>( \
  952. device void* out, \
  953. constant void* input, \
  954. constant void* other1, \
  955. constant void* other2, \
  956. constant long* sizes, \
  957. constant long* output_strides, \
  958. constant long* input_strides, \
  959. constant long* other1_strides, \
  960. constant long* other2_strides, \
  961. constant uint& ndim, \
  962. constant uint4& types, \
  963. uint tid); \
  964. template [[host_name(#NAME "_dense_" #DTYPEO "_" #DTYPEI)]] kernel void :: \
  965. c10::metal::ternary_dense<DTYPEI, NAME##_functor, OMT>( \
  966. device ::c10::metal:: \
  967. result_of<NAME##_functor, DTYPEI, DTYPEI, DTYPEI> * \
  968. out_, \
  969. constant DTYPEI * input_, \
  970. constant DTYPEI * other1_, \
  971. constant DTYPEI * other2_, \
  972. uint tid); \
  973. template [[host_name(#NAME "_dense_cast_" #DTYPEI)]] kernel void ::c10:: \
  974. metal::ternary_dense_cast<DTYPEI, NAME##_functor, OMT>( \
  975. device ::c10::metal:: \
  976. result_of<NAME##_functor, DTYPEI, DTYPEI, DTYPEI> * \
  977. out_, \
  978. constant void* input, \
  979. constant void* other1, \
  980. constant void* other2, \
  981. constant uint3& sizes, \
  982. constant uint3& types, \
  983. uint tid)
  984. // OpMath ternary Op promotes inputs to higher precision type before Functor
  985. // call
  986. #define REGISTER_OPMATH_TERNARY_OP(NAME, DTYPEI, DTYPEO) \
  987. REGISTER_TERNARY_OP_(NAME, DTYPEI, DTYPEO, ::c10::metal::opmath_t<DTYPEI>)
  988. #define REGISTER_TERNARY_OP(NAME, DTYPEI, DTYPEO) \
  989. REGISTER_TERNARY_OP_(NAME, DTYPEI, DTYPEO, DTYPEI)
  990. } // namespace metal
  991. } // namespace c10
  992. #else
  993. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  994. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)