utils.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528
  1. import collections
  2. import logging
  3. import os
  4. import random
  5. import shutil
  6. import subprocess
  7. import sys
  8. import threading
  9. import time
  10. from ray._common.network_utils import is_ipv6
# Module-level logger used by the Ray-on-Spark utility helpers in this file
# (e.g. for memory-capping warnings and nvidia-smi failures).
_logger = logging.getLogger("ray.util.spark.utils")
  12. def is_in_databricks_runtime():
  13. return "DATABRICKS_RUNTIME_VERSION" in os.environ
  14. def gen_cmd_exec_failure_msg(cmd, return_code, tail_output_deque):
  15. cmd_str = " ".join(cmd)
  16. tail_output = "".join(tail_output_deque)
  17. return (
  18. f"Command {cmd_str} failed with return code {return_code}, tail output are "
  19. f"included below.\n{tail_output}\n"
  20. )
  21. def get_configured_spark_executor_memory_bytes(spark):
  22. value_str = spark.conf.get("spark.executor.memory", "1g").lower()
  23. value_num = int(value_str[:-1])
  24. value_unit = value_str[-1]
  25. unit_map = {
  26. "k": 1024,
  27. "m": 1024 * 1024,
  28. "g": 1024 * 1024 * 1024,
  29. "t": 1024 * 1024 * 1024 * 1024,
  30. }
  31. return value_num * unit_map[value_unit]
  32. def exec_cmd(
  33. cmd,
  34. *,
  35. extra_env=None,
  36. synchronous=True,
  37. **kwargs,
  38. ):
  39. """
  40. A convenience wrapper of `subprocess.Popen` for running a command from a Python
  41. script.
  42. If `synchronous` is True, wait until the process terminated and if subprocess
  43. return code is not 0, raise error containing last 100 lines output.
  44. If `synchronous` is False, return an `Popen` instance and a deque instance holding
  45. tail outputs.
  46. The subprocess stdout / stderr output will be streamly redirected to current
  47. process stdout.
  48. """
  49. illegal_kwargs = set(kwargs.keys()).intersection({"text", "stdout", "stderr"})
  50. if illegal_kwargs:
  51. raise ValueError(f"`kwargs` cannot contain {list(illegal_kwargs)}")
  52. env = kwargs.pop("env", None)
  53. if extra_env is not None and env is not None:
  54. raise ValueError("`extra_env` and `env` cannot be used at the same time")
  55. env = env if extra_env is None else {**os.environ, **extra_env}
  56. process = subprocess.Popen(
  57. cmd,
  58. env=env,
  59. text=True,
  60. stdout=subprocess.PIPE,
  61. stderr=subprocess.STDOUT,
  62. **kwargs,
  63. )
  64. tail_output_deque = collections.deque(maxlen=100)
  65. def redirect_log_thread_fn():
  66. for line in process.stdout:
  67. # collect tail logs by `tail_output_deque`
  68. tail_output_deque.append(line)
  69. # redirect to stdout.
  70. sys.stdout.write(line)
  71. threading.Thread(target=redirect_log_thread_fn, args=()).start()
  72. if not synchronous:
  73. return process, tail_output_deque
  74. return_code = process.wait()
  75. if return_code != 0:
  76. raise RuntimeError(
  77. gen_cmd_exec_failure_msg(cmd, return_code, tail_output_deque)
  78. )
  79. def is_port_in_use(host, port):
  80. import socket
  81. from contextlib import closing
  82. with closing(
  83. socket.socket(
  84. socket.AF_INET6 if is_ipv6(host) else socket.AF_INET, socket.SOCK_STREAM
  85. )
  86. ) as sock:
  87. return sock.connect_ex((host, port)) == 0
  88. def _wait_service_up(host, port, timeout):
  89. beg_time = time.time()
  90. while time.time() - beg_time < timeout:
  91. if is_port_in_use(host, port):
  92. return True
  93. time.sleep(1)
  94. return False
  95. def get_random_unused_port(
  96. host, min_port=1024, max_port=65535, max_retries=100, exclude_list=None
  97. ):
  98. """
  99. Get random unused port.
  100. """
  101. # Use true random generator
  102. rng = random.SystemRandom()
  103. exclude_list = exclude_list or []
  104. for _ in range(max_retries):
  105. port = rng.randint(min_port, max_port)
  106. if port in exclude_list:
  107. continue
  108. if not is_port_in_use(host, port):
  109. return port
  110. raise RuntimeError(
  111. f"Get available port between range {min_port} and {max_port} failed."
  112. )
  113. def get_spark_session():
  114. from pyspark.sql import SparkSession
  115. spark_session = SparkSession.getActiveSession()
  116. if spark_session is None:
  117. raise RuntimeError(
  118. "Spark session haven't been initiated yet. Please use "
  119. "`SparkSession.builder` to create a spark session and connect to a spark "
  120. "cluster."
  121. )
  122. return spark_session
  123. def get_spark_application_driver_host(spark):
  124. return spark.conf.get("spark.driver.host")
  125. def get_max_num_concurrent_tasks(spark_context, resource_profile):
  126. """Gets the current max number of concurrent tasks."""
  127. # pylint: disable=protected-access=
  128. ssc = spark_context._jsc.sc()
  129. if resource_profile is not None:
  130. def dummpy_mapper(_):
  131. pass
  132. # Runs a dummy spark job to register the `res_profile`
  133. spark_context.parallelize([1], 1).withResources(resource_profile).map(
  134. dummpy_mapper
  135. ).collect()
  136. return ssc.maxNumConcurrentTasks(resource_profile._java_resource_profile)
  137. else:
  138. return ssc.maxNumConcurrentTasks(
  139. ssc.resourceProfileManager().defaultResourceProfile()
  140. )
  141. def _get_spark_worker_total_physical_memory():
  142. import psutil
  143. if RAY_ON_SPARK_WORKER_PHYSICAL_MEMORY_BYTES in os.environ:
  144. return int(os.environ[RAY_ON_SPARK_WORKER_PHYSICAL_MEMORY_BYTES])
  145. return psutil.virtual_memory().total
  146. def _get_spark_worker_total_shared_memory():
  147. import shutil
  148. if RAY_ON_SPARK_WORKER_SHARED_MEMORY_BYTES in os.environ:
  149. return int(os.environ[RAY_ON_SPARK_WORKER_SHARED_MEMORY_BYTES])
  150. return shutil.disk_usage("/dev/shm").total
# Upper bound on a Ray node's object store size, expressed as a fraction of the
# physical memory available to that node (applied in `_calc_mem_per_ray_node`).
_RAY_ON_SPARK_MAX_OBJECT_STORE_MEMORY_PROPORTION = 0.8
# Fraction of the total physical / shared memory treated as usable when sizing
# Ray nodes. Totals are multiplied by this factor, leaving 20% as headroom.
_RAY_ON_SPARK_NODE_MEMORY_BUFFER_OFFSET = 0.8
  155. def calc_mem_ray_head_node(configured_heap_memory_bytes, configured_object_store_bytes):
  156. import shutil
  157. import psutil
  158. if RAY_ON_SPARK_DRIVER_PHYSICAL_MEMORY_BYTES in os.environ:
  159. available_physical_mem = int(
  160. os.environ[RAY_ON_SPARK_DRIVER_PHYSICAL_MEMORY_BYTES]
  161. )
  162. else:
  163. available_physical_mem = psutil.virtual_memory().total
  164. available_physical_mem = (
  165. available_physical_mem * _RAY_ON_SPARK_NODE_MEMORY_BUFFER_OFFSET
  166. )
  167. if RAY_ON_SPARK_DRIVER_SHARED_MEMORY_BYTES in os.environ:
  168. available_shared_mem = int(os.environ[RAY_ON_SPARK_DRIVER_SHARED_MEMORY_BYTES])
  169. else:
  170. available_shared_mem = shutil.disk_usage("/dev/shm").total
  171. available_shared_mem = (
  172. available_shared_mem * _RAY_ON_SPARK_NODE_MEMORY_BUFFER_OFFSET
  173. )
  174. heap_mem_bytes, object_store_bytes, warning_msg = _calc_mem_per_ray_node(
  175. available_physical_mem,
  176. available_shared_mem,
  177. configured_heap_memory_bytes,
  178. configured_object_store_bytes,
  179. )
  180. if warning_msg is not None:
  181. _logger.warning(warning_msg)
  182. return heap_mem_bytes, object_store_bytes
  183. def _calc_mem_per_ray_worker_node(
  184. num_task_slots,
  185. physical_mem_bytes,
  186. shared_mem_bytes,
  187. configured_heap_memory_bytes,
  188. configured_object_store_bytes,
  189. ):
  190. available_physical_mem_per_node = int(
  191. physical_mem_bytes / num_task_slots * _RAY_ON_SPARK_NODE_MEMORY_BUFFER_OFFSET
  192. )
  193. available_shared_mem_per_node = int(
  194. shared_mem_bytes / num_task_slots * _RAY_ON_SPARK_NODE_MEMORY_BUFFER_OFFSET
  195. )
  196. return _calc_mem_per_ray_node(
  197. available_physical_mem_per_node,
  198. available_shared_mem_per_node,
  199. configured_heap_memory_bytes,
  200. configured_object_store_bytes,
  201. )
  202. def _calc_mem_per_ray_node(
  203. available_physical_mem_per_node,
  204. available_shared_mem_per_node,
  205. configured_heap_memory_bytes,
  206. configured_object_store_bytes,
  207. ):
  208. from ray._private.ray_constants import (
  209. DEFAULT_OBJECT_STORE_MEMORY_PROPORTION,
  210. OBJECT_STORE_MINIMUM_MEMORY_BYTES,
  211. )
  212. warning_msg = None
  213. object_store_bytes = configured_object_store_bytes or (
  214. available_physical_mem_per_node * DEFAULT_OBJECT_STORE_MEMORY_PROPORTION
  215. )
  216. # If allow Ray using slow storage oas object store,
  217. # we don't need to cap object store size by /dev/shm capacity
  218. if not os.environ.get("RAY_OBJECT_STORE_ALLOW_SLOW_STORAGE"):
  219. if object_store_bytes > available_shared_mem_per_node:
  220. object_store_bytes = available_shared_mem_per_node
  221. object_store_bytes_upper_bound = (
  222. available_physical_mem_per_node
  223. * _RAY_ON_SPARK_MAX_OBJECT_STORE_MEMORY_PROPORTION
  224. )
  225. if object_store_bytes > object_store_bytes_upper_bound:
  226. object_store_bytes = object_store_bytes_upper_bound
  227. warning_msg = (
  228. "Your configured `object_store_memory_per_node` value "
  229. "is too high and it is capped by 80% of per-Ray node "
  230. "allocated memory."
  231. )
  232. if object_store_bytes < OBJECT_STORE_MINIMUM_MEMORY_BYTES:
  233. if object_store_bytes == available_shared_mem_per_node:
  234. warning_msg = (
  235. "Your operating system is configured with too small /dev/shm "
  236. "size, so `object_store_memory_worker_node` value is configured "
  237. f"to minimal size ({OBJECT_STORE_MINIMUM_MEMORY_BYTES} bytes),"
  238. f"Please increase system /dev/shm size."
  239. )
  240. else:
  241. warning_msg = (
  242. "You configured too small Ray node object store memory size, "
  243. "so `object_store_memory_worker_node` value is configured "
  244. f"to minimal size ({OBJECT_STORE_MINIMUM_MEMORY_BYTES} bytes),"
  245. "Please increase 'object_store_memory_worker_node' argument value."
  246. )
  247. object_store_bytes = OBJECT_STORE_MINIMUM_MEMORY_BYTES
  248. object_store_bytes = int(object_store_bytes)
  249. if configured_heap_memory_bytes is None:
  250. heap_mem_bytes = int(available_physical_mem_per_node - object_store_bytes)
  251. else:
  252. heap_mem_bytes = int(configured_heap_memory_bytes)
  253. return heap_mem_bytes, object_store_bytes, warning_msg
# Environment variables that let users override values Ray on Spark otherwise
# detects automatically on each Spark worker: CPU cores, GPU count, physical
# memory and /dev/shm size. Use them when automatic detection fails or is wrong.
# They must be set on the Spark executor side, i.e. via the Spark config
# `spark.executorEnv.[EnvironmentVariableName]`.
RAY_ON_SPARK_WORKER_CPU_CORES = "RAY_ON_SPARK_WORKER_CPU_CORES"
RAY_ON_SPARK_WORKER_GPU_NUM = "RAY_ON_SPARK_WORKER_GPU_NUM"
RAY_ON_SPARK_WORKER_PHYSICAL_MEMORY_BYTES = "RAY_ON_SPARK_WORKER_PHYSICAL_MEMORY_BYTES"
RAY_ON_SPARK_WORKER_SHARED_MEMORY_BYTES = "RAY_ON_SPARK_WORKER_SHARED_MEMORY_BYTES"
# Environment variables that let users override the physical memory and
# /dev/shm size detected on the Spark driver node, when automatic detection
# there fails or is wrong.
RAY_ON_SPARK_DRIVER_PHYSICAL_MEMORY_BYTES = "RAY_ON_SPARK_DRIVER_PHYSICAL_MEMORY_BYTES"
RAY_ON_SPARK_DRIVER_SHARED_MEMORY_BYTES = "RAY_ON_SPARK_DRIVER_SHARED_MEMORY_BYTES"
  267. def _get_cpu_cores():
  268. import multiprocessing
  269. if RAY_ON_SPARK_WORKER_CPU_CORES in os.environ:
  270. # In some cases, spark standalone cluster might configure virtual cpu cores
  271. # for spark worker that different with number of physical cpu cores,
  272. # but we cannot easily get the virtual cpu cores configured for spark
  273. # worker, as a workaround, we provide an environmental variable config
  274. # `RAY_ON_SPARK_WORKER_CPU_CORES` for user.
  275. return int(os.environ[RAY_ON_SPARK_WORKER_CPU_CORES])
  276. return multiprocessing.cpu_count()
  277. def _get_num_physical_gpus():
  278. if RAY_ON_SPARK_WORKER_GPU_NUM in os.environ:
  279. # In some cases, spark standalone cluster might configure part of physical
  280. # GPUs for spark worker,
  281. # but we cannot easily get related configuration,
  282. # as a workaround, we provide an environmental variable config
  283. # `RAY_ON_SPARK_WORKER_CPU_CORES` for user.
  284. return int(os.environ[RAY_ON_SPARK_WORKER_GPU_NUM])
  285. if shutil.which("nvidia-smi") is None:
  286. # GPU driver is not installed.
  287. return 0
  288. try:
  289. completed_proc = subprocess.run(
  290. "nvidia-smi --query-gpu=name --format=csv,noheader",
  291. shell=True,
  292. check=True,
  293. text=True,
  294. capture_output=True,
  295. )
  296. return len(completed_proc.stdout.strip().split("\n"))
  297. except Exception as e:
  298. _logger.info(
  299. "'nvidia-smi --query-gpu=name --format=csv,noheader' command execution "
  300. f"failed, error: {repr(e)}"
  301. )
  302. return 0
  303. def _get_local_ray_node_slots(
  304. num_cpus,
  305. num_gpus,
  306. num_cpus_per_node,
  307. num_gpus_per_node,
  308. ):
  309. if num_cpus_per_node > num_cpus:
  310. raise ValueError(
  311. "cpu number per Ray worker node should be <= spark worker node CPU cores, "
  312. f"you set cpu number per Ray worker node to {num_cpus_per_node} but "
  313. f"spark worker node CPU core number is {num_cpus}."
  314. )
  315. num_ray_node_slots = num_cpus // num_cpus_per_node
  316. if num_gpus_per_node > 0:
  317. if num_gpus_per_node > num_gpus:
  318. raise ValueError(
  319. "gpu number per Ray worker node should be <= spark worker node "
  320. "GPU number, you set GPU devices number per Ray worker node to "
  321. f"{num_gpus_per_node} but spark worker node GPU devices number "
  322. f"is {num_gpus}."
  323. )
  324. if num_ray_node_slots > num_gpus // num_gpus_per_node:
  325. num_ray_node_slots = num_gpus // num_gpus_per_node
  326. return num_ray_node_slots
  327. def _get_avail_mem_per_ray_worker_node(
  328. num_cpus_per_node,
  329. num_gpus_per_node,
  330. heap_memory_per_node,
  331. object_store_memory_per_node,
  332. ):
  333. """
  334. Returns tuple of (
  335. ray_worker_node_heap_mem_bytes,
  336. ray_worker_node_object_store_bytes,
  337. error_message, # always None
  338. warning_message,
  339. )
  340. """
  341. num_cpus = _get_cpu_cores()
  342. if num_gpus_per_node > 0:
  343. num_gpus = _get_num_physical_gpus()
  344. else:
  345. num_gpus = 0
  346. num_ray_node_slots = _get_local_ray_node_slots(
  347. num_cpus, num_gpus, num_cpus_per_node, num_gpus_per_node
  348. )
  349. physical_mem_bytes = _get_spark_worker_total_physical_memory()
  350. shared_mem_bytes = _get_spark_worker_total_shared_memory()
  351. (
  352. ray_worker_node_heap_mem_bytes,
  353. ray_worker_node_object_store_bytes,
  354. warning_msg,
  355. ) = _calc_mem_per_ray_worker_node(
  356. num_ray_node_slots,
  357. physical_mem_bytes,
  358. shared_mem_bytes,
  359. heap_memory_per_node,
  360. object_store_memory_per_node,
  361. )
  362. return (
  363. ray_worker_node_heap_mem_bytes,
  364. ray_worker_node_object_store_bytes,
  365. None,
  366. warning_msg,
  367. )
  368. def get_avail_mem_per_ray_worker_node(
  369. spark,
  370. heap_memory_per_node,
  371. object_store_memory_per_node,
  372. num_cpus_per_node,
  373. num_gpus_per_node,
  374. ):
  375. """
  376. Return the available heap memory and object store memory for each ray worker,
  377. and error / warning message if it has.
  378. Return value is a tuple of
  379. (ray_worker_node_heap_mem_bytes, ray_worker_node_object_store_bytes,
  380. error_message, warning_message)
  381. NB: We have one ray node per spark task.
  382. """
  383. def mapper(_):
  384. try:
  385. return _get_avail_mem_per_ray_worker_node(
  386. num_cpus_per_node,
  387. num_gpus_per_node,
  388. heap_memory_per_node,
  389. object_store_memory_per_node,
  390. )
  391. except Exception as e:
  392. import traceback
  393. trace_msg = "\n".join(traceback.format_tb(e.__traceback__))
  394. return -1, -1, repr(e) + trace_msg, None
  395. # Running memory inference routine on spark executor side since the spark worker
  396. # nodes may have a different machine configuration compared to the spark driver
  397. # node.
  398. (
  399. inferred_ray_worker_node_heap_mem_bytes,
  400. inferred_ray_worker_node_object_store_bytes,
  401. err,
  402. warning_msg,
  403. ) = (
  404. spark.sparkContext.parallelize([1], 1).map(mapper).collect()[0]
  405. )
  406. if err is not None:
  407. raise RuntimeError(
  408. f"Inferring ray worker node available memory failed, error: {err}. "
  409. "You can bypass this error by setting following spark configs: "
  410. "spark.executorEnv.RAY_ON_SPARK_WORKER_CPU_CORES, "
  411. "spark.executorEnv.RAY_ON_SPARK_WORKER_GPU_NUM, "
  412. "spark.executorEnv.RAY_ON_SPARK_WORKER_PHYSICAL_MEMORY_BYTES, "
  413. "spark.executorEnv.RAY_ON_SPARK_WORKER_SHARED_MEMORY_BYTES."
  414. )
  415. if warning_msg is not None:
  416. _logger.warning(warning_msg)
  417. return (
  418. inferred_ray_worker_node_heap_mem_bytes,
  419. inferred_ray_worker_node_object_store_bytes,
  420. )
  421. def get_spark_task_assigned_physical_gpus(gpu_addr_list):
  422. if "CUDA_VISIBLE_DEVICES" in os.environ:
  423. visible_cuda_dev_list = [
  424. int(dev.strip()) for dev in os.environ["CUDA_VISIBLE_DEVICES"].split(",")
  425. ]
  426. return [visible_cuda_dev_list[addr] for addr in gpu_addr_list]
  427. else:
  428. return gpu_addr_list