| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350 |
- """isort:skip_file"""
- # Import order is significant here.
- from . import math
- from . import extra
- from .standard import (
- argmax,
- argmin,
- bitonic_merge,
- cdiv,
- cumprod,
- cumsum,
- flip,
- interleave,
- max,
- min,
- ravel,
- reduce_or,
- sigmoid,
- softmax,
- sort,
- sum,
- swizzle2d,
- topk,
- xor_sum,
- zeros,
- zeros_like,
- )
- from .core import (
- PropagateNan,
- TRITON_MAX_TENSOR_NUMEL,
- load_tensor_descriptor,
- store_tensor_descriptor,
- make_tensor_descriptor,
- tensor_descriptor,
- tensor_descriptor_type,
- add,
- advance,
- arange,
- associative_scan,
- assume,
- atomic_add,
- atomic_and,
- atomic_cas,
- atomic_max,
- atomic_min,
- atomic_or,
- atomic_xchg,
- atomic_xor,
- bfloat16,
- block_type,
- broadcast,
- broadcast_to,
- cat,
- cast,
- clamp,
- condition,
- const,
- constexpr,
- constexpr_type,
- debug_barrier,
- device_assert,
- device_print,
- dot,
- dot_scaled,
- dtype,
- expand_dims,
- float16,
- float32,
- float64,
- float8e4b15,
- float8e4nv,
- float8e4b8,
- float8e5,
- float8e5b16,
- full,
- gather,
- histogram,
- inline_asm_elementwise,
- int1,
- int16,
- int32,
- int64,
- int8,
- join,
- load,
- make_block_ptr,
- map_elementwise,
- max_constancy,
- max_contiguous,
- maximum,
- minimum,
- mul,
- multiple_of,
- num_programs,
- permute,
- pi32_t,
- pointer_type,
- program_id,
- range,
- reduce,
- reshape,
- slice,
- split,
- static_assert,
- static_print,
- static_range,
- store,
- sub,
- tensor,
- trans,
- tuple,
- tuple_type,
- uint16,
- uint32,
- uint64,
- uint8,
- view,
- void,
- where,
- )
- from .math import (umulhi, exp, exp2, fma, log, log2, cos, rsqrt, sin, sqrt, sqrt_rn, abs, fdiv, div_rn, erf, floor,
- ceil)
- from .random import (
- pair_uniform_to_normal,
- philox,
- philox_impl,
- rand,
- rand4x,
- randint,
- randint4x,
- randn,
- randn4x,
- uint_to_uniform_float,
- )
- from . import target_info
# Public names exported by ``from <this package> import *``.
# Descriptor-related names come first; the rest is kept in the module's
# established (roughly alphabetical) order, packed by leading letter.
__all__ = [
    "PropagateNan", "TRITON_MAX_TENSOR_NUMEL",
    "load_tensor_descriptor", "store_tensor_descriptor", "make_tensor_descriptor", "tensor_descriptor",
    # a
    "abs", "add", "advance", "arange", "argmax", "argmin", "associative_scan", "assume",
    "atomic_add", "atomic_and", "atomic_cas", "atomic_max", "atomic_min", "atomic_or",
    "atomic_xchg", "atomic_xor",
    # b
    "bfloat16", "bitonic_merge", "block_type", "broadcast", "broadcast_to",
    # c
    "cat", "cast", "cdiv", "ceil", "clamp", "condition", "const", "constexpr", "constexpr_type",
    "cos", "cumprod", "cumsum",
    # d
    "debug_barrier", "device_assert", "device_print", "div_rn", "dot", "dot_scaled", "dtype",
    # e
    "erf", "exp", "exp2", "expand_dims", "extra",
    # f
    "fdiv", "flip", "float16", "float32", "float64", "float8e4b15", "float8e4nv", "float8e4b8",
    "float8e5", "float8e5b16", "floor", "fma", "full",
    # g, h
    "gather", "histogram",
    # i
    "inline_asm_elementwise", "interleave", "int1", "int16", "int32", "int64", "int8",
    # j, l
    "join", "load", "log", "log2",
    # m
    "make_block_ptr", "map_elementwise", "math", "max", "max_constancy", "max_contiguous",
    "maximum", "min", "minimum", "mul", "multiple_of",
    # n
    "num_programs",
    # p
    "pair_uniform_to_normal", "permute", "philox", "philox_impl", "pi32_t", "pointer_type",
    "program_id",
    # r
    "rand", "rand4x", "randint", "randint4x", "randn", "randn4x", "range", "ravel", "reduce",
    "reduce_or", "reshape", "rsqrt",
    # s
    "slice", "sigmoid", "sin", "softmax", "sort", "split", "sqrt", "sqrt_rn", "static_assert",
    "static_print", "static_range", "store", "sub", "sum", "swizzle2d",
    # t
    "target_info", "tensor", "topk", "trans", "tuple",
    # u
    "uint16", "uint32", "uint64", "uint8", "uint_to_uniform_float", "umulhi",
    # v, w, x, z
    "view", "void", "where", "xor_sum", "zeros", "zeros_like",
]
def str_to_ty(name, c):
    """Translate a signature type description into a language type object.

    Args:
        name: Either a (possibly named) tuple of nested descriptions, which is
            converted element by element and wrapped in a ``tuple_type``, or a
            string in one of these forms:

            * ``"*<ty>"`` / ``"*k<ty>"`` -- pointer / const pointer to ``<ty>``;
            * ``"tensordesc<<ty>[d0, d1, ...]>"`` -- tensor descriptor over a
              block of element type ``<ty>``; a layout expression may follow
              the block shape as ``...],<layout>>``;
            * any string starting with ``"constexpr"``;
            * one of the scalar abbreviations listed in the table at the end
              of this function (``"fp16"``, ``"i32"``, ``"u8"``, ``"B"``, ...).
        c: Value forwarded to ``constexpr_type`` when ``name`` denotes a
            constexpr. It is passed through unchanged when recursing into
            tuples and pointers; tensor-descriptor element types are resolved
            with ``None`` instead.

    Returns:
        The type object corresponding to ``name``.

    Raises:
        KeyError: If ``name`` is a string that matches none of the forms above.
    """
    # The module-level name ``tuple`` is the one imported from ``.core``, so
    # the builtin has to be fetched explicitly for the isinstance check.
    from builtins import tuple as builtin_tuple

    if isinstance(name, builtin_tuple):
        # Named tuples expose ``_fields`` on their class; plain tuples yield None.
        fields = type(name).__dict__.get("_fields", None)
        return tuple_type([str_to_ty(x, c) for x in name], fields)

    if name[0] == "*":
        pointee = name[1:]
        # A "k" right after the star marks the pointer as const.
        is_const = pointee[0] == "k"
        if is_const:
            pointee = pointee[1:]
        return pointer_type(element_ty=str_to_ty(pointee, c), const=is_const)

    if name.startswith("tensordesc"):
        inner = name.split("<")[1].rstrip(">")
        elem_name, rest = inner.split("[", maxsplit=1)
        shape_text, rest = rest.split("]", maxsplit=1)
        block_shape = [int(dim.strip()) for dim in shape_text.split(",")]
        layout_src = rest.lstrip(",")
        # Descriptors that carry a layout expression take the gluon path below.
        has_layout = bool(layout_src)

        elem_ty = str_to_ty(elem_name, None)
        ndim = len(block_shape)
        shape_type = tuple_type([int32] * ndim)
        # FIXME: Last dim stride should be constexpr(1)
        stride_type = tuple_type([int64] * ndim)
        block = block_type(elem_ty, block_shape)

        if not has_layout:
            return tensor_descriptor_type(block, shape_type, stride_type)

        from triton.experimental.gluon.language._layouts import (
            NVMMASharedLayout,
            PaddedSharedLayout,
            SwizzledSharedLayout,
        )
        from triton.experimental.gluon.language.nvidia.hopper.tma import (
            tensor_descriptor_type as nvidia_tensor_descriptor_type, )
        from triton.experimental.gluon.language.amd.gfx1250.tdm import (
            tensor_descriptor_type as amd_tensor_descriptor_type, )

        # SECURITY NOTE(review): the layout text is evaluated as a Python
        # expression. The globals dict only adds the layout classes; it does
        # not remove builtins, so this must never be fed an untrusted
        # signature string.
        layout = eval(
            layout_src,
            dict(NVMMASharedLayout=NVMMASharedLayout, PaddedSharedLayout=PaddedSharedLayout,
                 SwizzledSharedLayout=SwizzledSharedLayout))
        if isinstance(layout, NVMMASharedLayout):
            return nvidia_tensor_descriptor_type(block, shape_type, stride_type, layout)
        return amd_tensor_descriptor_type(block, shape_type, stride_type, layout)

    if name.startswith("constexpr"):
        return constexpr_type(c)

    # Scalar abbreviations; "u1" and "B" both map to the 1-bit integer type.
    tys = {
        "fp8e4nv": float8e4nv,
        "fp8e4b8": float8e4b8,
        "fp8e5": float8e5,
        "fp8e5b16": float8e5b16,
        "fp8e4b15": float8e4b15,
        "fp16": float16,
        "bf16": bfloat16,
        "fp32": float32,
        "fp64": float64,
        "i1": int1,
        "i8": int8,
        "i16": int16,
        "i32": int32,
        "i64": int64,
        "u1": int1,
        "u8": uint8,
        "u16": uint16,
        "u32": uint32,
        "u64": uint64,
        "B": int1,
    }
    return tys[name]
|