serialization_context.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. import warnings
  2. from typing import TYPE_CHECKING, Any, Dict, List, Set, Tuple, Union
  3. from ray.experimental.util.types import Device
  4. if TYPE_CHECKING:
  5. import numpy as np
  6. import torch
# Controls whether deserialize_from_numpy_or_scalar wraps its non-CPU
# numpy -> tensor conversion in a warnings filter that silences torch's
# "The given NumPy array is not writable" UserWarning. That method sets it to
# False once the filter is no longer needed: per its comments, torch only
# emits that warning once, and filtering warnings on every call would be a
# deserialization bottleneck.
_TORCH_WARNING_FILTER_ACTIVATE: bool = True
  8. class _SerializationContext:
  9. def __init__(self):
  10. # If true, then tensors found in the data to serialize are extracted
  11. # and the caller should send them through an external transport.
  12. self._use_external_transport: bool = False
  13. # If _use_external_transport is True, then these are
  14. # the tensors that should be sent or received
  15. # out-of-band, through the external transport.
  16. self._out_of_band_tensors: List["torch.Tensor"] = []
  17. # During serialization, tensors sent out-of-band are replaced with
  18. # integer placeholders. This tracks the set of placeholders seen.
  19. self._deserialized_tensor_placeholders: Set[int] = set()
  20. # Buffer for transferring data between tasks in the same worker process.
  21. # The key is the channel ID, and the value is the data. We don't use a
  22. # lock when reading/writing the buffer because a DAG node actor will only
  23. # execute one task at a time in `do_exec_tasks`. It will not execute multiple
  24. # Ray tasks on a single actor simultaneously.
  25. self.intra_process_channel_buffers: Dict[str, Any] = {}
  26. # The number of readers for each channel. When the number of readers
  27. # reaches 0, remove the data from the buffer.
  28. self.channel_id_to_num_readers: Dict[str, int] = {}
  29. def set_target_device(self, device: Device) -> None:
  30. self._target_device = device
  31. def set_data(self, channel_id: str, value: Any, num_readers: int) -> None:
  32. assert num_readers > 0, "num_readers must be greater than 0."
  33. assert (
  34. channel_id not in self.intra_process_channel_buffers
  35. ), f"Channel {channel_id} already exists in the buffer."
  36. assert (
  37. channel_id not in self.channel_id_to_num_readers
  38. ), f"Channel {channel_id} already exists in the channel_id_to_num_readers."
  39. self.intra_process_channel_buffers[channel_id] = value
  40. self.channel_id_to_num_readers[channel_id] = num_readers
  41. def has_data(self, channel_id: str) -> bool:
  42. return channel_id in self.intra_process_channel_buffers
  43. def get_data(self, channel_id: str) -> Any:
  44. assert (
  45. channel_id in self.intra_process_channel_buffers
  46. ), f"Channel {channel_id} does not exist in the buffer."
  47. assert (
  48. channel_id in self.channel_id_to_num_readers
  49. ), f"Channel {channel_id} does not exist in the channel_id_to_num_readers."
  50. self.channel_id_to_num_readers[channel_id] -= 1
  51. if self.channel_id_to_num_readers[channel_id] == 0:
  52. # All readers have read the data, so we can remove it.
  53. self.channel_id_to_num_readers.pop(channel_id)
  54. return self.intra_process_channel_buffers.pop(channel_id)
  55. return self.intra_process_channel_buffers[channel_id]
  56. def reset_data(self, channel_id: str) -> None:
  57. self.intra_process_channel_buffers.pop(channel_id, None)
  58. self.channel_id_to_num_readers.pop(channel_id, None)
  59. def set_use_external_transport(self, use_external_transport: bool) -> None:
  60. self._use_external_transport = use_external_transport
  61. @property
  62. def use_external_transport(self) -> bool:
  63. return self._use_external_transport
  64. def reset_out_of_band_tensors(
  65. self, tensors: List["torch.Tensor"]
  66. ) -> Tuple[List["torch.Tensor"], Set[int]]:
  67. """
  68. Return and reset the out-of-band tensors and all tensor placeholders
  69. that were deserialized since the last call to reset.
  70. """
  71. prev_tensors = self._out_of_band_tensors
  72. deserialized_tensor_placeholders = self._deserialized_tensor_placeholders
  73. self._out_of_band_tensors = tensors
  74. self._deserialized_tensor_placeholders = set()
  75. return prev_tensors, deserialized_tensor_placeholders
  76. def serialize_tensor(
  77. self, tensor: "torch.Tensor"
  78. ) -> Union[int, Tuple["np.ndarray", "torch.dtype", str]]:
  79. from ray.experimental.channel import ChannelContext
  80. ctx = ChannelContext.get_current()
  81. if self._use_external_transport and (
  82. ctx._torch_device is None or ctx._torch_device == tensor.device
  83. ):
  84. # External transport is enabled and we found a tensor that matches
  85. # our device. Add the actual tensor to a buffer. The buffer of
  86. # tensors should later be popped by the caller and sent via
  87. # external transport.
  88. self._out_of_band_tensors.append(tensor)
  89. # Return a placeholder.
  90. return len(self._out_of_band_tensors) - 1
  91. return self.serialize_to_numpy_or_scalar(tensor)
  92. def serialize_to_numpy_or_scalar(
  93. self, tensor: "torch.Tensor"
  94. ) -> Tuple[Union["np.ndarray", Any], "torch.dtype", str]:
  95. """
  96. Serialize a tensor to a numpy array,
  97. or a scalar when the tensor is 0-dim.
  98. """
  99. import torch
  100. tensor_device_type = tensor.device.type
  101. # Transfer through Ray's shared memory store for now.
  102. # TODO(swang): This requires two copies, one to transfer from GPU to
  103. # CPU and another from CPU to shared memory. Ideally we should elide
  104. # the first copy and memcpy directly from GPU to the shared memory
  105. # buffer.
  106. if tensor_device_type != "cpu":
  107. tensor = tensor.to("cpu")
  108. # Numpy does not have an equivalent dtype for all torch dtypes, so
  109. # instead of casting directly to numpy:
  110. # 1) for non-scalar tensors, we first use a view with a common dtype (uint8)
  111. # and then view as numpy array.
  112. # 2) for scalar tensors, we cannot use a uint8 view when the size differs,
  113. # so we save the original item and type information.
  114. if tensor.dim() > 0:
  115. return (tensor.view(torch.uint8).numpy(), tensor.dtype, tensor_device_type)
  116. else:
  117. return (tensor.item(), tensor.dtype, tensor_device_type)
  118. def deserialize_tensor(
  119. self,
  120. val: Union[Tuple["np.ndarray", "torch.dtype", str], int],
  121. target_device: Device,
  122. ):
  123. # Found a placeholder for a tensor that was serialized via accelerator.
  124. # Replace it with the corresponding deserialized tensor.
  125. if isinstance(val, int):
  126. placeholder = val
  127. self._deserialized_tensor_placeholders.add(placeholder)
  128. assert placeholder < len(self._out_of_band_tensors), (
  129. "placeholder",
  130. placeholder,
  131. "out_of_band_tensors",
  132. self._out_of_band_tensors,
  133. )
  134. tensor = self._out_of_band_tensors[placeholder]
  135. if target_device == Device.CPU:
  136. tensor = tensor.to("cpu")
  137. return tensor
  138. np_array, dtype, tensor_device_type = val
  139. return self.deserialize_from_numpy_or_scalar(
  140. np_array, dtype, tensor_device_type, target_device
  141. )
  142. def deserialize_from_numpy_or_scalar(
  143. self,
  144. np_array: Union["np.ndarray", Any],
  145. dtype: "torch.dtype",
  146. tensor_device_type: str,
  147. target_device: Device,
  148. ):
  149. import numpy as np
  150. import torch
  151. if target_device == Device.DEFAULT:
  152. target_device_type = tensor_device_type
  153. elif target_device in [Device.GPU, Device.CUDA]:
  154. target_device_type = "cuda"
  155. else:
  156. target_device_type = target_device.value
  157. # TODO(swang): Support local P2P transfers if available.
  158. if target_device_type != "cpu":
  159. def convert_numpy_to_tensor(np_array):
  160. if not isinstance(np_array, np.ndarray):
  161. # For scalar tensors, create the 0-dim tensor.
  162. return torch.tensor(
  163. np_array, device=target_device_type, dtype=dtype
  164. )
  165. else:
  166. # For non-scalar tensors, view as the original dtype.
  167. # It does zero-copy convert np_array inside shared memory to
  168. # a tensor. Since we move data to GPU immediately, it is safe.
  169. cpu_tensor = torch.from_numpy(np_array).view(dtype)
  170. return cpu_tensor.to(device=target_device_type)
  171. global _TORCH_WARNING_FILTER_ACTIVATE
  172. # filtering warning messages would be the bottleneck for
  173. # deserializing torch tensors. Since the warning only prompts once,
  174. # we would only deal with it for the first time.
  175. if _TORCH_WARNING_FILTER_ACTIVATE:
  176. with warnings.catch_warnings():
  177. # Since np_array.is_writable is False (it is set by Ray),
  178. # this raises a warning. Suppress it.
  179. warnings.filterwarnings(
  180. "ignore",
  181. category=UserWarning,
  182. message="The given NumPy array is not writable",
  183. )
  184. gpu_tensor = convert_numpy_to_tensor(np_array)
  185. _TORCH_WARNING_FILTER_ACTIVATE = False
  186. else:
  187. gpu_tensor = convert_numpy_to_tensor(np_array)
  188. return gpu_tensor
  189. # TODO(swang): Use zero-copy from_numpy() if np_array.flags.writeable
  190. # is True. This is safe to set when deserializing np_array if the
  191. # upstream task has num_readers=1.
  192. if not isinstance(np_array, np.ndarray):
  193. # For scalar tensors, create the 0-dim tensor.
  194. return torch.tensor(np_array, device=target_device_type, dtype=dtype)
  195. else:
  196. # For non-scalar tensors, view as the original dtype.
  197. return torch.tensor(np_array, device=target_device_type).view(dtype)