# ray_client_helpers.py (2.4 KB)

import time
from contextlib import contextmanager
from typing import Any, Dict, Optional

import ray as real_ray
import ray.util.client.server.server as ray_client_server
from ray._private.client_mode_hook import disable_client_hook
from ray.job_config import JobConfig
from ray.util.client import ray
  9. @contextmanager
  10. def ray_start_client_server(metadata=None, ray_connect_handler=None, **kwargs):
  11. with ray_start_client_server_pair(
  12. metadata=metadata, ray_connect_handler=ray_connect_handler, **kwargs
  13. ) as pair:
  14. client, server = pair
  15. yield client
  16. @contextmanager
  17. def ray_start_client_server_for_address(address):
  18. """
  19. Starts a Ray client server that initializes drivers at the specified address.
  20. """
  21. def connect_handler(
  22. job_config: JobConfig = None, **ray_init_kwargs: Dict[str, Any]
  23. ):
  24. import ray
  25. with disable_client_hook():
  26. if not ray.is_initialized():
  27. return ray.init(address, job_config=job_config, **ray_init_kwargs)
  28. with ray_start_client_server(ray_connect_handler=connect_handler) as ray:
  29. yield ray
  30. @contextmanager
  31. def ray_start_client_server_pair(metadata=None, ray_connect_handler=None, **kwargs):
  32. ray._inside_client_test = True
  33. with disable_client_hook():
  34. assert not ray.is_initialized()
  35. server = ray_client_server.serve(
  36. "127.0.0.1", 50051, ray_connect_handler=ray_connect_handler
  37. )
  38. ray.connect("127.0.0.1:50051", metadata=metadata, **kwargs)
  39. try:
  40. yield ray, server
  41. finally:
  42. ray._inside_client_test = False
  43. ray.disconnect()
  44. server.stop(0)
  45. del server
  46. start = time.monotonic()
  47. with disable_client_hook():
  48. while ray.is_initialized():
  49. time.sleep(1)
  50. if time.monotonic() - start > 30:
  51. raise RuntimeError("Failed to terminate Ray")
  52. # Allow windows to close processes before moving on
  53. time.sleep(3)
  54. @contextmanager
  55. def ray_start_cluster_client_server_pair(address):
  56. ray._inside_client_test = True
  57. def ray_connect_handler(job_config=None, **ray_init_kwargs):
  58. real_ray.init(address=address)
  59. server = ray_client_server.serve(
  60. "127.0.0.1", 50051, ray_connect_handler=ray_connect_handler
  61. )
  62. ray.connect("127.0.0.1:50051")
  63. try:
  64. yield ray, server
  65. finally:
  66. ray._inside_client_test = False
  67. ray.disconnect()
  68. server.stop(0)