py_obj_scanner.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
import io
import pickle  # noqa: F401
from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union

import ray
from ray.dag.base import DAGNodeBase
  6. # Used in deserialization hooks to reference scanner instances.
  7. _instances: Dict[int, "_PyObjScanner"] = {}
  8. # Generic types for the scanner to transform from and to.
  9. SourceType = TypeVar("SourceType")
  10. TransformedType = TypeVar("TransformedType")
  11. def _get_node(instance_id: int, node_index: int) -> SourceType:
  12. """Get the node instance.
  13. Note: This function should be static and globally importable,
  14. otherwise the serialization overhead would be very significant.
  15. """
  16. return _instances[instance_id]._replace_index(node_index)
  17. class _PyObjScanner(ray.cloudpickle.CloudPickler, Generic[SourceType, TransformedType]):
  18. """Utility to find and replace the `source_type` in Python objects.
  19. `source_type` can either be a single type or a tuple of multiple types.
  20. The caller must first call `find_nodes()`, then compute a replacement table and
  21. pass it to `replace_nodes`.
  22. This uses cloudpickle under the hood, so all sub-objects that are not `source_type`
  23. must be serializable.
  24. Args:
  25. source_type: the type(s) of object to find and replace. Default to DAGNodeBase.
  26. """
  27. def __init__(self, source_type: Union[Type, Tuple] = DAGNodeBase):
  28. self.source_type = source_type
  29. # Buffer to keep intermediate serialized state.
  30. self._buf = io.BytesIO()
  31. # List of top-level SourceType found during the serialization pass.
  32. self._found = None
  33. # List of other objects found during the serialization pass.
  34. # This is used to store references to objects so they won't be
  35. # serialized by cloudpickle.
  36. self._objects = []
  37. # Replacement table to consult during deserialization.
  38. self._replace_table: Dict[SourceType, TransformedType] = None
  39. _instances[id(self)] = self
  40. super().__init__(self._buf)
  41. def reducer_override(self, obj):
  42. """Hook for reducing objects.
  43. Objects of `self.source_type` are saved to `self._found` and a global map so
  44. they can later be replaced.
  45. All other objects fall back to the default `CloudPickler` serialization.
  46. """
  47. if isinstance(obj, self.source_type):
  48. index = len(self._found)
  49. self._found.append(obj)
  50. return _get_node, (id(self), index)
  51. return super().reducer_override(obj)
  52. def find_nodes(self, obj: Any) -> List[SourceType]:
  53. """
  54. Serialize `obj` and store all instances of `source_type` found in `_found`.
  55. Args:
  56. obj: The object to scan for `source_type`.
  57. Returns:
  58. A list of all instances of `source_type` found in `obj`.
  59. """
  60. assert (
  61. self._found is None
  62. ), "find_nodes cannot be called twice on the same PyObjScanner instance."
  63. self._found = []
  64. self._objects = []
  65. self.dump(obj)
  66. return self._found
  67. def replace_nodes(self, table: Dict[SourceType, TransformedType]) -> Any:
  68. """Replace previously found DAGNodes per the given table."""
  69. assert self._found is not None, "find_nodes must be called first"
  70. self._replace_table = table
  71. self._buf.seek(0)
  72. return pickle.load(self._buf)
  73. def _replace_index(self, i: int) -> SourceType:
  74. return self._replace_table[self._found[i]]
  75. def clear(self):
  76. """Clear the scanner from the _instances"""
  77. if id(self) in _instances:
  78. del _instances[id(self)]
  79. def __del__(self):
  80. self.clear()