# memory_utils.py (18 KB)
  1. import base64
  2. import logging
  3. from collections import defaultdict
  4. from enum import Enum
  5. from typing import List
  6. import ray
  7. from ray._private.internal_api import node_stats
  8. from ray._raylet import ActorID, JobID, TaskID
  9. from ray.dashboard.utils import node_stats_to_dict
  10. logger = logging.getLogger(__name__)
  11. # These values are used to calculate if objectRefs are actor handles.
  12. TASKID_BYTES_SIZE = TaskID.size()
  13. ACTORID_BYTES_SIZE = ActorID.size()
  14. JOBID_BYTES_SIZE = JobID.size()
  15. def decode_object_ref_if_needed(object_ref: str) -> bytes:
  16. """Decode objectRef bytes string.
  17. gRPC reply contains an objectRef that is encodded by Base64.
  18. This function is used to decode the objectRef.
  19. Note that there are times that objectRef is already decoded as
  20. a hex string. In this case, just convert it to a binary number.
  21. """
  22. if object_ref.endswith("="):
  23. # If the object ref ends with =, that means it is base64 encoded.
  24. # Object refs will always have = as a padding
  25. # when it is base64 encoded because objectRef is always 20B.
  26. return base64.standard_b64decode(object_ref)
  27. else:
  28. return ray._common.utils.hex_to_binary(object_ref)
  29. class SortingType(Enum):
  30. PID = 1
  31. OBJECT_SIZE = 3
  32. REFERENCE_TYPE = 4
  33. class GroupByType(Enum):
  34. NODE_ADDRESS = "node"
  35. STACK_TRACE = "stack_trace"
  36. class ReferenceType(Enum):
  37. # We don't use enum because enum is not json serializable.
  38. ACTOR_HANDLE = "ACTOR_HANDLE"
  39. PINNED_IN_MEMORY = "PINNED_IN_MEMORY"
  40. LOCAL_REFERENCE = "LOCAL_REFERENCE"
  41. USED_BY_PENDING_TASK = "USED_BY_PENDING_TASK"
  42. CAPTURED_IN_OBJECT = "CAPTURED_IN_OBJECT"
  43. UNKNOWN_STATUS = "UNKNOWN_STATUS"
  44. def get_sorting_type(sort_by: str):
  45. """Translate string input into SortingType instance"""
  46. sort_by = sort_by.upper()
  47. if sort_by == "PID":
  48. return SortingType.PID
  49. elif sort_by == "OBJECT_SIZE":
  50. return SortingType.OBJECT_SIZE
  51. elif sort_by == "REFERENCE_TYPE":
  52. return SortingType.REFERENCE_TYPE
  53. else:
  54. raise Exception(
  55. "The sort-by input provided is not one of\
  56. PID, OBJECT_SIZE, or REFERENCE_TYPE."
  57. )
  58. def get_group_by_type(group_by: str):
  59. """Translate string input into GroupByType instance"""
  60. group_by = group_by.upper()
  61. if group_by == "NODE_ADDRESS":
  62. return GroupByType.NODE_ADDRESS
  63. elif group_by == "STACK_TRACE":
  64. return GroupByType.STACK_TRACE
  65. else:
  66. raise Exception(
  67. "The group-by input provided is not one of\
  68. NODE_ADDRESS or STACK_TRACE."
  69. )
  70. class MemoryTableEntry:
  71. def __init__(
  72. self, *, object_ref: dict, node_address: str, is_driver: bool, pid: int
  73. ):
  74. # worker info
  75. self.is_driver = is_driver
  76. self.pid = pid
  77. self.node_address = node_address
  78. # object info
  79. self.task_status = object_ref.get("taskStatus", "?")
  80. if self.task_status == "NIL":
  81. self.task_status = "-"
  82. self.attempt_number = int(object_ref.get("attemptNumber", 0)) + 1
  83. self.object_size = int(object_ref.get("objectSize", -1))
  84. self.call_site = object_ref.get("callSite", "<Unknown>")
  85. if len(self.call_site) == 0:
  86. self.call_site = "disabled"
  87. self.object_ref = ray.ObjectRef(
  88. decode_object_ref_if_needed(object_ref["objectId"])
  89. )
  90. # reference info
  91. self.local_ref_count = int(object_ref.get("localRefCount", 0))
  92. self.pinned_in_memory = bool(object_ref.get("pinnedInMemory", False))
  93. self.submitted_task_ref_count = int(object_ref.get("submittedTaskRefCount", 0))
  94. self.contained_in_owned = [
  95. ray.ObjectRef(decode_object_ref_if_needed(object_ref))
  96. for object_ref in object_ref.get("containedInOwned", [])
  97. ]
  98. self.reference_type = self._get_reference_type()
  99. def is_valid(self) -> bool:
  100. # If the entry doesn't have a reference type or some invalid state,
  101. # (e.g., no object ref presented), it is considered invalid.
  102. if (
  103. not self.pinned_in_memory
  104. and self.local_ref_count == 0
  105. and self.submitted_task_ref_count == 0
  106. and len(self.contained_in_owned) == 0
  107. ):
  108. return False
  109. elif self.object_ref.is_nil():
  110. return False
  111. else:
  112. return True
  113. def group_key(self, group_by_type: GroupByType) -> str:
  114. if group_by_type == GroupByType.NODE_ADDRESS:
  115. return self.node_address
  116. elif group_by_type == GroupByType.STACK_TRACE:
  117. return self.call_site
  118. else:
  119. raise ValueError(f"group by type {group_by_type} is invalid.")
  120. def _get_reference_type(self) -> str:
  121. if self._is_object_ref_actor_handle():
  122. return ReferenceType.ACTOR_HANDLE.value
  123. if self.pinned_in_memory:
  124. return ReferenceType.PINNED_IN_MEMORY.value
  125. elif self.submitted_task_ref_count > 0:
  126. return ReferenceType.USED_BY_PENDING_TASK.value
  127. elif self.local_ref_count > 0:
  128. return ReferenceType.LOCAL_REFERENCE.value
  129. elif len(self.contained_in_owned) > 0:
  130. return ReferenceType.CAPTURED_IN_OBJECT.value
  131. else:
  132. return ReferenceType.UNKNOWN_STATUS.value
  133. def _is_object_ref_actor_handle(self) -> bool:
  134. object_ref_hex = self.object_ref.hex()
  135. # We need to multiply 2 because we need bits size instead of bytes size.
  136. taskid_random_bits_size = (TASKID_BYTES_SIZE - ACTORID_BYTES_SIZE) * 2
  137. actorid_random_bits_size = (ACTORID_BYTES_SIZE - JOBID_BYTES_SIZE) * 2
  138. # random (8B) | ActorID(6B) | flag (2B) | index (6B)
  139. # ActorID(6B) == ActorRandomByte(4B) + JobID(2B)
  140. # If random bytes are all 'f', but ActorRandomBytes
  141. # are not all 'f', that means it is an actor creation
  142. # task, which is an actor handle.
  143. random_bits = object_ref_hex[:taskid_random_bits_size]
  144. actor_random_bits = object_ref_hex[
  145. taskid_random_bits_size : taskid_random_bits_size + actorid_random_bits_size
  146. ]
  147. if random_bits == "f" * 16 and not actor_random_bits == "f" * 24:
  148. return True
  149. else:
  150. return False
  151. def as_dict(self):
  152. return {
  153. "object_ref": self.object_ref.hex(),
  154. "pid": self.pid,
  155. "node_ip_address": self.node_address,
  156. "object_size": self.object_size,
  157. "reference_type": self.reference_type,
  158. "call_site": self.call_site,
  159. "task_status": self.task_status,
  160. "attempt_number": self.attempt_number,
  161. "local_ref_count": self.local_ref_count,
  162. "pinned_in_memory": self.pinned_in_memory,
  163. "submitted_task_ref_count": self.submitted_task_ref_count,
  164. "contained_in_owned": [
  165. object_ref.hex() for object_ref in self.contained_in_owned
  166. ],
  167. "type": "Driver" if self.is_driver else "Worker",
  168. }
  169. def __str__(self):
  170. return self.__repr__()
  171. def __repr__(self):
  172. return str(self.as_dict())
  173. class MemoryTable:
  174. def __init__(
  175. self,
  176. entries: List[MemoryTableEntry],
  177. group_by_type: GroupByType = GroupByType.NODE_ADDRESS,
  178. sort_by_type: SortingType = SortingType.PID,
  179. ):
  180. self.table = entries
  181. # Group is a list of memory tables grouped by a group key.
  182. self.group = {}
  183. self.summary = defaultdict(int)
  184. # NOTE YOU MUST SORT TABLE BEFORE GROUPING.
  185. # self._group_by(..)._sort_by(..) != self._sort_by(..)._group_by(..)
  186. if group_by_type and sort_by_type:
  187. self.setup(group_by_type, sort_by_type)
  188. elif group_by_type:
  189. self._group_by(group_by_type)
  190. elif sort_by_type:
  191. self._sort_by(sort_by_type)
  192. def setup(self, group_by_type: GroupByType, sort_by_type: SortingType):
  193. """Setup memory table.
  194. This will sort entries first and group them after.
  195. Sort order will be still kept.
  196. """
  197. self._sort_by(sort_by_type)._group_by(group_by_type)
  198. for group_memory_table in self.group.values():
  199. group_memory_table.summarize()
  200. self.summarize()
  201. return self
  202. def insert_entry(self, entry: MemoryTableEntry):
  203. self.table.append(entry)
  204. def summarize(self):
  205. # Reset summary.
  206. total_object_size = 0
  207. total_local_ref_count = 0
  208. total_pinned_in_memory = 0
  209. total_used_by_pending_task = 0
  210. total_captured_in_objects = 0
  211. total_actor_handles = 0
  212. for entry in self.table:
  213. if entry.object_size > 0:
  214. total_object_size += entry.object_size
  215. if entry.reference_type == ReferenceType.LOCAL_REFERENCE.value:
  216. total_local_ref_count += 1
  217. elif entry.reference_type == ReferenceType.PINNED_IN_MEMORY.value:
  218. total_pinned_in_memory += 1
  219. elif entry.reference_type == ReferenceType.USED_BY_PENDING_TASK.value:
  220. total_used_by_pending_task += 1
  221. elif entry.reference_type == ReferenceType.CAPTURED_IN_OBJECT.value:
  222. total_captured_in_objects += 1
  223. elif entry.reference_type == ReferenceType.ACTOR_HANDLE.value:
  224. total_actor_handles += 1
  225. self.summary = {
  226. "total_object_size": total_object_size,
  227. "total_local_ref_count": total_local_ref_count,
  228. "total_pinned_in_memory": total_pinned_in_memory,
  229. "total_used_by_pending_task": total_used_by_pending_task,
  230. "total_captured_in_objects": total_captured_in_objects,
  231. "total_actor_handles": total_actor_handles,
  232. }
  233. return self
  234. def _sort_by(self, sorting_type: SortingType):
  235. if sorting_type == SortingType.PID:
  236. self.table.sort(key=lambda entry: entry.pid)
  237. elif sorting_type == SortingType.OBJECT_SIZE:
  238. self.table.sort(key=lambda entry: entry.object_size)
  239. elif sorting_type == SortingType.REFERENCE_TYPE:
  240. self.table.sort(key=lambda entry: entry.reference_type)
  241. else:
  242. raise ValueError(f"Give sorting type: {sorting_type} is invalid.")
  243. return self
  244. def _group_by(self, group_by_type: GroupByType):
  245. """Group entries and summarize the result.
  246. NOTE: Each group is another MemoryTable.
  247. """
  248. # Reset group
  249. self.group = {}
  250. # Build entries per group.
  251. group = defaultdict(list)
  252. for entry in self.table:
  253. group[entry.group_key(group_by_type)].append(entry)
  254. # Build a group table.
  255. for group_key, entries in group.items():
  256. self.group[group_key] = MemoryTable(
  257. entries, group_by_type=None, sort_by_type=None
  258. )
  259. for group_key, group_memory_table in self.group.items():
  260. group_memory_table.summarize()
  261. return self
  262. def as_dict(self):
  263. return {
  264. "summary": self.summary,
  265. "group": {
  266. group_key: {
  267. "entries": group_memory_table.get_entries(),
  268. "summary": group_memory_table.summary,
  269. }
  270. for group_key, group_memory_table in self.group.items()
  271. },
  272. }
  273. def get_entries(self) -> List[dict]:
  274. return [entry.as_dict() for entry in self.table]
  275. def __repr__(self):
  276. return str(self.as_dict())
  277. def __str__(self):
  278. return self.__repr__()
  279. def construct_memory_table(
  280. workers_stats: List,
  281. group_by: GroupByType = GroupByType.NODE_ADDRESS,
  282. sort_by=SortingType.OBJECT_SIZE,
  283. ) -> MemoryTable:
  284. memory_table_entries = []
  285. for core_worker_stats in workers_stats:
  286. pid = core_worker_stats["pid"]
  287. is_driver = core_worker_stats.get("workerType") == "DRIVER"
  288. node_address = core_worker_stats["ipAddress"]
  289. object_refs = core_worker_stats.get("objectRefs", [])
  290. for object_ref in object_refs:
  291. memory_table_entry = MemoryTableEntry(
  292. object_ref=object_ref,
  293. node_address=node_address,
  294. is_driver=is_driver,
  295. pid=pid,
  296. )
  297. if memory_table_entry.is_valid():
  298. memory_table_entries.append(memory_table_entry)
  299. memory_table = MemoryTable(
  300. memory_table_entries, group_by_type=group_by, sort_by_type=sort_by
  301. )
  302. return memory_table
  303. def track_reference_size(group):
  304. """Returns dictionary mapping reference type
  305. to memory usage for a given memory table group."""
  306. d = defaultdict(int)
  307. table_name = {
  308. "LOCAL_REFERENCE": "total_local_ref_count",
  309. "PINNED_IN_MEMORY": "total_pinned_in_memory",
  310. "USED_BY_PENDING_TASK": "total_used_by_pending_task",
  311. "CAPTURED_IN_OBJECT": "total_captured_in_objects",
  312. "ACTOR_HANDLE": "total_actor_handles",
  313. }
  314. for entry in group["entries"]:
  315. size = entry["object_size"]
  316. if size == -1:
  317. # size not recorded
  318. size = 0
  319. d[table_name[entry["reference_type"]]] += size
  320. return d
  321. def memory_summary(
  322. state,
  323. group_by="NODE_ADDRESS",
  324. sort_by="OBJECT_SIZE",
  325. line_wrap=True,
  326. unit="B",
  327. num_entries=None,
  328. ) -> str:
  329. # Get terminal size
  330. import shutil
  331. size = shutil.get_terminal_size((80, 20)).columns
  332. line_wrap_threshold = 137
  333. # Unit conversions
  334. units = {"B": 10**0, "KB": 10**3, "MB": 10**6, "GB": 10**9}
  335. # Fetch core memory worker stats, store as a dictionary
  336. core_worker_stats = []
  337. for raylet in state.node_table():
  338. if not raylet["Alive"]:
  339. continue
  340. try:
  341. stats = node_stats_to_dict(
  342. node_stats(raylet["NodeManagerAddress"], raylet["NodeManagerPort"])
  343. )
  344. except RuntimeError:
  345. continue
  346. core_worker_stats.extend(stats["coreWorkersStats"])
  347. assert type(stats) is dict and "coreWorkersStats" in stats
  348. # Build memory table with "group_by" and "sort_by" parameters
  349. group_by, sort_by = get_group_by_type(group_by), get_sorting_type(sort_by)
  350. memory_table = construct_memory_table(
  351. core_worker_stats, group_by, sort_by
  352. ).as_dict()
  353. assert "summary" in memory_table and "group" in memory_table
  354. # Build memory summary
  355. mem = ""
  356. group_by, sort_by = group_by.name.lower().replace(
  357. "_", " "
  358. ), sort_by.name.lower().replace("_", " ")
  359. summary_labels = [
  360. "Mem Used by Objects",
  361. "Local References",
  362. "Pinned",
  363. "Used by task",
  364. "Captured in Objects",
  365. "Actor Handles",
  366. ]
  367. summary_string = "{:<19} {:<16} {:<12} {:<13} {:<19} {:<13}\n"
  368. object_ref_labels = [
  369. "IP Address",
  370. "PID",
  371. "Type",
  372. "Call Site",
  373. "Status",
  374. "Attempt",
  375. "Size",
  376. "Reference Type",
  377. "Object Ref",
  378. ]
  379. object_ref_string = "{:<13} | {:<8} | {:<7} | {:<9} \
  380. | {:<9} | {:<8} | {:<8} | {:<14} | {:<10}\n"
  381. if size > line_wrap_threshold and line_wrap:
  382. object_ref_string = "{:<15} {:<5} {:<6} {:<22} {:<14} {:<8} {:<6} \
  383. {:<18} {:<56}\n"
  384. mem += f"Grouping by {group_by}...\
  385. Sorting by {sort_by}...\
  386. Display {num_entries if num_entries is not None else 'all'} \
  387. entries per group...\n\n\n"
  388. for key, group in memory_table["group"].items():
  389. # Group summary
  390. summary = group["summary"]
  391. ref_size = track_reference_size(group)
  392. for k, v in summary.items():
  393. if k == "total_object_size":
  394. summary[k] = str(v / units[unit]) + f" {unit}"
  395. else:
  396. summary[k] = str(v) + f", ({ref_size[k] / units[unit]} {unit})"
  397. mem += f"--- Summary for {group_by}: {key} ---\n"
  398. mem += summary_string.format(*summary_labels)
  399. mem += summary_string.format(*summary.values()) + "\n"
  400. # Memory table per group
  401. mem += f"--- Object references for {group_by}: {key} ---\n"
  402. mem += object_ref_string.format(*object_ref_labels)
  403. n = 1 # Counter for num entries per group
  404. for entry in group["entries"]:
  405. if num_entries is not None and n > num_entries:
  406. break
  407. entry["object_size"] = (
  408. str(entry["object_size"] / units[unit]) + f" {unit}"
  409. if entry["object_size"] > -1
  410. else "?"
  411. )
  412. num_lines = 1
  413. if size > line_wrap_threshold and line_wrap:
  414. call_site_length = 22
  415. if len(entry["call_site"]) == 0:
  416. entry["call_site"] = ["disabled"]
  417. else:
  418. entry["call_site"] = [
  419. entry["call_site"][i : i + call_site_length]
  420. for i in range(0, len(entry["call_site"]), call_site_length)
  421. ]
  422. task_status_length = 12
  423. entry["task_status"] = [
  424. entry["task_status"][i : i + task_status_length]
  425. for i in range(0, len(entry["task_status"]), task_status_length)
  426. ]
  427. num_lines = max(len(entry["call_site"]), len(entry["task_status"]))
  428. else:
  429. mem += "\n"
  430. object_ref_values = [
  431. entry["node_ip_address"],
  432. entry["pid"],
  433. entry["type"],
  434. entry["call_site"],
  435. entry["task_status"],
  436. entry["attempt_number"],
  437. entry["object_size"],
  438. entry["reference_type"],
  439. entry["object_ref"],
  440. ]
  441. for i in range(len(object_ref_values)):
  442. if not isinstance(object_ref_values[i], list):
  443. object_ref_values[i] = [object_ref_values[i]]
  444. object_ref_values[i].extend(
  445. ["" for x in range(num_lines - len(object_ref_values[i]))]
  446. )
  447. for i in range(num_lines):
  448. row = [elem[i] for elem in object_ref_values]
  449. mem += object_ref_string.format(*row)
  450. mem += "\n"
  451. n += 1
  452. mem += (
  453. "To record callsite information for each ObjectRef created, set "
  454. "env variable RAY_record_ref_creation_sites=1\n\n"
  455. )
  456. return mem