cpu_neon_bf16.cpp 1.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. #if (defined __GNUC__ && (defined __arm__ || defined __aarch64__)) || (defined _MSC_VER && (defined _M_ARM64 || defined _M_ARM64EC))
  2. #include <stdio.h>
  3. #include "arm_neon.h"
  4. /*#if defined __clang__
  5. #pragma clang attribute push (__attribute__((target("bf16"))), apply_to=function)
  6. #elif defined GCC
  7. #pragma GCC push_options
  8. #pragma GCC target("armv8.2-a", "bf16")
  9. #endif*/
  // Reads eight consecutive floats starting at `src` and returns them as a
  // single 8-lane bfloat16 vector (elements 0..3 in the low half, 4..7 in the
  // high half).
  bfloat16x8_t vld1q_as_bf16(const float* src)
  {
      // Convert each group of four floats right after loading it.
      const bfloat16x4_t lower = vcvt_bf16_f32(vld1q_f32(src));
      const bfloat16x4_t upper = vcvt_bf16_f32(vld1q_f32(src + 4));

      return vcombine_bf16(lower, upper);
  }
  // Prints the four float lanes of `r` to stdout on one line, prefixed by
  // `name`, with two digits after the decimal point per lane.
  void vprintreg(const char* name, const float32x4_t& r)
  {
      // Pull each lane out directly instead of spilling to a temporary array.
      const float lane0 = vgetq_lane_f32(r, 0);
      const float lane1 = vgetq_lane_f32(r, 1);
      const float lane2 = vgetq_lane_f32(r, 2);
      const float lane3 = vgetq_lane_f32(r, 3);

      printf("%s: (%.2f, %.2f, %.2f, %.2f)\n", name, lane0, lane1, lane2, lane3);
  }
  // Exercises the BF16 dot-product intrinsic: converts two 8-element float
  // arrays to bfloat16 vectors, accumulates their pairwise dot products onto
  // a zero vector and prints the four resulting lanes.
  void test()
  {
      const float in1[] = { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f };
      const float in2[] = { 1.f, 3.f, 6.f, 10.f, 15.f, 21.f, 28.f, 36.f };

      const bfloat16x8_t s1 = vld1q_as_bf16(in1);
      const bfloat16x8_t s2 = vld1q_as_bf16(in2);

      // Start from an all-zero accumulator so the result is the plain dot product.
      const float32x4_t zero = vdupq_n_f32(0.f);
      const float32x4_t dots = vbfdotq_f32(zero, s1, s2);

      vprintreg("(s1[0]*s2[0] + s1[1]*s2[1], ... s1[6]*s2[6] + s1[7]*s2[7])", dots);
  }
  30. /*#if defined __clang__
  31. #pragma clang attribute pop
  32. #elif defined GCC
  33. #pragma GCC pop_options
  34. #endif*/
  35. #else
  36. #error "BF16 is not supported"
  37. #endif
  38. int main()
  39. {
  40. test();
  41. return 0;
  42. }