| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182 |
- import time
- from contextlib import contextmanager
- from typing import Any, Dict
- import ray as real_ray
- import ray.util.client.server.server as ray_client_server
- from ray._private.client_mode_hook import disable_client_hook
- from ray.job_config import JobConfig
- from ray.util.client import ray
@contextmanager
def ray_start_client_server(metadata=None, ray_connect_handler=None, **kwargs):
    """Yield only the client half of ``ray_start_client_server_pair``.

    Every argument is forwarded unchanged to ``ray_start_client_server_pair``.
    The server handle it yields is discarded here, and all setup/teardown is
    delegated to that context manager.
    """
    pair_cm = ray_start_client_server_pair(
        metadata=metadata, ray_connect_handler=ray_connect_handler, **kwargs
    )
    with pair_cm as (client, _server):
        yield client
@contextmanager
def ray_start_client_server_for_address(address):
    """Run a Ray client server whose drivers attach to the cluster at ``address``.

    When the server invokes ``connect_handler``, it calls ``ray.init(address, ...)``
    on the real (non-client) Ray API, unless Ray is already initialized in this
    process, in which case it returns ``None``.

    Args:
        address: Passed positionally as the first argument to ``ray.init``.

    Yields:
        Whatever ``ray_start_client_server`` yields (the client-side Ray API).
    """

    def connect_handler(
        job_config: JobConfig = None, **ray_init_kwargs: Dict[str, Any]
    ):
        # Bypass the client hook so the calls below hit the real Ray API.
        with disable_client_hook():
            if real_ray.is_initialized():
                return None
            return real_ray.init(address, job_config=job_config, **ray_init_kwargs)

    with ray_start_client_server(ray_connect_handler=connect_handler) as client:
        yield client
@contextmanager
def ray_start_client_server_pair(metadata=None, ray_connect_handler=None, **kwargs):
    """Start an in-process Ray client server on 127.0.0.1:50051 and connect to it.

    Args:
        metadata: Forwarded to ``ray.connect``.
        ray_connect_handler: Forwarded to ``ray_client_server.serve``.
        **kwargs: Extra keyword arguments forwarded to ``ray.connect``.

    Yields:
        A ``(ray, server)`` tuple: the client-mode Ray API object and the handle
        returned by ``ray_client_server.serve``.

    Raises:
        AssertionError: If Ray is already initialized on entry.
        RuntimeError: If Ray still reports itself initialized about 30 seconds
            after teardown begins.
    """
    ray._inside_client_test = True
    server = None
    try:
        with disable_client_hook():
            assert not ray.is_initialized()
        server = ray_client_server.serve(
            "127.0.0.1", 50051, ray_connect_handler=ray_connect_handler
        )
        ray.connect("127.0.0.1:50051", metadata=metadata, **kwargs)
    except BaseException:
        # Setup failed before the caller got control: undo the global flag and
        # stop any server already listening so neither leaks into later tests.
        ray._inside_client_test = False
        if server is not None:
            server.stop(0)
        raise

    try:
        yield ray, server
    finally:
        ray._inside_client_test = False
        ray.disconnect()
        server.stop(0)
        del server
        start = time.monotonic()
        with disable_client_hook():
            # Poll once per second until Ray reports it is no longer
            # initialized, giving up after roughly 30 seconds.
            while ray.is_initialized():
                time.sleep(1)
                if time.monotonic() - start > 30:
                    raise RuntimeError("Failed to terminate Ray")
        # Allow windows to close processes before moving on
        time.sleep(3)
@contextmanager
def ray_start_cluster_client_server_pair(address):
    """Start a Ray client server on 127.0.0.1:50051 whose drivers join ``address``.

    Args:
        address: Cluster address passed as ``ray.init(address=...)`` on the real
            (non-client) Ray API when the server invokes the connect handler.

    Yields:
        A ``(ray, server)`` tuple: the client-mode Ray API object and the handle
        returned by ``ray_client_server.serve``.
    """
    ray._inside_client_test = True

    def ray_connect_handler(job_config=None, **ray_init_kwargs):
        # NOTE(review): job_config and ray_init_kwargs are accepted (presumably
        # to match the signature the server calls the handler with) but are
        # ignored here, unlike ray_start_client_server_for_address — confirm
        # this is intentional.
        real_ray.init(address=address)

    server = None
    try:
        server = ray_client_server.serve(
            "127.0.0.1", 50051, ray_connect_handler=ray_connect_handler
        )
        ray.connect("127.0.0.1:50051")
    except BaseException:
        # Setup failed before the caller got control: undo the global flag and
        # stop any server already listening so neither leaks into later tests.
        ray._inside_client_test = False
        if server is not None:
            server.stop(0)
        raise

    try:
        yield ray, server
    finally:
        ray._inside_client_test = False
        ray.disconnect()
        server.stop(0)
|