| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249 |
- from . import core
- from functools import wraps
- from typing import List
- T = core.TypeVar('T')
- def _check_dtype(dtypes: List[str]) -> T:
- """
- We're following libdevice's convention to check accepted data types for math functions.
- It is not a good practice to support all data types as accelerators/GPUs don't support
- many float16 and bfloat16 math operations.
- We should let the users know that they are using and invoke explicit cast to convert
- the data type to the supported one.
- """
- def wrapper(fn):
- @wraps(fn)
- def check(*args, **kwargs):
- # concatenate args and kwargs
- all_args = list(args) + list(kwargs.values())
- for arg in [a for a in all_args if isinstance(a, core.tensor)]:
- if arg.type.scalar.name not in dtypes:
- raise ValueError(f"Expected dtype {dtypes} but got {arg.type.scalar.name}")
- return fn(*args, **kwargs)
- return check
- return wrapper
- def _add_math_1arg_docstr(name: str) -> core.Callable[[T], T]:
- def _decorator(func: T) -> T:
- docstr = """
- Computes the element-wise {name} of :code:`x`.
- :param x: the input values
- :type x: Block
- """
- func.__doc__ = docstr.format(name=name)
- return func
- return _decorator
- def _add_math_2arg_docstr(name: str) -> core.Callable[[T], T]:
- def _decorator(func: T) -> T:
- docstr = """
- Computes the element-wise {name} of :code:`x` and :code:`y`.
- :param x: the input values
- :type x: Block
- :param y: the input values
- :type y: Block
- """
- func.__doc__ = docstr.format(name=name)
- return func
- return _decorator
- def _add_math_3arg_docstr(name: str) -> core.Callable[[T], T]:
- def _decorator(func: T) -> T:
- docstr = """
- Computes the element-wise {name} of :code:`x`, :code:`y`, and :code:`z`.
- :param x: the input values
- :type x: Block
- :param y: the input values
- :type y: Block
- :param z: the input values
- :type z: Block
- """
- func.__doc__ = docstr.format(name=name)
- return func
- return _decorator
- @core.builtin
- @_check_dtype(dtypes=["int32", "int64", "uint32", "uint64"])
- @_add_math_2arg_docstr("most significant N bits of the 2N-bit product")
- def umulhi(x, y, _semantic=None):
- x = _semantic.to_tensor(x)
- y = _semantic.to_tensor(y)
- x, y = core.binary_op_type_legalization(x, y, _semantic)
- return core.tensor(_semantic.builder.create_umulhi(x.handle, y.handle), x.type)
- @core.builtin
- @_check_dtype(dtypes=["fp32", "fp64"])
- @_add_math_1arg_docstr("exponential")
- @core._tensor_member_fn
- def exp(x, _semantic=None):
- x = _semantic.to_tensor(x)
- return core.tensor(_semantic.builder.create_exp(x.handle), x.type)
- @core.builtin
- @_check_dtype(dtypes=["fp32", "fp64"])
- @_add_math_1arg_docstr("exponential (base 2)")
- @core._tensor_member_fn
- def exp2(x, _semantic=None):
- x = _semantic.to_tensor(x)
- return core.tensor(_semantic.builder.create_exp2(x.handle), x.type)
- @core.builtin
- @_check_dtype(dtypes=["fp32", "fp64"])
- @_add_math_1arg_docstr("natural logarithm")
- @core._tensor_member_fn
- def log(x, _semantic=None):
- x = _semantic.to_tensor(x)
- return core.tensor(_semantic.builder.create_log(x.handle), x.type)
- @core.builtin
- @_check_dtype(dtypes=["fp32", "fp64"])
- @_add_math_1arg_docstr("logarithm (base 2)")
- @core._tensor_member_fn
- def log2(x, _semantic=None):
- x = _semantic.to_tensor(x)
- return core.tensor(_semantic.builder.create_log2(x.handle), x.type)
- @core.builtin
- @_check_dtype(dtypes=["fp32", "fp64"])
- @_add_math_1arg_docstr("cosine")
- @core._tensor_member_fn
- def cos(x, _semantic=None):
- x = _semantic.to_tensor(x)
- return core.tensor(_semantic.builder.create_cos(x.handle), x.type)
- @core.builtin
- @_check_dtype(dtypes=["fp32", "fp64"])
- @_add_math_1arg_docstr("sine")
- @core._tensor_member_fn
- def sin(x, _semantic=None):
- x = _semantic.to_tensor(x)
- return core.tensor(_semantic.builder.create_sin(x.handle), x.type)
- @core.builtin
- @_check_dtype(dtypes=["fp32", "fp64"])
- @_add_math_1arg_docstr("fast square root")
- @core._tensor_member_fn
- def sqrt(x, _semantic=None):
- x = _semantic.to_tensor(x)
- return core.tensor(_semantic.builder.create_sqrt(x.handle), x.type)
- @core.builtin
- @_check_dtype(dtypes=["fp32"])
- @_add_math_1arg_docstr("precise square root (rounding to nearest wrt the IEEE standard)")
- @core._tensor_member_fn
- def sqrt_rn(x, _semantic=None):
- x = _semantic.to_tensor(x)
- return core.tensor(_semantic.builder.create_precise_sqrt(x.handle), x.type)
- @core.builtin
- @_check_dtype(dtypes=["fp32", "fp64"])
- @_add_math_1arg_docstr("inverse square root")
- @core._tensor_member_fn
- def rsqrt(x, _semantic=None):
- x = _semantic.to_tensor(x)
- return core.tensor(_semantic.builder.create_rsqrt(x.handle), x.type)
- @core._tensor_member_fn
- @core.builtin
- @_add_math_1arg_docstr("absolute value")
- def abs(x, _semantic=None):
- x = _semantic.to_tensor(x)
- dtype = x.dtype
- if dtype.is_fp8e4b15():
- mask = core.full(x.shape, 0x7F, core.int8, _semantic=_semantic)
- return core.tensor(_semantic.builder.create_and(x.handle, mask.handle), x.type)
- elif dtype.is_floating():
- return core.tensor(_semantic.builder.create_fabs(x.handle), x.type)
- elif dtype.is_int_signed():
- return core.tensor(_semantic.builder.create_iabs(x.handle), x.type)
- elif dtype.is_int_unsigned():
- return x # no-op
- else:
- assert False, f"Unexpected dtype {dtype}"
- @core.builtin
- @_add_math_2arg_docstr("fast division")
- def fdiv(x, y, ieee_rounding=False, _semantic=None):
- ieee_rounding = core._unwrap_if_constexpr(ieee_rounding)
- x = _semantic.to_tensor(x)
- y = _semantic.to_tensor(y)
- return _semantic.fdiv(x, y, ieee_rounding)
- @core.builtin
- @_check_dtype(dtypes=["fp32"])
- @_add_math_2arg_docstr("precise division (rounding to nearest wrt the IEEE standard)")
- def div_rn(x, y, _semantic=None):
- x = _semantic.to_tensor(x)
- y = _semantic.to_tensor(y)
- x, y = core.binary_op_type_legalization(x, y, _semantic)
- return core.tensor(_semantic.builder.create_precise_divf(x.handle, y.handle), x.type)
- @core.builtin
- @_check_dtype(dtypes=["fp32", "fp64"])
- @_add_math_1arg_docstr("error function")
- @core._tensor_member_fn
- def erf(x, _semantic=None):
- x = _semantic.to_tensor(x)
- return core.tensor(_semantic.builder.create_erf(x.handle), x.type)
- @core.builtin
- @_check_dtype(dtypes=["fp32", "fp64"])
- @_add_math_1arg_docstr("floor")
- @core._tensor_member_fn
- def floor(x, _semantic=None):
- x = _semantic.to_tensor(x)
- return core.tensor(_semantic.builder.create_floor(x.handle), x.type)
- @core.builtin
- @_check_dtype(dtypes=["fp32", "fp64"])
- @_add_math_1arg_docstr("ceil")
- @core._tensor_member_fn
- def ceil(x, _semantic=None):
- x = _semantic.to_tensor(x)
- return core.tensor(_semantic.builder.create_ceil(x.handle), x.type)
- @core.builtin
- @_add_math_3arg_docstr("fused multiply-add")
- def fma(x, y, z, _semantic=None):
- x = _semantic.to_tensor(x)
- y = _semantic.to_tensor(y)
- z = _semantic.to_tensor(z)
- x, y = core.binary_op_type_legalization(x, y, _semantic)
- z, x = core.binary_op_type_legalization(z, x, _semantic)
- z, y = core.binary_op_type_legalization(z, y, _semantic)
- return core.tensor(_semantic.builder.create_fma(x.handle, y.handle, z.handle), x.type)
|