client_mode_hook.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. import os
  2. import threading
  3. from contextlib import contextmanager
  4. from functools import wraps
  5. from ray._private.auto_init_hook import auto_init_ray
  6. # Attr set on func defs to mark they have been converted to client mode.
  7. RAY_CLIENT_MODE_ATTR = "__ray_client_mode_key__"
  8. # Global setting of whether client mode is enabled. This default to OFF,
  9. # but is enabled upon ray.client(...).connect() or in tests.
  10. is_client_mode_enabled = os.environ.get("RAY_CLIENT_MODE", "0") == "1"
  11. # When RAY_CLIENT_MODE == 1, we treat it as default enabled client mode
  12. # This is useful for testing
  13. is_client_mode_enabled_by_default = is_client_mode_enabled
  14. os.environ.update({"RAY_CLIENT_MODE": "0"})
  15. is_init_called = False
  16. # Local setting of whether to ignore client hook conversion. This defaults
  17. # to TRUE and is disabled when the underlying 'real' Ray function is needed.
  18. _client_hook_status_on_thread = threading.local()
  19. _client_hook_status_on_thread.status = True
  20. def _get_client_hook_status_on_thread():
  21. """Get's the value of `_client_hook_status_on_thread`.
  22. Since `_client_hook_status_on_thread` is a thread-local variable, we may
  23. need to add and set the 'status' attribute.
  24. """
  25. global _client_hook_status_on_thread
  26. if not hasattr(_client_hook_status_on_thread, "status"):
  27. _client_hook_status_on_thread.status = True
  28. return _client_hook_status_on_thread.status
  29. def _set_client_hook_status(val: bool):
  30. global _client_hook_status_on_thread
  31. _client_hook_status_on_thread.status = val
  32. def _disable_client_hook():
  33. global _client_hook_status_on_thread
  34. out = _get_client_hook_status_on_thread()
  35. _client_hook_status_on_thread.status = False
  36. return out
  37. def _explicitly_enable_client_mode():
  38. """Force client mode to be enabled.
  39. NOTE: This should not be used in tests, use `enable_client_mode`.
  40. """
  41. global is_client_mode_enabled
  42. is_client_mode_enabled = True
  43. def _explicitly_disable_client_mode():
  44. global is_client_mode_enabled
  45. is_client_mode_enabled = False
  46. @contextmanager
  47. def disable_client_hook():
  48. val = _disable_client_hook()
  49. try:
  50. yield None
  51. finally:
  52. _set_client_hook_status(val)
  53. @contextmanager
  54. def enable_client_mode():
  55. _explicitly_enable_client_mode()
  56. try:
  57. yield None
  58. finally:
  59. _explicitly_disable_client_mode()
  60. def client_mode_hook(func: callable):
  61. """Decorator for whether to use the 'regular' ray version of a function,
  62. or the Ray Client version of that function.
  63. Args:
  64. func: This function. This is set when this function is used
  65. as a decorator.
  66. """
  67. from ray.util.client import ray
  68. @wraps(func)
  69. def wrapper(*args, **kwargs):
  70. # NOTE(hchen): DO NOT use "import" inside this function.
  71. # Because when it's called within a `__del__` method, this error
  72. # will be raised (see #35114):
  73. # ImportError: sys.meta_path is None, Python is likely shutting down.
  74. if client_mode_should_convert():
  75. # Legacy code
  76. # we only convert init function if RAY_CLIENT_MODE=1
  77. if func.__name__ != "init" or is_client_mode_enabled_by_default:
  78. return getattr(ray, func.__name__)(*args, **kwargs)
  79. return func(*args, **kwargs)
  80. return wrapper
  81. def client_mode_should_convert():
  82. """Determines if functions should be converted to client mode."""
  83. # `is_client_mode_enabled_by_default` is used for testing with
  84. # `RAY_CLIENT_MODE=1`. This flag means all tests run with client mode.
  85. return (
  86. is_client_mode_enabled or is_client_mode_enabled_by_default
  87. ) and _get_client_hook_status_on_thread()
  88. def client_mode_wrap(func):
  89. """Wraps a function called during client mode for execution as a remote
  90. task.
  91. Can be used to implement public features of ray client which do not
  92. belong in the main ray API (`ray.*`), yet require server-side execution.
  93. An example is the creation of placement groups:
  94. `ray.util.placement_group.placement_group()`. When called on the client
  95. side, this function is wrapped in a task to facilitate interaction with
  96. the GCS.
  97. """
  98. @wraps(func)
  99. def wrapper(*args, **kwargs):
  100. from ray.util.client import ray
  101. auto_init_ray()
  102. # Directly pass this through since `client_mode_wrap` is for
  103. # Placement Group APIs
  104. if client_mode_should_convert():
  105. f = ray.remote(num_cpus=0)(func)
  106. ref = f.remote(*args, **kwargs)
  107. return ray.get(ref)
  108. return func(*args, **kwargs)
  109. return wrapper
  110. def client_mode_convert_function(func_cls, in_args, in_kwargs, **kwargs):
  111. """Runs a preregistered ray RemoteFunction through the ray client.
  112. The common case for this is to transparently convert that RemoteFunction
  113. to a ClientRemoteFunction. This happens in circumstances where the
  114. RemoteFunction is declared early, in a library and only then is Ray used in
  115. client mode -- necessitating a conversion.
  116. """
  117. from ray.util.client import ray
  118. key = getattr(func_cls, RAY_CLIENT_MODE_ATTR, None)
  119. # Second part of "or" is needed in case func_cls is reused between Ray
  120. # client sessions in one Python interpreter session.
  121. if (key is None) or (not ray._converted_key_exists(key)):
  122. key = ray._convert_function(func_cls)
  123. setattr(func_cls, RAY_CLIENT_MODE_ATTR, key)
  124. client_func = ray._get_converted(key)
  125. return client_func._remote(in_args, in_kwargs, **kwargs)
  126. def client_mode_convert_actor(actor_cls, in_args, in_kwargs, **kwargs):
  127. """Runs a preregistered actor class on the ray client
  128. The common case for this decorator is for instantiating an ActorClass
  129. transparently as a ClientActorClass. This happens in circumstances where
  130. the ActorClass is declared early, in a library and only then is Ray used in
  131. client mode -- necessitating a conversion.
  132. """
  133. from ray.util.client import ray
  134. key = getattr(actor_cls, RAY_CLIENT_MODE_ATTR, None)
  135. # Second part of "or" is needed in case actor_cls is reused between Ray
  136. # client sessions in one Python interpreter session.
  137. if (key is None) or (not ray._converted_key_exists(key)):
  138. key = ray._convert_actor(actor_cls)
  139. setattr(actor_cls, RAY_CLIENT_MODE_ATTR, key)
  140. client_actor = ray._get_converted(key)
  141. return client_actor._remote(in_args, in_kwargs, **kwargs)