deterministic.py 1.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. import random
  2. from typing import Optional
  3. import numpy as np
  4. from ray.rllib.utils.annotations import DeveloperAPI
  5. from ray.rllib.utils.framework import try_import_tf
  6. from ray.rllib.utils.torch_utils import set_torch_seed
  7. @DeveloperAPI
  8. def update_global_seed_if_necessary(
  9. framework: Optional[str] = None, seed: Optional[int] = None
  10. ) -> None:
  11. """Seed global modules such as random, numpy, torch, or tf.
  12. This is useful for debugging and testing.
  13. Args:
  14. framework: The framework specifier (may be None).
  15. seed: An optional int seed. If None, will not do
  16. anything.
  17. """
  18. if seed is None:
  19. return
  20. # Python random module.
  21. random.seed(seed)
  22. # Numpy.
  23. np.random.seed(seed)
  24. # Torch.
  25. if framework == "torch":
  26. set_torch_seed(seed=seed)
  27. elif framework == "tf2":
  28. tf1, tf, tfv = try_import_tf()
  29. # Tf2.x.
  30. if tfv == 2:
  31. tf.random.set_seed(seed)
  32. # Tf1.x.
  33. else:
  34. tf1.set_random_seed(seed)