random.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. from ..runtime.jit import jit
  2. from . import core as tl
  3. from . import math
# Default number of Philox rounds; used as the default value of `n_rounds`
# by every generator function in this module.
N_ROUNDS_DEFAULT = tl.constexpr(10)
  5. # -------------------
  6. # randint
  7. # -------------------
  8. @jit
  9. def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
  10. """
  11. Run `n_rounds` rounds of Philox for state (c0, c1, c2, c3) and key (k0, k1).
  12. """
  13. if c0.dtype == tl.uint32:
  14. PHILOX_KEY_A: tl.constexpr = 0x9E3779B9
  15. PHILOX_KEY_B: tl.constexpr = 0xBB67AE85
  16. PHILOX_ROUND_A: tl.constexpr = 0xD2511F53
  17. PHILOX_ROUND_B: tl.constexpr = 0xCD9E8D57
  18. else:
  19. tl.static_assert(c0.dtype == tl.uint64, "dtype not supported in philox_impl")
  20. PHILOX_KEY_A: tl.constexpr = 0x9E3779B97F4A7C15
  21. PHILOX_KEY_B: tl.constexpr = 0xBB67AE8584CAA73B
  22. PHILOX_ROUND_A: tl.constexpr = 0xD2E7470EE14C6C93
  23. PHILOX_ROUND_B: tl.constexpr = 0xCA5A826395121157
  24. for _ in tl.static_range(n_rounds):
  25. # for _ in range(n_rounds):
  26. # update random state
  27. A = PHILOX_ROUND_A
  28. B = PHILOX_ROUND_B
  29. _c0, _c2 = c0, c2
  30. c0 = math.umulhi(B, _c2) ^ c1 ^ k0
  31. c2 = math.umulhi(A, _c0) ^ c3 ^ k1
  32. c1 = tl.mul(B, _c2, sanitize_overflow=False)
  33. c3 = tl.mul(A, _c0, sanitize_overflow=False)
  34. # raise key
  35. k0 = tl.add(k0, PHILOX_KEY_A, sanitize_overflow=False)
  36. k1 = tl.add(k1, PHILOX_KEY_B, sanitize_overflow=False)
  37. return c0, c1, c2, c3
  38. @jit
  39. def philox(seed, c0, c1, c2, c3, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
  40. seed = tl.to_tensor(seed)
  41. tl.static_assert(seed.dtype.is_int())
  42. seed = seed.to(tl.uint64)
  43. c0 = tl.to_tensor(c0)
  44. c1 = tl.to_tensor(c1)
  45. c2 = tl.to_tensor(c2)
  46. c3 = tl.to_tensor(c3)
  47. if tl.constexpr(c0.dtype.primitive_bitwidth) == 32:
  48. int_dtype = tl.uint32
  49. seed_hi = ((seed >> 32) & 0xffffffff).to(tl.uint32)
  50. seed_lo = (seed & 0xffffffff).to(tl.uint32)
  51. else:
  52. tl.static_assert(tl.constexpr(c0.dtype.primitive_bitwidth) == 64, "bitwidth not supported in philox")
  53. int_dtype = tl.uint64
  54. seed_hi = tl.full((1, ), 0, dtype=int_dtype)
  55. seed_lo = seed
  56. c0 = c0.to(int_dtype, bitcast=True)
  57. c1 = c1.to(int_dtype, bitcast=True)
  58. c2 = c2.to(int_dtype, bitcast=True)
  59. c3 = c3.to(int_dtype, bitcast=True)
  60. return philox_impl(c0, c1, c2, c3, seed_lo, seed_hi, n_rounds)
  61. @jit
  62. def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
  63. """
  64. Given a :code:`seed` scalar and an :code:`offset` block, returns a single
  65. block of random :code:`int32`.
  66. If you need multiple streams of random numbers,
  67. using `randint4x` is likely to be faster than calling `randint` 4 times.
  68. :param seed: The seed for generating random numbers.
  69. :param offset: The offsets to generate random numbers for.
  70. """
  71. ret, _, _, _ = randint4x(seed, offset, n_rounds)
  72. return ret
  73. @jit
  74. def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
  75. """
  76. Given a :code:`seed` scalar and an :code:`offset` block, returns four
  77. blocks of random :code:`int32`.
  78. This is the maximally efficient entry point
  79. to Triton's Philox pseudo-random number generator.
  80. :param seed: The seed for generating random numbers.
  81. :param offsets: The offsets to generate random numbers for.
  82. """
  83. # _0 = tl.zeros(offset.shape, offset.dtype)
  84. offset_lo = offset.to(tl.uint32)
  85. _0 = offset_lo * 0
  86. if tl.constexpr(offset.dtype.primitive_bitwidth) > 32:
  87. offset_hi = (offset >> 32).to(tl.uint32)
  88. else:
  89. offset_hi = _0
  90. return philox(seed, offset_lo, offset_hi, _0, _0, n_rounds)
  91. # -------------------
  92. # rand
  93. # -------------------
  94. # @jit
  95. # def uint32_to_uniform_float(x):
  96. # """
  97. # Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1).
  98. # """
  99. # two_to_the_minus_32: tl.constexpr = 2.328306e-10
  100. # return x * two_to_the_minus_32
  101. @jit
  102. def uint_to_uniform_float(x):
  103. """
  104. Numerically stable function to convert a random uint into a random float uniformly sampled in [0, 1).
  105. """
  106. # TODO: fix frontend issues and cleanup
  107. # conditions can be simplified
  108. # scale is ((2**23 - 1) / 2**23) * 2**(N_BITS - 1)
  109. if tl.constexpr(x.dtype == tl.uint32) or tl.constexpr(x.dtype == tl.int32):
  110. # maximum value such that `MAX_INT * scale < 1.0` (with float rounding)
  111. x = x.to(tl.int32, bitcast=True)
  112. scale = 4.6566127342e-10
  113. else:
  114. tl.static_assert(tl.constexpr(x.dtype == tl.uint64) or tl.constexpr(x.dtype == tl.int64))
  115. x = x.to(tl.int64, bitcast=True)
  116. scale = 1.0842020432385337e-19
  117. x = tl.where(x < 0, -x - 1, x)
  118. return x * scale
  119. @jit
  120. def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
  121. """
  122. Given a :code:`seed` scalar and an :code:`offset` block,
  123. returns a block of random :code:`float32` in :math:`U(0, 1)`.
  124. :param seed: The seed for generating random numbers.
  125. :param offsets: The offsets to generate random numbers for.
  126. """
  127. source = randint(seed, offset, n_rounds)
  128. return uint_to_uniform_float(source)
  129. @jit
  130. def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
  131. """
  132. Given a :code:`seed` scalar and an :code:`offsets` block,
  133. returns 4 blocks of random :code:`float32` in :math:`U(0, 1)`.
  134. :param seed: The seed for generating random numbers.
  135. :param offsets: The offsets to generate random numbers for.
  136. """
  137. i1, i2, i3, i4 = randint4x(seed, offsets, n_rounds)
  138. u1 = uint_to_uniform_float(i1)
  139. u2 = uint_to_uniform_float(i2)
  140. u3 = uint_to_uniform_float(i3)
  141. u4 = uint_to_uniform_float(i4)
  142. return u1, u2, u3, u4
  143. # -------------------
  144. # randn
  145. # -------------------
  146. @jit
  147. def pair_uniform_to_normal(u1, u2):
  148. """Box-Muller transform"""
  149. u1 = tl.maximum(1.0e-7, u1)
  150. th = 6.283185307179586 * u2
  151. r = math.sqrt(-2.0 * math.log(u1))
  152. return r * math.cos(th), r * math.sin(th)
  153. @jit
  154. def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
  155. """
  156. Given a :code:`seed` scalar and an :code:`offset` block,
  157. returns a block of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`.
  158. :param seed: The seed for generating random numbers.
  159. :param offsets: The offsets to generate random numbers for.
  160. """
  161. i1, i2, _, _ = randint4x(seed, offset, n_rounds)
  162. u1 = uint_to_uniform_float(i1)
  163. u2 = uint_to_uniform_float(i2)
  164. n1, _ = pair_uniform_to_normal(u1, u2)
  165. return n1
  166. @jit
  167. def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
  168. """
  169. Given a :code:`seed` scalar and an :code:`offset` block,
  170. returns 4 blocks of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`.
  171. :param seed: The seed for generating random numbers.
  172. :param offsets: The offsets to generate random numbers for.
  173. """
  174. u1, u2, u3, u4 = rand4x(seed, offset, n_rounds)
  175. n1, n2 = pair_uniform_to_normal(u1, u2)
  176. n3, n4 = pair_uniform_to_normal(u3, u4)
  177. return n1, n2, n3, n4