check_serialize.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. """A utility for debugging serialization issues."""
  2. import inspect
  3. from contextlib import contextmanager
  4. from typing import Any, Optional, Set, Tuple
  5. import colorama
  6. # Import ray first to use the bundled colorama
  7. import ray # noqa: F401
  8. import ray.cloudpickle as cp
  9. from ray.util.annotations import DeveloperAPI
  10. @contextmanager
  11. def _indent(printer):
  12. printer.level += 1
  13. yield
  14. printer.level -= 1
  15. class _Printer:
  16. def __init__(self, print_file):
  17. self.level = 0
  18. self.print_file = print_file
  19. def indent(self):
  20. return _indent(self)
  21. def print(self, msg):
  22. indent = " " * self.level
  23. print(indent + msg, file=self.print_file)
  24. @DeveloperAPI
  25. class FailureTuple:
  26. """Represents the serialization 'frame'.
  27. Attributes:
  28. obj: The object that fails serialization.
  29. name: The variable name of the object.
  30. parent: The object that references the `obj`.
  31. """
  32. def __init__(self, obj: Any, name: str, parent: Any):
  33. self.obj = obj
  34. self.name = name
  35. self.parent = parent
  36. def __repr__(self):
  37. return f"FailTuple({self.name} [obj={self.obj}, parent={self.parent}])"
  38. def _inspect_func_serialization(base_obj, depth, parent, failure_set, printer):
  39. """Adds the first-found non-serializable element to the failure_set."""
  40. assert inspect.isfunction(base_obj)
  41. closure = inspect.getclosurevars(base_obj)
  42. found = False
  43. if closure.globals:
  44. printer.print(
  45. f"Detected {len(closure.globals)} global variables. "
  46. "Checking serializability..."
  47. )
  48. with printer.indent():
  49. for name, obj in closure.globals.items():
  50. serializable, _ = _inspect_serializability(
  51. obj,
  52. name=name,
  53. depth=depth - 1,
  54. parent=parent,
  55. failure_set=failure_set,
  56. printer=printer,
  57. )
  58. found = found or not serializable
  59. if found:
  60. break
  61. if closure.nonlocals:
  62. printer.print(
  63. f"Detected {len(closure.nonlocals)} nonlocal variables. "
  64. "Checking serializability..."
  65. )
  66. with printer.indent():
  67. for name, obj in closure.nonlocals.items():
  68. serializable, _ = _inspect_serializability(
  69. obj,
  70. name=name,
  71. depth=depth - 1,
  72. parent=parent,
  73. failure_set=failure_set,
  74. printer=printer,
  75. )
  76. found = found or not serializable
  77. if found:
  78. break
  79. if not found:
  80. printer.print(
  81. f"WARNING: Did not find non-serializable object in {base_obj}. "
  82. "This may be an oversight."
  83. )
  84. return found
  85. def _inspect_generic_serialization(base_obj, depth, parent, failure_set, printer):
  86. """Adds the first-found non-serializable element to the failure_set."""
  87. assert not inspect.isfunction(base_obj)
  88. functions = inspect.getmembers(base_obj, predicate=inspect.isfunction)
  89. found = False
  90. with printer.indent():
  91. for name, obj in functions:
  92. serializable, _ = _inspect_serializability(
  93. obj,
  94. name=name,
  95. depth=depth - 1,
  96. parent=parent,
  97. failure_set=failure_set,
  98. printer=printer,
  99. )
  100. found = found or not serializable
  101. if found:
  102. break
  103. with printer.indent():
  104. members = inspect.getmembers(base_obj)
  105. for name, obj in members:
  106. if name.startswith("__") and name.endswith("__") or inspect.isbuiltin(obj):
  107. continue
  108. serializable, _ = _inspect_serializability(
  109. obj,
  110. name=name,
  111. depth=depth - 1,
  112. parent=parent,
  113. failure_set=failure_set,
  114. printer=printer,
  115. )
  116. found = found or not serializable
  117. if found:
  118. break
  119. if not found:
  120. printer.print(
  121. f"WARNING: Did not find non-serializable object in {base_obj}. "
  122. "This may be an oversight."
  123. )
  124. return found
  125. @DeveloperAPI
  126. def inspect_serializability(
  127. base_obj: Any,
  128. name: Optional[str] = None,
  129. depth: int = 3,
  130. print_file: Optional[Any] = None,
  131. ) -> Tuple[bool, Set[FailureTuple]]:
  132. """Identifies what objects are preventing serialization.
  133. Args:
  134. base_obj: Object to be serialized.
  135. name: Optional name of string.
  136. depth: Depth of the scope stack to walk through. Defaults to 3.
  137. print_file: file argument that will be passed to print().
  138. Returns:
  139. bool: True if serializable.
  140. set[FailureTuple]: Set of unserializable objects.
  141. .. versionadded:: 1.1.0
  142. """
  143. printer = _Printer(print_file)
  144. return _inspect_serializability(base_obj, name, depth, None, None, printer)
  145. def _inspect_serializability(
  146. base_obj, name, depth, parent, failure_set, printer
  147. ) -> Tuple[bool, Set[FailureTuple]]:
  148. colorama.init()
  149. top_level = False
  150. declaration = ""
  151. found = False
  152. if failure_set is None:
  153. top_level = True
  154. failure_set = set()
  155. declaration = f"Checking Serializability of {base_obj}"
  156. printer.print("=" * min(len(declaration), 80))
  157. printer.print(declaration)
  158. printer.print("=" * min(len(declaration), 80))
  159. if name is None:
  160. name = str(base_obj)
  161. else:
  162. printer.print(f"Serializing '{name}' {base_obj}...")
  163. try:
  164. cp.dumps(base_obj)
  165. return True, failure_set
  166. except Exception as e:
  167. printer.print(
  168. f"{colorama.Fore.RED}!!! FAIL{colorama.Fore.RESET} " f"serialization: {e}"
  169. )
  170. found = True
  171. try:
  172. if depth == 0:
  173. failure_set.add(FailureTuple(base_obj, name, parent))
  174. # Some objects may not be hashable, so we skip adding this to the set.
  175. except Exception:
  176. pass
  177. if depth <= 0:
  178. return False, failure_set
  179. # TODO: we only differentiate between 'function' and 'object'
  180. # but we should do a better job of diving into something
  181. # more specific like a Type, Object, etc.
  182. if inspect.isfunction(base_obj):
  183. _inspect_func_serialization(
  184. base_obj,
  185. depth=depth,
  186. parent=base_obj,
  187. failure_set=failure_set,
  188. printer=printer,
  189. )
  190. else:
  191. _inspect_generic_serialization(
  192. base_obj,
  193. depth=depth,
  194. parent=base_obj,
  195. failure_set=failure_set,
  196. printer=printer,
  197. )
  198. if not failure_set:
  199. failure_set.add(FailureTuple(base_obj, name, parent))
  200. if top_level:
  201. printer.print("=" * min(len(declaration), 80))
  202. if not failure_set:
  203. printer.print(
  204. "Nothing failed the inspect_serialization test, though "
  205. "serialization did not succeed."
  206. )
  207. else:
  208. fail_vars = (
  209. f"\n\n\t{colorama.Style.BRIGHT}"
  210. + "\n".join(str(k) for k in failure_set)
  211. + f"{colorama.Style.RESET_ALL}\n\n"
  212. )
  213. printer.print(
  214. f"Variable: {fail_vars}was found to be non-serializable. "
  215. "There may be multiple other undetected variables that were "
  216. "non-serializable. "
  217. )
  218. printer.print(
  219. "Consider either removing the "
  220. "instantiation/imports of these variables or moving the "
  221. "instantiation into the scope of the function/class. "
  222. )
  223. printer.print("=" * min(len(declaration), 80))
  224. printer.print(
  225. "Check https://docs.ray.io/en/master/ray-core/objects/serialization.html#troubleshooting for more information." # noqa
  226. )
  227. printer.print(
  228. "If you have any suggestions on how to improve "
  229. "this error message, please reach out to the "
  230. "Ray developers on github.com/ray-project/ray/issues/"
  231. )
  232. printer.print("=" * min(len(declaration), 80))
  233. return not found, failure_set