utils.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. from triton.language import core
  @core.extern
  def globaltimer(_semantic=None):
      """Emit inline PTX that reads the ``%globaltimer`` special register.

      The register is moved into a 64-bit output operand (``"=l"``) and the
      result is typed as ``core.int64``.  The asm is declared impure
      (``is_pure=False``).

      :param _semantic: forwarded unchanged to ``core.inline_asm_elementwise``.
      :return: the value produced by ``core.inline_asm_elementwise``.
      """
      ptx = "mov.u64 $0, %globaltimer;"
      return core.inline_asm_elementwise(
          ptx,
          "=l",
          [],  # the instruction takes no input operands
          dtype=core.int64,
          is_pure=False,
          pack=1,
          _semantic=_semantic,
      )
  @core.extern
  def smid(_semantic=None):
      """Emit inline PTX that reads the ``%smid`` special register.

      The register is moved into a 32-bit output operand (``"=r"``) and the
      result is typed as ``core.int32``.  The asm is declared pure
      (``is_pure=True``).

      :param _semantic: forwarded unchanged to ``core.inline_asm_elementwise``.
      :return: the value produced by ``core.inline_asm_elementwise``.
      """
      ptx = "mov.u32 $0, %smid;"
      return core.inline_asm_elementwise(
          ptx,
          "=r",
          [],  # the instruction takes no input operands
          dtype=core.int32,
          is_pure=True,
          pack=1,
          _semantic=_semantic,
      )
  @core.builtin
  def num_threads(_semantic=None):
      """Return ``num_warps * 32`` as a ``core.constexpr``.

      ``num_warps`` is read from ``_semantic.builder.options``; the factor 32
      (threads per warp) is hard-coded here.
      """
      warps = _semantic.builder.options.num_warps
      return core.constexpr(warps * 32)
  @core.builtin
  def num_warps(_semantic=None):
      """Return the ``num_warps`` builder option wrapped in a ``core.constexpr``."""
      options = _semantic.builder.options
      return core.constexpr(options.num_warps)
  # ----- FP8E4M3B15 ------
  # This data type is a variant of the standard FP8E4M3 format.
  # It was designed for fast software conversion to FP16 on
  # NVIDIA GPUs that do not support it natively.
  # It is the same format as FP8E4M3Nv, except that:
  #   - the exponent bias is 15 instead of 7
  #   - 0xff and 0x7f are mapped to +-1.750 instead of +-nan
  @core.builtin
  def convert_fp8e4b15_to_float16(arg, _semantic=None):
      """Emit inline PTX that widens packed fp8e4b15 values to float16.

      One asm instance handles four elements (``pack=4``): a single 32-bit
      input register (``$2``) holding four fp8 bytes, and two 32-bit output
      registers (``$0``, ``$1``) each holding two 16-bit results.

      :param arg: the value to convert; its type is not checked here
          (presumably fp8e4b15 -- callers are expected to guarantee this).
      :param _semantic: forwarded unchanged to ``core.inline_asm_elementwise``.
      :return: the ``core.float16`` result of ``core.inline_asm_elementwise``.
      """
      ptx = (
          "{ \n"
          ".reg .b32 a<2>, b<2>; \n"
          # Shuffle the four input bytes: two end up in the high byte of each
          # 16-bit lane of a0, the other two in the low byte of each lane.
          "prmt.b32 a0, 0, $2, 0x5746; \n"
          # b0: the 7 non-sign bits of the high-byte pair.
          "and.b32 b0, a0, 0x7f007f00; \n"
          # b1: the full low-byte pair; a1: just the sign bits of that pair.
          "and.b32 b1, a0, 0x00ff00ff; \n"
          "and.b32 a1, a0, 0x00800080; \n"
          # Move the 7 non-sign bits down to bits 13..7 of each lane.
          "shr.b32 b0, b0, 1; \n"
          # Adding the sign bit to itself carries it from bit 7 into bit 8.
          "add.u32 b1, b1, a1; \n"
          # LUT 0xf8 computes a | (b & c): b0 | (0x80008000 & a0), i.e. OR the
          # sign bits (bit 15 of each lane) back in.
          "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n"
          # Shift so the sign lands in bit 15 and the rest in bits 13..7.
          "shl.b32 $1, b1, 7; \n"
          "} \n"
      )
      constraints = "=r,=r,r"  # two 32-bit outputs, one 32-bit input
      return core.inline_asm_elementwise(
          ptx,
          constraints,
          [arg],
          dtype=core.float16,
          is_pure=True,
          pack=4,
          _semantic=_semantic,
      )
  @core.builtin
  def convert_float16_to_fp8e4b15(arg, has_minx2, _semantic=None):
      """Build and emit inline PTX that narrows float16 values to fp8e4b15.

      One asm instance handles four elements (``pack=4``): two 32-bit input
      registers (``$1``, ``$2``) each holding two float16 values, and one
      32-bit output register (``$0``) holding four fp8 bytes.

      :param arg: the value to convert; its type is not checked here
          (presumably float16 -- callers are expected to guarantee this).
      :param has_minx2: truthy to clamp with ``min.f16x2``; falsy to clamp
          with a ``setp``/``selp`` sequence instead.
      :param _semantic: forwarded unchanged to ``core.inline_asm_elementwise``.
      :return: the ``core.float8e4b15`` result of
          ``core.inline_asm_elementwise``.
      """
      # Declarations, the clamp constant (0x3F00 is 1.75 in float16 encoding)
      # and sign-stripped copies of both inputs (mask 0x7fff per 16-bit lane).
      setup = "\n".join((
          "{",
          ".reg .pred p<4>;",
          ".reg .b32 a<2>, b<2>;",
          ".reg .b16 c<4>;",
          ".reg .b16 max_val_f16;",
          ".reg .b32 max_val_f16x2;",
          "mov.b16 max_val_f16, 0x3F00;",
          "mov.b32 max_val_f16x2, 0x3F003F00;",
          "and.b32 a0, $1, 0x7fff7fff;",
          "and.b32 a1, $2, 0x7fff7fff;",
      ))

      if has_minx2:
          # Clamp both lanes of each register with one packed min each.
          clamp = "\n".join((
              "min.f16x2 a0, a0, max_val_f16x2;",
              "min.f16x2 a1, a1, max_val_f16x2;",
          ))
      else:
          # Same clamp without min.f16x2: compare each lane against the
          # maximum, unpack to 16-bit registers, select, and repack.
          clamp = "\n".join((
              "setp.lt.f16x2 p0|p1, a0, max_val_f16x2;",
              "setp.lt.f16x2 p2|p3, a1, max_val_f16x2;",
              "mov.b32 {c0, c1}, a0;",
              "mov.b32 {c2, c3}, a1;",
              "selp.b16 c0, c0, max_val_f16, p0;",
              "selp.b16 c1, c1, max_val_f16, p1;",
              "selp.b16 c2, c2, max_val_f16, p2;",
              "selp.b16 c3, c3, max_val_f16, p3;",
              "mov.b32 a0, {c0, c1};",
              "mov.b32 a1, {c2, c3};",
          ))

      # a*2 + 0x00800080 doubles each register and adds 0x80 to every 16-bit
      # lane.  LUT 0xea computes (a & b) | c, i.e. ORs the original sign bits
      # back in.  prmt 0x7531 then gathers the high byte of each lane of
      # b0 and b1 into the output.
      finish = "\n".join((
          "mad.lo.u32 a0, a0, 2, 0x00800080;",
          "mad.lo.u32 a1, a1, 2, 0x00800080;",
          "lop3.b32 b0, $1, 0x80008000, a0, 0xea;",
          "lop3.b32 b1, $2, 0x80008000, a1, 0xea;",
          "prmt.b32 $0, b0, b1, 0x7531;",
          "}",
      ))

      # The three pieces are concatenated with no separator between them;
      # every PTX statement already ends in ';'.
      asm = setup + clamp + finish
      return core.inline_asm_elementwise(
          asm,
          "=r,r,r",  # one 32-bit output, two 32-bit inputs
          [arg],
          dtype=core.float8e4b15,
          is_pure=True,
          pack=4,
          _semantic=_semantic,
      )
  @core.builtin
  def convert_custom_float8(arg, dst_ty, fp_downcast_rounding, has_minx2, _semantic=None):
      """Convert to or from fp8e4b15, choosing the direction from ``arg``'s type.

      * If ``arg`` is fp8e4b15 it is widened to float16, and further to
        float32 when ``dst_ty`` is fp32.  Any other ``dst_ty`` gets the
        float16 value back unchanged.
      * Otherwise ``arg`` must be fp16 or fp32 (asserted).  fp32 is first
        narrowed to float16, then the float16 value is packed to fp8e4b15.

      :param arg: value to convert.
      :param dst_ty: requested destination type; only consulted on the
          widening path.
      :param fp_downcast_rounding: accepted but not read by this function;
          the fp32 -> fp16 step always uses ``"rtz"``.
      :param has_minx2: forwarded to ``convert_float16_to_fp8e4b15``.
      :param _semantic: forwarded unchanged to every helper call.
      :return: the converted value.
      :raises AssertionError: if ``arg`` is none of fp8e4b15, fp16 or fp32.
      """
      src_scalar = arg.type.scalar

      # Widening path: fp8e4b15 -> fp16 (-> fp32).
      if src_scalar.is_fp8e4b15():
          widened = convert_fp8e4b15_to_float16(arg, _semantic=_semantic)
          if not dst_ty.scalar.is_fp32():
              return widened
          return widened.to(core.float32, _semantic=_semantic)

      # Narrowing path: (fp32 ->) fp16 -> fp8e4b15.
      assert src_scalar.is_fp16() or src_scalar.is_fp32()
      half = arg
      if src_scalar.is_fp32():
          half = arg.to(core.float16, fp_downcast_rounding="rtz", _semantic=_semantic)
      return convert_float16_to_fp8e4b15(half, has_minx2=has_minx2, _semantic=_semantic)
  @core.builtin
  def convert_custom_float8_sm80(arg, dst_ty, fp_downcast_rounding=None, _semantic=None):
      """Call ``convert_custom_float8`` with ``has_minx2=True``.

      All other arguments are passed through unchanged.
      """
      return convert_custom_float8(
          arg,
          dst_ty,
          fp_downcast_rounding,
          has_minx2=True,
          _semantic=_semantic,
      )
  @core.builtin
  def convert_custom_float8_sm70(arg, dst_ty, fp_downcast_rounding=None, _semantic=None):
      """Call ``convert_custom_float8`` with ``has_minx2=False``.

      All other arguments are passed through unchanged.
      """
      return convert_custom_float8(
          arg,
          dst_ty,
          fp_downcast_rounding,
          has_minx2=False,
          _semantic=_semantic,
      )