| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
7771778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183 |
- import asyncio
- import fnmatch
- import io
- import json
- import logging
- import os
- import pathlib
- import random
- import socket
- import subprocess
- import sys
- import tempfile
- import threading
- import time
- import timeit
- import traceback
- import uuid
- from collections import defaultdict
- from contextlib import contextmanager, redirect_stderr, redirect_stdout
- from dataclasses import dataclass, field
- from datetime import datetime
- from typing import Any, Callable, Dict, List, Optional, Set, Tuple
- from urllib.parse import quote, urlparse
- import requests
- import yaml
- import ray
- import ray._private.memory_monitor as memory_monitor
- import ray._private.services
- import ray._private.services as services
- import ray._private.utils
- import ray.dashboard.consts as dashboard_consts
- from ray._common.network_utils import build_address, parse_address
- from ray._common.test_utils import wait_for_condition
- from ray._common.utils import get_or_create_event_loop
- from ray._private import (
- ray_constants,
- )
- from ray._private.internal_api import memory_summary
- from ray._private.services import ProcessInfo
- from ray._private.tls_utils import generate_self_signed_tls_certs
- from ray._private.worker import RayContext
- from ray._raylet import Config, GcsClient, GcsClientOptions, GlobalStateAccessor
- from ray.core.generated import (
- gcs_pb2,
- gcs_service_pb2,
- node_manager_pb2,
- )
- from ray.util.queue import Empty, Queue, _QueueActor
- from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
- from ray.util.state import get_actor, list_actors
- import psutil # We must import psutil after ray because we bundle it with ray.
# Logger shared by the helpers in this module.
logger = logging.getLogger(__name__)

# Executable file names only carry a suffix on Windows.
if sys.platform == "win32":
    EXE_SUFFIX = ".exe"
else:
    EXE_SUFFIX = ""

# Absolute path of the directory two levels above this file.
RAY_PATH = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))

# Expected location of the redis-server binary below RAY_PATH.
REDIS_EXECUTABLE = os.path.join(
    RAY_PATH,
    "core/src/ray/thirdparty/redis/src/redis-server" + EXE_SUFFIX,
)
# prometheus_client is optional. When it cannot be imported, the type names
# are bound to None and the parser is replaced by a stub that raises on use.
try:
    from prometheus_client.core import Metric
    from prometheus_client.parser import Sample, text_string_to_metric_families
except ImportError:  # ModuleNotFoundError is a subclass of ImportError.
    Metric = Sample = None

    def text_string_to_metric_families(*args, **kwargs):
        """Stand-in parser used when `prometheus_client` is unavailable."""
        raise ModuleNotFoundError("`prometheus_client` not found")
def make_global_state_accessor(ray_context):
    """Create and connect a GlobalStateAccessor for the given Ray context.

    Args:
        ray_context: Object whose ``address_info["gcs_address"]`` holds the
            GCS address to connect to.

    Returns:
        The GlobalStateAccessor, after ``connect()`` has been called on it.
    """
    gcs_address = ray_context.address_info["gcs_address"]
    accessor = GlobalStateAccessor(
        GcsClientOptions.create(
            gcs_address,
            None,
            allow_cluster_id_nil=True,
            fetch_cluster_id_if_nil=False,
        )
    )
    accessor.connect()
    return accessor
def external_redis_test_enabled():
    """Return True only when TEST_EXTERNAL_REDIS is set to exactly "1"."""
    flag = os.environ.get("TEST_EXTERNAL_REDIS")
    return flag == "1"
def redis_replicas():
    """Return TEST_EXTERNAL_REDIS_REPLICAS as an int (1 when unset)."""
    raw = os.environ.get("TEST_EXTERNAL_REDIS_REPLICAS")
    if raw is None:
        return 1
    return int(raw)
def redis_sentinel_replicas():
    """Return TEST_EXTERNAL_REDIS_SENTINEL_REPLICAS as an int (2 when unset)."""
    raw = os.environ.get("TEST_EXTERNAL_REDIS_SENTINEL_REPLICAS")
    if raw is None:
        return 2
    return int(raw)
def get_redis_cli(port, enable_tls):
    """Build a redis client for the server on localhost at ``port``.

    Args:
        port: Port of the Redis server; passed to the client as a string.
        enable_tls: When true, the client is created with SSL enabled and
            with whichever of the CA cert / client cert / client key paths
            ``Config`` reports as set.

    Returns:
        A ``redis.Redis`` instance, or ``True`` when the ``redis`` package
        cannot be imported.
    """
    try:
        import redis
    except Exception:
        # Without the redis package (e.g. in a minimal test environment)
        # there is nothing to check, so the original sentinel value is kept.
        return True

    if not enable_tls:
        return redis.Redis("localhost", str(port))

    from ray._raylet import Config

    tls_kwargs = {"ssl": True, "ssl_cert_reqs": "required"}
    optional_paths = (
        ("ssl_ca_certs", Config.REDIS_CA_CERT),
        ("ssl_certfile", Config.REDIS_CLIENT_CERT),
        ("ssl_keyfile", Config.REDIS_CLIENT_KEY),
    )
    for kwarg_name, getter in optional_paths:
        if getter():
            tls_kwargs[kwarg_name] = getter()
    return redis.Redis("localhost", str(port), **tls_kwargs)
def start_redis_sentinel_instance(
    session_dir_path: str,
    port: int,
    redis_master_port: int,
    password: Optional[str] = None,
    enable_tls: bool = False,
    db_dir=None,
    free_port=0,
):
    """Write a sentinel config file and start redis in ``--sentinel`` mode.

    Args:
        session_dir_path: Directory in which the generated
            ``redis-sentinel-<hex>.conf`` file is written.
        port: Port for this sentinel. With TLS enabled it is written as
            ``tls-port`` and ``free_port`` is written as ``port`` instead.
        redis_master_port: Port of the master at 127.0.0.1 to monitor under
            the name ``redis-test``.
        password: If truthy, written as ``sentinel auth-pass``.
        enable_tls: Whether to emit the TLS directives.
        db_dir: If truthy, written as the ``dir`` directive.
        free_port: Value of the ``port`` directive when TLS is enabled.

    Returns:
        Whatever ``ray._private.services.start_ray_process`` returns for the
        started process.
    """
    plain_port = free_port if enable_tls else port
    sentinel_conf = [
        f"port {plain_port}",
        # Monitor the Redis master.
        f"sentinel monitor redis-test 127.0.0.1 {redis_master_port} 1",
        # Fail over after one second of downtime.
        "sentinel down-after-milliseconds redis-test 1000",
        "sentinel failover-timeout redis-test 5000",
        "sentinel parallel-syncs redis-test 1",
    ]
    if password:
        sentinel_conf.append(f"sentinel auth-pass redis-test {password}")

    if enable_tls:
        sentinel_conf.append(f"tls-port {port}")
        # Each certificate/key directive is only written when configured.
        tls_file_directives = (
            ("tls-ca-cert-file", Config.REDIS_CA_CERT),
            ("tls-cert-file", Config.REDIS_CLIENT_CERT),
            ("tls-key-file", Config.REDIS_CLIENT_KEY),
        )
        for directive, getter in tls_file_directives:
            if getter():
                sentinel_conf.append(f"{directive} {getter()}")
        sentinel_conf += [
            "tls-auth-clients no",
            "sentinel tls-auth-clients redis-test no",
        ]

    if db_dir:
        sentinel_conf.append(f"dir {db_dir}")

    conf_path = os.path.join(
        session_dir_path, "redis-sentinel-" + uuid.uuid4().hex + ".conf"
    )
    with open(conf_path, "w") as conf_file:
        conf_file.write("\n".join(sentinel_conf))

    return ray._private.services.start_ray_process(
        [REDIS_EXECUTABLE, conf_path, "--sentinel"],
        ray_constants.PROCESS_TYPE_REDIS_SERVER,
        fate_share=False,
    )
def start_redis_instance(
    session_dir_path: str,
    port: int,
    redis_max_clients: Optional[int] = None,
    num_retries: int = 20,
    stdout_file: Optional[str] = None,
    stderr_file: Optional[str] = None,
    password: Optional[str] = None,
    fate_share: Optional[bool] = None,
    port_denylist: Optional[List[int]] = None,
    listen_to_localhost_only: bool = False,
    enable_tls: bool = False,
    replica_of=None,
    leader_id=None,
    db_dir=None,
    free_port=0,
):
    """Start a single Redis server and, in cluster mode, configure it.

    Notes:
        The server is started exactly once on ``port``; this function does
        not retry on other ports. When ``redis_replicas() > 1`` the function
        then polls the new server every 0.1s until the cluster commands
        succeed, raising if the server process exits in the meantime.

    Args:
        session_dir_path: Path to the session directory of this Ray
            cluster; the Redis pidfile is placed there.
        port: Port for the Redis server. With TLS enabled this is passed as
            ``--tls-port``.
        redis_max_clients: Accepted for interface compatibility; not used
            by this function.
        num_retries: Accepted for interface compatibility; not used by this
            function.
        stdout_file: Forwarded to ``start_ray_process`` as ``stdout_file``.
        stderr_file: Forwarded to ``start_ray_process`` as ``stderr_file``.
        password: If provided, passed to Redis as ``--requirepass``. Must
            not contain spaces.
        fate_share: Forwarded to ``start_ray_process`` as ``fate_share``.
        port_denylist: Accepted for interface compatibility; not used by
            this function.
        listen_to_localhost_only: If true, Redis is bound to 127.0.0.1.
        enable_tls: Enable TLS/SSL in Redis or not.
        replica_of: In cluster mode, ``None`` makes this node take all
            16384 slots; otherwise it is the port (on 127.0.0.1) of the node
            to ``cluster meet``.
        leader_id: In cluster mode with ``replica_of`` set, the node id
            passed to ``cluster replicate``.
        db_dir: If not None, passed to Redis as ``--dir``.
        free_port: With TLS enabled, the value passed as ``--port``.

    Returns:
        A tuple ``(node_id, process_info)``. ``node_id`` is the result of
        ``cluster myid`` in cluster mode and ``None`` otherwise;
        ``process_info`` is the value returned by ``start_ray_process``.

    Raises:
        AssertionError: If ``REDIS_EXECUTABLE`` is not an existing file.
        ValueError: If ``password`` contains a space.
        Exception: If the Redis process exits while cluster setup is still
            being retried.
    """
    assert os.path.isfile(REDIS_EXECUTABLE)

    # Construct the command to start the Redis server.
    command = [REDIS_EXECUTABLE]
    if password:
        if " " in password:
            raise ValueError("Spaces not permitted in redis password.")
        command += ["--requirepass", password]
    if redis_replicas() > 1:
        command += ["--cluster-enabled", "yes", "--cluster-config-file", f"node-{port}"]
    if enable_tls:
        command += [
            "--tls-port",
            str(port),
            "--loglevel",
            "warning",
            "--port",
            str(free_port),
        ]
    else:
        command += ["--port", str(port), "--loglevel", "warning"]

    if listen_to_localhost_only:
        command += ["--bind", "127.0.0.1"]

    pidfile = os.path.join(session_dir_path, "redis-" + uuid.uuid4().hex + ".pid")
    command += ["--pidfile", pidfile]

    if enable_tls:
        if Config.REDIS_CA_CERT():
            command += ["--tls-ca-cert-file", Config.REDIS_CA_CERT()]
        if Config.REDIS_CLIENT_CERT():
            command += ["--tls-cert-file", Config.REDIS_CLIENT_CERT()]
        if Config.REDIS_CLIENT_KEY():
            command += ["--tls-key-file", Config.REDIS_CLIENT_KEY()]
        if replica_of is not None:
            command += ["--tls-replication", "yes"]
        command += ["--tls-auth-clients", "no", "--tls-cluster", "yes"]

    if sys.platform != "win32":
        # Disable RDB snapshots and the append-only file.
        command += ["--save", "", "--appendonly", "no"]
    if db_dir is not None:
        command += ["--dir", str(db_dir)]

    process_info = ray._private.services.start_ray_process(
        command,
        ray_constants.PROCESS_TYPE_REDIS_SERVER,
        stdout_file=stdout_file,
        stderr_file=stderr_file,
        fate_share=fate_share,
    )

    node_id = None
    if redis_replicas() > 1:
        # Set up the redis cluster. The server may not accept commands yet,
        # so keep retrying for as long as the process is still running.
        import redis

        while True:
            try:
                redis_cli = get_redis_cli(port, enable_tls)
                if replica_of is None:
                    slots = [str(i) for i in range(16384)]
                    redis_cli.cluster("addslots", *slots)
                else:
                    logger.info(redis_cli.cluster("meet", "127.0.0.1", str(replica_of)))
                    logger.info(redis_cli.cluster("replicate", leader_id))
                node_id = redis_cli.cluster("myid")
                break
            except (
                redis.exceptions.ConnectionError,
                redis.exceptions.ResponseError,
            ) as e:
                logger.info(
                    f"Waiting for redis to be up. Check failed with error: {e}. "
                    "Will retry in 0.1s"
                )
                if process_info.process.poll() is not None:
                    raise Exception(
                        f"Redis process exited unexpectedly: {process_info}. "
                        f"Exit code: {process_info.process.returncode}"
                    )
                time.sleep(0.1)

    logger.info(
        f"Redis started with node_id {node_id} and pid {process_info.process.pid}"
    )
    return node_id, process_info
def _pid_alive(pid):
    """Report whether the process with the given PID is alive.

    Args:
        pid: The pid to check.

    Returns:
        False when psutil reports no such process or the process is a
        zombie; True otherwise.
    """
    try:
        status = psutil.Process(pid).status()
    except psutil.NoSuchProcess:
        return False
    return status != psutil.STATUS_ZOMBIE
- def _check_call_windows(main, argv, capture_stdout=False, capture_stderr=False):
- # We use this function instead of calling the "ray" command to work around
- # some deadlocks that occur when piping ray's output on Windows
- stream = io.TextIOWrapper(io.BytesIO(), encoding=sys.stdout.encoding)
- old_argv = sys.argv[:]
- try:
- sys.argv = argv[:]
- try:
- with redirect_stderr(stream if capture_stderr else sys.stderr):
- with redirect_stdout(stream if capture_stdout else sys.stdout):
- main()
- finally:
- stream.flush()
- except SystemExit as ex:
- if ex.code:
- output = stream.buffer.getvalue()
- raise subprocess.CalledProcessError(ex.code, argv, output)
- except Exception as ex:
- output = stream.buffer.getvalue()
- raise subprocess.CalledProcessError(1, argv, output, ex.args[0])
- finally:
- sys.argv = old_argv
- if capture_stdout:
- sys.stdout.buffer.write(stream.buffer.getvalue())
- elif capture_stderr:
- sys.stderr.buffer.write(stream.buffer.getvalue())
- return stream.buffer.getvalue()
def check_call_subprocess(argv, capture_stdout=False, capture_stderr=False):
    """Run ``argv`` and return its captured output, raising on failure.

    On Windows the Ray CLI entry point is invoked in-process instead of
    spawning the command, to work around deadlocks that occur when piping
    ray's output there.

    Args:
        argv: The command line to run.
        capture_stdout: Capture the child's stdout.
        capture_stderr: Capture the child's stderr; when stdout is captured
            as well, stderr is merged into it.

    Returns:
        The captured bytes (empty when nothing was captured).

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero
            return code.
    """
    from ray.scripts.scripts import main as ray_main

    if sys.platform == "win32":
        return _check_call_windows(
            ray_main, argv, capture_stdout=capture_stdout, capture_stderr=capture_stderr
        )

    stdout_target = subprocess.PIPE if capture_stdout else None
    if not capture_stderr:
        stderr_target = None
    elif capture_stdout:
        stderr_target = subprocess.STDOUT
    else:
        stderr_target = subprocess.PIPE

    child = subprocess.Popen(argv, stdout=stdout_target, stderr=stderr_target)
    out, err = child.communicate()
    if child.returncode:
        raise subprocess.CalledProcessError(child.returncode, argv, out, err)
    return b"".join(chunk for chunk in (out, err) if chunk is not None)
def check_call_ray(args, capture_stdout=False, capture_stderr=False):
    """Run the ``ray`` command with ``args`` via ``check_call_subprocess``.

    The captured output is discarded; this function returns None.
    """
    ray_argv = ["ray"] + args
    check_call_subprocess(ray_argv, capture_stdout, capture_stderr)
def get_dashboard_agent_address(gcs_client: GcsClient, node_id: str):
    """Look up the dashboard agent address of a node in the GCS internal KV.

    Args:
        gcs_client: Client used for the ``internal_kv_get`` call.
        node_id: Node id appended to the dashboard agent address key prefix.

    Returns:
        ``"<ip>:<grpc_port>"`` when an entry exists, otherwise None.
    """
    kv_key = f"{dashboard_consts.DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX}{node_id}".encode()
    raw = gcs_client.internal_kv_get(
        kv_key,
        namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
        timeout=dashboard_consts.GCS_RPC_TIMEOUT_SECONDS,
    )
    if not raw:
        return None
    # The stored value is a JSON list: [ip, http_port, grpc_port].
    ip, _http_port, grpc_port = json.loads(raw)
    return f"{ip}:{grpc_port}"
def wait_for_dashboard_agent_available(cluster):
    """Block until the head node's dashboard agent address is published.

    Args:
        cluster: Object providing ``address`` (used to build the GcsClient)
            and ``head_node.node_id``.
    """
    gcs_client = GcsClient(address=cluster.address)

    def _agent_address_published():
        head_node_id = cluster.head_node.node_id
        return get_dashboard_agent_address(gcs_client, head_node_id) is not None

    wait_for_condition(_agent_address_published)
def wait_for_aggregator_agent(address: str, node_id: str, timeout: float = 10) -> None:
    """Wait until the node's agent has published an address and accepts connections.

    Args:
        address: GCS address used to build the GcsClient.
        node_id: Node whose dashboard agent address is looked up.
        timeout: Timeout passed to the socket-connectivity wait. The wait
            for the address to be published uses ``wait_for_condition``'s
            default.
    """
    gcs_client = GcsClient(address=address)

    def _lookup_agent_address():
        return get_dashboard_agent_address(gcs_client, node_id)

    # First wait for the agent to publish its address.
    wait_for_condition(lambda: _lookup_agent_address() is not None)

    # Then fetch it again and split it into host and port.
    endpoint = urlparse(f"grpc://{_lookup_agent_address()}")

    def _can_connect() -> bool:
        try:
            with socket.create_connection(
                (endpoint.hostname, endpoint.port), timeout=1
            ):
                return True
        except OSError:
            return False

    wait_for_condition(_can_connect, timeout=timeout)
def wait_for_aggregator_agent_if_enabled(
    address: str, node_id: str, timeout: float = 10
) -> None:
    """Wait for the aggregator agent, but only when aggregator mode is on.

    Aggregator mode is considered on when the environment variable
    RAY_enable_core_worker_ray_event_to_aggregator equals "1"; otherwise
    this function returns immediately.
    """
    enabled = os.environ.get("RAY_enable_core_worker_ray_event_to_aggregator")
    if enabled != "1":
        return
    wait_for_aggregator_agent(address, node_id, timeout)
def wait_for_pid_to_exit(pid: int, timeout: float = 20):
    """Poll every 0.1s until ``_pid_alive(pid)`` is false.

    Args:
        pid: Process id to watch.
        timeout: Maximum number of seconds to keep polling.

    Raises:
        TimeoutError: If the process is still alive when the timeout elapses.
    """
    began = time.time()
    while True:
        if time.time() - began >= timeout:
            raise TimeoutError(f"Timed out while waiting for process {pid} to exit.")
        if not _pid_alive(pid):
            return
        time.sleep(0.1)
def wait_for_children_of_pid(pid, num_children=1, timeout=20):
    """Wait until the process has at least ``num_children`` direct children.

    Args:
        pid: Parent process id.
        num_children: Minimum number of direct (non-recursive) children.
        timeout: Maximum number of seconds to keep polling (every 0.1s).

    Raises:
        TimeoutError: If fewer than ``num_children`` children are seen
            before the timeout elapses.
    """
    parent = psutil.Process(pid)
    start_time = time.time()
    alive = []
    while time.time() - start_time < timeout:
        alive = parent.children(recursive=False)
        if len(alive) >= num_children:
            return
        time.sleep(0.1)
    # The count is derived from ``alive`` so the message can also be built
    # when the loop body never ran (e.g. timeout <= 0).
    raise TimeoutError(
        f"Timed out while waiting for process {pid} children to start "
        f"({len(alive)}/{num_children} started: {alive})."
    )
def wait_for_children_of_pid_to_exit(pid, timeout=20):
    """Wait for the current children of ``pid`` to terminate.

    Args:
        pid: Parent process id.
        timeout: Timeout in seconds passed to ``psutil.wait_procs``.

    Raises:
        TimeoutError: If any child is still alive after the wait.
    """
    children = psutil.Process(pid).children()
    if not children:
        return
    _gone, still_alive = psutil.wait_procs(children, timeout=timeout)
    if still_alive:
        names = [p.name() for p in still_alive]
        raise TimeoutError(
            "Timed out while waiting for process children to exit."
            f" Children still alive: {names}."
        )
def kill_process_by_name(name, SIGKILL=False):
    """Stop every process whose name equals ``name`` plus the executable suffix.

    Args:
        name: Process name without the platform executable suffix.
        SIGKILL: When true call ``kill()`` on each match, otherwise
            ``terminate()``.
    """
    for proc in psutil.process_iter(attrs=["name"]):
        if proc.info["name"] != name + ray._private.services.EXE_SUFFIX:
            continue
        stop = proc.kill if SIGKILL else proc.terminate
        stop()
def kill_processes(process_infos: List[ProcessInfo]):
    """Kill each given process and wait up to 5 seconds for it to exit.

    Processes that are already gone (``ProcessLookupError``) are skipped.

    Args:
        process_infos: The list of ProcessInfo representing the processes to kill.

    Raises:
        TimeoutError: If the process did not exit within 5 seconds.
    """
    for info in process_infos:
        try:
            info.process.kill()
            info.process.wait(timeout=5)
        except ProcessLookupError:
            # Nothing left to kill.
            continue
        except subprocess.TimeoutExpired as wait_timeout:
            raise TimeoutError(
                f"Process {info.process.pid} did not exit within 5 seconds "
                "after SIGKILL"
            ) from wait_timeout
def run_string_as_driver(driver_script: str, env: Dict = None, encode: str = "utf-8"):
    """Run a Python script in a child interpreter and return its output.

    The script is fed to ``sys.executable -`` on stdin; the child's stderr
    is merged into its stdout.

    Args:
        driver_script: A string to run as a Python script.
        env: The environment variables for the driver.
        encode: Encoding used for the script sent to the child and for
            decoding its output.

    Returns:
        The script's decoded output.

    Raises:
        subprocess.CalledProcessError: If the child exits with a non-zero
            return code; the decoded output is printed first.
    """
    child = subprocess.Popen(
        [sys.executable, "-"],
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        env=env,
    )
    with child as proc:
        payload = driver_script.encode(encoding=encode)
        output, _ = proc.communicate(payload)
        decoded = ray._common.utils.decode(output, encode_type=encode)
        if proc.returncode:
            print(decoded)
            # stderr is merged into stdout above, so proc.stderr is None here.
            logger.error(proc.stderr)
            raise subprocess.CalledProcessError(
                proc.returncode, proc.args, output, proc.stderr
            )
    return decoded
def run_string_as_driver_stdout_stderr(
    driver_script: str, env: Dict = None, encode: str = "utf-8"
) -> Tuple[str, str]:
    """Run a Python script in a child interpreter, keeping stdout and stderr apart.

    Args:
        driver_script: A string to run as a Python script.
        env: The environment variables for the driver.
        encode: Encoding used for the script sent to the child and for
            decoding its output streams.

    Returns:
        The script's decoded stdout and stderr.

    Raises:
        subprocess.CalledProcessError: If the child exits with a non-zero
            return code; both streams are printed first.
    """
    child = subprocess.Popen(
        [sys.executable, "-"],
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        env=env,
    )
    with child as proc:
        raw_out, raw_err = proc.communicate(driver_script.encode(encoding=encode))
        out_str = ray._common.utils.decode(raw_out, encode_type=encode)
        err_str = ray._common.utils.decode(raw_err, encode_type=encode)
        if proc.returncode:
            print(out_str)
            print(err_str)
            raise subprocess.CalledProcessError(
                proc.returncode, proc.args, out_str, err_str
            )
    return out_str, err_str
def run_string_as_driver_nonblocking(driver_script, env: Dict = None):
    """Start a driver as a separate process and return immediately.

    The script is written to the child's stdin as UTF-8 and decoded as
    UTF-8 by the child, so non-ASCII driver scripts work regardless of the
    child's default stdin encoding.

    Args:
        driver_script: A string to run as a Python script.
        env: The environment variables for the driver.

    Returns:
        A handle to the driver process (a ``subprocess.Popen`` with stdout
        and stderr pipes open and stdin already closed).
    """
    # The bootstrap reads the whole script from stdin, closes stdin, removes
    # its own helper names and then executes the script.
    bootstrap = "; ".join(
        [
            "import sys",
            "script = sys.stdin.buffer.read().decode('utf-8')",
            "sys.stdin.close()",
            "del sys",
            'exec("del script\\n" + script)',
        ]
    )
    proc = subprocess.Popen(
        [sys.executable, "-c", bootstrap],
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        env=env,
    )
    proc.stdin.write(driver_script.encode("utf-8"))
    proc.stdin.close()
    return proc
def convert_actor_state(state):
    """Map an ActorState enum number to its name.

    Returns None when ``state`` is falsy (which includes ``None`` and ``0``).
    """
    if state:
        descriptor = gcs_pb2.ActorTableData.ActorState.DESCRIPTOR
        return descriptor.values_by_number[state].name
    return None
def wait_for_num_actors(num_actors, state=None, timeout=10):
    """Poll ``list_actors`` until at least ``num_actors`` actors are listed.

    Args:
        num_actors: Minimum number of actors to wait for; also used as the
            ``limit`` of each listing.
        state: Optional ActorState enum number; when it converts to a name,
            only actors in that state are counted.
        timeout: Maximum number of seconds to keep polling (every 0.1s).

    Raises:
        TimeoutError: If the count is not reached before the timeout.
    """
    state_name = convert_actor_state(state)
    began = time.time()
    while time.time() - began < timeout:
        state_filter = [("state", "=", state_name)] if state_name else None
        listed = list_actors(filters=state_filter, limit=num_actors)
        if len(listed) >= num_actors:
            return
        time.sleep(0.1)
    raise TimeoutError("Timed out while waiting for global state.")
def kill_actor_and_wait_for_failure(actor, timeout=10, retry_interval_ms=100):
    """Kill an actor and wait until it is reported DEAD or restarted.

    Args:
        actor: Actor handle to pass to ``ray.kill``.
        timeout: Maximum number of seconds to wait.
        retry_interval_ms: Delay between state checks, in milliseconds.

    Raises:
        RuntimeError: If neither condition is observed within the timeout.
    """
    actor_id = actor._actor_id.hex()
    # Snapshot the restart count before killing so a restart can be detected.
    restarts_before = get_actor(id=actor_id).num_restarts
    ray.kill(actor)

    began = time.time()
    while time.time() - began <= timeout:
        snapshot = get_actor(id=actor_id)
        if snapshot.state == "DEAD":
            return
        if snapshot.num_restarts > restarts_before:
            return
        time.sleep(retry_interval_ms / 1000.0)
    raise RuntimeError("It took too much time to kill an actor: {}".format(actor_id))
def wait_for_assertion(
    assertion_predictor: Callable,
    timeout: int = 10,
    retry_interval_ms: int = 100,
    raise_exceptions: bool = False,
    **kwargs: Any,
):
    """Wait until an assertion is met or time out with an exception.

    Args:
        assertion_predictor: A callable that raises AssertionError while the
            assertion does not hold yet.
        timeout: Maximum timeout in seconds.
        retry_interval_ms: Retry interval in milliseconds.
        raise_exceptions: If true, exceptions that occur while executing
            assertion_predictor won't be caught and instead will be raised.
        **kwargs: Arguments to pass to assertion_predictor.

    Raises:
        AssertionError: If the assertion is still not met after the timeout
            expires; raised by a final call to ``assertion_predictor`` so
            the failure message is the assertion's own.
    """

    def _assertion_to_condition():
        try:
            assertion_predictor(**kwargs)
        except AssertionError:
            return False
        return True

    try:
        # kwargs are already bound in the closure above. They must not be
        # forwarded here as well: wait_for_condition passes extra keyword
        # arguments on to the predicate, and the closure accepts none.
        wait_for_condition(
            _assertion_to_condition,
            timeout=timeout,
            retry_interval_ms=retry_interval_ms,
            raise_exceptions=raise_exceptions,
        )
    except RuntimeError:
        assertion_predictor(**kwargs)  # Should fail assert
@dataclass
class MetricSamplePattern:
    """Pattern matched against a metric sample; fields left as None are ignored."""

    # Exact sample name to require.
    name: Optional[str] = None
    # Exact sample value to require.
    value: Optional[str] = None
    # Labels the sample must carry with these values; extra labels are allowed.
    partial_label_match: Optional[Dict[str, str]] = None

    def matches(self, sample: Sample):
        """Return True when ``sample`` satisfies every non-None field."""
        if self.name is not None and self.name != sample.name:
            return False
        if self.value is not None and self.value != sample.value:
            return False
        wanted_labels = self.partial_label_match
        if wanted_labels is None:
            return True
        return not any(
            sample.labels.get(label) != wanted
            for label, wanted in wanted_labels.items()
        )
@dataclass
class PrometheusTimeseries:
    """Accumulated Prometheus timeseries scraped from one or more addresses.

    Each timeseries is a collection of samples with the same metric name and
    labels. The fields are:
    - components_dict: a dictionary of addresses to the Component labels
    - metric_descriptors: a dictionary of metric names to the Metric object
    - metric_samples: the latest value of each label
    """

    components_dict: Dict[str, Set[str]] = field(default_factory=dict)
    metric_descriptors: Dict[str, Metric] = field(default_factory=dict)
    metric_samples: Dict[frozenset, Sample] = field(default_factory=dict)

    def flush(self):
        """Empty all three dictionaries in place."""
        for store in (
            self.components_dict,
            self.metric_descriptors,
            self.metric_samples,
        ):
            store.clear()
def get_metric_check_condition(
    metrics_to_check: List[MetricSamplePattern],
    timeseries: PrometheusTimeseries,
    export_addr: Optional[str] = None,
) -> Callable[[], bool]:
    """Build a predicate that checks whether all given metric patterns were emitted.

    The returned callable fetches metrics each time it is called, so it can
    be passed to a `wait_for_condition` style function.

    Args:
        metrics_to_check: A list of MetricSamplePattern. The fields that
            aren't `None` will be matched.
        timeseries: A PrometheusTimeseries object to store the metrics.
        export_addr: Optional address to scrape; when not given, the address
            is built from the first entry of ``ray.nodes()``.

    Returns:
        A function that returns True if all the metrics are emitted.
    """
    first_node = ray.nodes()[0]
    metrics_export_port = first_node["MetricsExportPort"]
    node_ip = first_node["NodeManagerAddress"]
    prom_addr = export_addr or build_address(node_ip, metrics_export_port)

    def _all_patterns_found():
        for pattern in metrics_to_check:
            # Re-fetch for every pattern, accumulating into ``timeseries``.
            samples = fetch_prometheus_timeseries(
                [prom_addr], timeseries
            ).metric_samples.values()
            if any(pattern.matches(sample) for sample in samples):
                continue
            logger.info(
                f"Didn't find {pattern} in all samples: {samples}",
            )
            return False
        return True

    return _all_patterns_found
def wait_until_succeeded_without_exception(
    func, exceptions, *args, timeout_ms=1000, retry_interval_ms=100, raise_last_ex=False
):
    """Retry ``func`` until it completes without one of the expected exceptions.

    Args:
        func: A function to run.
        exceptions: Exceptions that are supposed to occur, as a tuple of
            exception classes. A single exception class is also accepted.
        args: Arguments to pass for a given func.
        timeout_ms: Maximum timeout in milliseconds.
        retry_interval_ms: Retry interval in milliseconds.
        raise_last_ex: Raise an exception carrying the last traceback when
            the timeout is reached, instead of returning False.

    Returns:
        True if ``func`` succeeded within the timeout, False otherwise.

    Raises:
        Exception: If ``exceptions`` is neither a tuple nor an exception
            class, or if the timeout is reached and ``raise_last_ex`` is set.
    """
    # Validate the exception spec itself (not its type, which is never a
    # tuple instance). A bare exception class is accepted for convenience.
    if isinstance(exceptions, type) and issubclass(exceptions, BaseException):
        exceptions = (exceptions,)
    if not isinstance(exceptions, tuple):
        raise Exception("exceptions arguments should be given as a tuple")

    time_elapsed = 0
    start = time.time()
    last_ex = None
    while time_elapsed <= timeout_ms:
        try:
            func(*args)
            return True
        except exceptions as ex:
            last_ex = ex
            time_elapsed = (time.time() - start) * 1000
            time.sleep(retry_interval_ms / 1000.0)

    if raise_last_ex:
        ex_stack = (
            traceback.format_exception(type(last_ex), last_ex, last_ex.__traceback__)
            if last_ex
            else []
        )
        ex_stack = "".join(ex_stack)
        raise Exception(f"Timed out while testing, {ex_stack}")
    return False
def recursive_fnmatch(dirpath, pattern):
    """Collect paths of files under ``dirpath`` whose file name matches ``pattern``.

    The whole directory subtree is walked and ``pattern`` is an fnmatch-style
    pattern applied to file names only (not to directory names).

    Returns:
        A list of matching paths, each joined onto its containing directory.
    """
    return [
        os.path.join(root, filename)
        for root, _dirnames, filenames in os.walk(dirpath)
        for filename in fnmatch.filter(filenames, pattern)
    ]
def generate_system_config_map(**kwargs):
    """Wrap the given keyword arguments under the ``"_system_config"`` key."""
    return {"_system_config": kwargs}
def same_elements(elems_a, elems_b):
    """Tell whether two iterables hold the same elements.

    Membership is tested with ``in`` on lists, so elements need not be
    hashable (dicts work). Multiplicity and order are ignored. The check is
    quadratic in the input sizes.
    """
    left, right = list(elems_a), list(elems_b)
    if any(item not in right for item in left):
        return False
    return all(item in left for item in right)
@ray.remote
def _put(obj):
    """Remote task that returns its argument unchanged.

    Used by ``put_object`` as the alternative to ``ray.put``.
    """
    return obj
def put_object(obj, use_ray_put):
    """Store ``obj`` via ``ray.put`` or via the ``_put`` remote task.

    Args:
        obj: The object to store.
        use_ray_put: When truthy call ``ray.put(obj)``; otherwise submit
            ``_put.remote(obj)``.

    Returns:
        Whatever the selected call returns.
    """
    return ray.put(obj) if use_ray_put else _put.remote(obj)
def wait_until_server_available(address, timeout_ms=5000, retry_interval_ms=100):
    """Poll a TCP endpoint until a connection succeeds or the timeout passes.

    Args:
        address: Address string understood by ``parse_address``.
        timeout_ms: Maximum time to keep retrying, in milliseconds.
        retry_interval_ms: Pause between attempts, in milliseconds.

    Returns:
        True as soon as a connection is established, False on timeout.
    """
    host, raw_port = parse_address(address)
    target = (host, int(raw_port))
    began = time.time()
    elapsed_ms = 0
    while elapsed_ms <= timeout_ms:
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Each individual connection attempt waits at most one second.
        sock.settimeout(1)
        try:
            sock.connect(target)
        except Exception:
            elapsed_ms = (time.time() - began) * 1000
            time.sleep(retry_interval_ms / 1000.0)
            connected = False
        else:
            connected = True
        sock.close()
        if connected:
            return True
    return False
def get_other_nodes(cluster, exclude_head=False):
    """Return the cluster nodes other than the one this process is attached to.

    A node counts as "ours" when its raylet socket name equals that of the
    global node. With ``exclude_head`` set, nodes whose ``head`` attribute is
    not ``False`` are dropped as well.
    """
    others = []
    for candidate in cluster.list_all_nodes():
        local_socket = ray._private.worker._global_node._raylet_socket_name
        if candidate._raylet_socket_name == local_socket:
            continue
        if exclude_head is False or candidate.head is False:
            others.append(candidate)
    return others
def get_non_head_nodes(cluster):
    """Return the nodes of ``cluster`` whose ``head`` attribute is ``False``."""
    return [node for node in cluster.list_all_nodes() if node.head is False]
def init_error_pubsub():
    """Create a GCS error subscriber, subscribe it, and return it.

    The subscriber is pointed at the GCS address of the global worker's
    GCS client.
    """
    gcs_address = ray._private.worker.global_worker.gcs_client.address
    subscriber = ray._raylet.GcsErrorSubscriber(address=gcs_address)
    subscriber.subscribe()
    return subscriber
def get_error_message(subscriber, num=1e6, error_type=None, timeout=20):
    """Poll a GCS error subscriber and collect the errors it reports.

    Args:
        subscriber: Object whose ``poll(timeout=...)`` returns a
            ``(key, error_data)`` pair.
        num: Stop once this many errors have been collected.
        error_type: When given, keep only errors whose ``"type"`` equals it.
        timeout: Overall time budget in seconds.

    Returns:
        The list of collected error entries. Polling also stops as soon as
        a poll yields no data.
    """
    deadline = time.time() + timeout
    collected = []
    while True:
        if time.time() >= deadline or len(collected) >= num:
            break
        _, error_data = subscriber.poll(timeout=deadline - time.time())
        if not error_data:
            # Timed out before any data is received.
            break
        wanted = error_type is None or error_type == error_data["type"]
        if not wanted:
            # Brief pause before polling again after a non-matching error.
            time.sleep(0.01)
            continue
        collected.append(error_data)
    return collected
def init_log_pubsub():
    """Create a GCS log subscriber, subscribe it, and return it.

    The subscriber is pointed at the GCS address of the global worker's
    GCS client.
    """
    gcs_address = ray._private.worker.global_worker.gcs_client.address
    subscriber = ray._raylet.GcsLogSubscriber(address=gcs_address)
    subscriber.subscribe()
    return subscriber
def get_log_data(
    subscriber,
    num: int = 1e6,
    timeout: float = 20,
    job_id: Optional[str] = None,
    matcher=None,
) -> List[dict]:
    """Poll a log subscriber and collect the log batches it returns.

    Args:
        subscriber: Object whose ``poll(timeout=...)`` returns one log batch
            (a mapping with at least ``"job"`` and ``"lines"``) or a falsy
            value when nothing arrived.
        num: Stop once this many batches have been collected.
        timeout: Overall time budget in seconds.
        job_id: When given, skip batches whose ``"job"`` differs.
        matcher: When given, skip batches in which no line satisfies it.

    Returns:
        The collected batches, in arrival order.
    """
    deadline = time.time() + timeout
    collected = []
    while time.time() < deadline and len(collected) < num:
        batch = subscriber.poll(timeout=deadline - time.time())
        if not batch:
            # Timed out before any data is received.
            break
        wrong_job = job_id and job_id != batch["job"]
        if wrong_job:
            continue
        if matcher and not any(matcher(line) for line in batch["lines"]):
            continue
        collected.append(batch)
    return collected
def get_log_message(
    subscriber,
    num: int = 1e6,
    timeout: float = 20,
    job_id: Optional[str] = None,
    matcher=None,
) -> List[List[str]]:
    """Fetch log batches via ``get_log_data`` and return only their lines.

    At most ``num`` batches are gathered within ``timeout``; ``job_id`` and
    ``matcher`` are forwarded to ``get_log_data`` as filters. The result has
    one entry per batch, holding that batch's ``"lines"`` value.
    """
    lines_per_batch = []
    for batch in get_log_data(subscriber, num, timeout, job_id, matcher):
        lines_per_batch.append(batch["lines"])
    return lines_per_batch
def get_log_sources(
    subscriber,
    num: int = 1e6,
    timeout: float = 20,
    job_id: Optional[str] = None,
    matcher=None,
):
    """Return the set of ``"pid"`` values seen in the fetched log batches.

    Batches are gathered through ``get_log_data`` with the same arguments.
    """
    sources = set()
    for batch in get_log_data(subscriber, num, timeout, job_id, matcher):
        sources.add(batch["pid"])
    return sources
def get_log_batch(
    subscriber,
    num: int,
    timeout: float = 20,
    job_id: Optional[str] = None,
    matcher=None,
) -> List[dict]:
    """Gets log batches through GCS subscriber.

    Returns maximum `num` batches of logs. Each batch is a dict that includes
    metadata such as `pid`, `job_id`, and `lines` of log messages.

    If `job_id` or `match` is specified, only returns log batches from `job_id`
    or when `matcher` is true.

    Args:
        subscriber: Object whose ``poll(timeout=...)`` returns one batch, or
            a falsy value when nothing arrived in time.
        num: Maximum number of batches to return.
        timeout: Overall time budget in seconds.
        job_id: When given, batches whose ``"job"`` differs are skipped.
        matcher: When given, called with the whole batch (not per line);
            batches for which it returns a falsy value are skipped.

    Returns:
        The collected batches (dicts), in arrival order.
    """
    deadline = time.time() + timeout
    batches = []
    while time.time() < deadline and len(batches) < num:
        logs_data = subscriber.poll(timeout=deadline - time.time())
        if not logs_data:
            # Timed out before any data is received.
            break
        if job_id and job_id != logs_data["job"]:
            continue
        if matcher and not matcher(logs_data):
            continue
        batches.append(logs_data)
    return batches
def format_web_url(url):
    """Format web url.

    Replaces ``localhost`` with ``127.0.0.1`` and makes sure the result
    carries a scheme.

    Args:
        url: A host[:port] string or a full URL.

    Returns:
        The URL with ``localhost`` replaced. ``http://`` is prepended only
        when the URL does not already start with ``http://`` or ``https://``.
    """
    # Substitute only the host name; adding the scheme here as well would
    # duplicate it for inputs such as "http://localhost:8265".
    url = url.replace("localhost", "127.0.0.1")
    if url.startswith(("http://", "https://")):
        return url
    return "http://" + url
def client_test_enabled() -> bool:
    """Return the value of ``ray._private.client_mode_hook.is_client_mode_enabled``.

    NOTE(review): presumably a flag telling whether tests run in Ray client
    mode -- confirm against ``client_mode_hook``.
    """
    return ray._private.client_mode_hook.is_client_mode_enabled
def object_memory_usage() -> float:
    """Returns the number of bytes used in the object store.

    Computed as the cluster-wide ``object_store_memory`` total minus the
    currently available amount; either value defaults to 0 when the resource
    is not reported.
    """
    total = ray.cluster_resources().get("object_store_memory", 0)
    avail = ray.available_resources().get("object_store_memory", 0)
    return total - avail
def fetch_raw_prometheus(prom_addresses):
    """Yield ``(address, response_text)`` for each reachable metrics endpoint.

    Every address is queried at ``http://<address>/metrics``. Addresses that
    raise a ``requests`` connection error are skipped.
    """
    # Local import so minimal dependency tests can run without requests
    import requests

    for address in prom_addresses:
        metrics_url = f"http://{address}/metrics"
        try:
            reply = requests.get(metrics_url)
            yield address, reply.text
        except requests.exceptions.ConnectionError:
            continue
def fetch_prometheus(prom_addresses):
    """Scrape the given addresses and parse the Prometheus text output.

    Returns:
        A 3-tuple ``(components_dict, metric_descriptors, metric_samples)``:
        ``components_dict`` maps every requested address to the set of
        ``Component`` label values seen there, ``metric_descriptors`` maps a
        sample name to the metric family it came from, and ``metric_samples``
        is the flat list of all samples.
    """
    # Every requested address gets an entry, even if it turns out unreachable.
    components_dict = {}
    for address in prom_addresses:
        components_dict.setdefault(address, set())

    metric_descriptors = {}
    metric_samples = []
    for address, body in fetch_raw_prometheus(prom_addresses):
        for family in text_string_to_metric_families(body):
            for sample in family.samples:
                metric_descriptors[sample.name] = family
                metric_samples.append(sample)
                if "Component" in sample.labels:
                    components_dict[address].add(sample.labels["Component"])
    return components_dict, metric_descriptors, metric_samples
def fetch_prometheus_timeseries(
    prom_addreses: List[str],
    result: PrometheusTimeseries,
) -> PrometheusTimeseries:
    """Scrape the addresses and merge the outcome into ``result``.

    Components are unioned per address, descriptors are merged, and each
    sample overwrites any earlier sample with the same labels and metric
    name. ``result`` is mutated in place and also returned.
    """
    components_by_addr, descriptors, samples = fetch_prometheus(prom_addreses)

    for address, components in components_by_addr.items():
        result.components_dict.setdefault(address, set()).update(components)

    result.metric_descriptors.update(descriptors)

    for sample in samples:
        # Key on labels plus metric name so the latest value replaces the old.
        series_key = frozenset(
            [*sample.labels.items(), ("_metric_name_", sample.name)]
        )
        result.metric_samples[series_key] = sample
    return result
def fetch_prometheus_metrics(prom_addresses: List[str]) -> Dict[str, List[Any]]:
    """Scrape the given addresses and group the samples by metric name.

    Args:
        prom_addresses: List of metrics_agent addresses to collect metrics from.

    Returns:
        Dict mapping from metric name to list of samples for the metric.
    """
    _components, _descriptors, scraped = fetch_prometheus(prom_addresses)
    grouped = defaultdict(list)
    for entry in scraped:
        grouped[entry.name].append(entry)
    return grouped
def fetch_prometheus_metric_timeseries(
    prom_addresses: List[str], result: PrometheusTimeseries
) -> Dict[str, List[Any]]:
    """Refresh ``result`` from the addresses and group its samples by name.

    ``result`` is updated through ``fetch_prometheus_timeseries``; the
    returned mapping covers every sample stored in it afterwards.
    """
    updated = fetch_prometheus_timeseries(prom_addresses, result)
    grouped = defaultdict(list)
    for entry in updated.metric_samples.values():
        grouped[entry.name].append(entry)
    return grouped
def raw_metric_timeseries(
    info: RayContext, result: PrometheusTimeseries
) -> Dict[str, List[Any]]:
    """Fetch the metrics timeseries for the node described by a RayContext.

    The endpoint is ``localhost`` on the context's ``metrics_export_port``.
    """
    export_port = info.address_info["metrics_export_port"]
    metrics_page = "localhost:{}".format(export_port)
    print("Fetch metrics from", metrics_page)
    return fetch_prometheus_metric_timeseries([metrics_page], result)
def get_system_metric_for_component(
    system_metric: str, component: str, prometheus_server_address: str
) -> List[float]:
    """Query a Prometheus server for a system metric of one component.

    The query sums ``system_metric`` by ``pid``, filtered on the ``Component``
    label and on the current session name (the base name of the global node's
    session directory).

    Args:
        system_metric: Name of the metric to query.
        component: Value of the ``Component`` label to filter on.
        prometheus_server_address: Base URL of the Prometheus server.

    Returns:
        One float per series in the query result.

    Raises:
        Exception: If the server answers with a status other than 200.
    """
    session_dir = ray._private.worker._global_node.get_session_dir_path()
    session_name = os.path.basename(session_dir)
    query = f"sum({system_metric}{{Component='{component}',SessionName='{session_name}'}}) by (pid)"
    query_url = f"{prometheus_server_address}/api/v1/query?query={quote(query)}"
    resp = requests.get(query_url)
    if resp.status_code != 200:
        raise Exception(f"Failed to query Prometheus: {resp.status_code}")

    values = []
    for series in resp.json()["data"]["result"]:
        values.append(float(series["value"][1]))
    return values
def get_test_config_path(config_file_name):
    """Build the path of a config file under ``tests/test_cli_patterns``.

    The directory is located relative to this file: two levels up from the
    resolved location of the module, then ``tests/test_cli_patterns``.
    """
    package_root = pathlib.Path(os.path.realpath(__file__)).parent.parent
    return os.path.join(package_root, "tests/test_cli_patterns", config_file_name)
def load_test_config(config_file_name):
    """Loads a config yaml from tests/test_cli_patterns.

    Args:
        config_file_name: File name, resolved with ``get_test_config_path``.

    Returns:
        The object produced by ``yaml.safe_load`` for the file's contents.
    """
    config_path = get_test_config_path(config_file_name)
    # Use a context manager so the file handle is closed deterministically.
    with open(config_path) as config_file:
        return yaml.safe_load(config_file.read())
def set_setup_func():
    """Set ``ray._private.runtime_env.VAR`` to ``"hello world"``.

    NOTE(review): the name suggests this is used as a worker setup hook in
    tests -- confirm against the callers.
    """
    import ray._private.runtime_env as runtime_env

    runtime_env.VAR = "hello world"
class BatchQueue(Queue):
    """Queue whose backing actor is a ``_BatchQueueActor``.

    Adds ``get_batch`` for pulling several items in one call. The constructor
    sets up its own actor and does not invoke ``Queue.__init__``.
    """

    def __init__(self, maxsize: int = 0, actor_options: Optional[Dict] = None) -> None:
        self.maxsize = maxsize
        options = actor_options if actor_options else {}
        actor_cls = ray.remote(_BatchQueueActor).options(**options)
        self.actor = actor_cls.remote(self.maxsize)

    def get_batch(
        self,
        batch_size: int = None,
        total_timeout: Optional[float] = None,
        first_timeout: Optional[float] = None,
    ) -> List[Any]:
        """Fetch a batch of items from the queue, returned as an ordered list.

        The arguments are forwarded unchanged to the actor's ``get_batch``.

        Raises:
            Empty: if the queue does not contain the desired number of items
        """
        pending = self.actor.get_batch.remote(batch_size, total_timeout, first_timeout)
        return ray.get(pending)
class _BatchQueueActor(_QueueActor):
    """Queue actor extended with a ``get_batch`` coroutine."""

    async def get_batch(self, batch_size=None, total_timeout=None, first_timeout=None):
        """Collect a list of items from ``self.queue``.

        Args:
            batch_size: Maximum number of items to return. With ``None`` items
                are collected until a wait on the queue times out.
            total_timeout: Time budget in seconds for the waits after the
                first item. When truthy it is reduced by the duration of each
                wait (including the first one) and never drops below 0.
            first_timeout: Seconds to wait for the first item.

        Returns:
            A list holding at least the first item.

        Raises:
            Empty: If the first item does not arrive within ``first_timeout``.
        """
        start = timeit.default_timer()
        try:
            first = await asyncio.wait_for(self.queue.get(), first_timeout)
            batch = [first]
            if total_timeout:
                # Charge the time spent waiting for the first item.
                end = timeit.default_timer()
                total_timeout = max(total_timeout - (end - start), 0)
        except asyncio.TimeoutError:
            raise Empty
        if batch_size is None:
            # Unbounded batch: with no budget given, use a zero timeout for
            # each further wait.
            if total_timeout is None:
                total_timeout = 0
            while True:
                try:
                    start = timeit.default_timer()
                    batch.append(
                        await asyncio.wait_for(self.queue.get(), total_timeout)
                    )
                    if total_timeout:
                        end = timeit.default_timer()
                        total_timeout = max(total_timeout - (end - start), 0)
                except asyncio.TimeoutError:
                    # A timed-out wait ends the batch.
                    break
        else:
            # Bounded batch: the first item is already in, so fetch at most
            # batch_size - 1 more. A total_timeout of None is passed through
            # to asyncio.wait_for unchanged.
            for _ in range(batch_size - 1):
                try:
                    start = timeit.default_timer()
                    batch.append(
                        await asyncio.wait_for(self.queue.get(), total_timeout)
                    )
                    if total_timeout:
                        end = timeit.default_timer()
                        total_timeout = max(total_timeout - (end - start), 0)
                except asyncio.TimeoutError:
                    # Return what has been gathered so far.
                    break
        return batch
def is_placement_group_removed(pg):
    """Tell whether the placement group table reports ``pg`` as ``REMOVED``.

    Returns False when the table entry has no ``"state"`` key.
    """
    entry = ray.util.placement_group_table(pg)
    return "state" in entry and entry["state"] == "REMOVED"
def placement_group_assert_no_leak(pgs_created):
    """Remove the given placement groups and wait until nothing is left over.

    After removal this waits (via ``wait_for_condition``) first for every
    entry of the placement group table to be ``REMOVED``, then for the
    available resources to match the cluster totals with no ``_group_``
    resources remaining. ``memory`` and ``object_store_memory`` are left out
    of the comparison.
    """
    for pg in pgs_created:
        ray.util.remove_placement_group(pg)

    def wait_for_pg_removed():
        table = ray.util.placement_group_table()
        return all(entry["state"] == "REMOVED" for entry in table.values())

    wait_for_condition(wait_for_pg_removed)

    cluster_resources = ray.cluster_resources()
    for excluded in ("memory", "object_store_memory"):
        cluster_resources.pop(excluded)

    def wait_for_resource_recovered():
        for resource, val in ray.available_resources().items():
            mismatch = (
                resource in cluster_resources and cluster_resources[resource] != val
            )
            # Any placement-group-scoped resource still present is a leak.
            if mismatch or "_group_" in resource:
                return False
        return True

    wait_for_condition(wait_for_resource_recovered)
def monitor_memory_usage(
    print_interval_s: int = 30,
    record_interval_s: int = 5,
    warning_threshold: float = 0.9,
):
    """Run the memory monitor actor that prints the memory usage.

    The actor is created with a small fraction of the ``node:<ip>`` resource
    of the current worker's node, this function waits for it to be ready, and
    then starts its ``run`` loop without waiting for it.

    Params:
        print_interval_s: The interval memory usage information is printed.
        record_interval_s: The interval memory usage information is recorded.
        warning_threshold: The threshold where the
            memory usage warning is printed.

    Returns:
        The memory monitor actor.
    """
    assert ray.is_initialized(), "The API is only available when Ray is initialized."

    @ray.remote(num_cpus=0)
    class MemoryMonitorActor:
        def __init__(
            self,
            print_interval_s: float = 20,
            record_interval_s: float = 5,
            warning_threshold: float = 0.9,
            n: int = 10,
        ):
            """The actor that monitor the memory usage of the cluster.

            Params:
                print_interval_s: The interval where
                    memory usage is printed.
                record_interval_s: The interval where
                    memory usage is recorded.
                warning_threshold: The threshold where
                    memory warning is printed
                n: When memory usage is printed,
                    top n entries are printed.
            """
            # -- Interval the monitor prints the memory usage information. --
            self.print_interval_s = print_interval_s
            # -- Interval the monitor records the memory usage information. --
            self.record_interval_s = record_interval_s
            # -- Whether or not the monitor is running. --
            self.is_running = False
            # -- The used_gb/total_gb threshold above which a warning is emitted. --
            self.warning_threshold = warning_threshold
            # -- The monitor that calculates the memory usage of the node. --
            self.monitor = memory_monitor.MemoryMonitor()
            # -- The top n memory usage of processes are printed. --
            self.n = n
            # -- The peak memory usage in GB during lifetime of monitor. --
            self.peak_memory_usage = 0
            # -- The top n memory usage of processes
            # during peak memory usage. --
            self.peak_top_n_memory_usage = ""
            # -- The last time memory usage was printed --
            self._last_print_time = 0
            # -- logger. --
            logging.basicConfig(level=logging.INFO)

        def ready(self):
            """No-op; callers wait on it to know the actor has started."""
            pass

        async def run(self):
            """Run the monitor.

            Loops until ``stop_run`` clears ``is_running``; every iteration
            samples the memory usage, tracks the peak, warns above the
            threshold, and prints at most once per ``print_interval_s``.
            """
            self.is_running = True
            while self.is_running:
                now = time.time()
                used_gb, total_gb = self.monitor.get_memory_usage()
                top_n_memory_usage = memory_monitor.get_top_n_memory_usage(n=self.n)
                # Remember the highest usage seen and the process list at that time.
                if used_gb > self.peak_memory_usage:
                    self.peak_memory_usage = used_gb
                    self.peak_top_n_memory_usage = top_n_memory_usage

                if used_gb > total_gb * self.warning_threshold:
                    logging.warning(
                        "The memory usage is high: " f"{used_gb / total_gb * 100}%"
                    )
                if now - self._last_print_time > self.print_interval_s:
                    logging.info(f"Memory usage: {used_gb} / {total_gb}")
                    logging.info(f"Top {self.n} process memory usage:")
                    logging.info(top_n_memory_usage)
                    self._last_print_time = now
                await asyncio.sleep(self.record_interval_s)

        async def stop_run(self):
            """Stop running the monitor.

            Returns:
                True if the monitor was running when this was called.
                False otherwise.
            """
            was_running = self.is_running
            self.is_running = False
            return was_running

        async def get_peak_memory_info(self):
            """Return the tuple of the peak memory usage and the
            top n process information during the peak memory usage.
            """
            return self.peak_memory_usage, self.peak_top_n_memory_usage

    current_node_ip = ray._private.worker.global_worker.node_ip_address
    # Schedule the actor on the current node.
    memory_monitor_actor = MemoryMonitorActor.options(
        resources={f"node:{current_node_ip}": 0.001}
    ).remote(
        print_interval_s=print_interval_s,
        record_interval_s=record_interval_s,
        warning_threshold=warning_threshold,
    )
    print("Waiting for memory monitor actor to be ready...")
    ray.get(memory_monitor_actor.ready.remote())
    print("Memory monitor actor is ready now.")
    # Fire and forget: the run loop keeps going until stop_run is called.
    memory_monitor_actor.run.remote()
    return memory_monitor_actor
def setup_tls():
    """Generate a self-signed cert/key pair and export the TLS env variables.

    The test is skipped on macOS. The certificate and key are written into a
    fresh temporary directory, and ``RAY_USE_TLS``, ``RAY_TLS_SERVER_CERT``,
    ``RAY_TLS_SERVER_KEY`` and ``RAY_TLS_CA_CERT`` are set accordingly (the
    certificate doubles as the CA cert).

    Returns:
        The tuple ``(key_filepath, cert_filepath, temp_dir)``.
    """
    import pytest

    if sys.platform == "darwin":
        pytest.skip("Cryptography doesn't install in Mac build pipeline")

    cert, key = generate_self_signed_tls_certs()
    temp_dir = tempfile.mkdtemp("ray-test-certs")
    cert_filepath = os.path.join(temp_dir, "server.crt")
    key_filepath = os.path.join(temp_dir, "server.key")
    for path, contents in ((cert_filepath, cert), (key_filepath, key)):
        with open(path, "w") as fh:
            fh.write(contents)

    os.environ.update(
        {
            "RAY_USE_TLS": "1",
            "RAY_TLS_SERVER_CERT": cert_filepath,
            "RAY_TLS_SERVER_KEY": key_filepath,
            "RAY_TLS_CA_CERT": cert_filepath,
        }
    )
    return key_filepath, cert_filepath, temp_dir
def teardown_tls(key_filepath, cert_filepath, temp_dir):
    """Delete the TLS key/cert files and directory and unset the TLS env vars.

    Takes the three values returned by ``setup_tls``.
    """
    for path in (key_filepath, cert_filepath):
        os.remove(path)
    os.removedirs(temp_dir)
    for name in (
        "RAY_USE_TLS",
        "RAY_TLS_SERVER_CERT",
        "RAY_TLS_SERVER_KEY",
        "RAY_TLS_CA_CERT",
    ):
        del os.environ[name]
class ResourceKillerActor:
    """Abstract base class used to implement resource killers for chaos testing.

    Subclasses should implement _find_resources_to_kill, which should find the
    resources to kill. It should return a list of argument tuples; each tuple is
    unpacked into _kill_resource, which is another abstract method that should
    kill the resource and add it to the `killed` set.
    """

    def __init__(
        self,
        head_node_id,
        kill_interval_s: float = 60,
        kill_delay_s: float = 0,
        max_to_kill: Optional[int] = 2,
        batch_size_to_kill: int = 1,
        kill_filter_fn: Optional[Callable] = None,
    ):
        """Store the kill schedule and create the completion future.

        Args:
            head_node_id: ID of the head node.
            kill_interval_s: Length of one kill round in seconds.
            kill_delay_s: Delay in seconds before the first round.
            max_to_kill: Stop once this many resources were killed; ``None``
                means no limit.
            batch_size_to_kill: Number of resources to kill per round.
            kill_filter_fn: Optional factory; subclasses call it and use the
                result to filter kill candidates.
        """
        self.kill_interval_s = kill_interval_s
        self.kill_delay_s = kill_delay_s
        self.is_running = False
        self.head_node_id = head_node_id
        self.killed = set()
        self.done = get_or_create_event_loop().create_future()
        self.max_to_kill = max_to_kill
        self.batch_size_to_kill = batch_size_to_kill
        self.kill_filter_fn = kill_filter_fn
        self.kill_immediately_after_found = False
        # -- logger. --
        logging.basicConfig(level=logging.INFO)

    def ready(self):
        """No-op; callers wait on it to know the actor has started."""
        pass

    async def run(self):
        """Run kill rounds until stopped or until ``max_to_kill`` is reached.

        Each round finds the resources to kill, waits a random part of
        ``kill_interval_s`` (skipped when ``kill_immediately_after_found`` is
        set), kills them, and then waits for the rest of the interval.
        When the loop ends, ``done`` is resolved and ``stop_run`` is invoked.
        """
        self.is_running = True
        # Sleep cooperatively: a blocking time.sleep() inside a coroutine
        # freezes the event loop, so stop_run() could not be served meanwhile.
        await asyncio.sleep(self.kill_delay_s)

        while self.is_running:
            to_kills = await self._find_resources_to_kill()

            if not self.is_running:
                break

            if self.kill_immediately_after_found:
                sleep_interval = 0
            else:
                sleep_interval = random.random() * self.kill_interval_s
                await asyncio.sleep(sleep_interval)
                # stop_run() may have been served while this coroutine slept.
                if not self.is_running:
                    break

            for to_kill in to_kills:
                self._kill_resource(*to_kill)
            if self.max_to_kill is not None and len(self.killed) >= self.max_to_kill:
                break
            # Wait out the remainder of the interval.
            await asyncio.sleep(self.kill_interval_s - sleep_interval)

        self.done.set_result(True)
        await self.stop_run()

    async def _find_resources_to_kill(self):
        """Return a list of argument tuples for ``_kill_resource``."""
        raise NotImplementedError

    def _kill_resource(self, *args):
        """Kill one resource and record it in ``self.killed``."""
        raise NotImplementedError

    async def stop_run(self):
        """Stop the kill loop.

        Returns:
            True if the killer was running when this was called (in which
            case ``_cleanup`` is run), False otherwise.
        """
        was_running = self.is_running
        if was_running:
            self._cleanup()

        self.is_running = False
        return was_running

    async def get_total_killed(self):
        """Get the total number of killed resources"""
        await self.done
        return self.killed

    def _cleanup(self):
        """Cleanup any resources created by the killer.

        Overriding this method is optional.
        """
        pass
class NodeKillerBase(ResourceKillerActor):
    """Resource killer whose kill candidates are the cluster's worker nodes."""

    async def _find_resources_to_kill(self):
        """Pick up to ``batch_size_to_kill`` worker nodes to kill.

        Candidates are alive nodes that are neither the head node nor already
        in ``self.killed``, optionally narrowed by the predicate returned from
        ``kill_filter_fn()``. The search repeats once per second until nodes
        are found or the killer is stopped.

        Returns:
            A list of ``(node_id, node_manager_address, node_manager_port)``
            tuples; empty if the killer was stopped before any were found.
        """
        nodes_to_kill = []
        while not nodes_to_kill and self.is_running:
            worker_nodes = [
                node
                for node in ray.nodes()
                if node["Alive"]
                and (node["NodeID"] != self.head_node_id)
                and (node["NodeID"] not in self.killed)
            ]
            if self.kill_filter_fn:
                candidates = list(filter(self.kill_filter_fn(), worker_nodes))
            else:
                candidates = worker_nodes

            # Ensure at least one worker node remains alive. Also wait when the
            # filter rejected every node: without awaiting here the loop would
            # spin without yielding to the event loop, so stop_run() could
            # never be served.
            if len(worker_nodes) < self.batch_size_to_kill + 1 or not candidates:
                # Give the cluster some time to start.
                await asyncio.sleep(1)
                continue

            # Collect nodes to kill, limited by batch size.
            for candidate in candidates[: self.batch_size_to_kill]:
                nodes_to_kill.append(
                    (
                        candidate["NodeID"],
                        candidate["NodeManagerAddress"],
                        candidate["NodeManagerPort"],
                    )
                )

        return nodes_to_kill
@ray.remote(num_cpus=0)
class RayletKiller(NodeKillerBase):
    """Node killer that shuts down the raylet of the chosen node over gRPC."""

    def _kill_resource(self, node_id, node_to_kill_ip, node_to_kill_port):
        """Ask the node's raylet to shut down and record the node as killed.

        The kill is best effort: a failure is logged and the node is still
        added to ``self.killed``.
        """
        if node_to_kill_port is not None:
            try:
                self._kill_raylet(node_to_kill_ip, node_to_kill_port, graceful=False)
            except Exception:
                # Keep going, but leave a trace instead of hiding the failure.
                logging.warning(
                    f"Failed to kill raylet of node {node_id} at address: "
                    f"{node_to_kill_ip}, port: {node_to_kill_port}",
                    exc_info=True,
                )
            logging.info(
                f"Killed node {node_id} at address: "
                f"{node_to_kill_ip}, port: {node_to_kill_port}"
            )
            self.killed.add(node_id)

    def _kill_raylet(self, ip, port, graceful=False):
        """Send a ``ShutdownRaylet`` RPC to the raylet at ``ip``/``port``.

        An inactive-RPC error is tolerated for a non-graceful shutdown.
        """
        import grpc
        from grpc._channel import _InactiveRpcError

        from ray.core.generated import node_manager_pb2_grpc

        raylet_address = build_address(ip, port)
        channel = grpc.insecure_channel(raylet_address)
        stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
        try:
            stub.ShutdownRaylet(
                node_manager_pb2.ShutdownRayletRequest(graceful=graceful)
            )
        except _InactiveRpcError:
            assert not graceful
@ray.remote(num_cpus=0)
class EC2InstanceTerminator(NodeKillerBase):
    """Node killer that terminates the EC2 instance behind the chosen node."""

    def _kill_resource(self, node_id, node_to_kill_ip, _):
        """Terminate the instance at ``node_to_kill_ip`` and record the node.

        The termination is best effort: a failure is logged and the node is
        still added to ``self.killed``. The third argument (the node manager
        port) is unused.
        """
        if node_to_kill_ip is not None:
            try:
                _terminate_ec2_instance(node_to_kill_ip)
            except Exception:
                # Keep going, but leave a trace instead of hiding the failure.
                logging.warning(
                    f"Failed to terminate instance, {node_id=}, "
                    f"address={node_to_kill_ip}",
                    exc_info=True,
                )
            logging.info(f"Terminated instance, {node_id=}, address={node_to_kill_ip}")
            self.killed.add(node_id)
@ray.remote(num_cpus=0)
class EC2InstanceTerminatorWithGracePeriod(NodeKillerBase):
    """Node killer that drains a node, waits a grace period, then stops Ray.

    Every kill runs on its own daemon thread so the grace period does not
    hold up the caller.
    """

    def __init__(self, *args, grace_period_s: int = 30, **kwargs):
        super().__init__(*args, **kwargs)

        self._grace_period_s = grace_period_s
        self._kill_threads: Set[threading.Thread] = set()

    def _kill_resource(self, node_id, node_to_kill_ip, _):
        """Start a background kill of ``node_id`` and record it as killed."""
        assert node_id not in self.killed

        # Reap the threads that have already finished.
        for worker in list(self._kill_threads):
            if worker.is_alive():
                continue
            worker.join()
            self._kill_threads.remove(worker)

        def _kill_node_with_grace_period(node_id, node_to_kill_ip):
            self._drain_node(node_id)
            time.sleep(self._grace_period_s)
            # Anyscale extends the drain deadline if you shut down the instance
            # directly. To work around this, we force-stop Ray on the node. Anyscale
            # should then terminate it shortly after without updating the drain
            # deadline.
            _execute_command_on_node("ray stop --force", node_to_kill_ip)

        logger.info(f"Starting killing thread {node_id=}, {node_to_kill_ip=}")
        killer_thread = threading.Thread(
            target=_kill_node_with_grace_period,
            args=(node_id, node_to_kill_ip),
            daemon=True,
        )
        killer_thread.start()
        self._kill_threads.add(killer_thread)
        self.killed.add(node_id)

    def _drain_node(self, node_id: str) -> None:
        """Send a preemption drain request for ``node_id`` to the GCS.

        The drain deadline is the current time plus the grace period. A
        ``RayError`` is logged and re-raised; a rejected request fails an
        assertion.
        """
        # We need to lazily import this object. Otherwise, Ray can't serialize the
        # class.
        from ray.core.generated import autoscaler_pb2

        assert ray.NodeID.from_hex(node_id) != ray.NodeID.nil()

        logging.info(f"Draining node {node_id=}")
        address = services.canonicalize_bootstrap_address_or_die(addr="auto")
        gcs_client = ray._raylet.GcsClient(address=address)
        deadline_timestamp_ms = (time.time_ns() // 1e6) + (self._grace_period_s * 1e3)
        drain_reason = autoscaler_pb2.DrainNodeReason.Value(
            "DRAIN_NODE_REASON_PREEMPTION"
        )

        try:
            is_accepted, _ = gcs_client.drain_node(
                node_id,
                drain_reason,
                "",
                deadline_timestamp_ms,
            )
        except ray.exceptions.RayError as e:
            logger.error(f"Failed to drain node {node_id=}")
            raise e

        assert is_accepted, "Drain node request was rejected"

    def _cleanup(self):
        """Wait for every outstanding kill thread to finish."""
        for worker in tuple(self._kill_threads):
            worker.join()
            self._kill_threads.remove(worker)

        assert not self._kill_threads
@ray.remote(num_cpus=0)
class WorkerKillerActor(ResourceKillerActor):
    """Resource killer that kills the worker process of a running task."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Kill worker immediately so that the task does
        # not finish successfully on its own.
        self.kill_immediately_after_found = True

        from ray.util.state.api import StateApiClient
        from ray.util.state.common import ListApiOptions

        self.client = StateApiClient()
        self.task_options = ListApiOptions(
            filters=[
                ("state", "=", "RUNNING"),
                # Leave out this actor's own `run` task. The filter has to
                # spell the class name exactly to match anything.
                # NOTE(review): assumes the task name is reported as
                # "<class name>.<method name>" -- confirm against the state API.
                ("name", "!=", "WorkerKillerActor.run"),
            ]
        )

    async def _find_resources_to_kill(self):
        """Find one running task whose worker process can be killed.

        Polls the state API until a task with a worker ID and node ID is
        found or the killer is stopped.

        Returns:
            A single-element list holding ``(task_id, worker_pid, node_id)``;
            all three are ``None`` if the killer was stopped first.
        """
        from ray.util.state.common import StateResource

        process_to_kill_task_id = None
        process_to_kill_pid = None
        process_to_kill_node_id = None
        while process_to_kill_pid is None and self.is_running:
            tasks = self.client.list(
                StateResource.TASKS,
                options=self.task_options,
                raise_on_missing_output=False,
            )
            if self.kill_filter_fn is not None:
                tasks = list(filter(self.kill_filter_fn(), tasks))

            for task in tasks:
                if task.worker_id is not None and task.node_id is not None:
                    process_to_kill_task_id = task.task_id
                    process_to_kill_pid = task.worker_pid
                    process_to_kill_node_id = task.node_id
                    break

            # Give the cluster some time to start.
            await asyncio.sleep(0.1)

        return [(process_to_kill_task_id, process_to_kill_pid, process_to_kill_node_id)]

    def _kill_resource(
        self, process_to_kill_task_id, process_to_kill_pid, process_to_kill_node_id
    ):
        """Kill process ``process_to_kill_pid`` on node ``process_to_kill_node_id``.

        The kill is performed by a remote task pinned to that node. Nothing
        happens when the pid is ``None``.
        """
        if process_to_kill_pid is not None:

            @ray.remote
            def kill_process(pid):
                import psutil

                proc = psutil.Process(pid)
                proc.kill()

            scheduling_strategy = (
                ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
                    node_id=process_to_kill_node_id,
                    soft=False,
                )
            )
            kill_process.options(scheduling_strategy=scheduling_strategy).remote(
                process_to_kill_pid
            )
            logging.info(
                f"Killing pid {process_to_kill_pid} on node {process_to_kill_node_id}"
            )
            # Store both task_id and pid because retried tasks have same task_id.
            self.killed.add((process_to_kill_task_id, process_to_kill_pid))
def get_and_run_resource_killer(
    resource_killer_cls,
    kill_interval_s,
    namespace=None,
    lifetime=None,
    no_start=False,
    max_to_kill=2,
    batch_size_to_kill=1,
    kill_delay_s=0,
    kill_filter_fn=None,
):
    """Create a resource killer actor on the current node and optionally start it.

    The actor is named ``ResourceKiller`` and pinned to the node this
    function runs on; that node's ID is also handed to the actor as the head
    node ID. Unless ``no_start`` is set, its ``run`` loop is kicked off
    without waiting for it.

    Returns:
        The handle of the created actor.
    """
    assert ray.is_initialized(), "The API is only available when Ray is initialized."

    head_node_id = ray.get_runtime_context().get_node_id()
    # Schedule the actor on the current node.
    placement = NodeAffinitySchedulingStrategy(node_id=head_node_id, soft=False)
    configured_cls = resource_killer_cls.options(
        scheduling_strategy=placement,
        namespace=namespace,
        name="ResourceKiller",
        lifetime=lifetime,
    )
    resource_killer = configured_cls.remote(
        head_node_id,
        kill_interval_s=kill_interval_s,
        kill_delay_s=kill_delay_s,
        max_to_kill=max_to_kill,
        batch_size_to_kill=batch_size_to_kill,
        kill_filter_fn=kill_filter_fn,
    )
    print("Waiting for ResourceKiller to be ready...")
    ray.get(resource_killer.ready.remote())
    print("ResourceKiller is ready now.")
    if no_start:
        return resource_killer
    resource_killer.run.remote()
    return resource_killer
def get_actor_node_id(actor_handle: "ray.actor.ActorHandle") -> str:
    """Return the ID of the node on which the given actor is running.

    The lookup is executed inside the actor via ``__ray_call__``.
    """

    def _node_id_of(self):
        return ray.get_runtime_context().get_node_id()

    return ray.get(actor_handle.__ray_call__.remote(_node_id_of))
@contextmanager
def chdir(d: str):
    """Context manager that switches the working directory to ``d``.

    The previous working directory is restored on exit, also when the body
    raises.
    """
    previous = os.getcwd()
    os.chdir(d)
    try:
        yield
    finally:
        os.chdir(previous)
def test_get_directory_size_bytes():
    """Check ``get_directory_size_bytes`` as files are added to a directory.

    Starts from an empty temporary directory, then adds files of 100, 50 and
    2 bytes (the last one inside a sub-directory) and checks the running
    total after each step.
    """
    steps = (("test_file", 100), ("test_file_2", 50), ("subdir/subdir_file", 2))
    with tempfile.TemporaryDirectory() as tmp_dir, chdir(tmp_dir):
        assert ray._private.utils.get_directory_size_bytes(tmp_dir) == 0
        expected = 0
        for relpath, nbytes in steps:
            parent = os.path.dirname(relpath)
            if parent:
                os.mkdir(parent)
            with open(relpath, "wb") as f:
                f.write(os.urandom(nbytes))
            expected += nbytes
            assert ray._private.utils.get_directory_size_bytes(tmp_dir) == expected
def check_local_files_gced(cluster):
    """Check that the runtime_env directories of every node hold no leftovers.

    For each node the ``conda``, ``pip``, ``working_dir_files`` and
    ``py_modules_files`` sub-directories of its runtime env directory are
    listed. Entries ending in ``.lock`` or ``.txt`` are tolerated.

    Returns:
        False (after printing the offending entries) as soon as a leftover is
        found, True otherwise.
    """
    # Lock files and generated requirements.txt files may remain.
    # Note: On Windows the top folder is not deleted as it is in use.
    # TODO(architkulkarni): these files should get cleaned up too!
    tolerated_suffixes = (".lock", ".txt")
    for node in cluster.list_all_nodes():
        for subdir in ["conda", "pip", "working_dir_files", "py_modules_files"]:
            subdir_path = os.path.join(node.get_runtime_env_dir_path(), subdir)
            items = [
                name
                for name in os.listdir(subdir_path)
                if not name.endswith(tolerated_suffixes)
            ]
            if items:
                print(f"runtime_env files not GC'd from subdir '{subdir}': {items}")
                return False
    return True
def generate_runtime_env_dict(field, spec_format, tmp_path, pip_list=None):
    """Build a ``runtime_env`` dict that installs ``pip_list`` via conda or pip.

    Args:
        field: ``"conda"`` or ``"pip"``; the key of the returned dict.
        spec_format: ``"file"`` to write the spec into a file under
            ``tmp_path`` and reference it by path, or ``"python_object"`` to
            embed the spec directly.
        tmp_path: Directory (a ``pathlib.Path``-like object) for the spec
            file. Only used when ``spec_format`` is ``"file"``.
        pip_list: Requirements to install. Defaults to
            ``["pip-install-test==0.5"]``.

    Returns:
        ``{"conda": ...}`` or ``{"pip": ...}``.

    Raises:
        ValueError: If ``field`` or ``spec_format`` is not a supported value.
    """
    # Reject unknown values up front; otherwise they would only show up as an
    # unbound local variable further down.
    if field not in ("conda", "pip"):
        raise ValueError(f"Unsupported field: {field!r}")
    if spec_format not in ("file", "python_object"):
        raise ValueError(f"Unsupported spec_format: {spec_format!r}")

    if pip_list is None:
        pip_list = ["pip-install-test==0.5"]
    if field == "conda":
        conda_dict = {"dependencies": ["pip", {"pip": pip_list}]}
        if spec_format == "file":
            conda_file = tmp_path / f"environment-{hash(str(pip_list))}.yml"
            conda_file.write_text(yaml.dump(conda_dict))
            conda = str(conda_file)
        else:
            conda = conda_dict
        runtime_env = {"conda": conda}
    else:
        if spec_format == "file":
            pip_file = tmp_path / f"requirements-{hash(str(pip_list))}.txt"
            pip_file.write_text("\n".join(pip_list))
            pip = str(pip_file)
        else:
            pip = pip_list
        runtime_env = {"pip": pip}
    return runtime_env
def check_spilled_mb(address, spilled=None, restored=None, fallback=None):
    """Wait until the memory summary shows the expected spill statistics.

    Args:
        address: Mapping whose ``"address"`` entry is passed to
            ``memory_summary``.
        spilled: Expected spilled MiB, or a list of acceptable values. When
            falsy the summary must not mention ``Spilled`` at all.
        restored: Expected restored MiB. When falsy the summary must not
            mention ``Restored`` at all.
        fallback: Expected plasma filesystem mmap usage in MiB. When falsy
            the summary must not mention that line at all.
    """

    def ok():
        s = memory_summary(address=address["address"], stats_only=True)
        print(s)

        if restored:
            restored_ok = "Restored {} MiB".format(restored) in s
        else:
            restored_ok = "Restored" not in s
        if not restored_ok:
            return False

        if spilled:
            accepted = spilled if isinstance(spilled, list) else [spilled]
            spilled_ok = False
            for n in accepted:
                if "Spilled {} MiB".format(n) in s:
                    spilled_ok = True
        else:
            spilled_ok = "Spilled" not in s
        if not spilled_ok:
            return False

        if fallback:
            return "Plasma filesystem mmap usage: {} MiB".format(fallback) in s
        return "Plasma filesystem mmap usage:" not in s

    wait_for_condition(ok, timeout=3, retry_interval_ms=1000)
def no_resource_leaks_excluding_node_resources():
    """Tell whether all resources except node resources are fully available.

    Resources whose name contains ``"node"`` are dropped from both the
    cluster totals and the available resources before the two are compared.

    Returns:
        True if the remaining totals equal the remaining available amounts.
    """
    cluster_resources = ray.cluster_resources()
    available_resources = ray.available_resources()
    # Iterate over a snapshot of the keys because entries are deleted below.
    for r in list(cluster_resources):
        if "node" in r:
            del cluster_resources[r]
            # Tolerate a node resource that is absent from the available view
            # instead of failing with a KeyError.
            available_resources.pop(r, None)

    return cluster_resources == available_resources
def job_hook(**kwargs):
    """Function called by reflection by test_cli_integration.

    Prints the intercepted entrypoint command and exits the process with
    status 0.
    """
    entrypoint_parts = kwargs["entrypoint"]
    cmd = " ".join(entrypoint_parts)
    print(f"hook intercepted: {cmd}")
    sys.exit(0)
def wandb_setup_api_key_hook():
    """
    Example external hook that supplies the W&B API key in
    WandbIntegrationTest.testWandbLoggerConfig; always returns a fixed
    placeholder value.
    """
    placeholder_api_key = "abcd"
    return placeholder_api_key
# Get node stats from node manager.
def get_node_stats(raylet, num_retry=5, timeout=2):
    """Fetch the node stats of a raylet over gRPC, retrying on RPC errors.

    Args:
        raylet: Mapping with the ``NodeManagerAddress`` and
            ``NodeManagerPort`` of the raylet to query.
        num_retry: Maximum number of attempts.
        timeout: Per-attempt RPC timeout in seconds.

    Returns:
        The ``GetNodeStats`` reply.

    Raises:
        AssertionError: If every attempt failed with a gRPC error.
    """
    import grpc

    from ray._private.grpc_utils import init_grpc_channel
    from ray.core.generated import node_manager_pb2_grpc

    raylet_address = build_address(
        raylet["NodeManagerAddress"], raylet["NodeManagerPort"]
    )
    channel = init_grpc_channel(raylet_address)
    stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
    # Start from None so that exhausting all retries trips the assertion
    # below rather than an unbound-variable error.
    reply = None
    for _ in range(num_retry):
        try:
            reply = stub.GetNodeStats(
                node_manager_pb2.GetNodeStatsRequest(), timeout=timeout
            )
            break
        except grpc.RpcError:
            continue
    assert reply is not None
    return reply
# Gets resource usage assuming gcs is local.
def get_resource_usage(gcs_address, timeout=10):
    """Query the GCS for the resource usage of all nodes.

    Args:
        gcs_address: Address of the GCS. When falsy, the GCS address of the
            global node is used.
        timeout: RPC timeout in seconds.

    Returns:
        The ``resource_usage_data`` field of the ``GetAllResourceUsage`` reply.
    """
    from ray._private.grpc_utils import init_grpc_channel
    from ray.core.generated import gcs_service_pb2_grpc

    if not gcs_address:
        # Same access path to the global node as the rest of this module.
        gcs_address = ray._private.worker._global_node.gcs_address

    gcs_channel = init_grpc_channel(
        gcs_address, ray_constants.GLOBAL_GRPC_OPTIONS, asynchronous=False
    )

    gcs_node_resources_stub = gcs_service_pb2_grpc.NodeResourceInfoGcsServiceStub(
        gcs_channel
    )

    request = gcs_service_pb2.GetAllResourceUsageRequest()
    response = gcs_node_resources_stub.GetAllResourceUsage(request, timeout=timeout)
    resources_batch_data = response.resource_usage_data

    return resources_batch_data
# Gets the load metrics report assuming gcs is local.
def get_load_metrics_report(webui_url):
    """Fetch the load metrics report from the dashboard's cluster status API.

    ``webui_url`` is normalized with ``format_web_url`` first. An HTTP error
    status raises through ``raise_for_status``.
    """
    base_url = format_web_url(webui_url)
    reply = requests.get(f"{base_url}/api/cluster_status")
    reply.raise_for_status()
    payload = reply.json()
    return payload["data"]["clusterStatus"]["loadMetricsReport"]
# Send a RPC to the raylet to have it self-destruct its process.
def kill_raylet(raylet, graceful=False):
    """Ask a raylet to shut itself down via the ShutdownRaylet RPC.

    Args:
        raylet: Mapping providing the "NodeManagerAddress" and
            "NodeManagerPort" entries of the target raylet.
        graceful: Forwarded as the ``graceful`` field of the request.

    An ``_InactiveRpcError`` raised by the call is tolerated only for a
    non-graceful shutdown; for a graceful one it trips an assertion.
    """
    import grpc
    from grpc._channel import _InactiveRpcError

    from ray.core.generated import node_manager_pb2_grpc

    target = build_address(raylet["NodeManagerAddress"], raylet["NodeManagerPort"])
    stub = node_manager_pb2_grpc.NodeManagerServiceStub(grpc.insecure_channel(target))
    shutdown_request = node_manager_pb2.ShutdownRayletRequest(graceful=graceful)
    try:
        stub.ShutdownRaylet(shutdown_request)
    except _InactiveRpcError:
        assert not graceful
def get_gcs_memory_used():
    """Return the summed RSS of the ``gcs_server`` and ``redis-server`` processes.

    Zombie and dead processes are ignored. RSS is recorded per process
    name, so when several live processes share a name only the last one
    iterated is counted.

    Raises:
        AssertionError: If no live ``gcs_server`` process was found.
    """
    import psutil

    ignored_states = (psutil.STATUS_ZOMBIE, psutil.STATUS_DEAD)
    tracked_names = ("gcs_server", "redis-server")

    rss_by_name = {}
    for proc in psutil.process_iter(["status", "name", "memory_info"]):
        info = proc.info
        if info["status"] in ignored_states:
            continue
        if info["name"] not in tracked_names:
            continue
        rss_by_name[info["name"]] = info["memory_info"].rss

    assert "gcs_server" in rss_by_name
    return sum(rss_by_name.values())
def safe_write_to_results_json(
    result: dict,
    default_file_name: str = "/tmp/release_test_output.json",
    env_var: Optional[str] = "TEST_OUTPUT_JSON",
):
    """
    Safe (atomic) write to file to guard against malforming the json
    if the job gets interrupted in the middle of writing.

    The result is first dumped to a uniquely named temporary file next to
    the destination and then moved into place with ``os.replace``.

    Args:
        result: JSON-serializable dict to write.
        default_file_name: Destination used when ``env_var`` is None or the
            environment variable is not set.
        env_var: Name of the environment variable that may override the
            destination path. Pass None to always use ``default_file_name``.

    Raises:
        TypeError: If ``result`` is not JSON serializable.
        OSError: If the temporary file cannot be written or moved.
    """
    # os.environ.get(None, ...) raises TypeError, so only consult the
    # environment when a variable name was actually provided.
    if env_var is None:
        test_output_json = default_file_name
    else:
        test_output_json = os.environ.get(env_var, default_file_name)
    test_output_json_tmp = f"{test_output_json}.tmp.{str(uuid.uuid4())}"
    try:
        with open(test_output_json_tmp, "wt") as f:
            json.dump(result, f)
            f.flush()
        os.replace(test_output_json_tmp, test_output_json)
    except BaseException:
        # The temp name is unique per call, so a failed attempt would
        # otherwise leave a stray partial file behind on every failure.
        try:
            os.remove(test_output_json_tmp)
        except OSError:
            # Best-effort cleanup: the temp file may never have been created.
            pass
        raise
    logger.info(f"Wrote results to {test_output_json}")
    logger.info(json.dumps(result))
def get_current_unused_port():
    """
    Returns a port number that is not currently in use.

    This is useful for testing when we need to bind to a port but don't
    care which one.

    Returns:
        A port number that is not currently in use. (Note that this port
        might become used by the time you try to bind to it.)
    """
    # The context manager closes the socket even if bind() or
    # getsockname() raises, so a failure does not leak the descriptor.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        # Bind the socket to a local address with a random port number
        sock.bind(("localhost", 0))
        return sock.getsockname()[1]
# Call counter for external_ray_cluster_activity_hook1; it is embedded in
# the hook's `reason` field so each call returns a different value.
ray_cluster_activity_hook_counter: int = 0
# Call counter for external_ray_cluster_activity_hook5, used the same way.
ray_cluster_activity_hook_5_counter: int = 0
def external_ray_cluster_activity_hook1():
    """
    Example external hook for test_component_activities_hook.

    Bumps the module-level call counter and returns a valid response for
    "test_component1" whose `reason` field carries the new counter value.
    """
    global ray_cluster_activity_hook_counter
    ray_cluster_activity_hook_counter += 1

    from pydantic import BaseModel, Extra

    class TestRayActivityResponse(BaseModel, extra=Extra.allow):
        """
        Redefinition of dashboard.modules.api.api_head.RayActivityResponse
        used in test_component_activities_hook to mimic typical
        usage of redefining or extending response type.
        """

        is_active: str
        reason: Optional[str] = None
        timestamp: float

    activity = TestRayActivityResponse(
        is_active="ACTIVE",
        reason=f"Counter: {ray_cluster_activity_hook_counter}",
        timestamp=datetime.now().timestamp(),
    )
    return {"test_component1": activity}
def external_ray_cluster_activity_hook2():
    """
    Example external hook for test_component_activities_hook.

    Deliberately returns invalid output: the value stored under
    `test_component2` is a plain string rather than a RayActivityResponse.
    """
    invalid_response = "bad_output"
    return {"test_component2": invalid_response}
def external_ray_cluster_activity_hook3():
    """
    Example external hook for test_component_activities_hook.

    Deliberately returns invalid output: a bare string instead of a
    Dict[str, RayActivityResponse].
    """
    invalid_output = "bad_output"
    return invalid_output
def external_ray_cluster_activity_hook4():
    """
    Example external hook for test_component_activities_hook.

    Always fails: raises a plain Exception instead of returning a response.
    """
    failure_message = "Error in external cluster activity hook"
    raise Exception(failure_message)
def external_ray_cluster_activity_hook5():
    """
    Example external hook for test_component_activities_hook.

    Bumps the module-level call counter and returns a valid response for
    "test_component5" as a plain dict whose `reason` field carries the new
    counter value.
    """
    global ray_cluster_activity_hook_5_counter
    ray_cluster_activity_hook_5_counter += 1

    activity = {
        "is_active": "ACTIVE",
        "reason": f"Counter: {ray_cluster_activity_hook_5_counter}",
        "timestamp": datetime.now().timestamp(),
    }
    return {"test_component5": activity}
# TODO(rickyx): We could remove this once we unify the autoscaler v1 and v2
# code path for ray status
def reset_autoscaler_v2_enabled_cache():
    """Clear ``cached_is_autoscaler_v2`` in ray.autoscaler.v2.utils by setting it to None."""
    import ray.autoscaler.v2.utils as autoscaler_v2_utils

    autoscaler_v2_utils.cached_is_autoscaler_v2 = None
def _terminate_ec2_instance(node_ip: str) -> None:
    """Run a self-terminate command on the EC2 instance at ``node_ip``.

    The command is handed to ``_execute_command_on_node``. It reads the
    instance id and region from the instance metadata endpoint (using
    the ``$TOKEN`` shell variable for the IMDSv2 header) and then calls
    the aws cli to terminate that instance.
    """
    logging.info(f"Terminating instance {node_ip}")

    lookup_instance_id = 'instanceId=$(curl -H "X-aws-ec2-metadata-token: $TOKEN" http://169.254.169.254/latest/meta-data/instance-id/);'  # noqa: E501
    lookup_region = 'region=$(curl -H "X-aws-ec2-metadata-token: $TOKEN" http://169.254.169.254/latest/meta-data/placement/region);'  # noqa: E501
    terminate_self = "aws ec2 terminate-instances --region $region --instance-ids $instanceId"  # noqa: E501

    _execute_command_on_node(
        lookup_instance_id + lookup_region + terminate_self, node_ip
    )
def _execute_command_on_node(command: str, node_ip: str):
    """Run a shell command on a remote node over ssh (best effort).

    The command is prefixed with a step that stores an IMDSv2 session token
    in the remote ``$TOKEN`` shell variable, then sent to ``ray@<node_ip>``
    on port 2222. Failures are printed rather than raised.

    Args:
        command: Shell command line to run on the remote node.
        node_ip: Address of the node to connect to.
    """
    logging.debug(f"Executing command on node {node_ip}: {command}")
    multi_line_command = (
        'TOKEN=$(curl -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 21600");'  # noqa: E501
        f"{command}"
    )
    # This is a feature on Anyscale platform that enables
    # easy ssh access to worker nodes.
    # An argv list (no local shell) hands the remote command to ssh as one
    # argument, so quotes or shell metacharacters in `command` / `node_ip`
    # can neither break the quoting nor be interpreted on this machine.
    ssh_command = [
        "ssh",
        "-o",
        "StrictHostKeyChecking=no",
        "-o",
        "UserKnownHostsFile=/dev/null",
        "-p",
        "2222",
        f"ray@{node_ip}",
        multi_line_command,
    ]
    try:
        subprocess.run(ssh_command, capture_output=True, text=True, check=True)
    except subprocess.CalledProcessError as e:
        print("Exit code:", e.returncode)
        print("Stderr:", e.stderr)
    except OSError as e:
        # Without a local shell, a missing or non-executable ssh binary
        # surfaces as OSError; keep the call best-effort in that case too.
        print("Failed to launch ssh:", e)
# Failure-probability settings keyed by failure type: each entry sets the
# probability for its own failure kind to 100 and the other two to 0.
RPC_FAILURE_MAP = {
    "request": dict(
        req_failure_prob=100,
        resp_failure_prob=0,
        in_flight_failure_prob=0,
    ),
    "response": dict(
        req_failure_prob=0,
        resp_failure_prob=100,
        in_flight_failure_prob=0,
    ),
    "in_flight": dict(
        req_failure_prob=0,
        resp_failure_prob=0,
        in_flight_failure_prob=100,
    ),
}

# Names of the supported failure types, in declaration order.
RPC_FAILURE_TYPES = list(RPC_FAILURE_MAP)
|