math.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. from . import core
  2. from functools import wraps
  3. from typing import List
# Type variable used in the annotations of the decorator helpers in this module
# (a decorator takes a callable of type T and returns that same type).
T = core.TypeVar('T')
  5. def _check_dtype(dtypes: List[str]) -> T:
  6. """
  7. We're following libdevice's convention to check accepted data types for math functions.
  8. It is not a good practice to support all data types as accelerators/GPUs don't support
  9. many float16 and bfloat16 math operations.
  10. We should let the users know that they are using and invoke explicit cast to convert
  11. the data type to the supported one.
  12. """
  13. def wrapper(fn):
  14. @wraps(fn)
  15. def check(*args, **kwargs):
  16. # concatenate args and kwargs
  17. all_args = list(args) + list(kwargs.values())
  18. for arg in [a for a in all_args if isinstance(a, core.tensor)]:
  19. if arg.type.scalar.name not in dtypes:
  20. raise ValueError(f"Expected dtype {dtypes} but got {arg.type.scalar.name}")
  21. return fn(*args, **kwargs)
  22. return check
  23. return wrapper
  24. def _add_math_1arg_docstr(name: str) -> core.Callable[[T], T]:
  25. def _decorator(func: T) -> T:
  26. docstr = """
  27. Computes the element-wise {name} of :code:`x`.
  28. :param x: the input values
  29. :type x: Block
  30. """
  31. func.__doc__ = docstr.format(name=name)
  32. return func
  33. return _decorator
  34. def _add_math_2arg_docstr(name: str) -> core.Callable[[T], T]:
  35. def _decorator(func: T) -> T:
  36. docstr = """
  37. Computes the element-wise {name} of :code:`x` and :code:`y`.
  38. :param x: the input values
  39. :type x: Block
  40. :param y: the input values
  41. :type y: Block
  42. """
  43. func.__doc__ = docstr.format(name=name)
  44. return func
  45. return _decorator
  46. def _add_math_3arg_docstr(name: str) -> core.Callable[[T], T]:
  47. def _decorator(func: T) -> T:
  48. docstr = """
  49. Computes the element-wise {name} of :code:`x`, :code:`y`, and :code:`z`.
  50. :param x: the input values
  51. :type x: Block
  52. :param y: the input values
  53. :type y: Block
  54. :param z: the input values
  55. :type z: Block
  56. """
  57. func.__doc__ = docstr.format(name=name)
  58. return func
  59. return _decorator
  60. @core.builtin
  61. @_check_dtype(dtypes=["int32", "int64", "uint32", "uint64"])
  62. @_add_math_2arg_docstr("most significant N bits of the 2N-bit product")
  63. def umulhi(x, y, _semantic=None):
  64. x = _semantic.to_tensor(x)
  65. y = _semantic.to_tensor(y)
  66. x, y = core.binary_op_type_legalization(x, y, _semantic)
  67. return core.tensor(_semantic.builder.create_umulhi(x.handle, y.handle), x.type)
  68. @core.builtin
  69. @_check_dtype(dtypes=["fp32", "fp64"])
  70. @_add_math_1arg_docstr("exponential")
  71. @core._tensor_member_fn
  72. def exp(x, _semantic=None):
  73. x = _semantic.to_tensor(x)
  74. return core.tensor(_semantic.builder.create_exp(x.handle), x.type)
  75. @core.builtin
  76. @_check_dtype(dtypes=["fp32", "fp64"])
  77. @_add_math_1arg_docstr("exponential (base 2)")
  78. @core._tensor_member_fn
  79. def exp2(x, _semantic=None):
  80. x = _semantic.to_tensor(x)
  81. return core.tensor(_semantic.builder.create_exp2(x.handle), x.type)
  82. @core.builtin
  83. @_check_dtype(dtypes=["fp32", "fp64"])
  84. @_add_math_1arg_docstr("natural logarithm")
  85. @core._tensor_member_fn
  86. def log(x, _semantic=None):
  87. x = _semantic.to_tensor(x)
  88. return core.tensor(_semantic.builder.create_log(x.handle), x.type)
  89. @core.builtin
  90. @_check_dtype(dtypes=["fp32", "fp64"])
  91. @_add_math_1arg_docstr("logarithm (base 2)")
  92. @core._tensor_member_fn
  93. def log2(x, _semantic=None):
  94. x = _semantic.to_tensor(x)
  95. return core.tensor(_semantic.builder.create_log2(x.handle), x.type)
  96. @core.builtin
  97. @_check_dtype(dtypes=["fp32", "fp64"])
  98. @_add_math_1arg_docstr("cosine")
  99. @core._tensor_member_fn
  100. def cos(x, _semantic=None):
  101. x = _semantic.to_tensor(x)
  102. return core.tensor(_semantic.builder.create_cos(x.handle), x.type)
  103. @core.builtin
  104. @_check_dtype(dtypes=["fp32", "fp64"])
  105. @_add_math_1arg_docstr("sine")
  106. @core._tensor_member_fn
  107. def sin(x, _semantic=None):
  108. x = _semantic.to_tensor(x)
  109. return core.tensor(_semantic.builder.create_sin(x.handle), x.type)
  110. @core.builtin
  111. @_check_dtype(dtypes=["fp32", "fp64"])
  112. @_add_math_1arg_docstr("fast square root")
  113. @core._tensor_member_fn
  114. def sqrt(x, _semantic=None):
  115. x = _semantic.to_tensor(x)
  116. return core.tensor(_semantic.builder.create_sqrt(x.handle), x.type)
  117. @core.builtin
  118. @_check_dtype(dtypes=["fp32"])
  119. @_add_math_1arg_docstr("precise square root (rounding to nearest wrt the IEEE standard)")
  120. @core._tensor_member_fn
  121. def sqrt_rn(x, _semantic=None):
  122. x = _semantic.to_tensor(x)
  123. return core.tensor(_semantic.builder.create_precise_sqrt(x.handle), x.type)
  124. @core.builtin
  125. @_check_dtype(dtypes=["fp32", "fp64"])
  126. @_add_math_1arg_docstr("inverse square root")
  127. @core._tensor_member_fn
  128. def rsqrt(x, _semantic=None):
  129. x = _semantic.to_tensor(x)
  130. return core.tensor(_semantic.builder.create_rsqrt(x.handle), x.type)
  131. @core._tensor_member_fn
  132. @core.builtin
  133. @_add_math_1arg_docstr("absolute value")
  134. def abs(x, _semantic=None):
  135. x = _semantic.to_tensor(x)
  136. dtype = x.dtype
  137. if dtype.is_fp8e4b15():
  138. mask = core.full(x.shape, 0x7F, core.int8, _semantic=_semantic)
  139. return core.tensor(_semantic.builder.create_and(x.handle, mask.handle), x.type)
  140. elif dtype.is_floating():
  141. return core.tensor(_semantic.builder.create_fabs(x.handle), x.type)
  142. elif dtype.is_int_signed():
  143. return core.tensor(_semantic.builder.create_iabs(x.handle), x.type)
  144. elif dtype.is_int_unsigned():
  145. return x # no-op
  146. else:
  147. assert False, f"Unexpected dtype {dtype}"
  148. @core.builtin
  149. @_add_math_2arg_docstr("fast division")
  150. def fdiv(x, y, ieee_rounding=False, _semantic=None):
  151. ieee_rounding = core._unwrap_if_constexpr(ieee_rounding)
  152. x = _semantic.to_tensor(x)
  153. y = _semantic.to_tensor(y)
  154. return _semantic.fdiv(x, y, ieee_rounding)
  155. @core.builtin
  156. @_check_dtype(dtypes=["fp32"])
  157. @_add_math_2arg_docstr("precise division (rounding to nearest wrt the IEEE standard)")
  158. def div_rn(x, y, _semantic=None):
  159. x = _semantic.to_tensor(x)
  160. y = _semantic.to_tensor(y)
  161. x, y = core.binary_op_type_legalization(x, y, _semantic)
  162. return core.tensor(_semantic.builder.create_precise_divf(x.handle, y.handle), x.type)
  163. @core.builtin
  164. @_check_dtype(dtypes=["fp32", "fp64"])
  165. @_add_math_1arg_docstr("error function")
  166. @core._tensor_member_fn
  167. def erf(x, _semantic=None):
  168. x = _semantic.to_tensor(x)
  169. return core.tensor(_semantic.builder.create_erf(x.handle), x.type)
  170. @core.builtin
  171. @_check_dtype(dtypes=["fp32", "fp64"])
  172. @_add_math_1arg_docstr("floor")
  173. @core._tensor_member_fn
  174. def floor(x, _semantic=None):
  175. x = _semantic.to_tensor(x)
  176. return core.tensor(_semantic.builder.create_floor(x.handle), x.type)
  177. @core.builtin
  178. @_check_dtype(dtypes=["fp32", "fp64"])
  179. @_add_math_1arg_docstr("ceil")
  180. @core._tensor_member_fn
  181. def ceil(x, _semantic=None):
  182. x = _semantic.to_tensor(x)
  183. return core.tensor(_semantic.builder.create_ceil(x.handle), x.type)
  184. @core.builtin
  185. @_add_math_3arg_docstr("fused multiply-add")
  186. def fma(x, y, z, _semantic=None):
  187. x = _semantic.to_tensor(x)
  188. y = _semantic.to_tensor(y)
  189. z = _semantic.to_tensor(z)
  190. x, y = core.binary_op_type_legalization(x, y, _semantic)
  191. z, x = core.binary_op_type_legalization(z, x, _semantic)
  192. z, y = core.binary_op_type_legalization(z, y, _semantic)
  193. return core.tensor(_semantic.builder.create_fma(x.handle, y.handle, z.handle), x.type)