random.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. # mypy: allow-untyped-defs
  2. from collections.abc import Iterable
  3. import torch
  4. from torch import Tensor
  5. from . import _lazy_call, _lazy_init, current_device, device_count, is_initialized
  6. def get_rng_state(device: int | str | torch.device = "xpu") -> Tensor:
  7. r"""Return the random number generator state of the specified GPU as a ByteTensor.
  8. Args:
  9. device (torch.device or int, optional): The device to return the RNG state of.
  10. Default: ``'xpu'`` (i.e., ``torch.device('xpu')``, the current XPU device).
  11. .. warning::
  12. This function eagerly initializes XPU.
  13. """
  14. _lazy_init()
  15. if isinstance(device, str):
  16. device = torch.device(device)
  17. elif isinstance(device, int):
  18. device = torch.device("xpu", device)
  19. idx = device.index
  20. if idx is None:
  21. idx = current_device()
  22. default_generator = torch.xpu.default_generators[idx]
  23. return default_generator.get_state()
  24. def get_rng_state_all() -> list[Tensor]:
  25. r"""Return a list of ByteTensor representing the random number states of all devices."""
  26. results = [get_rng_state(i) for i in range(device_count())]
  27. return results
  28. def set_rng_state(new_state: Tensor, device: int | str | torch.device = "xpu") -> None:
  29. r"""Set the random number generator state of the specified GPU.
  30. Args:
  31. new_state (torch.ByteTensor): The desired state
  32. device (torch.device or int, optional): The device to set the RNG state.
  33. Default: ``'xpu'`` (i.e., ``torch.device('xpu')``, the current XPU device).
  34. """
  35. if not is_initialized():
  36. with torch._C._DisableFuncTorch():
  37. new_state = new_state.clone(memory_format=torch.contiguous_format)
  38. if isinstance(device, str):
  39. device = torch.device(device)
  40. elif isinstance(device, int):
  41. device = torch.device("xpu", device)
  42. def cb() -> None:
  43. idx = device.index
  44. if idx is None:
  45. idx = current_device()
  46. default_generator = torch.xpu.default_generators[idx]
  47. default_generator.set_state(new_state)
  48. _lazy_call(cb)
  49. def set_rng_state_all(new_states: Iterable[Tensor]) -> None:
  50. r"""Set the random number generator state of all devices.
  51. Args:
  52. new_states (Iterable of torch.ByteTensor): The desired state for each device.
  53. """
  54. for i, state in enumerate(new_states):
  55. set_rng_state(state, i)
  56. def manual_seed(seed: int) -> None:
  57. r"""Set the seed for generating random numbers for the current GPU.
  58. It's safe to call this function if XPU is not available; in that case, it is silently ignored.
  59. Args:
  60. seed (int): The desired seed.
  61. .. warning::
  62. If you are working with a multi-GPU model, this function is insufficient
  63. to get determinism. To seed all GPUs, use :func:`manual_seed_all`.
  64. """
  65. seed = int(seed)
  66. def cb() -> None:
  67. idx = current_device()
  68. default_generator = torch.xpu.default_generators[idx]
  69. default_generator.manual_seed(seed)
  70. _lazy_call(cb, seed=True)
  71. def manual_seed_all(seed: int) -> None:
  72. r"""Set the seed for generating random numbers on all GPUs.
  73. It's safe to call this function if XPU is not available; in that case, it is silently ignored.
  74. Args:
  75. seed (int): The desired seed.
  76. """
  77. seed = int(seed)
  78. def cb() -> None:
  79. for i in range(device_count()):
  80. default_generator = torch.xpu.default_generators[i]
  81. default_generator.manual_seed(seed)
  82. _lazy_call(cb, seed_all=True)
  83. def seed() -> None:
  84. r"""Set the seed for generating random numbers to a random number for the current GPU.
  85. It's safe to call this function if XPU is not available; in that case, it is silently ignored.
  86. .. warning::
  87. If you are working with a multi-GPU model, this function will only initialize
  88. the seed on one GPU. To initialize all GPUs, use :func:`seed_all`.
  89. """
  90. def cb() -> None:
  91. idx = current_device()
  92. default_generator = torch.xpu.default_generators[idx]
  93. default_generator.seed()
  94. _lazy_call(cb)
  95. def seed_all() -> None:
  96. r"""Set the seed for generating random numbers to a random number on all GPUs.
  97. It's safe to call this function if XPU is not available; in that case, it is silently ignored.
  98. """
  99. def cb() -> None:
  100. random_seed = 0
  101. seeded = False
  102. for i in range(device_count()):
  103. default_generator = torch.xpu.default_generators[i]
  104. if not seeded:
  105. default_generator.seed()
  106. random_seed = default_generator.initial_seed()
  107. seeded = True
  108. else:
  109. default_generator.manual_seed(random_seed)
  110. _lazy_call(cb)
  111. def initial_seed() -> int:
  112. r"""Return the current random seed of the current GPU.
  113. .. warning::
  114. This function eagerly initializes XPU.
  115. """
  116. _lazy_init()
  117. idx = current_device()
  118. default_generator = torch.xpu.default_generators[idx]
  119. return default_generator.initial_seed()
  120. __all__ = [
  121. "get_rng_state",
  122. "get_rng_state_all",
  123. "set_rng_state",
  124. "set_rng_state_all",
  125. "manual_seed",
  126. "manual_seed_all",
  127. "seed",
  128. "seed_all",
  129. "initial_seed",
  130. ]