random.h 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  // Philox counter-based RNG implementation for Metal.
  // Borrowed from aten/src/ATen/core/PhiloxRNGEngine.h,
  // which in turn was borrowed from
  // http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
  6. #pragma once
  7. #include <metal_stdlib>
  namespace c10 {
  namespace metal {
  namespace detail {

  // Maps a 32-bit random word to a float in [0, 1).
  // Only the low 31 bits of `value` are used (the top bit is masked off).
  constexpr float uint32_to_uniform_float(uint32_t value) {
    // maximum value such that `MAX_INT * scale < 1.0` (with float rounding)
    constexpr float scale = 4.6566127342e-10;
    return static_cast<float>(value & 0x7FFFFFFF) * scale;
  }

  // Splits a 64-bit value into its 32-bit halves: .x = high word, .y = low word.
  inline uint2 splitlong(ulong v) {
    return uint2(v >> 32, v & 0xffffffff);
  }

  } // namespace detail

  namespace philox4 {

  // Full 32x32 -> 64-bit product of `a` and `b`.
  // Returns (.x = high 32 bits, .y = low 32 bits), following detail::splitlong.
  inline uint2 mulhilo(uint a, uint b) {
    auto rc = static_cast<ulong>(a) * b;
    return detail::splitlong(rc);
  }

  // One Philox4x32 round applied to the 4-word counter `ctr` with round key `key`.
  inline uint4 single_round(uint4 ctr, uint2 key) {
    constexpr uint kPhiloxSA = 0xD2511F53;
    constexpr uint kPhiloxSB = 0xCD9E8D57;
    const uint2 prod0 = mulhilo(kPhiloxSA, ctr.x); // (.x = hi0, .y = lo0)
    const uint2 prod1 = mulhilo(kPhiloxSB, ctr.z); // (.x = hi1, .y = lo1)
    // Philox S-box: the HIGH product words are XORed with the other counter
    // lanes and the key, and the LOW words are passed through. The low word of
    // a product with an odd multiplier is invertible, which keeps each round a
    // bijection. Swapping hi/lo here would make the round lossy.
    return uint4(
        prod1.x ^ ctr.y ^ key.x, prod1.y, prod0.x ^ ctr.w ^ key.y, prod0.y);
  }

  // Applies `rounds` Philox rounds to `ctr`, bumping `key` by the Weyl
  // constants after each round. `rounds == 0` returns `ctr` unchanged.
  inline uint4 multiple_rounds(uint4 ctr, uint2 key, uint rounds) {
    constexpr uint2 kPhilox10 = {0x9E3779B9, 0xBB67AE85};
    // Iterate exactly `rounds` times. The final key bump is unused, but this
    // form avoids the `rounds - 1` unsigned underflow and the missing last round.
    for (uint round = 0; round < rounds; ++round) {
      ctr = single_round(ctr, key);
      key += kPhilox10;
    }
    return ctr;
  }

  // Philox4x32-10: four 32-bit random words for counter `index` under key `seed`.
  // Counter layout: (0, 0, high(index), low(index)); key: (high(seed), low(seed)).
  inline uint4 rand(long seed, long index) {
    uint4 ctr = 0;
    ctr.zw = detail::splitlong(index);
    return multiple_rounds(ctr, detail::splitlong(seed), 10);
  }

  } // namespace philox4

  // Standard-normal sample via the Box-Muller transform, using the first two
  // Philox words for (seed, index).
  inline float randn(long seed, long index) {
    auto value = philox4::rand(seed, index);
    // 1 - U maps [0, 1) onto (0, 1], so log(u1) stays finite.
    float u1 = 1.0f - detail::uint32_to_uniform_float(value.x);
    float u2 = 1.0f - detail::uint32_to_uniform_float(value.y);
    // Float literals: MSL has no double, so avoid double-typed constants.
    return ::metal::sqrt(-2.0f * ::metal::log(u1)) *
        ::metal::cos(2.0f * M_PI_F * u2);
  }

  // Uniform float in [0, 1) from the first Philox word for (seed, index).
  inline float rand(long seed, long index) {
    auto value = philox4::rand(seed, index);
    return detail::uint32_to_uniform_float(value.x);
  }

  // Integer sample intended to lie in [low, high) (assumes high > low).
  // Scales a uniform float in [0, 1) by (high - low) and truncates.
  // The result only has float precision (24-bit significand), so for ranges
  // wider than 2^24 not every value is reachable.
  // NOTE(review): float rounding of a very large `range` may also need an
  // upper-bound check.
  inline long randint64(long seed, long index, long low, long high) {
    auto range = high - low;
    auto value = philox4::rand(seed, index);
    // TODO: Implement better algorithm for large ranges
    return low +
        static_cast<long>(detail::uint32_to_uniform_float(value.x) * range);
  }

  } // namespace metal
  } // namespace c10
  66. #else
  67. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  68. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)