random.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. # mypy: allow-untyped-defs
  2. from collections.abc import Iterable
  3. import torch
  4. from torch import Tensor
  5. from . import _lazy_call, _lazy_init, current_device, device_count, is_initialized
# Public API of this module: snapshot/restore of the per-device CUDA default
# generator state, followed by the seeding helpers.
__all__ = [
    # Generator state getters / setters.
    "get_rng_state",
    "get_rng_state_all",
    "set_rng_state",
    "set_rng_state_all",
    # Seeding helpers.
    "manual_seed",
    "manual_seed_all",
    "seed",
    "seed_all",
    "initial_seed",
]
  17. def get_rng_state(device: int | str | torch.device = "cuda") -> Tensor:
  18. r"""Return the random number generator state of the specified GPU as a ByteTensor.
  19. Args:
  20. device (torch.device or int, optional): The device to return the RNG state of.
  21. Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).
  22. .. warning::
  23. This function eagerly initializes CUDA.
  24. """
  25. _lazy_init()
  26. if isinstance(device, str):
  27. device = torch.device(device)
  28. elif isinstance(device, int):
  29. device = torch.device("cuda", device)
  30. idx = device.index
  31. if idx is None:
  32. idx = current_device()
  33. default_generator = torch.cuda.default_generators[idx]
  34. return default_generator.get_state()
  35. def get_rng_state_all() -> list[Tensor]:
  36. r"""Return a list of ByteTensor representing the random number states of all devices."""
  37. results = [get_rng_state(i) for i in range(device_count())]
  38. return results
  39. def set_rng_state(new_state: Tensor, device: int | str | torch.device = "cuda") -> None:
  40. r"""Set the random number generator state of the specified GPU.
  41. Args:
  42. new_state (torch.ByteTensor): The desired state
  43. device (torch.device or int, optional): The device to set the RNG state.
  44. Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).
  45. """
  46. if not is_initialized():
  47. with torch._C._DisableFuncTorch():
  48. # Clone the state because the callback will be triggered
  49. # later when CUDA is lazy initialized.
  50. new_state = new_state.clone(memory_format=torch.contiguous_format)
  51. if isinstance(device, str):
  52. device = torch.device(device)
  53. elif isinstance(device, int):
  54. device = torch.device("cuda", device)
  55. def cb():
  56. idx = device.index
  57. if idx is None:
  58. idx = current_device()
  59. default_generator = torch.cuda.default_generators[idx]
  60. default_generator.set_state(new_state)
  61. _lazy_call(cb)
  62. def set_rng_state_all(new_states: Iterable[Tensor]) -> None:
  63. r"""Set the random number generator state of all devices.
  64. Args:
  65. new_states (Iterable of torch.ByteTensor): The desired state for each device.
  66. """
  67. for i, state in enumerate(new_states):
  68. set_rng_state(state, i)
  69. def manual_seed(seed: int) -> None:
  70. r"""Set the seed for generating random numbers for the current GPU.
  71. It's safe to call this function if CUDA is not available; in that
  72. case, it is silently ignored.
  73. Args:
  74. seed (int): The desired seed.
  75. .. warning::
  76. If you are working with a multi-GPU model, this function is insufficient
  77. to get determinism. To seed all GPUs, use :func:`manual_seed_all`.
  78. """
  79. seed = int(seed)
  80. def cb():
  81. idx = current_device()
  82. default_generator = torch.cuda.default_generators[idx]
  83. default_generator.manual_seed(seed)
  84. _lazy_call(cb, seed=True)
  85. def manual_seed_all(seed: int) -> None:
  86. r"""Set the seed for generating random numbers on all GPUs.
  87. It's safe to call this function if CUDA is not available; in that
  88. case, it is silently ignored.
  89. Args:
  90. seed (int): The desired seed.
  91. """
  92. seed = int(seed)
  93. def cb():
  94. for i in range(device_count()):
  95. default_generator = torch.cuda.default_generators[i]
  96. default_generator.manual_seed(seed)
  97. _lazy_call(cb, seed_all=True)
  98. def seed() -> None:
  99. r"""Set the seed for generating random numbers to a random number for the current GPU.
  100. It's safe to call this function if CUDA is not available; in that
  101. case, it is silently ignored.
  102. .. warning::
  103. If you are working with a multi-GPU model, this function will only initialize
  104. the seed on one GPU. To initialize all GPUs, use :func:`seed_all`.
  105. """
  106. def cb():
  107. idx = current_device()
  108. default_generator = torch.cuda.default_generators[idx]
  109. default_generator.seed()
  110. _lazy_call(cb)
  111. def seed_all() -> None:
  112. r"""Set the seed for generating random numbers to a random number on all GPUs.
  113. It's safe to call this function if CUDA is not available; in that
  114. case, it is silently ignored.
  115. """
  116. def cb():
  117. random_seed = 0
  118. seeded = False
  119. for i in range(device_count()):
  120. default_generator = torch.cuda.default_generators[i]
  121. if not seeded:
  122. default_generator.seed()
  123. random_seed = default_generator.initial_seed()
  124. seeded = True
  125. else:
  126. default_generator.manual_seed(random_seed)
  127. _lazy_call(cb)
  128. def initial_seed() -> int:
  129. r"""Return the current random seed of the current GPU.
  130. .. warning::
  131. This function eagerly initializes CUDA.
  132. """
  133. _lazy_init()
  134. idx = current_device()
  135. default_generator = torch.cuda.default_generators[idx]
  136. return default_generator.initial_seed()