runtime_context.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. from types import SimpleNamespace
  2. from typing import TYPE_CHECKING
  3. if TYPE_CHECKING:
  4. from ray import JobID, NodeID
  5. from ray.runtime_context import RuntimeContext
  6. class _ClientWorkerPropertyAPI:
  7. """Emulates the properties of the ray._private.worker object for the client"""
  8. def __init__(self, worker):
  9. assert worker is not None
  10. self.worker = worker
  11. def build_runtime_context(self) -> "RuntimeContext":
  12. """Creates a RuntimeContext backed by the properites of this API"""
  13. # Defer the import of RuntimeContext until needed to avoid cycles
  14. from ray.runtime_context import RuntimeContext
  15. return RuntimeContext(self)
  16. def _fetch_runtime_context(self):
  17. import ray.core.generated.ray_client_pb2 as ray_client_pb2
  18. return self.worker.get_cluster_info(
  19. ray_client_pb2.ClusterInfoType.RUNTIME_CONTEXT
  20. )
  21. @property
  22. def mode(self):
  23. from ray._private.worker import SCRIPT_MODE
  24. return SCRIPT_MODE
  25. @property
  26. def current_job_id(self) -> "JobID":
  27. from ray import JobID
  28. return JobID(self._fetch_runtime_context().job_id)
  29. @property
  30. def current_node_id(self) -> "NodeID":
  31. from ray import NodeID
  32. return NodeID(self._fetch_runtime_context().node_id)
  33. @property
  34. def namespace(self) -> str:
  35. return self._fetch_runtime_context().namespace
  36. @property
  37. def should_capture_child_tasks_in_placement_group(self) -> bool:
  38. return self._fetch_runtime_context().capture_client_tasks
  39. @property
  40. def runtime_env(self) -> str:
  41. return self._fetch_runtime_context().runtime_env
  42. def check_connected(self) -> bool:
  43. return self.worker.ping_server()
  44. @property
  45. def gcs_client(self) -> str:
  46. return SimpleNamespace(address=self._fetch_runtime_context().gcs_address)
  47. @property
  48. def node(self):
  49. """Emulates the worker.node property for client mode"""
  50. return SimpleNamespace(session_name=self._fetch_runtime_context().session_name)