tls_utils.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. import datetime
  2. import os
  3. import socket
  4. from ray._common.network_utils import (
  5. get_localhost_ip,
  6. node_ip_address_from_perspective,
  7. )
  8. def generate_self_signed_tls_certs():
  9. """Create self-signed key/cert pair for testing.
  10. This method requires the library ``cryptography`` be installed.
  11. """
  12. try:
  13. from cryptography import x509
  14. from cryptography.hazmat.backends import default_backend
  15. from cryptography.hazmat.primitives import hashes, serialization
  16. from cryptography.hazmat.primitives.asymmetric import rsa
  17. from cryptography.x509.oid import NameOID
  18. except ImportError:
  19. raise ImportError(
  20. "Using `Security.temporary` requires `cryptography`, please "
  21. "install it using either pip or conda"
  22. )
  23. key = rsa.generate_private_key(
  24. public_exponent=65537, key_size=2048, backend=default_backend()
  25. )
  26. key_contents = key.private_bytes(
  27. encoding=serialization.Encoding.PEM,
  28. format=serialization.PrivateFormat.PKCS8,
  29. encryption_algorithm=serialization.NoEncryption(),
  30. ).decode()
  31. ray_interal = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "ray-internal")])
  32. altnames = x509.SubjectAlternativeName(
  33. [
  34. x509.DNSName(
  35. socket.gethostbyname(socket.gethostname())
  36. ), # Probably 127.0.0.1 or ::1
  37. x509.DNSName(get_localhost_ip()),
  38. x509.DNSName(node_ip_address_from_perspective()),
  39. x509.DNSName("localhost"),
  40. ]
  41. )
  42. now = datetime.datetime.utcnow()
  43. cert = (
  44. x509.CertificateBuilder()
  45. .subject_name(ray_interal)
  46. .issuer_name(ray_interal)
  47. .add_extension(altnames, critical=False)
  48. .public_key(key.public_key())
  49. .serial_number(x509.random_serial_number())
  50. .not_valid_before(now)
  51. .not_valid_after(now + datetime.timedelta(days=365))
  52. .sign(key, hashes.SHA256(), default_backend())
  53. )
  54. cert_contents = cert.public_bytes(serialization.Encoding.PEM).decode()
  55. return cert_contents, key_contents
  56. def add_port_to_grpc_server(server, address):
  57. import grpc
  58. if os.environ.get("RAY_USE_TLS", "0").lower() in ("1", "true"):
  59. server_cert_chain, private_key, ca_cert = load_certs_from_env()
  60. credentials = grpc.ssl_server_credentials(
  61. [(private_key, server_cert_chain)],
  62. root_certificates=ca_cert,
  63. require_client_auth=ca_cert is not None,
  64. )
  65. return server.add_secure_port(address, credentials)
  66. else:
  67. return server.add_insecure_port(address)
  68. def load_certs_from_env():
  69. tls_env_vars = ["RAY_TLS_SERVER_CERT", "RAY_TLS_SERVER_KEY", "RAY_TLS_CA_CERT"]
  70. if any(v not in os.environ for v in tls_env_vars):
  71. raise RuntimeError(
  72. "If the environment variable RAY_USE_TLS is set to true "
  73. "then RAY_TLS_SERVER_CERT, RAY_TLS_SERVER_KEY and "
  74. "RAY_TLS_CA_CERT must also be set."
  75. )
  76. with open(os.environ["RAY_TLS_SERVER_CERT"], "rb") as f:
  77. server_cert_chain = f.read()
  78. with open(os.environ["RAY_TLS_SERVER_KEY"], "rb") as f:
  79. private_key = f.read()
  80. with open(os.environ["RAY_TLS_CA_CERT"], "rb") as f:
  81. ca_cert = f.read()
  82. return server_cert_chain, private_key, ca_cert