psimd.h 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #ifndef FP16_PSIMD_H
  4. #define FP16_PSIMD_H
  5. #if defined(__cplusplus) && (__cplusplus >= 201103L)
  6. #include <cstdint>
  7. #elif !defined(__OPENCL_VERSION__)
  8. #include <stdint.h>
  9. #endif
  10. #include <psimd.h>
  11. PSIMD_INTRINSIC psimd_f32 fp16_ieee_to_fp32_psimd(psimd_u16 half) {
  12. const psimd_u32 word = (psimd_u32) psimd_interleave_lo_u16(psimd_zero_u16(), half);
  13. const psimd_u32 sign = word & psimd_splat_u32(UINT32_C(0x80000000));
  14. const psimd_u32 shr3_nonsign = (word + word) >> psimd_splat_u32(4);
  15. const psimd_u32 exp_offset = psimd_splat_u32(UINT32_C(0x70000000));
  16. #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
  17. const psimd_f32 exp_scale = psimd_splat_f32(0x1.0p-112f);
  18. #else
  19. const psimd_f32 exp_scale = psimd_splat_f32(fp32_from_bits(UINT32_C(0x7800000)));
  20. #endif
  21. const psimd_f32 norm_nonsign = psimd_mul_f32((psimd_f32) (shr3_nonsign + exp_offset), exp_scale);
  22. const psimd_u16 magic_mask = psimd_splat_u16(UINT16_C(0x3E80));
  23. const psimd_f32 magic_bias = psimd_splat_f32(0.25f);
  24. const psimd_f32 denorm_nonsign = psimd_sub_f32((psimd_f32) psimd_interleave_lo_u16(half + half, magic_mask), magic_bias);
  25. const psimd_s32 denorm_cutoff = psimd_splat_s32(INT32_C(0x00800000));
  26. const psimd_s32 denorm_mask = (psimd_s32) shr3_nonsign < denorm_cutoff;
  27. return (psimd_f32) (sign | (psimd_s32) psimd_blend_f32(denorm_mask, denorm_nonsign, norm_nonsign));
  28. }
  29. PSIMD_INTRINSIC psimd_f32x2 fp16_ieee_to_fp32x2_psimd(psimd_u16 half) {
  30. const psimd_u32 word_lo = (psimd_u32) psimd_interleave_lo_u16(psimd_zero_u16(), half);
  31. const psimd_u32 word_hi = (psimd_u32) psimd_interleave_hi_u16(psimd_zero_u16(), half);
  32. const psimd_u32 sign_mask = psimd_splat_u32(UINT32_C(0x80000000));
  33. const psimd_u32 sign_lo = word_lo & sign_mask;
  34. const psimd_u32 sign_hi = word_hi & sign_mask;
  35. const psimd_u32 shr3_nonsign_lo = (word_lo + word_lo) >> psimd_splat_u32(4);
  36. const psimd_u32 shr3_nonsign_hi = (word_hi + word_hi) >> psimd_splat_u32(4);
  37. const psimd_u32 exp_offset = psimd_splat_u32(UINT32_C(0x70000000));
  38. #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
  39. const psimd_f32 exp_scale = psimd_splat_f32(0x1.0p-112f);
  40. #else
  41. const psimd_f32 exp_scale = psimd_splat_f32(fp32_from_bits(UINT32_C(0x7800000)));
  42. #endif
  43. const psimd_f32 norm_nonsign_lo = psimd_mul_f32((psimd_f32) (shr3_nonsign_lo + exp_offset), exp_scale);
  44. const psimd_f32 norm_nonsign_hi = psimd_mul_f32((psimd_f32) (shr3_nonsign_hi + exp_offset), exp_scale);
  45. const psimd_u16 magic_mask = psimd_splat_u16(UINT16_C(0x3E80));
  46. const psimd_u16 shl1_half = half + half;
  47. const psimd_f32 magic_bias = psimd_splat_f32(0.25f);
  48. const psimd_f32 denorm_nonsign_lo = psimd_sub_f32((psimd_f32) psimd_interleave_lo_u16(shl1_half, magic_mask), magic_bias);
  49. const psimd_f32 denorm_nonsign_hi = psimd_sub_f32((psimd_f32) psimd_interleave_hi_u16(shl1_half, magic_mask), magic_bias);
  50. const psimd_s32 denorm_cutoff = psimd_splat_s32(INT32_C(0x00800000));
  51. const psimd_s32 denorm_mask_lo = (psimd_s32) shr3_nonsign_lo < denorm_cutoff;
  52. const psimd_s32 denorm_mask_hi = (psimd_s32) shr3_nonsign_hi < denorm_cutoff;
  53. psimd_f32x2 result;
  54. result.lo = (psimd_f32) (sign_lo | (psimd_s32) psimd_blend_f32(denorm_mask_lo, denorm_nonsign_lo, norm_nonsign_lo));
  55. result.hi = (psimd_f32) (sign_hi | (psimd_s32) psimd_blend_f32(denorm_mask_hi, denorm_nonsign_hi, norm_nonsign_hi));
  56. return result;
  57. }
  58. PSIMD_INTRINSIC psimd_f32 fp16_alt_to_fp32_psimd(psimd_u16 half) {
  59. const psimd_u32 word = (psimd_u32) psimd_interleave_lo_u16(psimd_zero_u16(), half);
  60. const psimd_u32 sign = word & psimd_splat_u32(INT32_C(0x80000000));
  61. const psimd_u32 shr3_nonsign = (word + word) >> psimd_splat_u32(4);
  62. #if 0
  63. const psimd_s32 exp112_offset = psimd_splat_s32(INT32_C(0x38000000));
  64. const psimd_s32 nonsign_bits = (psimd_s32) shr3_nonsign + exp112_offset;
  65. const psimd_s32 exp1_offset = psimd_splat_s32(INT32_C(0x00800000));
  66. const psimd_f32 two_nonsign = (psimd_f32) (nonsign_bits + exp1_offset);
  67. const psimd_s32 exp113_offset = exp112_offset | exp1_offset;
  68. return (psimd_f32) (sign | (psimd_s32) psimd_sub_f32(two_nonsign, (psimd_f32) psimd_max_s32(nonsign_bits, exp113_offset)));
  69. #else
  70. const psimd_u32 exp_offset = psimd_splat_u32(UINT32_C(0x38000000));
  71. const psimd_f32 nonsign = (psimd_f32) (shr3_nonsign + exp_offset);
  72. #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
  73. const psimd_f32 denorm_bias = psimd_splat_f32(0x1.0p-14f);
  74. #else
  75. const psimd_f32 denorm_bias = psimd_splat_f32(fp32_from_bits(UINT32_C(0x38800000)));
  76. #endif
  77. return (psimd_f32) (sign | (psimd_s32) psimd_sub_f32(psimd_add_f32(nonsign, nonsign), psimd_max_f32(nonsign, denorm_bias)));
  78. #endif
  79. }
  80. PSIMD_INTRINSIC psimd_f32x2 fp16_alt_to_fp32x2_psimd(psimd_u16 half) {
  81. const psimd_u32 word_lo = (psimd_u32) psimd_interleave_lo_u16(psimd_zero_u16(), half);
  82. const psimd_u32 word_hi = (psimd_u32) psimd_interleave_hi_u16(psimd_zero_u16(), half);
  83. const psimd_u32 sign_mask = psimd_splat_u32(UINT32_C(0x80000000));
  84. const psimd_u32 sign_lo = word_lo & sign_mask;
  85. const psimd_u32 sign_hi = word_hi & sign_mask;
  86. const psimd_u32 shr3_nonsign_lo = (word_lo + word_lo) >> psimd_splat_u32(4);
  87. const psimd_u32 shr3_nonsign_hi = (word_hi + word_hi) >> psimd_splat_u32(4);
  88. #if 1
  89. const psimd_s32 exp112_offset = psimd_splat_s32(INT32_C(0x38000000));
  90. const psimd_s32 nonsign_bits_lo = (psimd_s32) shr3_nonsign_lo + exp112_offset;
  91. const psimd_s32 nonsign_bits_hi = (psimd_s32) shr3_nonsign_hi + exp112_offset;
  92. const psimd_s32 exp1_offset = psimd_splat_s32(INT32_C(0x00800000));
  93. const psimd_f32 two_nonsign_lo = (psimd_f32) (nonsign_bits_lo + exp1_offset);
  94. const psimd_f32 two_nonsign_hi = (psimd_f32) (nonsign_bits_hi + exp1_offset);
  95. const psimd_s32 exp113_offset = exp1_offset | exp112_offset;
  96. psimd_f32x2 result;
  97. result.lo = (psimd_f32) (sign_lo | (psimd_s32) psimd_sub_f32(two_nonsign_lo, (psimd_f32) psimd_max_s32(nonsign_bits_lo, exp113_offset)));
  98. result.hi = (psimd_f32) (sign_hi | (psimd_s32) psimd_sub_f32(two_nonsign_hi, (psimd_f32) psimd_max_s32(nonsign_bits_hi, exp113_offset)));
  99. return result;
  100. #else
  101. const psimd_u32 exp_offset = psimd_splat_u32(UINT32_C(0x38000000));
  102. const psimd_f32 nonsign_lo = (psimd_f32) (shr3_nonsign_lo + exp_offset);
  103. const psimd_f32 nonsign_hi = (psimd_f32) (shr3_nonsign_hi + exp_offset);
  104. const psimd_f32 denorm_bias = psimd_splat_f32(0x1.0p-14f);
  105. psimd_f32x2 result;
  106. result.lo = (psimd_f32) (sign_lo | (psimd_s32) psimd_sub_f32(psimd_add_f32(nonsign_lo, nonsign_lo), psimd_max_f32(nonsign_lo, denorm_bias)));
  107. result.hi = (psimd_f32) (sign_hi | (psimd_s32) psimd_sub_f32(psimd_add_f32(nonsign_hi, nonsign_hi), psimd_max_f32(nonsign_hi, denorm_bias)));
  108. return result;
  109. #endif
  110. }
  111. #endif /* FP16_PSIMD_H */
  112. #else
  113. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  114. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)