libdevice.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491
  1. from triton.language import core
  @core.extern
  def abs(arg0, _semantic=None):
      """Elementwise absolute value of ``arg0``.

      int32/int64 inputs call ``__triton_hip_iabs`` and fp32/fp64 inputs call
      ``__triton_hip_fabs``; in every case the result keeps the input dtype.
      ``_semantic`` is forwarded unchanged to ``core.extern_elementwise``.
      """
      i32, i64 = core.dtype("int32"), core.dtype("int64")
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (i32, ): ("__triton_hip_iabs", i32),
          (i64, ): ("__triton_hip_iabs", i64),
          (fp32, ): ("__triton_hip_fabs", fp32),
          (fp64, ): ("__triton_hip_fabs", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def floor(arg0, _semantic=None):
      """Elementwise ``floor`` of ``arg0``.

      fp32 inputs call ``__ocml_floor_f32`` and fp64 inputs call
      ``__ocml_floor_f64``; the result keeps the input dtype. ``_semantic``
      is forwarded unchanged to ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (fp32, ): ("__ocml_floor_f32", fp32),
          (fp64, ): ("__ocml_floor_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def rsqrt(arg0, _semantic=None):
      """Elementwise ``rsqrt`` of ``arg0``.

      fp32 inputs call ``__ocml_rsqrt_f32`` and fp64 inputs call
      ``__ocml_rsqrt_f64``; the result keeps the input dtype. ``_semantic``
      is forwarded unchanged to ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (fp32, ): ("__ocml_rsqrt_f32", fp32),
          (fp64, ): ("__ocml_rsqrt_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def ceil(arg0, _semantic=None):
      """Elementwise ``ceil`` of ``arg0``.

      fp32 inputs call ``__ocml_ceil_f32`` and fp64 inputs call
      ``__ocml_ceil_f64``; the result keeps the input dtype. ``_semantic``
      is forwarded unchanged to ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (fp32, ): ("__ocml_ceil_f32", fp32),
          (fp64, ): ("__ocml_ceil_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def trunc(arg0, _semantic=None):
      """Elementwise ``trunc`` of ``arg0``.

      fp32 inputs call ``__ocml_trunc_f32`` and fp64 inputs call
      ``__ocml_trunc_f64``; the result keeps the input dtype. ``_semantic``
      is forwarded unchanged to ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (fp32, ): ("__ocml_trunc_f32", fp32),
          (fp64, ): ("__ocml_trunc_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def exp2(arg0, _semantic=None):
      """Elementwise ``exp2`` of ``arg0``.

      fp32 inputs call ``__ocml_exp2_f32`` and fp64 inputs call
      ``__ocml_exp2_f64``; the result keeps the input dtype. ``_semantic``
      is forwarded unchanged to ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (fp32, ): ("__ocml_exp2_f32", fp32),
          (fp64, ): ("__ocml_exp2_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def exp(arg0, _semantic=None):
      """Elementwise ``exp`` of ``arg0``.

      fp32 inputs call ``__ocml_exp_f32`` and fp64 inputs call
      ``__ocml_exp_f64``; the result keeps the input dtype. ``_semantic``
      is forwarded unchanged to ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (fp32, ): ("__ocml_exp_f32", fp32),
          (fp64, ): ("__ocml_exp_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def fast_expf(arg0, _semantic=None):
      """Elementwise ``fast_expf`` of an fp32 ``arg0``.

      Only an fp32 overload exists: it calls ``__triton_hip_fast_expf`` and
      yields fp32. ``_semantic`` is forwarded unchanged to
      ``core.extern_elementwise``.
      """
      fp32 = core.dtype("fp32")
      overloads = {(fp32, ): ("__triton_hip_fast_expf", fp32)}
      return core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def fast_tanhf(arg0, _semantic=None):
      """Elementwise ``fast_tanhf`` of an fp32 ``arg0``.

      Only an fp32 overload exists: it calls ``__triton_hip_fast_tanhf`` and
      yields fp32. ``_semantic`` is forwarded unchanged to
      ``core.extern_elementwise``.
      """
      fp32 = core.dtype("fp32")
      overloads = {(fp32, ): ("__triton_hip_fast_tanhf", fp32)}
      return core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def fast_dividef(arg0, arg1, _semantic=None):
      """Elementwise ``fast_dividef(arg0, arg1)`` for fp32 operands.

      Only the (fp32, fp32) overload exists. It calls the symbol spelled
      ``__triton_hip_fast_fdividef`` (note the extra ``f``) and yields fp32.
      ``_semantic`` is forwarded unchanged to ``core.extern_elementwise``.
      """
      fp32 = core.dtype("fp32")
      overloads = {(fp32, fp32): ("__triton_hip_fast_fdividef", fp32)}
      return core.extern_elementwise(
          "", "", [arg0, arg1], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def sqrt(arg0, _semantic=None):
      """Elementwise ``sqrt`` of ``arg0``.

      fp32 inputs call ``__ocml_sqrt_f32`` and fp64 inputs call
      ``__ocml_sqrt_f64``; the result keeps the input dtype. ``_semantic``
      is forwarded unchanged to ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (fp32, ): ("__ocml_sqrt_f32", fp32),
          (fp64, ): ("__ocml_sqrt_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def llrint(arg0, _semantic=None):
      """Elementwise ``llrint`` of ``arg0``, producing int64.

      Both fp32 and fp64 inputs call the same ``__triton_hip_llrint`` symbol;
      the result dtype is int64 for either input. ``_semantic`` is forwarded
      unchanged to ``core.extern_elementwise``.
      """
      i64 = core.dtype("int64")
      overloads = {
          (core.dtype("fp32"), ): ("__triton_hip_llrint", i64),
          (core.dtype("fp64"), ): ("__triton_hip_llrint", i64),
      }
      return core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def nearbyint(arg0, _semantic=None):
      """Elementwise ``nearbyint`` of ``arg0``.

      fp32 inputs call ``__ocml_nearbyint_f32`` and fp64 inputs call
      ``__ocml_nearbyint_f64``; the result keeps the input dtype.
      ``_semantic`` is forwarded unchanged to ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (fp32, ): ("__ocml_nearbyint_f32", fp32),
          (fp64, ): ("__ocml_nearbyint_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def isnan(arg0, _semantic=None):
      """Elementwise ``isnan`` test on ``arg0``, returned as ``core.int1``.

      fp32 inputs call ``__ocml_isnan_f32`` and fp64 inputs call
      ``__ocml_isnan_f64``. Both return int32, which is then converted to
      ``core.int1``. ``_semantic`` is forwarded to both the extern call and
      the conversion.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      i32 = core.dtype("int32")
      overloads = {
          (fp32, ): ("__ocml_isnan_f32", i32),
          (fp64, ): ("__ocml_isnan_f64", i32),
      }
      as_int32 = core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
      # The device functions report the flag as int32; narrow it to int1.
      return as_int32.to(core.int1, _semantic=_semantic)
  @core.extern
  def signbit(arg0, _semantic=None):
      """Elementwise ``signbit`` of ``arg0``, returned as int32.

      fp32 inputs call ``__ocml_signbit_f32`` and fp64 inputs call
      ``__ocml_signbit_f64``. Unlike ``isnan``/``isinf`` in this module, the
      int32 result is returned as-is and is NOT narrowed to ``core.int1``.
      ``_semantic`` is forwarded unchanged to ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      i32 = core.dtype("int32")
      overloads = {
          (fp32, ): ("__ocml_signbit_f32", i32),
          (fp64, ): ("__ocml_signbit_f64", i32),
      }
      return core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def copysign(arg0, arg1, _semantic=None):
      """Elementwise ``copysign(arg0, arg1)``.

      Both operands must be fp32, which calls ``__ocml_copysign_f32``, or
      both fp64, which calls ``__ocml_copysign_f64``. The result has the
      operands' dtype. ``_semantic`` is forwarded unchanged to
      ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (fp32, fp32): ("__ocml_copysign_f32", fp32),
          (fp64, fp64): ("__ocml_copysign_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0, arg1], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def isinf(arg0, _semantic=None):
      """Elementwise ``isinf`` test on ``arg0``, returned as ``core.int1``.

      fp32 inputs call ``__ocml_isinf_f32`` and fp64 inputs call
      ``__ocml_isinf_f64``. Both return int32, which is then converted to
      ``core.int1``. ``_semantic`` is forwarded to both the extern call and
      the conversion.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      i32 = core.dtype("int32")
      overloads = {
          (fp32, ): ("__ocml_isinf_f32", i32),
          (fp64, ): ("__ocml_isinf_f64", i32),
      }
      as_int32 = core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
      # The device functions report the flag as int32; narrow it to int1.
      return as_int32.to(core.int1, _semantic=_semantic)
  @core.extern
  def nextafter(arg0, arg1, _semantic=None):
      """Elementwise ``nextafter(arg0, arg1)``.

      Both operands must be fp32, which calls ``__ocml_nextafter_f32``, or
      both fp64, which calls ``__ocml_nextafter_f64``. The result has the
      operands' dtype. ``_semantic`` is forwarded unchanged to
      ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (fp32, fp32): ("__ocml_nextafter_f32", fp32),
          (fp64, fp64): ("__ocml_nextafter_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0, arg1], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def sin(arg0, _semantic=None):
      """Elementwise ``sin`` of ``arg0``.

      fp32 inputs call ``__ocml_sin_f32`` and fp64 inputs call
      ``__ocml_sin_f64``; the result keeps the input dtype. ``_semantic``
      is forwarded unchanged to ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (fp32, ): ("__ocml_sin_f32", fp32),
          (fp64, ): ("__ocml_sin_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def cos(arg0, _semantic=None):
      """Elementwise ``cos`` of ``arg0``.

      fp32 inputs call ``__ocml_cos_f32`` and fp64 inputs call
      ``__ocml_cos_f64``; the result keeps the input dtype. ``_semantic``
      is forwarded unchanged to ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (fp32, ): ("__ocml_cos_f32", fp32),
          (fp64, ): ("__ocml_cos_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def tan(arg0, _semantic=None):
      """Elementwise ``tan`` of ``arg0``.

      fp32 inputs call ``__ocml_tan_f32`` and fp64 inputs call
      ``__ocml_tan_f64``; the result keeps the input dtype. ``_semantic``
      is forwarded unchanged to ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (fp32, ): ("__ocml_tan_f32", fp32),
          (fp64, ): ("__ocml_tan_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def log2(arg0, _semantic=None):
      """Elementwise ``log2`` of ``arg0``.

      fp32 inputs call ``__ocml_log2_f32`` and fp64 inputs call
      ``__ocml_log2_f64``; the result keeps the input dtype. ``_semantic``
      is forwarded unchanged to ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (fp32, ): ("__ocml_log2_f32", fp32),
          (fp64, ): ("__ocml_log2_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def cosh(arg0, _semantic=None):
      """Elementwise ``cosh`` of ``arg0``.

      fp32 inputs call ``__ocml_cosh_f32`` and fp64 inputs call
      ``__ocml_cosh_f64``; the result keeps the input dtype. ``_semantic``
      is forwarded unchanged to ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (fp32, ): ("__ocml_cosh_f32", fp32),
          (fp64, ): ("__ocml_cosh_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def sinh(arg0, _semantic=None):
      """Elementwise ``sinh`` of ``arg0``.

      fp32 inputs call ``__ocml_sinh_f32`` and fp64 inputs call
      ``__ocml_sinh_f64``; the result keeps the input dtype. ``_semantic``
      is forwarded unchanged to ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (fp32, ): ("__ocml_sinh_f32", fp32),
          (fp64, ): ("__ocml_sinh_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def tanh(arg0, _semantic=None):
      """Elementwise ``tanh`` of ``arg0``.

      fp32 inputs call ``__ocml_tanh_f32`` and fp64 inputs call
      ``__ocml_tanh_f64``; the result keeps the input dtype. ``_semantic``
      is forwarded unchanged to ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (fp32, ): ("__ocml_tanh_f32", fp32),
          (fp64, ): ("__ocml_tanh_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def atan2(arg0, arg1, _semantic=None):
      """Elementwise ``atan2(arg0, arg1)``.

      Both operands must be fp32, which calls ``__ocml_atan2_f32``, or both
      fp64, which calls ``__ocml_atan2_f64``. The result has the operands'
      dtype. ``_semantic`` is forwarded unchanged to
      ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (fp32, fp32): ("__ocml_atan2_f32", fp32),
          (fp64, fp64): ("__ocml_atan2_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0, arg1], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def atan(arg0, _semantic=None):
      """Elementwise ``atan`` of ``arg0``.

      fp32 inputs call ``__ocml_atan_f32`` and fp64 inputs call
      ``__ocml_atan_f64``; the result keeps the input dtype. ``_semantic``
      is forwarded unchanged to ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (fp32, ): ("__ocml_atan_f32", fp32),
          (fp64, ): ("__ocml_atan_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def asin(arg0, _semantic=None):
      """Elementwise ``asin`` of ``arg0``.

      fp32 inputs call ``__ocml_asin_f32`` and fp64 inputs call
      ``__ocml_asin_f64``; the result keeps the input dtype. ``_semantic``
      is forwarded unchanged to ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (fp32, ): ("__ocml_asin_f32", fp32),
          (fp64, ): ("__ocml_asin_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def acos(arg0, _semantic=None):
      """Elementwise ``acos`` of ``arg0``.

      fp32 inputs call ``__ocml_acos_f32`` and fp64 inputs call
      ``__ocml_acos_f64``; the result keeps the input dtype. ``_semantic``
      is forwarded unchanged to ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (fp32, ): ("__ocml_acos_f32", fp32),
          (fp64, ): ("__ocml_acos_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def log(arg0, _semantic=None):
      """Elementwise ``log`` of ``arg0``.

      fp32 inputs call ``__ocml_log_f32`` and fp64 inputs call
      ``__ocml_log_f64``; the result keeps the input dtype. ``_semantic``
      is forwarded unchanged to ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (fp32, ): ("__ocml_log_f32", fp32),
          (fp64, ): ("__ocml_log_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def log10(arg0, _semantic=None):
      """Elementwise ``log10`` of ``arg0``.

      fp32 inputs call ``__ocml_log10_f32`` and fp64 inputs call
      ``__ocml_log10_f64``; the result keeps the input dtype. ``_semantic``
      is forwarded unchanged to ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (fp32, ): ("__ocml_log10_f32", fp32),
          (fp64, ): ("__ocml_log10_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def log1p(arg0, _semantic=None):
      """Elementwise ``log1p`` of ``arg0``.

      fp32 inputs call ``__ocml_log1p_f32`` and fp64 inputs call
      ``__ocml_log1p_f64``; the result keeps the input dtype. ``_semantic``
      is forwarded unchanged to ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (fp32, ): ("__ocml_log1p_f32", fp32),
          (fp64, ): ("__ocml_log1p_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def acosh(arg0, _semantic=None):
      """Elementwise ``acosh`` of ``arg0``.

      fp32 inputs call ``__ocml_acosh_f32`` and fp64 inputs call
      ``__ocml_acosh_f64``; the result keeps the input dtype. ``_semantic``
      is forwarded unchanged to ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (fp32, ): ("__ocml_acosh_f32", fp32),
          (fp64, ): ("__ocml_acosh_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def asinh(arg0, _semantic=None):
      """Elementwise ``asinh`` of ``arg0``.

      fp32 inputs call ``__ocml_asinh_f32`` and fp64 inputs call
      ``__ocml_asinh_f64``; the result keeps the input dtype. ``_semantic``
      is forwarded unchanged to ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (fp32, ): ("__ocml_asinh_f32", fp32),
          (fp64, ): ("__ocml_asinh_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def atanh(arg0, _semantic=None):
      """Elementwise ``atanh`` of ``arg0``.

      fp32 inputs call ``__ocml_atanh_f32`` and fp64 inputs call
      ``__ocml_atanh_f64``; the result keeps the input dtype. ``_semantic``
      is forwarded unchanged to ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (fp32, ): ("__ocml_atanh_f32", fp32),
          (fp64, ): ("__ocml_atanh_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def expm1(arg0, _semantic=None):
      """Elementwise ``expm1`` of ``arg0``.

      fp32 inputs call ``__ocml_expm1_f32`` and fp64 inputs call
      ``__ocml_expm1_f64``; the result keeps the input dtype. ``_semantic``
      is forwarded unchanged to ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (fp32, ): ("__ocml_expm1_f32", fp32),
          (fp64, ): ("__ocml_expm1_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def hypot(arg0, arg1, _semantic=None):
      """Elementwise ``hypot(arg0, arg1)``.

      Both operands must be fp32, which calls ``__ocml_hypot_f32``, or both
      fp64, which calls ``__ocml_hypot_f64``. The result has the operands'
      dtype. ``_semantic`` is forwarded unchanged to
      ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (fp32, fp32): ("__ocml_hypot_f32", fp32),
          (fp64, fp64): ("__ocml_hypot_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0, arg1], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def j0(arg0, _semantic=None):
      """Elementwise ``j0`` of ``arg0``.

      fp32 inputs call ``__ocml_j0_f32`` and fp64 inputs call
      ``__ocml_j0_f64``; the result keeps the input dtype. ``_semantic``
      is forwarded unchanged to ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (fp32, ): ("__ocml_j0_f32", fp32),
          (fp64, ): ("__ocml_j0_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def j1(arg0, _semantic=None):
      """Elementwise ``j1`` of ``arg0``.

      fp32 inputs call ``__ocml_j1_f32`` and fp64 inputs call
      ``__ocml_j1_f64``; the result keeps the input dtype. ``_semantic``
      is forwarded unchanged to ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (fp32, ): ("__ocml_j1_f32", fp32),
          (fp64, ): ("__ocml_j1_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def y0(arg0, _semantic=None):
      """Elementwise ``y0`` of ``arg0``.

      fp32 inputs call ``__ocml_y0_f32`` and fp64 inputs call
      ``__ocml_y0_f64``; the result keeps the input dtype. ``_semantic``
      is forwarded unchanged to ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (fp32, ): ("__ocml_y0_f32", fp32),
          (fp64, ): ("__ocml_y0_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def y1(arg0, _semantic=None):
      """Elementwise ``y1`` of ``arg0``.

      fp32 inputs call ``__ocml_y1_f32`` and fp64 inputs call
      ``__ocml_y1_f64``; the result keeps the input dtype. ``_semantic``
      is forwarded unchanged to ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (fp32, ): ("__ocml_y1_f32", fp32),
          (fp64, ): ("__ocml_y1_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def cyl_bessel_i0(arg0, _semantic=None):
      """Elementwise ``cyl_bessel_i0`` of ``arg0``.

      fp32 inputs call ``__ocml_i0_f32`` and fp64 inputs call
      ``__ocml_i0_f64``; the result keeps the input dtype. ``_semantic``
      is forwarded unchanged to ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (fp32, ): ("__ocml_i0_f32", fp32),
          (fp64, ): ("__ocml_i0_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def cyl_bessel_i1(arg0, _semantic=None):
      """Elementwise ``cyl_bessel_i1`` of ``arg0``.

      fp32 inputs call ``__ocml_i1_f32`` and fp64 inputs call
      ``__ocml_i1_f64``; the result keeps the input dtype. ``_semantic``
      is forwarded unchanged to ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (fp32, ): ("__ocml_i1_f32", fp32),
          (fp64, ): ("__ocml_i1_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def erf(arg0, _semantic=None):
      """Elementwise ``erf`` of ``arg0``.

      fp32 inputs call ``__ocml_erf_f32`` and fp64 inputs call
      ``__ocml_erf_f64``; the result keeps the input dtype. ``_semantic``
      is forwarded unchanged to ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (fp32, ): ("__ocml_erf_f32", fp32),
          (fp64, ): ("__ocml_erf_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def erfinv(arg0, _semantic=None):
      """Elementwise ``erfinv`` of ``arg0``.

      fp32 inputs call ``__ocml_erfinv_f32`` and fp64 inputs call
      ``__ocml_erfinv_f64``; the result keeps the input dtype. ``_semantic``
      is forwarded unchanged to ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (fp32, ): ("__ocml_erfinv_f32", fp32),
          (fp64, ): ("__ocml_erfinv_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def erfc(arg0, _semantic=None):
      """Elementwise ``erfc`` of ``arg0``.

      fp32 inputs call ``__ocml_erfc_f32`` and fp64 inputs call
      ``__ocml_erfc_f64``; the result keeps the input dtype. ``_semantic``
      is forwarded unchanged to ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (fp32, ): ("__ocml_erfc_f32", fp32),
          (fp64, ): ("__ocml_erfc_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def erfcx(arg0, _semantic=None):
      """Elementwise ``erfcx`` of ``arg0``.

      fp32 inputs call ``__ocml_erfcx_f32`` and fp64 inputs call
      ``__ocml_erfcx_f64``; the result keeps the input dtype. ``_semantic``
      is forwarded unchanged to ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (fp32, ): ("__ocml_erfcx_f32", fp32),
          (fp64, ): ("__ocml_erfcx_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def lgamma(arg0, _semantic=None):
      """Elementwise ``lgamma`` of ``arg0``.

      fp32 inputs call ``__ocml_lgamma_f32`` and fp64 inputs call
      ``__ocml_lgamma_f64``; the result keeps the input dtype. ``_semantic``
      is forwarded unchanged to ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (fp32, ): ("__ocml_lgamma_f32", fp32),
          (fp64, ): ("__ocml_lgamma_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def ldexp(arg0, arg1, _semantic=None):
      """Elementwise ``ldexp(arg0, arg1)`` with an int32 second operand.

      (fp32, int32) calls ``__ocml_ldexp_f32`` and yields fp32; (fp64, int32)
      calls ``__ocml_ldexp_f64`` and yields fp64. ``_semantic`` is forwarded
      unchanged to ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      i32 = core.dtype("int32")
      overloads = {
          (fp32, i32): ("__ocml_ldexp_f32", fp32),
          (fp64, i32): ("__ocml_ldexp_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0, arg1], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def fmod(arg0, arg1, _semantic=None):
      """Elementwise ``fmod(arg0, arg1)``.

      Both operands must be fp32, which calls ``__ocml_fmod_f32``, or both
      fp64, which calls ``__ocml_fmod_f64``. The result has the operands'
      dtype. ``_semantic`` is forwarded unchanged to
      ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (fp32, fp32): ("__ocml_fmod_f32", fp32),
          (fp64, fp64): ("__ocml_fmod_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0, arg1], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def fma(arg0, arg1, arg2, _semantic=None):
      """Elementwise ``fma(arg0, arg1, arg2)``.

      All three operands must be fp32, which calls ``__ocml_fma_f32``, or all
      fp64, which calls ``__ocml_fma_f64``. The result has the operands'
      dtype. ``_semantic`` is forwarded unchanged to
      ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (fp32, fp32, fp32): ("__ocml_fma_f32", fp32),
          (fp64, fp64, fp64): ("__ocml_fma_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0, arg1, arg2], overloads, is_pure=True,
          _semantic=_semantic)
  @core.extern
  def pow(arg0, arg1, _semantic=None):
      """Elementwise ``pow(arg0, arg1)`` with integer or floating exponent.

      An int32 exponent selects ``__ocml_pown_f32``/``__ocml_pown_f64`` for
      fp32/fp64 bases. An exponent of the same float dtype as the base
      selects ``__ocml_pow_f32``/``__ocml_pow_f64``. The result has the
      base's dtype. ``_semantic`` is forwarded unchanged to
      ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      i32 = core.dtype("int32")
      overloads = {
          # int32 exponent -> pown variants.
          (fp32, i32): ("__ocml_pown_f32", fp32),
          (fp64, i32): ("__ocml_pown_f64", fp64),
          # Exponent with the same float dtype as the base -> pow variants.
          (fp32, fp32): ("__ocml_pow_f32", fp32),
          (fp64, fp64): ("__ocml_pow_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0, arg1], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def ilogb(arg0, _semantic=None):
      """Elementwise ``ilogb`` of ``arg0``, producing int32.

      fp32 inputs call ``__ocml_ilogb_f32`` and fp64 inputs call
      ``__ocml_ilogb_f64``; either way the result dtype is int32.
      ``_semantic`` is forwarded unchanged to ``core.extern_elementwise``.
      """
      i32 = core.dtype("int32")
      overloads = {
          (core.dtype("fp32"), ): ("__ocml_ilogb_f32", i32),
          (core.dtype("fp64"), ): ("__ocml_ilogb_f64", i32),
      }
      return core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)
  @core.extern
  def round(arg0, _semantic=None):
      """Elementwise ``round`` of ``arg0``.

      fp32 inputs call ``__ocml_round_f32`` and fp64 inputs call
      ``__ocml_round_f64``; the result keeps the input dtype. ``_semantic``
      is forwarded unchanged to ``core.extern_elementwise``.
      """
      fp32, fp64 = core.dtype("fp32"), core.dtype("fp64")
      overloads = {
          (fp32, ): ("__ocml_round_f32", fp32),
          (fp64, ): ("__ocml_round_f64", fp64),
      }
      return core.extern_elementwise(
          "", "", [arg0], overloads, is_pure=True, _semantic=_semantic)