grpc_utils.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. import os
  2. from concurrent import futures
  3. from typing import Any, Optional, Sequence, Tuple
  4. import grpc
  5. from grpc import aio as aiogrpc
  6. import ray
  7. from ray._private.authentication import authentication_utils
  8. from ray._private.tls_utils import load_certs_from_env
  9. def init_grpc_channel(
  10. address: str,
  11. options: Optional[Sequence[Tuple[str, Any]]] = None,
  12. asynchronous: bool = False,
  13. credentials: Optional[grpc.ChannelCredentials] = None,
  14. ):
  15. """Create a gRPC channel with authentication interceptors if token auth is enabled.
  16. This function handles:
  17. - TLS configuration via RAY_USE_TLS environment variable or custom credentials
  18. - Authentication interceptors when token auth is enabled
  19. - Keepalive settings from Ray config
  20. - Both synchronous and asynchronous channels
  21. Args:
  22. address: The gRPC server address (host:port)
  23. options: Optional gRPC channel options as sequence of (key, value) tuples
  24. asynchronous: If True, create async channel; otherwise sync
  25. credentials: Optional custom gRPC credentials for TLS. If provided, takes
  26. precedence over RAY_USE_TLS environment variable.
  27. Returns:
  28. grpc.Channel or grpc.aio.Channel: Configured gRPC channel with interceptors
  29. """
  30. grpc_module = aiogrpc if asynchronous else grpc
  31. options = options or []
  32. options_dict = dict(options)
  33. options_dict["grpc.keepalive_time_ms"] = options_dict.get(
  34. "grpc.keepalive_time_ms", ray._config.grpc_client_keepalive_time_ms()
  35. )
  36. options_dict["grpc.keepalive_timeout_ms"] = options_dict.get(
  37. "grpc.keepalive_timeout_ms", ray._config.grpc_client_keepalive_timeout_ms()
  38. )
  39. options = options_dict.items()
  40. # Build interceptors list
  41. interceptors = []
  42. if authentication_utils.is_token_auth_enabled():
  43. from ray._private.authentication.grpc_authentication_client_interceptor import (
  44. SyncAuthenticationMetadataClientInterceptor,
  45. get_async_auth_interceptors,
  46. )
  47. if asynchronous:
  48. interceptors.extend(get_async_auth_interceptors())
  49. else:
  50. interceptors.append(SyncAuthenticationMetadataClientInterceptor())
  51. # Determine channel type and credentials
  52. if credentials is not None:
  53. # Use provided custom credentials (takes precedence)
  54. channel_creator = grpc_module.secure_channel
  55. base_args = (address, credentials)
  56. elif os.environ.get("RAY_USE_TLS", "0").lower() in ("1", "true"):
  57. # Use TLS from environment variables
  58. server_cert_chain, private_key, ca_cert = load_certs_from_env()
  59. tls_credentials = grpc.ssl_channel_credentials(
  60. certificate_chain=server_cert_chain,
  61. private_key=private_key,
  62. root_certificates=ca_cert,
  63. )
  64. channel_creator = grpc_module.secure_channel
  65. base_args = (address, tls_credentials)
  66. else:
  67. # Insecure channel
  68. channel_creator = grpc_module.insecure_channel
  69. base_args = (address,)
  70. # Create channel (async channels get interceptors in constructor, sync via intercept_channel)
  71. if asynchronous:
  72. channel = channel_creator(
  73. *base_args, options=options, interceptors=interceptors
  74. )
  75. else:
  76. channel = channel_creator(*base_args, options=options)
  77. if interceptors:
  78. channel = grpc.intercept_channel(channel, *interceptors)
  79. return channel
  80. def create_grpc_server_with_interceptors(
  81. max_workers: Optional[int] = None,
  82. thread_name_prefix: str = "grpc_server",
  83. options: Optional[Sequence[Tuple[str, Any]]] = None,
  84. asynchronous: bool = False,
  85. ):
  86. """Create a gRPC server with authentication interceptors if token auth is enabled.
  87. This function handles:
  88. - Authentication interceptors when token auth is enabled
  89. - Both synchronous and asynchronous servers
  90. - Thread pool configuration for sync servers
  91. Args:
  92. max_workers: Max thread pool workers (required for sync, ignored for async)
  93. thread_name_prefix: Thread name prefix for sync thread pool
  94. options: Optional gRPC server options as sequence of (key, value) tuples
  95. asynchronous: If True, create async server; otherwise sync
  96. Returns:
  97. grpc.Server or grpc.aio.Server: Configured gRPC server with interceptors
  98. """
  99. grpc_module = aiogrpc if asynchronous else grpc
  100. # Build interceptors list
  101. interceptors = []
  102. if authentication_utils.is_token_auth_enabled():
  103. if asynchronous:
  104. from ray._private.authentication.grpc_authentication_server_interceptor import (
  105. AsyncAuthenticationServerInterceptor,
  106. )
  107. interceptors.append(AsyncAuthenticationServerInterceptor())
  108. else:
  109. from ray._private.authentication.grpc_authentication_server_interceptor import (
  110. SyncAuthenticationServerInterceptor,
  111. )
  112. interceptors.append(SyncAuthenticationServerInterceptor())
  113. # Create server
  114. if asynchronous:
  115. server = grpc_module.server(
  116. interceptors=interceptors if interceptors else None,
  117. options=options,
  118. )
  119. else:
  120. if max_workers is None:
  121. raise ValueError("max_workers is required for synchronous gRPC servers")
  122. executor = futures.ThreadPoolExecutor(
  123. max_workers=max_workers,
  124. thread_name_prefix=thread_name_prefix,
  125. )
  126. server = grpc_module.server(
  127. executor,
  128. interceptors=interceptors if interceptors else None,
  129. options=options,
  130. )
  131. return server