class_cache.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. import os
  2. import ray
  3. from ray.air.constants import COPY_DIRECTORY_CHECKPOINTS_INSTEAD_OF_MOVING_ENV
  4. from ray.train.constants import (
  5. ENABLE_V2_MIGRATION_WARNINGS_ENV_VAR,
  6. RAY_CHDIR_TO_TRIAL_DIR,
  7. )
  8. from ray.train.v2._internal.constants import (
  9. ENV_VARS_TO_PROPAGATE as TRAIN_ENV_VARS_TO_PROPAGATE,
  10. )
  11. DEFAULT_ENV_VARS = {
  12. # https://github.com/ray-project/ray/issues/28197
  13. "PL_DISABLE_FORK": "1"
  14. }
  15. ENV_VARS_TO_PROPAGATE = (
  16. {
  17. COPY_DIRECTORY_CHECKPOINTS_INSTEAD_OF_MOVING_ENV,
  18. RAY_CHDIR_TO_TRIAL_DIR,
  19. ENABLE_V2_MIGRATION_WARNINGS_ENV_VAR,
  20. "AWS_ACCESS_KEY_ID",
  21. "AWS_SECRET_ACCESS_KEY",
  22. "AWS_SECURITY_TOKEN",
  23. "AWS_SESSION_TOKEN",
  24. }
  25. # Propagate the Ray Train environment variables from the driver process
  26. # to the trainable process so that Tune + Train v2 can be used together.
  27. | TRAIN_ENV_VARS_TO_PROPAGATE
  28. )
  29. class _ActorClassCache:
  30. """Caches actor classes.
  31. ray.remote is a registration call. It sends the serialized object to the
  32. key value store (redis), and will be fetched at an arbitrary worker
  33. later. Registration does not use any Ray scheduling resources.
  34. Later, class.remote() actually creates the remote actor. The
  35. actor will be instantiated on some arbitrary machine,
  36. according to the underlying Ray scheduler.
  37. Without this cache, you would register the same serialized object
  38. over and over again. Naturally, since redis doesn’t spill to disk,
  39. this can easily nuke the redis instance (and basically blow up Ray).
  40. This cache instead allows us to register once and only once.
  41. Note that we assume there can be multiple trainables in the
  42. system at once.
  43. """
  44. def __init__(self):
  45. self._cache = {}
  46. def get(self, trainable_cls):
  47. """Gets the wrapped trainable_cls, otherwise calls ray.remote."""
  48. env_vars = DEFAULT_ENV_VARS.copy()
  49. for env_var_to_propagate in ENV_VARS_TO_PROPAGATE:
  50. if env_var_to_propagate in os.environ:
  51. env_vars[env_var_to_propagate] = os.environ[env_var_to_propagate]
  52. runtime_env = {"env_vars": env_vars}
  53. if trainable_cls not in self._cache:
  54. remote_cls = ray.remote(runtime_env=runtime_env)(trainable_cls)
  55. self._cache[trainable_cls] = remote_cls
  56. return self._cache[trainable_cls]