random.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. # mypy: allow-untyped-defs
  2. import contextlib
  3. import warnings
  4. from collections.abc import Generator
  5. import torch
  6. from torch._C import default_generator
  7. def set_rng_state(new_state: torch.Tensor) -> None:
  8. r"""Sets the random number generator state.
  9. .. note:: This function only works for CPU. For CUDA, please use
  10. :func:`torch.manual_seed`, which works for both CPU and CUDA.
  11. Args:
  12. new_state (torch.ByteTensor): The desired state
  13. """
  14. default_generator.set_state(new_state)
  15. def get_rng_state() -> torch.Tensor:
  16. r"""Returns the random number generator state as a `torch.ByteTensor`.
  17. .. note:: The returned state is for the default generator on CPU only.
  18. See also: :func:`torch.random.fork_rng`.
  19. """
  20. return default_generator.get_state()
  21. def manual_seed(seed) -> torch._C.Generator:
  22. r"""Sets the seed for generating random numbers on all devices. Returns a
  23. `torch.Generator` object.
  24. Args:
  25. seed (int): The desired seed. Value must be within the inclusive range
  26. `[-0x8000_0000_0000_0000, 0xffff_ffff_ffff_ffff]`. Otherwise, a RuntimeError
  27. is raised. Negative inputs are remapped to positive values with the formula
  28. `0xffff_ffff_ffff_ffff + seed`.
  29. """
  30. return _manual_seed_impl(seed)
  31. def _manual_seed_impl(seed) -> torch._C.Generator:
  32. seed = int(seed)
  33. import torch.cuda
  34. if not torch.cuda._is_in_bad_fork():
  35. torch.cuda.manual_seed_all(seed)
  36. import torch.mps
  37. if not torch.mps._is_in_bad_fork():
  38. torch.mps.manual_seed(seed)
  39. import torch.xpu
  40. if not torch.xpu._is_in_bad_fork():
  41. torch.xpu.manual_seed_all(seed)
  42. import torch.mtia
  43. if not torch.mtia._is_in_bad_fork():
  44. torch.mtia.manual_seed_all(seed)
  45. _seed_custom_device(seed)
  46. return default_generator.manual_seed(seed)
  47. def seed() -> int:
  48. r"""Sets the seed for generating random numbers to a non-deterministic
  49. random number on all devices. Returns a 64 bit number used to seed the RNG.
  50. """
  51. seed = default_generator.seed()
  52. import torch.cuda
  53. if not torch.cuda._is_in_bad_fork():
  54. torch.cuda.manual_seed_all(seed)
  55. import torch.mps
  56. if not torch.mps._is_in_bad_fork():
  57. torch.mps.manual_seed(seed)
  58. import torch.xpu
  59. if not torch.xpu._is_in_bad_fork():
  60. torch.xpu.manual_seed_all(seed)
  61. import torch.mtia
  62. if not torch.mtia._is_in_bad_fork():
  63. torch.mtia.manual_seed_all(seed)
  64. _seed_custom_device(seed)
  65. return seed
  66. def _seed_custom_device(seed) -> None:
  67. r"""Sets the seed to generate random numbers for custom device.
  68. Args:
  69. seed (int): The desired seed.
  70. See [Note: support the custom device with privateuse1]
  71. """
  72. seed = int(seed)
  73. custom_backend_name = torch._C._get_privateuse1_backend_name()
  74. if hasattr(torch, custom_backend_name):
  75. custom_device_mod = getattr(torch, custom_backend_name)
  76. _bad_fork_name = "_is_in_bad_fork"
  77. _seed_all_name = "manual_seed_all"
  78. if hasattr(custom_device_mod, _bad_fork_name) and hasattr(
  79. custom_device_mod, _seed_all_name
  80. ):
  81. if not getattr(custom_device_mod, _bad_fork_name)():
  82. getattr(custom_device_mod, _seed_all_name)(seed)
  83. else:
  84. message = f"Set seed for `{custom_backend_name}` device does not take effect, please add API's "
  85. message += f"`{_bad_fork_name}` and `{_seed_all_name}` to `{custom_backend_name}` device module."
  86. warnings.warn(message, UserWarning, stacklevel=3)
  87. def initial_seed() -> int:
  88. r"""Returns the initial seed for generating random numbers as a
  89. Python `long`.
  90. .. note:: The returned seed is for the default generator on CPU only.
  91. """
  92. return default_generator.initial_seed()
  93. _fork_rng_warned_already = False
  94. @contextlib.contextmanager
  95. def fork_rng(
  96. devices=None,
  97. enabled=True,
  98. _caller="fork_rng",
  99. _devices_kw="devices",
  100. device_type="cuda",
  101. ) -> Generator:
  102. """
  103. Forks the RNG, so that when you return, the RNG is reset
  104. to the state that it was previously in.
  105. Args:
  106. devices (iterable of Device IDs): devices for which to fork
  107. the RNG. CPU RNG state is always forked. By default, :meth:`fork_rng` operates
  108. on all devices, but will emit a warning if your machine has a lot
  109. of devices, since this function will run very slowly in that case.
  110. If you explicitly specify devices, this warning will be suppressed
  111. enabled (bool): if ``False``, the RNG is not forked. This is a convenience
  112. argument for easily disabling the context manager without having
  113. to delete it and unindent your Python code under it.
  114. device_type (str): device type str, default is `cuda`. As for supported device,
  115. see details in :ref:`accelerator<accelerators>`
  116. """
  117. if device_type == "meta":
  118. yield
  119. return
  120. device_type = torch.device(device_type).type
  121. device_mod = getattr(torch, device_type, None)
  122. if device_mod is None:
  123. raise RuntimeError(
  124. f"torch has no module of `{device_type}`, you should register "
  125. + "a module by `torch._register_device_module`."
  126. )
  127. global _fork_rng_warned_already
  128. # Internal arguments:
  129. # _caller: the function which called fork_rng, which the user used
  130. # _devices_kw: the devices keyword of _caller
  131. if not enabled:
  132. yield
  133. return
  134. if devices is None:
  135. num_devices = device_mod.device_count()
  136. if num_devices > 1 and not _fork_rng_warned_already:
  137. message = (
  138. f"{device_type.upper()} reports that you have {num_devices} available devices, and "
  139. f"you have used {_caller} without explicitly specifying which devices are being used. "
  140. f"For safety, we initialize *every* {device_type.upper()} device by default, which can "
  141. f"be quite slow if you have a lot of {device_type.upper()}s. If you know that you are only"
  142. f" making use of a few {device_type.upper()} devices, set the environment variable "
  143. f"{device_type.upper()}_VISIBLE_DEVICES or the '{_devices_kw}' keyword argument of {_caller} "
  144. "with the set of devices you are actually using. For example, if you are using CPU only, "
  145. "set device.upper()_VISIBLE_DEVICES= or devices=[]; if you are using device 0 only, "
  146. f"set {device_type.upper()}_VISIBLE_DEVICES=0 or devices=[0]. To initialize all devices "
  147. f"and suppress this warning, set the '{_devices_kw}' keyword argument to "
  148. f"`range(torch.{device_type}.device_count())`."
  149. )
  150. warnings.warn(message, stacklevel=2)
  151. _fork_rng_warned_already = True
  152. devices = list(range(num_devices))
  153. else:
  154. # Protect against user passing us a generator; we need to traverse this
  155. # multiple times but a generator will be exhausted upon first traversal
  156. devices = list(devices)
  157. cpu_rng_state = torch.get_rng_state()
  158. device_rng_states = [device_mod.get_rng_state(device) for device in devices]
  159. try:
  160. yield
  161. finally:
  162. torch.set_rng_state(cpu_rng_state)
  163. for device, device_rng_state in zip(devices, device_rng_states):
  164. device_mod.set_rng_state(device_rng_state, device)