# test_utils.py (71 KB)
  1. import asyncio
  2. import fnmatch
  3. import io
  4. import json
  5. import logging
  6. import os
  7. import pathlib
  8. import random
  9. import socket
  10. import subprocess
  11. import sys
  12. import tempfile
  13. import threading
  14. import time
  15. import timeit
  16. import traceback
  17. import uuid
  18. from collections import defaultdict
  19. from contextlib import contextmanager, redirect_stderr, redirect_stdout
  20. from dataclasses import dataclass, field
  21. from datetime import datetime
  22. from typing import Any, Callable, Dict, List, Optional, Set, Tuple
  23. from urllib.parse import quote, urlparse
  24. import requests
  25. import yaml
  26. import ray
  27. import ray._private.memory_monitor as memory_monitor
  28. import ray._private.services
  29. import ray._private.services as services
  30. import ray._private.utils
  31. import ray.dashboard.consts as dashboard_consts
  32. from ray._common.network_utils import build_address, parse_address
  33. from ray._common.test_utils import wait_for_condition
  34. from ray._common.utils import get_or_create_event_loop
  35. from ray._private import (
  36. ray_constants,
  37. )
  38. from ray._private.internal_api import memory_summary
  39. from ray._private.services import ProcessInfo
  40. from ray._private.tls_utils import generate_self_signed_tls_certs
  41. from ray._private.worker import RayContext
  42. from ray._raylet import Config, GcsClient, GcsClientOptions, GlobalStateAccessor
  43. from ray.core.generated import (
  44. gcs_pb2,
  45. gcs_service_pb2,
  46. node_manager_pb2,
  47. )
  48. from ray.util.queue import Empty, Queue, _QueueActor
  49. from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
  50. from ray.util.state import get_actor, list_actors
  51. import psutil # We must import psutil after ray because we bundle it with ray.
  52. logger = logging.getLogger(__name__)
  53. EXE_SUFFIX = ".exe" if sys.platform == "win32" else ""
  54. RAY_PATH = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
  55. REDIS_EXECUTABLE = os.path.join(
  56. RAY_PATH, "core/src/ray/thirdparty/redis/src/redis-server" + EXE_SUFFIX
  57. )
  58. try:
  59. from prometheus_client.core import Metric
  60. from prometheus_client.parser import Sample, text_string_to_metric_families
  61. except (ImportError, ModuleNotFoundError):
  62. Metric = None
  63. Sample = None
  64. def text_string_to_metric_families(*args, **kwargs):
  65. raise ModuleNotFoundError("`prometheus_client` not found")
  66. def make_global_state_accessor(ray_context):
  67. gcs_options = GcsClientOptions.create(
  68. ray_context.address_info["gcs_address"],
  69. None,
  70. allow_cluster_id_nil=True,
  71. fetch_cluster_id_if_nil=False,
  72. )
  73. global_state_accessor = GlobalStateAccessor(gcs_options)
  74. global_state_accessor.connect()
  75. return global_state_accessor
  76. def external_redis_test_enabled():
  77. return os.environ.get("TEST_EXTERNAL_REDIS") == "1"
  78. def redis_replicas():
  79. return int(os.environ.get("TEST_EXTERNAL_REDIS_REPLICAS", "1"))
  80. def redis_sentinel_replicas():
  81. return int(os.environ.get("TEST_EXTERNAL_REDIS_SENTINEL_REPLICAS", "2"))
  82. def get_redis_cli(port, enable_tls):
  83. try:
  84. # If there is no redis libs installed, skip the check.
  85. # This could happen In minimal test, where we don't have
  86. # redis.
  87. import redis
  88. except Exception:
  89. return True
  90. params = {}
  91. if enable_tls:
  92. from ray._raylet import Config
  93. params = {"ssl": True, "ssl_cert_reqs": "required"}
  94. if Config.REDIS_CA_CERT():
  95. params["ssl_ca_certs"] = Config.REDIS_CA_CERT()
  96. if Config.REDIS_CLIENT_CERT():
  97. params["ssl_certfile"] = Config.REDIS_CLIENT_CERT()
  98. if Config.REDIS_CLIENT_KEY():
  99. params["ssl_keyfile"] = Config.REDIS_CLIENT_KEY()
  100. return redis.Redis("localhost", str(port), **params)
  101. def start_redis_sentinel_instance(
  102. session_dir_path: str,
  103. port: int,
  104. redis_master_port: int,
  105. password: Optional[str] = None,
  106. enable_tls: bool = False,
  107. db_dir=None,
  108. free_port=0,
  109. ):
  110. config_file = os.path.join(
  111. session_dir_path, "redis-sentinel-" + uuid.uuid4().hex + ".conf"
  112. )
  113. config_lines = []
  114. # Port for this Sentinel instance
  115. if enable_tls:
  116. config_lines.append(f"port {free_port}")
  117. else:
  118. config_lines.append(f"port {port}")
  119. # Monitor the Redis master
  120. config_lines.append(f"sentinel monitor redis-test 127.0.0.1 {redis_master_port} 1")
  121. config_lines.append(
  122. "sentinel down-after-milliseconds redis-test 1000"
  123. ) # failover after 1 second
  124. config_lines.append("sentinel failover-timeout redis-test 5000") #
  125. config_lines.append("sentinel parallel-syncs redis-test 1")
  126. if password:
  127. config_lines.append(f"sentinel auth-pass redis-test {password}")
  128. if enable_tls:
  129. config_lines.append(f"tls-port {port}")
  130. if Config.REDIS_CA_CERT():
  131. config_lines.append(f"tls-ca-cert-file {Config.REDIS_CA_CERT()}")
  132. # Check and add TLS client certificate file
  133. if Config.REDIS_CLIENT_CERT():
  134. config_lines.append(f"tls-cert-file {Config.REDIS_CLIENT_CERT()}")
  135. # Check and add TLS client key file
  136. if Config.REDIS_CLIENT_KEY():
  137. config_lines.append(f"tls-key-file {Config.REDIS_CLIENT_KEY()}")
  138. config_lines.append("tls-auth-clients no")
  139. config_lines.append("sentinel tls-auth-clients redis-test no")
  140. if db_dir:
  141. config_lines.append(f"dir {db_dir}")
  142. with open(config_file, "w") as f:
  143. f.write("\n".join(config_lines))
  144. command = [REDIS_EXECUTABLE, config_file, "--sentinel"]
  145. process_info = ray._private.services.start_ray_process(
  146. command,
  147. ray_constants.PROCESS_TYPE_REDIS_SERVER,
  148. fate_share=False,
  149. )
  150. return process_info
  151. def start_redis_instance(
  152. session_dir_path: str,
  153. port: int,
  154. redis_max_clients: Optional[int] = None,
  155. num_retries: int = 20,
  156. stdout_file: Optional[str] = None,
  157. stderr_file: Optional[str] = None,
  158. password: Optional[str] = None,
  159. fate_share: Optional[bool] = None,
  160. port_denylist: Optional[List[int]] = None,
  161. listen_to_localhost_only: bool = False,
  162. enable_tls: bool = False,
  163. replica_of=None,
  164. leader_id=None,
  165. db_dir=None,
  166. free_port=0,
  167. ):
  168. """Start a single Redis server.
  169. Notes:
  170. We will initially try to start the Redis instance at the given port,
  171. and then try at most `num_retries - 1` times to start the Redis
  172. instance at successive random ports.
  173. Args:
  174. session_dir_path: Path to the session directory of
  175. this Ray cluster.
  176. port: Try to start a Redis server at this port.
  177. redis_max_clients: If this is provided, Ray will attempt to configure
  178. Redis with this maxclients number.
  179. num_retries: The number of times to attempt to start Redis at
  180. successive ports.
  181. stdout_file: A file handle opened for writing to redirect stdout to. If
  182. no redirection should happen, then this should be None.
  183. stderr_file: A file handle opened for writing to redirect stderr to. If
  184. no redirection should happen, then this should be None.
  185. password: Prevents external clients without the password
  186. from connecting to Redis if provided.
  187. port_denylist: A set of denylist ports that shouldn't
  188. be used when allocating a new port.
  189. listen_to_localhost_only: Redis server only listens to
  190. localhost (127.0.0.1) if it's true,
  191. otherwise it listens to all network interfaces.
  192. enable_tls: Enable the TLS/SSL in Redis or not
  193. Returns:
  194. A tuple of the port used by Redis and ProcessInfo for the process that
  195. was started. If a port is passed in, then the returned port value
  196. is the same.
  197. Raises:
  198. Exception: An exception is raised if Redis could not be started.
  199. """
  200. assert os.path.isfile(REDIS_EXECUTABLE)
  201. # Construct the command to start the Redis server.
  202. command = [REDIS_EXECUTABLE]
  203. if password:
  204. if " " in password:
  205. raise ValueError("Spaces not permitted in redis password.")
  206. command += ["--requirepass", password]
  207. if redis_replicas() > 1:
  208. command += ["--cluster-enabled", "yes", "--cluster-config-file", f"node-{port}"]
  209. if enable_tls:
  210. command += [
  211. "--tls-port",
  212. str(port),
  213. "--loglevel",
  214. "warning",
  215. "--port",
  216. str(free_port),
  217. ]
  218. else:
  219. command += ["--port", str(port), "--loglevel", "warning"]
  220. if listen_to_localhost_only:
  221. command += ["--bind", "127.0.0.1"]
  222. pidfile = os.path.join(session_dir_path, "redis-" + uuid.uuid4().hex + ".pid")
  223. command += ["--pidfile", pidfile]
  224. if enable_tls:
  225. if Config.REDIS_CA_CERT():
  226. command += ["--tls-ca-cert-file", Config.REDIS_CA_CERT()]
  227. if Config.REDIS_CLIENT_CERT():
  228. command += ["--tls-cert-file", Config.REDIS_CLIENT_CERT()]
  229. if Config.REDIS_CLIENT_KEY():
  230. command += ["--tls-key-file", Config.REDIS_CLIENT_KEY()]
  231. if replica_of is not None:
  232. command += ["--tls-replication", "yes"]
  233. command += ["--tls-auth-clients", "no", "--tls-cluster", "yes"]
  234. if sys.platform != "win32":
  235. command += ["--save", "", "--appendonly", "no"]
  236. if db_dir is not None:
  237. command += ["--dir", str(db_dir)]
  238. process_info = ray._private.services.start_ray_process(
  239. command,
  240. ray_constants.PROCESS_TYPE_REDIS_SERVER,
  241. stdout_file=stdout_file,
  242. stderr_file=stderr_file,
  243. fate_share=fate_share,
  244. )
  245. node_id = None
  246. if redis_replicas() > 1:
  247. # Setup redis cluster
  248. import redis
  249. while True:
  250. try:
  251. redis_cli = get_redis_cli(port, enable_tls)
  252. if replica_of is None:
  253. slots = [str(i) for i in range(16384)]
  254. redis_cli.cluster("addslots", *slots)
  255. else:
  256. logger.info(redis_cli.cluster("meet", "127.0.0.1", str(replica_of)))
  257. logger.info(redis_cli.cluster("replicate", leader_id))
  258. node_id = redis_cli.cluster("myid")
  259. break
  260. except (
  261. redis.exceptions.ConnectionError,
  262. redis.exceptions.ResponseError,
  263. ) as e:
  264. from time import sleep
  265. logger.info(
  266. f"Waiting for redis to be up. Check failed with error: {e}. "
  267. "Will retry in 0.1s"
  268. )
  269. if process_info.process.poll() is not None:
  270. raise Exception(
  271. f"Redis process exited unexpectedly: {process_info}. "
  272. f"Exit code: {process_info.process.returncode}"
  273. )
  274. sleep(0.1)
  275. logger.info(
  276. f"Redis started with node_id {node_id} and pid {process_info.process.pid}"
  277. )
  278. return node_id, process_info
  279. def _pid_alive(pid):
  280. """Check if the process with this PID is alive or not.
  281. Args:
  282. pid: The pid to check.
  283. Returns:
  284. This returns false if the process is dead. Otherwise, it returns true.
  285. """
  286. alive = True
  287. try:
  288. proc = psutil.Process(pid)
  289. if proc.status() == psutil.STATUS_ZOMBIE:
  290. alive = False
  291. except psutil.NoSuchProcess:
  292. alive = False
  293. return alive
  294. def _check_call_windows(main, argv, capture_stdout=False, capture_stderr=False):
  295. # We use this function instead of calling the "ray" command to work around
  296. # some deadlocks that occur when piping ray's output on Windows
  297. stream = io.TextIOWrapper(io.BytesIO(), encoding=sys.stdout.encoding)
  298. old_argv = sys.argv[:]
  299. try:
  300. sys.argv = argv[:]
  301. try:
  302. with redirect_stderr(stream if capture_stderr else sys.stderr):
  303. with redirect_stdout(stream if capture_stdout else sys.stdout):
  304. main()
  305. finally:
  306. stream.flush()
  307. except SystemExit as ex:
  308. if ex.code:
  309. output = stream.buffer.getvalue()
  310. raise subprocess.CalledProcessError(ex.code, argv, output)
  311. except Exception as ex:
  312. output = stream.buffer.getvalue()
  313. raise subprocess.CalledProcessError(1, argv, output, ex.args[0])
  314. finally:
  315. sys.argv = old_argv
  316. if capture_stdout:
  317. sys.stdout.buffer.write(stream.buffer.getvalue())
  318. elif capture_stderr:
  319. sys.stderr.buffer.write(stream.buffer.getvalue())
  320. return stream.buffer.getvalue()
  321. def check_call_subprocess(argv, capture_stdout=False, capture_stderr=False):
  322. # We use this function instead of calling the "ray" command to work around
  323. # some deadlocks that occur when piping ray's output on Windows
  324. from ray.scripts.scripts import main as ray_main
  325. if sys.platform == "win32":
  326. result = _check_call_windows(
  327. ray_main, argv, capture_stdout=capture_stdout, capture_stderr=capture_stderr
  328. )
  329. else:
  330. stdout_redir = None
  331. stderr_redir = None
  332. if capture_stdout:
  333. stdout_redir = subprocess.PIPE
  334. if capture_stderr and capture_stdout:
  335. stderr_redir = subprocess.STDOUT
  336. elif capture_stderr:
  337. stderr_redir = subprocess.PIPE
  338. proc = subprocess.Popen(argv, stdout=stdout_redir, stderr=stderr_redir)
  339. (stdout, stderr) = proc.communicate()
  340. if proc.returncode:
  341. raise subprocess.CalledProcessError(proc.returncode, argv, stdout, stderr)
  342. result = b"".join([s for s in [stdout, stderr] if s is not None])
  343. return result
  344. def check_call_ray(args, capture_stdout=False, capture_stderr=False):
  345. check_call_subprocess(["ray"] + args, capture_stdout, capture_stderr)
  346. def get_dashboard_agent_address(gcs_client: GcsClient, node_id: str):
  347. result = gcs_client.internal_kv_get(
  348. f"{dashboard_consts.DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX}{node_id}".encode(),
  349. namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
  350. timeout=dashboard_consts.GCS_RPC_TIMEOUT_SECONDS,
  351. )
  352. if result:
  353. # Returns [ip, http_port, grpc_port]
  354. ip, _, grpc_port = json.loads(result)
  355. return f"{ip}:{grpc_port}"
  356. return None
  357. def wait_for_dashboard_agent_available(cluster):
  358. gcs_client = GcsClient(address=cluster.address)
  359. wait_for_condition(
  360. lambda: get_dashboard_agent_address(gcs_client, cluster.head_node.node_id)
  361. is not None
  362. )
  363. def wait_for_aggregator_agent(address: str, node_id: str, timeout: float = 10) -> None:
  364. """Wait for the aggregator agent to be ready by checking socket connectivity."""
  365. gcs_client = GcsClient(address=address)
  366. # Wait for the agent to publish its address
  367. wait_for_condition(
  368. lambda: get_dashboard_agent_address(gcs_client, node_id) is not None
  369. )
  370. # Get the agent address and test socket connectivity
  371. agent_address = get_dashboard_agent_address(gcs_client, node_id)
  372. parsed = urlparse(f"grpc://{agent_address}")
  373. def _can_connect() -> bool:
  374. try:
  375. with socket.create_connection((parsed.hostname, parsed.port), timeout=1):
  376. return True
  377. except OSError:
  378. return False
  379. wait_for_condition(_can_connect, timeout=timeout)
  380. def wait_for_aggregator_agent_if_enabled(
  381. address: str, node_id: str, timeout: float = 10
  382. ) -> None:
  383. """Wait for aggregator agent only if aggregator mode is enabled.
  384. Checks RAY_enable_core_worker_ray_event_to_aggregator env var.
  385. """
  386. if os.environ.get("RAY_enable_core_worker_ray_event_to_aggregator") == "1":
  387. wait_for_aggregator_agent(address, node_id, timeout)
  388. def wait_for_pid_to_exit(pid: int, timeout: float = 20):
  389. start_time = time.time()
  390. while time.time() - start_time < timeout:
  391. if not _pid_alive(pid):
  392. return
  393. time.sleep(0.1)
  394. raise TimeoutError(f"Timed out while waiting for process {pid} to exit.")
  395. def wait_for_children_of_pid(pid, num_children=1, timeout=20):
  396. p = psutil.Process(pid)
  397. start_time = time.time()
  398. alive = []
  399. while time.time() - start_time < timeout:
  400. alive = p.children(recursive=False)
  401. num_alive = len(alive)
  402. if num_alive >= num_children:
  403. return
  404. time.sleep(0.1)
  405. raise TimeoutError(
  406. f"Timed out while waiting for process {pid} children to start "
  407. f"({num_alive}/{num_children} started: {alive})."
  408. )
  409. def wait_for_children_of_pid_to_exit(pid, timeout=20):
  410. children = psutil.Process(pid).children()
  411. if len(children) == 0:
  412. return
  413. _, alive = psutil.wait_procs(children, timeout=timeout)
  414. if len(alive) > 0:
  415. raise TimeoutError(
  416. "Timed out while waiting for process children to exit."
  417. " Children still alive: {}.".format([p.name() for p in alive])
  418. )
  419. def kill_process_by_name(name, SIGKILL=False):
  420. for p in psutil.process_iter(attrs=["name"]):
  421. if p.info["name"] == name + ray._private.services.EXE_SUFFIX:
  422. if SIGKILL:
  423. p.kill()
  424. else:
  425. p.terminate()
  426. def kill_processes(process_infos: List[ProcessInfo]):
  427. """
  428. Forcefully kills the list of given processes.
  429. Ignores processes that are already dead.
  430. Args:
  431. process_infos: The list of ProcessInfo representing the processes to kill.
  432. Raises:
  433. TimeoutError: If the process did not exit within 5 seconds.
  434. """
  435. for process_info in process_infos:
  436. try:
  437. process_info.process.kill()
  438. process_info.process.wait(timeout=5)
  439. except ProcessLookupError:
  440. # Process already dead
  441. pass
  442. except subprocess.TimeoutExpired as exception:
  443. raise TimeoutError(
  444. f"Process {process_info.process.pid} did not exit within 5 seconds "
  445. "after SIGKILL"
  446. ) from exception
  447. def run_string_as_driver(driver_script: str, env: Dict = None, encode: str = "utf-8"):
  448. """Run a driver as a separate process.
  449. Args:
  450. driver_script: A string to run as a Python script.
  451. env: The environment variables for the driver.
  452. Returns:
  453. The script's output.
  454. """
  455. proc = subprocess.Popen(
  456. [sys.executable, "-"],
  457. stdin=subprocess.PIPE,
  458. stdout=subprocess.PIPE,
  459. stderr=subprocess.STDOUT,
  460. env=env,
  461. )
  462. with proc:
  463. output = proc.communicate(driver_script.encode(encoding=encode))[0]
  464. if proc.returncode:
  465. print(ray._common.utils.decode(output, encode_type=encode))
  466. logger.error(proc.stderr)
  467. raise subprocess.CalledProcessError(
  468. proc.returncode, proc.args, output, proc.stderr
  469. )
  470. out = ray._common.utils.decode(output, encode_type=encode)
  471. return out
  472. def run_string_as_driver_stdout_stderr(
  473. driver_script: str, env: Dict = None, encode: str = "utf-8"
  474. ) -> Tuple[str, str]:
  475. """Run a driver as a separate process.
  476. Args:
  477. driver_script: A string to run as a Python script.
  478. env: The environment variables for the driver.
  479. Returns:
  480. The script's stdout and stderr.
  481. """
  482. proc = subprocess.Popen(
  483. [sys.executable, "-"],
  484. stdin=subprocess.PIPE,
  485. stdout=subprocess.PIPE,
  486. stderr=subprocess.PIPE,
  487. env=env,
  488. )
  489. with proc:
  490. outputs_bytes = proc.communicate(driver_script.encode(encoding=encode))
  491. out_str, err_str = [
  492. ray._common.utils.decode(output, encode_type=encode)
  493. for output in outputs_bytes
  494. ]
  495. if proc.returncode:
  496. print(out_str)
  497. print(err_str)
  498. raise subprocess.CalledProcessError(
  499. proc.returncode, proc.args, out_str, err_str
  500. )
  501. return out_str, err_str
  502. def run_string_as_driver_nonblocking(driver_script, env: Dict = None):
  503. """Start a driver as a separate process and return immediately.
  504. Args:
  505. driver_script: A string to run as a Python script.
  506. Returns:
  507. A handle to the driver process.
  508. """
  509. script = "; ".join(
  510. [
  511. "import sys",
  512. "script = sys.stdin.read()",
  513. "sys.stdin.close()",
  514. "del sys",
  515. 'exec("del script\\n" + script)',
  516. ]
  517. )
  518. proc = subprocess.Popen(
  519. [sys.executable, "-c", script],
  520. stdin=subprocess.PIPE,
  521. stdout=subprocess.PIPE,
  522. stderr=subprocess.PIPE,
  523. env=env,
  524. )
  525. proc.stdin.write(driver_script.encode("ascii"))
  526. proc.stdin.close()
  527. return proc
  528. def convert_actor_state(state):
  529. if not state:
  530. return None
  531. return gcs_pb2.ActorTableData.ActorState.DESCRIPTOR.values_by_number[state].name
  532. def wait_for_num_actors(num_actors, state=None, timeout=10):
  533. state = convert_actor_state(state)
  534. start_time = time.time()
  535. while time.time() - start_time < timeout:
  536. if (
  537. len(
  538. list_actors(
  539. filters=[("state", "=", state)] if state else None,
  540. limit=num_actors,
  541. )
  542. )
  543. >= num_actors
  544. ):
  545. return
  546. time.sleep(0.1)
  547. raise TimeoutError("Timed out while waiting for global state.")
  548. def kill_actor_and_wait_for_failure(actor, timeout=10, retry_interval_ms=100):
  549. actor_id = actor._actor_id.hex()
  550. current_num_restarts = get_actor(id=actor_id).num_restarts
  551. ray.kill(actor)
  552. start = time.time()
  553. while time.time() - start <= timeout:
  554. actor_state = get_actor(id=actor_id)
  555. if (
  556. actor_state.state == "DEAD"
  557. or actor_state.num_restarts > current_num_restarts
  558. ):
  559. return
  560. time.sleep(retry_interval_ms / 1000.0)
  561. raise RuntimeError("It took too much time to kill an actor: {}".format(actor_id))
  562. def wait_for_assertion(
  563. assertion_predictor: Callable,
  564. timeout: int = 10,
  565. retry_interval_ms: int = 100,
  566. raise_exceptions: bool = False,
  567. **kwargs: Any,
  568. ):
  569. """Wait until an assertion is met or time out with an exception.
  570. Args:
  571. assertion_predictor: A function that predicts the assertion.
  572. timeout: Maximum timeout in seconds.
  573. retry_interval_ms: Retry interval in milliseconds.
  574. raise_exceptions: If true, exceptions that occur while executing
  575. assertion_predictor won't be caught and instead will be raised.
  576. **kwargs: Arguments to pass to the condition_predictor.
  577. Raises:
  578. RuntimeError: If the assertion is not met before the timeout expires.
  579. """
  580. def _assertion_to_condition():
  581. try:
  582. assertion_predictor(**kwargs)
  583. return True
  584. except AssertionError:
  585. return False
  586. try:
  587. wait_for_condition(
  588. _assertion_to_condition,
  589. timeout=timeout,
  590. retry_interval_ms=retry_interval_ms,
  591. raise_exceptions=raise_exceptions,
  592. **kwargs,
  593. )
  594. except RuntimeError:
  595. assertion_predictor(**kwargs) # Should fail assert
  596. @dataclass
  597. class MetricSamplePattern:
  598. name: Optional[str] = None
  599. value: Optional[str] = None
  600. partial_label_match: Optional[Dict[str, str]] = None
  601. def matches(self, sample: Sample):
  602. if self.name is not None:
  603. if self.name != sample.name:
  604. return False
  605. if self.value is not None:
  606. if self.value != sample.value:
  607. return False
  608. if self.partial_label_match is not None:
  609. for label, value in self.partial_label_match.items():
  610. if sample.labels.get(label) != value:
  611. return False
  612. return True
  613. @dataclass
  614. class PrometheusTimeseries:
  615. """A collection of timeseries from multiple addresses. Each timeseries is a
  616. collection of samples with the same metric name and labels. Concretely:
  617. - components_dict: a dictionary of addresses to the Component labels
  618. - metric_descriptors: a dictionary of metric names to the Metric object
  619. - metric_samples: the latest value of each label
  620. """
  621. components_dict: Dict[str, Set[str]] = field(default_factory=dict)
  622. metric_descriptors: Dict[str, Metric] = field(default_factory=dict)
  623. metric_samples: Dict[frozenset, Sample] = field(default_factory=dict)
  624. def flush(self):
  625. self.components_dict.clear()
  626. self.metric_descriptors.clear()
  627. self.metric_samples.clear()
  628. def get_metric_check_condition(
  629. metrics_to_check: List[MetricSamplePattern],
  630. timeseries: PrometheusTimeseries,
  631. export_addr: Optional[str] = None,
  632. ) -> Callable[[], bool]:
  633. """A condition to check if a prometheus metrics reach a certain value.
  634. This is a blocking check that can be passed into a `wait_for_condition`
  635. style function.
  636. Args:
  637. metrics_to_check: A list of MetricSamplePattern. The fields that
  638. aren't `None` will be matched.
  639. timeseries: A PrometheusTimeseries object to store the metrics.
  640. export_addr: Optional address to export metrics to.
  641. Returns:
  642. A function that returns True if all the metrics are emitted.
  643. """
  644. node_info = ray.nodes()[0]
  645. metrics_export_port = node_info["MetricsExportPort"]
  646. addr = node_info["NodeManagerAddress"]
  647. prom_addr = export_addr or build_address(addr, metrics_export_port)
  648. def f():
  649. for metric_pattern in metrics_to_check:
  650. metric_samples = fetch_prometheus_timeseries(
  651. [prom_addr], timeseries
  652. ).metric_samples.values()
  653. for metric_sample in metric_samples:
  654. if metric_pattern.matches(metric_sample):
  655. break
  656. else:
  657. logger.info(
  658. f"Didn't find {metric_pattern} in all samples: {metric_samples}",
  659. )
  660. return False
  661. return True
  662. return f
  663. def wait_until_succeeded_without_exception(
  664. func, exceptions, *args, timeout_ms=1000, retry_interval_ms=100, raise_last_ex=False
  665. ):
  666. """A helper function that waits until a given function
  667. completes without exceptions.
  668. Args:
  669. func: A function to run.
  670. exceptions: Exceptions that are supposed to occur.
  671. args: arguments to pass for a given func
  672. timeout_ms: Maximum timeout in milliseconds.
  673. retry_interval_ms: Retry interval in milliseconds.
  674. raise_last_ex: Raise the last exception when timeout.
  675. Return:
  676. Whether exception occurs within a timeout.
  677. """
  678. if isinstance(type(exceptions), tuple):
  679. raise Exception("exceptions arguments should be given as a tuple")
  680. time_elapsed = 0
  681. start = time.time()
  682. last_ex = None
  683. while time_elapsed <= timeout_ms:
  684. try:
  685. func(*args)
  686. return True
  687. except exceptions as ex:
  688. last_ex = ex
  689. time_elapsed = (time.time() - start) * 1000
  690. time.sleep(retry_interval_ms / 1000.0)
  691. if raise_last_ex:
  692. ex_stack = (
  693. traceback.format_exception(type(last_ex), last_ex, last_ex.__traceback__)
  694. if last_ex
  695. else []
  696. )
  697. ex_stack = "".join(ex_stack)
  698. raise Exception(f"Timed out while testing, {ex_stack}")
  699. return False
  700. def recursive_fnmatch(dirpath, pattern):
  701. """Looks at a file directory subtree for a filename pattern.
  702. Similar to glob.glob(..., recursive=True) but also supports 2.7
  703. """
  704. matches = []
  705. for root, dirnames, filenames in os.walk(dirpath):
  706. for filename in fnmatch.filter(filenames, pattern):
  707. matches.append(os.path.join(root, filename))
  708. return matches
  709. def generate_system_config_map(**kwargs):
  710. ray_kwargs = {
  711. "_system_config": kwargs,
  712. }
  713. return ray_kwargs
  714. def same_elements(elems_a, elems_b):
  715. """Checks if two iterables (such as lists) contain the same elements. Elements
  716. do not have to be hashable (this allows us to compare sets of dicts for
  717. example). This comparison is not necessarily efficient.
  718. """
  719. a = list(elems_a)
  720. b = list(elems_b)
  721. for x in a:
  722. if x not in b:
  723. return False
  724. for x in b:
  725. if x not in a:
  726. return False
  727. return True
@ray.remote
def _put(obj):
    """Remote task that returns its argument unchanged."""
    return obj
  731. def put_object(obj, use_ray_put):
  732. if use_ray_put:
  733. return ray.put(obj)
  734. else:
  735. return _put.remote(obj)
  736. def wait_until_server_available(address, timeout_ms=5000, retry_interval_ms=100):
  737. ip, port_str = parse_address(address)
  738. port = int(port_str)
  739. time_elapsed = 0
  740. start = time.time()
  741. while time_elapsed <= timeout_ms:
  742. s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  743. s.settimeout(1)
  744. try:
  745. s.connect((ip, port))
  746. except Exception:
  747. time_elapsed = (time.time() - start) * 1000
  748. time.sleep(retry_interval_ms / 1000.0)
  749. s.close()
  750. continue
  751. s.close()
  752. return True
  753. return False
  754. def get_other_nodes(cluster, exclude_head=False):
  755. """Get all nodes except the one that we're connected to."""
  756. return [
  757. node
  758. for node in cluster.list_all_nodes()
  759. if node._raylet_socket_name
  760. != ray._private.worker._global_node._raylet_socket_name
  761. and (exclude_head is False or node.head is False)
  762. ]
  763. def get_non_head_nodes(cluster):
  764. """Get all non-head nodes."""
  765. return list(filter(lambda x: x.head is False, cluster.list_all_nodes()))
  766. def init_error_pubsub():
  767. """Initialize error info pub/sub"""
  768. s = ray._raylet.GcsErrorSubscriber(
  769. address=ray._private.worker.global_worker.gcs_client.address
  770. )
  771. s.subscribe()
  772. return s
  773. def get_error_message(subscriber, num=1e6, error_type=None, timeout=20):
  774. """Gets errors from GCS subscriber.
  775. Returns maximum `num` error strings within `timeout`.
  776. Only returns errors of `error_type` if specified.
  777. """
  778. deadline = time.time() + timeout
  779. msgs = []
  780. while time.time() < deadline and len(msgs) < num:
  781. _, error_data = subscriber.poll(timeout=deadline - time.time())
  782. if not error_data:
  783. # Timed out before any data is received.
  784. break
  785. if error_type is None or error_type == error_data["type"]:
  786. msgs.append(error_data)
  787. else:
  788. time.sleep(0.01)
  789. return msgs
  790. def init_log_pubsub():
  791. """Initialize log pub/sub"""
  792. s = ray._raylet.GcsLogSubscriber(
  793. address=ray._private.worker.global_worker.gcs_client.address
  794. )
  795. s.subscribe()
  796. return s
  797. def get_log_data(
  798. subscriber,
  799. num: int = 1e6,
  800. timeout: float = 20,
  801. job_id: Optional[str] = None,
  802. matcher=None,
  803. ) -> List[dict]:
  804. deadline = time.time() + timeout
  805. msgs = []
  806. while time.time() < deadline and len(msgs) < num:
  807. logs_data = subscriber.poll(timeout=deadline - time.time())
  808. if not logs_data:
  809. # Timed out before any data is received.
  810. break
  811. if job_id and job_id != logs_data["job"]:
  812. continue
  813. if matcher and all(not matcher(line) for line in logs_data["lines"]):
  814. continue
  815. msgs.append(logs_data)
  816. return msgs
  817. def get_log_message(
  818. subscriber,
  819. num: int = 1e6,
  820. timeout: float = 20,
  821. job_id: Optional[str] = None,
  822. matcher=None,
  823. ) -> List[List[str]]:
  824. """Gets log lines through GCS subscriber.
  825. Returns maximum `num` of log messages, within `timeout`.
  826. If `job_id` or `match` is specified, only returns log lines from `job_id`
  827. or when `matcher` is true.
  828. """
  829. msgs = get_log_data(subscriber, num, timeout, job_id, matcher)
  830. return [msg["lines"] for msg in msgs]
  831. def get_log_sources(
  832. subscriber,
  833. num: int = 1e6,
  834. timeout: float = 20,
  835. job_id: Optional[str] = None,
  836. matcher=None,
  837. ):
  838. """Get the source of all log messages"""
  839. msgs = get_log_data(subscriber, num, timeout, job_id, matcher)
  840. return {msg["pid"] for msg in msgs}
  841. def get_log_batch(
  842. subscriber,
  843. num: int,
  844. timeout: float = 20,
  845. job_id: Optional[str] = None,
  846. matcher=None,
  847. ) -> List[str]:
  848. """Gets log batches through GCS subscriber.
  849. Returns maximum `num` batches of logs. Each batch is a dict that includes
  850. metadata such as `pid`, `job_id`, and `lines` of log messages.
  851. If `job_id` or `match` is specified, only returns log batches from `job_id`
  852. or when `matcher` is true.
  853. """
  854. deadline = time.time() + timeout
  855. batches = []
  856. while time.time() < deadline and len(batches) < num:
  857. logs_data = subscriber.poll(timeout=deadline - time.time())
  858. if not logs_data:
  859. # Timed out before any data is received.
  860. break
  861. if job_id and job_id != logs_data["job"]:
  862. continue
  863. if matcher and not matcher(logs_data):
  864. continue
  865. batches.append(logs_data)
  866. return batches
  867. def format_web_url(url):
  868. """Format web url."""
  869. url = url.replace("localhost", "http://127.0.0.1")
  870. if not url.startswith("http://"):
  871. return "http://" + url
  872. return url
def client_test_enabled() -> bool:
    """Return the `is_client_mode_enabled` flag of `ray._private.client_mode_hook`."""
    return ray._private.client_mode_hook.is_client_mode_enabled
def object_memory_usage() -> float:
    """Returns the number of bytes used in the object store.

    Computed as the cluster-wide total of the `object_store_memory` resource
    minus the currently available amount; either value defaults to 0 when the
    resource is not reported.
    """
    total = ray.cluster_resources().get("object_store_memory", 0)
    avail = ray.available_resources().get("object_store_memory", 0)
    return total - avail
  880. def fetch_raw_prometheus(prom_addresses):
  881. # Local import so minimal dependency tests can run without requests
  882. import requests
  883. for address in prom_addresses:
  884. try:
  885. response = requests.get(f"http://{address}/metrics")
  886. yield address, response.text
  887. except requests.exceptions.ConnectionError:
  888. continue
  889. def fetch_prometheus(prom_addresses):
  890. components_dict = {}
  891. metric_descriptors = {}
  892. metric_samples = []
  893. for address in prom_addresses:
  894. if address not in components_dict:
  895. components_dict[address] = set()
  896. for address, response in fetch_raw_prometheus(prom_addresses):
  897. for metric in text_string_to_metric_families(response):
  898. for sample in metric.samples:
  899. metric_descriptors[sample.name] = metric
  900. metric_samples.append(sample)
  901. if "Component" in sample.labels:
  902. components_dict[address].add(sample.labels["Component"])
  903. return components_dict, metric_descriptors, metric_samples
  904. def fetch_prometheus_timeseries(
  905. prom_addreses: List[str],
  906. result: PrometheusTimeseries,
  907. ) -> PrometheusTimeseries:
  908. components_dict, metric_descriptors, metric_samples = fetch_prometheus(
  909. prom_addreses
  910. )
  911. for address, components in components_dict.items():
  912. if address not in result.components_dict:
  913. result.components_dict[address] = set()
  914. result.components_dict[address].update(components)
  915. result.metric_descriptors.update(metric_descriptors)
  916. for sample in metric_samples:
  917. # udpate sample to the latest value
  918. result.metric_samples[
  919. frozenset(list(sample.labels.items()) + [("_metric_name_", sample.name)])
  920. ] = sample
  921. return result
  922. def fetch_prometheus_metrics(prom_addresses: List[str]) -> Dict[str, List[Any]]:
  923. """Return prometheus metrics from the given addresses.
  924. Args:
  925. prom_addresses: List of metrics_agent addresses to collect metrics from.
  926. Returns:
  927. Dict mapping from metric name to list of samples for the metric.
  928. """
  929. _, _, samples = fetch_prometheus(prom_addresses)
  930. samples_by_name = defaultdict(list)
  931. for sample in samples:
  932. samples_by_name[sample.name].append(sample)
  933. return samples_by_name
  934. def fetch_prometheus_metric_timeseries(
  935. prom_addresses: List[str], result: PrometheusTimeseries
  936. ) -> Dict[str, List[Any]]:
  937. samples = fetch_prometheus_timeseries(
  938. prom_addresses, result
  939. ).metric_samples.values()
  940. samples_by_name = defaultdict(list)
  941. for sample in samples:
  942. samples_by_name[sample.name].append(sample)
  943. return samples_by_name
  944. def raw_metric_timeseries(
  945. info: RayContext, result: PrometheusTimeseries
  946. ) -> Dict[str, List[Any]]:
  947. """Return prometheus timeseries from a RayContext"""
  948. metrics_page = "localhost:{}".format(info.address_info["metrics_export_port"])
  949. print("Fetch metrics from", metrics_page)
  950. return fetch_prometheus_metric_timeseries([metrics_page], result)
  951. def get_system_metric_for_component(
  952. system_metric: str, component: str, prometheus_server_address: str
  953. ) -> List[float]:
  954. """Get the system metric for a given component from a Prometheus server address.
  955. Please note:
  956. - This function requires the availability of the Prometheus server. Therefore, it
  957. requires the server address.
  958. - It assumes the system metric has a `Component` label and `pid` label. `pid` is the
  959. process id, so it can be used to uniquely identify the process.
  960. """
  961. session_name = os.path.basename(
  962. ray._private.worker._global_node.get_session_dir_path()
  963. )
  964. query = f"sum({system_metric}{{Component='{component}',SessionName='{session_name}'}}) by (pid)"
  965. resp = requests.get(
  966. f"{prometheus_server_address}/api/v1/query?query={quote(query)}"
  967. )
  968. if resp.status_code != 200:
  969. raise Exception(f"Failed to query Prometheus: {resp.status_code}")
  970. result = resp.json()
  971. return [float(item["value"][1]) for item in result["data"]["result"]]
  972. def get_test_config_path(config_file_name):
  973. """Resolve the test config path from the config file dir"""
  974. here = os.path.realpath(__file__)
  975. path = pathlib.Path(here)
  976. grandparent = path.parent.parent
  977. return os.path.join(grandparent, "tests/test_cli_patterns", config_file_name)
  978. def load_test_config(config_file_name):
  979. """Loads a config yaml from tests/test_cli_patterns."""
  980. config_path = get_test_config_path(config_file_name)
  981. config = yaml.safe_load(open(config_path).read())
  982. return config
def set_setup_func():
    """Set `ray._private.runtime_env.VAR` to "hello world"."""
    # Imported locally; the module is only needed when this function runs.
    import ray._private.runtime_env as runtime_env

    runtime_env.VAR = "hello world"
  986. class BatchQueue(Queue):
  987. def __init__(self, maxsize: int = 0, actor_options: Optional[Dict] = None) -> None:
  988. actor_options = actor_options or {}
  989. self.maxsize = maxsize
  990. self.actor = (
  991. ray.remote(_BatchQueueActor).options(**actor_options).remote(self.maxsize)
  992. )
  993. def get_batch(
  994. self,
  995. batch_size: int = None,
  996. total_timeout: Optional[float] = None,
  997. first_timeout: Optional[float] = None,
  998. ) -> List[Any]:
  999. """Gets batch of items from the queue and returns them in a
  1000. list in order.
  1001. Raises:
  1002. Empty: if the queue does not contain the desired number of items
  1003. """
  1004. return ray.get(
  1005. self.actor.get_batch.remote(batch_size, total_timeout, first_timeout)
  1006. )
class _BatchQueueActor(_QueueActor):
    """Queue actor that can return several items in one call."""

    async def get_batch(self, batch_size=None, total_timeout=None, first_timeout=None):
        """Collect a batch of items from `self.queue`.

        Args:
            batch_size: Maximum number of items to return. When None, items
                are collected until a wait times out.
            total_timeout: Time budget in seconds for the items after the
                first one. It is reduced by the time spent waiting for each
                item and never drops below 0.
            first_timeout: Timeout in seconds for the first item.

        Returns:
            The list of collected items, in the order they were received.

        Raises:
            Empty: If the first item does not arrive within `first_timeout`.
        """
        start = timeit.default_timer()
        try:
            first = await asyncio.wait_for(self.queue.get(), first_timeout)
            batch = [first]
            if total_timeout:
                # Charge the wait for the first item against the total budget.
                end = timeit.default_timer()
                total_timeout = max(total_timeout - (end - start), 0)
        except asyncio.TimeoutError:
            raise Empty
        if batch_size is None:
            # Without a size limit there must be a finite timeout, otherwise
            # the loop below would never end; default to 0.
            if total_timeout is None:
                total_timeout = 0
            while True:
                try:
                    start = timeit.default_timer()
                    batch.append(
                        await asyncio.wait_for(self.queue.get(), total_timeout)
                    )
                    if total_timeout:
                        end = timeit.default_timer()
                        total_timeout = max(total_timeout - (end - start), 0)
                except asyncio.TimeoutError:
                    break
        else:
            # One item is already in the batch, so fetch at most
            # `batch_size - 1` more.
            for _ in range(batch_size - 1):
                try:
                    start = timeit.default_timer()
                    batch.append(
                        await asyncio.wait_for(self.queue.get(), total_timeout)
                    )
                    if total_timeout:
                        end = timeit.default_timer()
                        total_timeout = max(total_timeout - (end - start), 0)
                except asyncio.TimeoutError:
                    break
        return batch
  1045. def is_placement_group_removed(pg):
  1046. table = ray.util.placement_group_table(pg)
  1047. if "state" not in table:
  1048. return False
  1049. return table["state"] == "REMOVED"
  1050. def placement_group_assert_no_leak(pgs_created):
  1051. for pg in pgs_created:
  1052. ray.util.remove_placement_group(pg)
  1053. def wait_for_pg_removed():
  1054. for pg_entry in ray.util.placement_group_table().values():
  1055. if pg_entry["state"] != "REMOVED":
  1056. return False
  1057. return True
  1058. wait_for_condition(wait_for_pg_removed)
  1059. cluster_resources = ray.cluster_resources()
  1060. cluster_resources.pop("memory")
  1061. cluster_resources.pop("object_store_memory")
  1062. def wait_for_resource_recovered():
  1063. for resource, val in ray.available_resources().items():
  1064. if resource in cluster_resources and cluster_resources[resource] != val:
  1065. return False
  1066. if "_group_" in resource:
  1067. return False
  1068. return True
  1069. wait_for_condition(wait_for_resource_recovered)
def monitor_memory_usage(
    print_interval_s: int = 30,
    record_interval_s: int = 5,
    warning_threshold: float = 0.9,
):
    """Run the memory monitor actor that prints the memory usage.

    The monitor will run on the same node as this function is called.

    Params:
        print_interval_s: The interval at which memory usage information
            is printed.
        record_interval_s: The interval at which memory usage is sampled.
        warning_threshold: The threshold where the
            memory usage warning is printed.

    Returns:
        The memory monitor actor.
    """
    assert ray.is_initialized(), "The API is only available when Ray is initialized."

    @ray.remote(num_cpus=0)
    class MemoryMonitorActor:
        def __init__(
            self,
            print_interval_s: float = 20,
            record_interval_s: float = 5,
            warning_threshold: float = 0.9,
            n: int = 10,
        ):
            """The actor that monitor the memory usage of the cluster.

            Params:
                print_interval_s: The interval where
                    memory usage is printed.
                record_interval_s: The interval where
                    memory usage is recorded.
                warning_threshold: The threshold where
                    memory warning is printed
                n: When memory usage is printed,
                    top n entries are printed.
            """
            # -- Interval the monitor prints the memory usage information. --
            self.print_interval_s = print_interval_s
            # -- Interval the monitor records the memory usage information. --
            self.record_interval_s = record_interval_s
            # -- Whether or not the monitor is running. --
            self.is_running = False
            # -- The used_gb/total_gb ratio above which a warning is logged. --
            self.warning_threshold = warning_threshold
            # -- The monitor that calculates the memory usage of the node. --
            self.monitor = memory_monitor.MemoryMonitor()
            # -- The top n memory usage of processes are printed. --
            self.n = n
            # -- The peak memory usage in GB during lifetime of monitor. --
            self.peak_memory_usage = 0
            # -- The top n memory usage of processes
            # during peak memory usage. --
            self.peak_top_n_memory_usage = ""
            # -- The last time memory usage was printed --
            self._last_print_time = 0
            # -- logger. --
            logging.basicConfig(level=logging.INFO)

        def ready(self):
            """No-op; lets callers wait until the actor has been constructed."""
            pass

        async def run(self):
            """Run the monitor."""
            self.is_running = True
            while self.is_running:
                now = time.time()
                used_gb, total_gb = self.monitor.get_memory_usage()
                top_n_memory_usage = memory_monitor.get_top_n_memory_usage(n=self.n)
                # Remember the highest usage seen and the top processes then.
                if used_gb > self.peak_memory_usage:
                    self.peak_memory_usage = used_gb
                    self.peak_top_n_memory_usage = top_n_memory_usage

                if used_gb > total_gb * self.warning_threshold:
                    logging.warning(
                        "The memory usage is high: " f"{used_gb / total_gb * 100}%"
                    )
                # Usage is sampled every record interval, but only printed
                # once per print interval.
                if now - self._last_print_time > self.print_interval_s:
                    logging.info(f"Memory usage: {used_gb} / {total_gb}")
                    logging.info(f"Top {self.n} process memory usage:")
                    logging.info(top_n_memory_usage)
                    self._last_print_time = now
                await asyncio.sleep(self.record_interval_s)

        async def stop_run(self):
            """Stop running the monitor.

            Returns:
                True if the monitor was running when this was called.
                False otherwise.
            """
            was_running = self.is_running
            self.is_running = False
            return was_running

        async def get_peak_memory_info(self):
            """Return the tuple of the peak memory usage and the
            top n process information during the peak memory usage.
            """
            return self.peak_memory_usage, self.peak_top_n_memory_usage

    current_node_ip = ray._private.worker.global_worker.node_ip_address
    # Schedule the actor on the current node.
    memory_monitor_actor = MemoryMonitorActor.options(
        resources={f"node:{current_node_ip}": 0.001}
    ).remote(
        print_interval_s=print_interval_s,
        record_interval_s=record_interval_s,
        warning_threshold=warning_threshold,
    )
    print("Waiting for memory monitor actor to be ready...")
    ray.get(memory_monitor_actor.ready.remote())
    print("Memory monitor actor is ready now.")
    # Start the monitoring loop without waiting for it; it runs until
    # `stop_run` is called on the actor.
    memory_monitor_actor.run.remote()
    return memory_monitor_actor
  1175. def setup_tls():
  1176. """Sets up required environment variables for tls"""
  1177. import pytest
  1178. if sys.platform == "darwin":
  1179. pytest.skip("Cryptography doesn't install in Mac build pipeline")
  1180. cert, key = generate_self_signed_tls_certs()
  1181. temp_dir = tempfile.mkdtemp("ray-test-certs")
  1182. cert_filepath = os.path.join(temp_dir, "server.crt")
  1183. key_filepath = os.path.join(temp_dir, "server.key")
  1184. with open(cert_filepath, "w") as fh:
  1185. fh.write(cert)
  1186. with open(key_filepath, "w") as fh:
  1187. fh.write(key)
  1188. os.environ["RAY_USE_TLS"] = "1"
  1189. os.environ["RAY_TLS_SERVER_CERT"] = cert_filepath
  1190. os.environ["RAY_TLS_SERVER_KEY"] = key_filepath
  1191. os.environ["RAY_TLS_CA_CERT"] = cert_filepath
  1192. return key_filepath, cert_filepath, temp_dir
  1193. def teardown_tls(key_filepath, cert_filepath, temp_dir):
  1194. os.remove(key_filepath)
  1195. os.remove(cert_filepath)
  1196. os.removedirs(temp_dir)
  1197. del os.environ["RAY_USE_TLS"]
  1198. del os.environ["RAY_TLS_SERVER_CERT"]
  1199. del os.environ["RAY_TLS_SERVER_KEY"]
  1200. del os.environ["RAY_TLS_CA_CERT"]
class ResourceKillerActor:
    """Abstract base class used to implement resource killers for chaos testing.

    Subclasses should implement _find_resources_to_kill, which should find
    resources to kill. This method should return a list of argument tuples for
    _kill_resource, which is another abstract method that should kill the
    resource and add it to the `killed` set.
    """

    def __init__(
        self,
        head_node_id,
        kill_interval_s: float = 60,
        kill_delay_s: float = 0,
        max_to_kill: Optional[int] = 2,
        batch_size_to_kill: int = 1,
        kill_filter_fn: Optional[Callable] = None,
    ):
        """Store the killer configuration.

        Args:
            head_node_id: Stored as `self.head_node_id` for subclasses.
            kill_interval_s: Length of one kill cycle in seconds.
            kill_delay_s: Delay before the first cycle in seconds.
            max_to_kill: Stop once this many resources are in `killed`;
                None disables the limit.
            batch_size_to_kill: Stored for subclasses that pick several
                resources per cycle.
            kill_filter_fn: Stored for subclasses; the subclasses in this
                file call it without arguments and use the result as a
                filter predicate.
        """
        self.kill_interval_s = kill_interval_s
        self.kill_delay_s = kill_delay_s
        self.is_running = False
        self.head_node_id = head_node_id
        self.killed = set()
        # Future that is resolved when `run` leaves its loop.
        self.done = get_or_create_event_loop().create_future()
        self.max_to_kill = max_to_kill
        self.batch_size_to_kill = batch_size_to_kill
        self.kill_filter_fn = kill_filter_fn
        # When True, `run` kills right after finding targets instead of
        # waiting a random part of the interval first.
        self.kill_immediately_after_found = False
        # -- logger. --
        logging.basicConfig(level=logging.INFO)

    def ready(self):
        """No-op; lets callers wait until the actor has been constructed."""
        pass

    async def run(self):
        """Repeatedly find and kill resources until stopped or the limit is hit."""
        self.is_running = True
        # NOTE(review): `time.sleep` blocks inside a coroutine, so nothing
        # else on this event loop runs during the delay — confirm intended.
        time.sleep(self.kill_delay_s)

        while self.is_running:
            to_kills = await self._find_resources_to_kill()

            # `stop_run` may have been called while searching.
            if not self.is_running:
                break

            if self.kill_immediately_after_found:
                sleep_interval = 0
            else:
                # Kill at a random point within the interval.
                sleep_interval = random.random() * self.kill_interval_s
                time.sleep(sleep_interval)

            for to_kill in to_kills:
                self._kill_resource(*to_kill)
            if self.max_to_kill is not None and len(self.killed) >= self.max_to_kill:
                break
            # Wait out the remainder of the interval.
            await asyncio.sleep(self.kill_interval_s - sleep_interval)

        self.done.set_result(True)
        await self.stop_run()

    async def _find_resources_to_kill(self):
        """Return a list of argument tuples for `_kill_resource` (abstract)."""
        raise NotImplementedError

    def _kill_resource(self, *args):
        """Kill one resource and record it in `self.killed` (abstract)."""
        raise NotImplementedError

    async def stop_run(self):
        """Stop the killer, running `_cleanup` if it was running.

        Returns:
            True if the killer was running when this was called.
        """
        was_running = self.is_running
        if was_running:
            self._cleanup()

        self.is_running = False
        return was_running

    async def get_total_killed(self):
        """Wait until `run` has finished, then return the set of killed resources."""
        await self.done
        return self.killed

    def _cleanup(self):
        """Cleanup any resources created by the killer.

        Overriding this method is optional.
        """
        pass
  1268. class NodeKillerBase(ResourceKillerActor):
  1269. async def _find_resources_to_kill(self):
  1270. nodes_to_kill = []
  1271. while not nodes_to_kill and self.is_running:
  1272. worker_nodes = [
  1273. node
  1274. for node in ray.nodes()
  1275. if node["Alive"]
  1276. and (node["NodeID"] != self.head_node_id)
  1277. and (node["NodeID"] not in self.killed)
  1278. ]
  1279. if self.kill_filter_fn:
  1280. candidates = list(filter(self.kill_filter_fn(), worker_nodes))
  1281. else:
  1282. candidates = worker_nodes
  1283. # Ensure at least one worker node remains alive.
  1284. if len(worker_nodes) < self.batch_size_to_kill + 1:
  1285. # Give the cluster some time to start.
  1286. await asyncio.sleep(1)
  1287. continue
  1288. # Collect nodes to kill, limited by batch size.
  1289. for candidate in candidates[: self.batch_size_to_kill]:
  1290. nodes_to_kill.append(
  1291. (
  1292. candidate["NodeID"],
  1293. candidate["NodeManagerAddress"],
  1294. candidate["NodeManagerPort"],
  1295. )
  1296. )
  1297. return nodes_to_kill
  1298. @ray.remote(num_cpus=0)
  1299. class RayletKiller(NodeKillerBase):
  1300. def _kill_resource(self, node_id, node_to_kill_ip, node_to_kill_port):
  1301. if node_to_kill_port is not None:
  1302. try:
  1303. self._kill_raylet(node_to_kill_ip, node_to_kill_port, graceful=False)
  1304. except Exception:
  1305. pass
  1306. logging.info(
  1307. f"Killed node {node_id} at address: "
  1308. f"{node_to_kill_ip}, port: {node_to_kill_port}"
  1309. )
  1310. self.killed.add(node_id)
  1311. def _kill_raylet(self, ip, port, graceful=False):
  1312. import grpc
  1313. from grpc._channel import _InactiveRpcError
  1314. from ray.core.generated import node_manager_pb2_grpc
  1315. raylet_address = build_address(ip, port)
  1316. channel = grpc.insecure_channel(raylet_address)
  1317. stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
  1318. try:
  1319. stub.ShutdownRaylet(
  1320. node_manager_pb2.ShutdownRayletRequest(graceful=graceful)
  1321. )
  1322. except _InactiveRpcError:
  1323. assert not graceful
  1324. @ray.remote(num_cpus=0)
  1325. class EC2InstanceTerminator(NodeKillerBase):
  1326. def _kill_resource(self, node_id, node_to_kill_ip, _):
  1327. if node_to_kill_ip is not None:
  1328. try:
  1329. _terminate_ec2_instance(node_to_kill_ip)
  1330. except Exception:
  1331. pass
  1332. logging.info(f"Terminated instance, {node_id=}, address={node_to_kill_ip}")
  1333. self.killed.add(node_id)
  1334. @ray.remote(num_cpus=0)
  1335. class EC2InstanceTerminatorWithGracePeriod(NodeKillerBase):
  1336. def __init__(self, *args, grace_period_s: int = 30, **kwargs):
  1337. super().__init__(*args, **kwargs)
  1338. self._grace_period_s = grace_period_s
  1339. self._kill_threads: Set[threading.Thread] = set()
  1340. def _kill_resource(self, node_id, node_to_kill_ip, _):
  1341. assert node_id not in self.killed
  1342. # Clean up any completed threads.
  1343. for thread in self._kill_threads.copy():
  1344. if not thread.is_alive():
  1345. thread.join()
  1346. self._kill_threads.remove(thread)
  1347. def _kill_node_with_grace_period(node_id, node_to_kill_ip):
  1348. self._drain_node(node_id)
  1349. time.sleep(self._grace_period_s)
  1350. # Anyscale extends the drain deadline if you shut down the instance
  1351. # directly. To work around this, we force-stop Ray on the node. Anyscale
  1352. # should then terminate it shortly after without updating the drain
  1353. # deadline.
  1354. _execute_command_on_node("ray stop --force", node_to_kill_ip)
  1355. logger.info(f"Starting killing thread {node_id=}, {node_to_kill_ip=}")
  1356. thread = threading.Thread(
  1357. target=_kill_node_with_grace_period,
  1358. args=(node_id, node_to_kill_ip),
  1359. daemon=True,
  1360. )
  1361. thread.start()
  1362. self._kill_threads.add(thread)
  1363. self.killed.add(node_id)
  1364. def _drain_node(self, node_id: str) -> None:
  1365. # We need to lazily import this object. Otherwise, Ray can't serialize the
  1366. # class.
  1367. from ray.core.generated import autoscaler_pb2
  1368. assert ray.NodeID.from_hex(node_id) != ray.NodeID.nil()
  1369. logging.info(f"Draining node {node_id=}")
  1370. address = services.canonicalize_bootstrap_address_or_die(addr="auto")
  1371. gcs_client = ray._raylet.GcsClient(address=address)
  1372. deadline_timestamp_ms = (time.time_ns() // 1e6) + (self._grace_period_s * 1e3)
  1373. try:
  1374. is_accepted, _ = gcs_client.drain_node(
  1375. node_id,
  1376. autoscaler_pb2.DrainNodeReason.Value("DRAIN_NODE_REASON_PREEMPTION"),
  1377. "",
  1378. deadline_timestamp_ms,
  1379. )
  1380. except ray.exceptions.RayError as e:
  1381. logger.error(f"Failed to drain node {node_id=}")
  1382. raise e
  1383. assert is_accepted, "Drain node request was rejected"
  1384. def _cleanup(self):
  1385. for thread in self._kill_threads.copy():
  1386. thread.join()
  1387. self._kill_threads.remove(thread)
  1388. assert not self._kill_threads
  1389. @ray.remote(num_cpus=0)
  1390. class WorkerKillerActor(ResourceKillerActor):
  1391. def __init__(self, *args, **kwargs):
  1392. super().__init__(*args, **kwargs)
  1393. # Kill worker immediately so that the task does
  1394. # not finish successfully on its own.
  1395. self.kill_immediately_after_found = True
  1396. from ray.util.state.api import StateApiClient
  1397. from ray.util.state.common import ListApiOptions
  1398. self.client = StateApiClient()
  1399. self.task_options = ListApiOptions(
  1400. filters=[
  1401. ("state", "=", "RUNNING"),
  1402. ("name", "!=", "WorkerKillActor.run"),
  1403. ]
  1404. )
  1405. async def _find_resources_to_kill(self):
  1406. from ray.util.state.common import StateResource
  1407. process_to_kill_task_id = None
  1408. process_to_kill_pid = None
  1409. process_to_kill_node_id = None
  1410. while process_to_kill_pid is None and self.is_running:
  1411. tasks = self.client.list(
  1412. StateResource.TASKS,
  1413. options=self.task_options,
  1414. raise_on_missing_output=False,
  1415. )
  1416. if self.kill_filter_fn is not None:
  1417. tasks = list(filter(self.kill_filter_fn(), tasks))
  1418. for task in tasks:
  1419. if task.worker_id is not None and task.node_id is not None:
  1420. process_to_kill_task_id = task.task_id
  1421. process_to_kill_pid = task.worker_pid
  1422. process_to_kill_node_id = task.node_id
  1423. break
  1424. # Give the cluster some time to start.
  1425. await asyncio.sleep(0.1)
  1426. return [(process_to_kill_task_id, process_to_kill_pid, process_to_kill_node_id)]
  1427. def _kill_resource(
  1428. self, process_to_kill_task_id, process_to_kill_pid, process_to_kill_node_id
  1429. ):
  1430. if process_to_kill_pid is not None:
  1431. @ray.remote
  1432. def kill_process(pid):
  1433. import psutil
  1434. proc = psutil.Process(pid)
  1435. proc.kill()
  1436. scheduling_strategy = (
  1437. ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
  1438. node_id=process_to_kill_node_id,
  1439. soft=False,
  1440. )
  1441. )
  1442. kill_process.options(scheduling_strategy=scheduling_strategy).remote(
  1443. process_to_kill_pid
  1444. )
  1445. logging.info(
  1446. f"Killing pid {process_to_kill_pid} on node {process_to_kill_node_id}"
  1447. )
  1448. # Store both task_id and pid because retried tasks have same task_id.
  1449. self.killed.add((process_to_kill_task_id, process_to_kill_pid))
  1450. def get_and_run_resource_killer(
  1451. resource_killer_cls,
  1452. kill_interval_s,
  1453. namespace=None,
  1454. lifetime=None,
  1455. no_start=False,
  1456. max_to_kill=2,
  1457. batch_size_to_kill=1,
  1458. kill_delay_s=0,
  1459. kill_filter_fn=None,
  1460. ):
  1461. assert ray.is_initialized(), "The API is only available when Ray is initialized."
  1462. head_node_id = ray.get_runtime_context().get_node_id()
  1463. # Schedule the actor on the current node.
  1464. resource_killer = resource_killer_cls.options(
  1465. scheduling_strategy=NodeAffinitySchedulingStrategy(
  1466. node_id=head_node_id, soft=False
  1467. ),
  1468. namespace=namespace,
  1469. name="ResourceKiller",
  1470. lifetime=lifetime,
  1471. ).remote(
  1472. head_node_id,
  1473. kill_interval_s=kill_interval_s,
  1474. kill_delay_s=kill_delay_s,
  1475. max_to_kill=max_to_kill,
  1476. batch_size_to_kill=batch_size_to_kill,
  1477. kill_filter_fn=kill_filter_fn,
  1478. )
  1479. print("Waiting for ResourceKiller to be ready...")
  1480. ray.get(resource_killer.ready.remote())
  1481. print("ResourceKiller is ready now.")
  1482. if not no_start:
  1483. resource_killer.run.remote()
  1484. return resource_killer
  1485. def get_actor_node_id(actor_handle: "ray.actor.ActorHandle") -> str:
  1486. return ray.get(
  1487. actor_handle.__ray_call__.remote(
  1488. lambda self: ray.get_runtime_context().get_node_id()
  1489. )
  1490. )
  1491. @contextmanager
  1492. def chdir(d: str):
  1493. old_dir = os.getcwd()
  1494. os.chdir(d)
  1495. try:
  1496. yield
  1497. finally:
  1498. os.chdir(old_dir)
  1499. def test_get_directory_size_bytes():
  1500. with tempfile.TemporaryDirectory() as tmp_dir, chdir(tmp_dir):
  1501. assert ray._private.utils.get_directory_size_bytes(tmp_dir) == 0
  1502. with open("test_file", "wb") as f:
  1503. f.write(os.urandom(100))
  1504. assert ray._private.utils.get_directory_size_bytes(tmp_dir) == 100
  1505. with open("test_file_2", "wb") as f:
  1506. f.write(os.urandom(50))
  1507. assert ray._private.utils.get_directory_size_bytes(tmp_dir) == 150
  1508. os.mkdir("subdir")
  1509. with open("subdir/subdir_file", "wb") as f:
  1510. f.write(os.urandom(2))
  1511. assert ray._private.utils.get_directory_size_bytes(tmp_dir) == 152
  1512. def check_local_files_gced(cluster):
  1513. for node in cluster.list_all_nodes():
  1514. for subdir in ["conda", "pip", "working_dir_files", "py_modules_files"]:
  1515. all_files = os.listdir(
  1516. os.path.join(node.get_runtime_env_dir_path(), subdir)
  1517. )
  1518. # Check that there are no files remaining except for .lock files
  1519. # and generated requirements.txt files.
  1520. # Note: On Windows the top folder is not deleted as it is in use.
  1521. # TODO(architkulkarni): these files should get cleaned up too!
  1522. items = list(filter(lambda f: not f.endswith((".lock", ".txt")), all_files))
  1523. if len(items) > 0:
  1524. print(f"runtime_env files not GC'd from subdir '{subdir}': {items}")
  1525. return False
  1526. return True
  1527. def generate_runtime_env_dict(field, spec_format, tmp_path, pip_list=None):
  1528. if pip_list is None:
  1529. pip_list = ["pip-install-test==0.5"]
  1530. if field == "conda":
  1531. conda_dict = {"dependencies": ["pip", {"pip": pip_list}]}
  1532. if spec_format == "file":
  1533. conda_file = tmp_path / f"environment-{hash(str(pip_list))}.yml"
  1534. conda_file.write_text(yaml.dump(conda_dict))
  1535. conda = str(conda_file)
  1536. elif spec_format == "python_object":
  1537. conda = conda_dict
  1538. runtime_env = {"conda": conda}
  1539. elif field == "pip":
  1540. if spec_format == "file":
  1541. pip_file = tmp_path / f"requirements-{hash(str(pip_list))}.txt"
  1542. pip_file.write_text("\n".join(pip_list))
  1543. pip = str(pip_file)
  1544. elif spec_format == "python_object":
  1545. pip = pip_list
  1546. runtime_env = {"pip": pip}
  1547. return runtime_env
  1548. def check_spilled_mb(address, spilled=None, restored=None, fallback=None):
  1549. def ok():
  1550. s = memory_summary(address=address["address"], stats_only=True)
  1551. print(s)
  1552. if restored:
  1553. if "Restored {} MiB".format(restored) not in s:
  1554. return False
  1555. else:
  1556. if "Restored" in s:
  1557. return False
  1558. if spilled:
  1559. if not isinstance(spilled, list):
  1560. spilled_lst = [spilled]
  1561. else:
  1562. spilled_lst = spilled
  1563. found = False
  1564. for n in spilled_lst:
  1565. if "Spilled {} MiB".format(n) in s:
  1566. found = True
  1567. if not found:
  1568. return False
  1569. else:
  1570. if "Spilled" in s:
  1571. return False
  1572. if fallback:
  1573. if "Plasma filesystem mmap usage: {} MiB".format(fallback) not in s:
  1574. return False
  1575. else:
  1576. if "Plasma filesystem mmap usage:" in s:
  1577. return False
  1578. return True
  1579. wait_for_condition(ok, timeout=3, retry_interval_ms=1000)
  1580. def no_resource_leaks_excluding_node_resources():
  1581. cluster_resources = ray.cluster_resources()
  1582. available_resources = ray.available_resources()
  1583. for r in ray.cluster_resources():
  1584. if "node" in r:
  1585. del cluster_resources[r]
  1586. del available_resources[r]
  1587. return cluster_resources == available_resources
  1588. def job_hook(**kwargs):
  1589. """Function called by reflection by test_cli_integration."""
  1590. cmd = " ".join(kwargs["entrypoint"])
  1591. print(f"hook intercepted: {cmd}")
  1592. sys.exit(0)
  1593. def wandb_setup_api_key_hook():
  1594. """
  1595. Example external hook to set up W&B API key in
  1596. WandbIntegrationTest.testWandbLoggerConfig
  1597. """
  1598. return "abcd"
  1599. # Get node stats from node manager.
  1600. def get_node_stats(raylet, num_retry=5, timeout=2):
  1601. import grpc
  1602. from ray._private.grpc_utils import init_grpc_channel
  1603. from ray.core.generated import node_manager_pb2_grpc
  1604. raylet_address = build_address(
  1605. raylet["NodeManagerAddress"], raylet["NodeManagerPort"]
  1606. )
  1607. channel = init_grpc_channel(raylet_address)
  1608. stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
  1609. for _ in range(num_retry):
  1610. try:
  1611. reply = stub.GetNodeStats(
  1612. node_manager_pb2.GetNodeStatsRequest(), timeout=timeout
  1613. )
  1614. break
  1615. except grpc.RpcError:
  1616. continue
  1617. assert reply is not None
  1618. return reply
  1619. # Gets resource usage assuming gcs is local.
  1620. def get_resource_usage(gcs_address, timeout=10):
  1621. from ray._private.grpc_utils import init_grpc_channel
  1622. from ray.core.generated import gcs_service_pb2_grpc
  1623. if not gcs_address:
  1624. gcs_address = ray.worker._global_node.gcs_address
  1625. gcs_channel = init_grpc_channel(
  1626. gcs_address, ray_constants.GLOBAL_GRPC_OPTIONS, asynchronous=False
  1627. )
  1628. gcs_node_resources_stub = gcs_service_pb2_grpc.NodeResourceInfoGcsServiceStub(
  1629. gcs_channel
  1630. )
  1631. request = gcs_service_pb2.GetAllResourceUsageRequest()
  1632. response = gcs_node_resources_stub.GetAllResourceUsage(request, timeout=timeout)
  1633. resources_batch_data = response.resource_usage_data
  1634. return resources_batch_data
  1635. # Gets the load metrics report assuming gcs is local.
  1636. def get_load_metrics_report(webui_url):
  1637. webui_url = format_web_url(webui_url)
  1638. response = requests.get(f"{webui_url}/api/cluster_status")
  1639. response.raise_for_status()
  1640. return response.json()["data"]["clusterStatus"]["loadMetricsReport"]
  1641. # Send a RPC to the raylet to have it self-destruct its process.
  1642. def kill_raylet(raylet, graceful=False):
  1643. import grpc
  1644. from grpc._channel import _InactiveRpcError
  1645. from ray.core.generated import node_manager_pb2_grpc
  1646. raylet_address = build_address(
  1647. raylet["NodeManagerAddress"], raylet["NodeManagerPort"]
  1648. )
  1649. channel = grpc.insecure_channel(raylet_address)
  1650. stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
  1651. try:
  1652. stub.ShutdownRaylet(node_manager_pb2.ShutdownRayletRequest(graceful=graceful))
  1653. except _InactiveRpcError:
  1654. assert not graceful
  1655. def get_gcs_memory_used():
  1656. import psutil
  1657. m = {
  1658. proc.info["name"]: proc.info["memory_info"].rss
  1659. for proc in psutil.process_iter(["status", "name", "memory_info"])
  1660. if (
  1661. proc.info["status"] not in (psutil.STATUS_ZOMBIE, psutil.STATUS_DEAD)
  1662. and proc.info["name"] in ("gcs_server", "redis-server")
  1663. )
  1664. }
  1665. assert "gcs_server" in m
  1666. return sum(m.values())
  1667. def safe_write_to_results_json(
  1668. result: dict,
  1669. default_file_name: str = "/tmp/release_test_output.json",
  1670. env_var: Optional[str] = "TEST_OUTPUT_JSON",
  1671. ):
  1672. """
  1673. Safe (atomic) write to file to guard against malforming the json
  1674. if the job gets interrupted in the middle of writing.
  1675. """
  1676. test_output_json = os.environ.get(env_var, default_file_name)
  1677. test_output_json_tmp = f"{test_output_json}.tmp.{str(uuid.uuid4())}"
  1678. with open(test_output_json_tmp, "wt") as f:
  1679. json.dump(result, f)
  1680. f.flush()
  1681. os.replace(test_output_json_tmp, test_output_json)
  1682. logger.info(f"Wrote results to {test_output_json}")
  1683. logger.info(json.dumps(result))
  1684. def get_current_unused_port():
  1685. """
  1686. Returns a port number that is not currently in use.
  1687. This is useful for testing when we need to bind to a port but don't
  1688. care which one.
  1689. Returns:
  1690. A port number that is not currently in use. (Note that this port
  1691. might become used by the time you try to bind to it.)
  1692. """
  1693. sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  1694. # Bind the socket to a local address with a random port number
  1695. sock.bind(("localhost", 0))
  1696. port = sock.getsockname()[1]
  1697. sock.close()
  1698. return port
  1699. # Global counter to test different return values
  1700. # for external_ray_cluster_activity_hook1.
  1701. ray_cluster_activity_hook_counter = 0
  1702. ray_cluster_activity_hook_5_counter = 0
  1703. def external_ray_cluster_activity_hook1():
  1704. """
  1705. Example external hook for test_component_activities_hook.
  1706. Returns valid response and increments counter in `reason`
  1707. field on each call.
  1708. """
  1709. global ray_cluster_activity_hook_counter
  1710. ray_cluster_activity_hook_counter += 1
  1711. from pydantic import BaseModel, Extra
  1712. class TestRayActivityResponse(BaseModel, extra=Extra.allow):
  1713. """
  1714. Redefinition of dashboard.modules.api.api_head.RayActivityResponse
  1715. used in test_component_activities_hook to mimic typical
  1716. usage of redefining or extending response type.
  1717. """
  1718. is_active: str
  1719. reason: Optional[str] = None
  1720. timestamp: float
  1721. return {
  1722. "test_component1": TestRayActivityResponse(
  1723. is_active="ACTIVE",
  1724. reason=f"Counter: {ray_cluster_activity_hook_counter}",
  1725. timestamp=datetime.now().timestamp(),
  1726. )
  1727. }
  1728. def external_ray_cluster_activity_hook2():
  1729. """
  1730. Example external hook for test_component_activities_hook.
  1731. Returns invalid output because the value of `test_component2`
  1732. should be of type RayActivityResponse.
  1733. """
  1734. return {"test_component2": "bad_output"}
  1735. def external_ray_cluster_activity_hook3():
  1736. """
  1737. Example external hook for test_component_activities_hook.
  1738. Returns invalid output because return type is not
  1739. Dict[str, RayActivityResponse]
  1740. """
  1741. return "bad_output"
  1742. def external_ray_cluster_activity_hook4():
  1743. """
  1744. Example external hook for test_component_activities_hook.
  1745. Errors during execution.
  1746. """
  1747. raise Exception("Error in external cluster activity hook")
  1748. def external_ray_cluster_activity_hook5():
  1749. """
  1750. Example external hook for test_component_activities_hook.
  1751. Returns valid response and increments counter in `reason`
  1752. field on each call.
  1753. """
  1754. global ray_cluster_activity_hook_5_counter
  1755. ray_cluster_activity_hook_5_counter += 1
  1756. return {
  1757. "test_component5": {
  1758. "is_active": "ACTIVE",
  1759. "reason": f"Counter: {ray_cluster_activity_hook_5_counter}",
  1760. "timestamp": datetime.now().timestamp(),
  1761. }
  1762. }
  1763. # TODO(rickyx): We could remove this once we unify the autoscaler v1 and v2
  1764. # code path for ray status
  1765. def reset_autoscaler_v2_enabled_cache():
  1766. import ray.autoscaler.v2.utils as u
  1767. u.cached_is_autoscaler_v2 = None
  1768. def _terminate_ec2_instance(node_ip: str) -> None:
  1769. logging.info(f"Terminating instance {node_ip}")
  1770. # This command uses IMDSv2 to get the host instance id and region.
  1771. # After that it terminates itself using aws cli.
  1772. command = (
  1773. 'instanceId=$(curl -H "X-aws-ec2-metadata-token: $TOKEN" http://169.254.169.254/latest/meta-data/instance-id/);' # noqa: E501
  1774. 'region=$(curl -H "X-aws-ec2-metadata-token: $TOKEN" http://169.254.169.254/latest/meta-data/placement/region);' # noqa: E501
  1775. "aws ec2 terminate-instances --region $region --instance-ids $instanceId" # noqa: E501
  1776. )
  1777. _execute_command_on_node(command, node_ip)
  1778. def _execute_command_on_node(command: str, node_ip: str):
  1779. logging.debug(f"Executing command on node {node_ip}: {command}")
  1780. multi_line_command = (
  1781. 'TOKEN=$(curl -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 21600");' # noqa: E501
  1782. f"{command}"
  1783. )
  1784. # This is a feature on Anyscale platform that enables
  1785. # easy ssh access to worker nodes.
  1786. ssh_command = f"ssh -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -p 2222 ray@{node_ip} '{multi_line_command}'" # noqa: E501
  1787. try:
  1788. subprocess.run(
  1789. ssh_command, shell=True, capture_output=True, text=True, check=True
  1790. )
  1791. except subprocess.CalledProcessError as e:
  1792. print("Exit code:", e.returncode)
  1793. print("Stderr:", e.stderr)
  1794. RPC_FAILURE_MAP = {
  1795. "request": {
  1796. "req_failure_prob": 100,
  1797. "resp_failure_prob": 0,
  1798. "in_flight_failure_prob": 0,
  1799. },
  1800. "response": {
  1801. "req_failure_prob": 0,
  1802. "resp_failure_prob": 100,
  1803. "in_flight_failure_prob": 0,
  1804. },
  1805. "in_flight": {
  1806. "req_failure_prob": 0,
  1807. "resp_failure_prob": 0,
  1808. "in_flight_failure_prob": 100,
  1809. },
  1810. }
  1811. RPC_FAILURE_TYPES = list(RPC_FAILURE_MAP.keys())