softplus.py 328 B

123456789101112131415
  1. import triton
  2. import triton.language as tl
  3. from packaging import version
  4. TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0")
  5. if TRITON3:
  6. @triton.jit
  7. def softplus(dt):
  8. return tl.math.log(tl.math.exp(dt) + 1)
  9. else:
  10. @triton.jit
  11. def softplus(dt):
  12. return tl.math.log1p(tl.exp(dt))