arrow_serialization.py 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795
# arrow_serialization.py must reside outside of ray.data; otherwise it
# causes circular dependency issues for AsyncActors due to ray.data's lazy
# import.
# See https://github.com/ray-project/ray/issues/30498 for more context.
  5. import logging
  6. import os
  7. import sys
  8. from dataclasses import dataclass
  9. from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
  10. from ray._private.utils import is_in_test
  11. if TYPE_CHECKING:
  12. import pyarrow
  13. RAY_DISABLE_CUSTOM_ARROW_JSON_OPTIONS_SERIALIZATION = (
  14. "RAY_DISABLE_CUSTOM_ARROW_JSON_OPTIONS_SERIALIZATION"
  15. )
  16. RAY_DISABLE_CUSTOM_ARROW_DATA_SERIALIZATION = (
  17. "RAY_DISABLE_CUSTOM_ARROW_DATA_SERIALIZATION"
  18. )
  19. logger = logging.getLogger(__name__)
  20. # Whether we have already warned the user about bloated fallback serialization.
  21. _serialization_fallback_set = set()
  22. def _register_custom_datasets_serializers(serialization_context):
  23. try:
  24. import pyarrow as pa # noqa: F401
  25. except ModuleNotFoundError:
  26. # No pyarrow installed so not using Arrow, so no need for custom serializers.
  27. return
  28. # Register all custom serializers required by Datasets.
  29. _register_arrow_data_serializer(serialization_context)
  30. _register_arrow_json_readoptions_serializer(serialization_context)
  31. _register_arrow_json_parseoptions_serializer(serialization_context)
  32. # Register custom Arrow JSON ReadOptions serializer to workaround it not being picklable
  33. # in Arrow < 8.0.0.
  34. def _register_arrow_json_readoptions_serializer(serialization_context):
  35. if (
  36. os.environ.get(
  37. RAY_DISABLE_CUSTOM_ARROW_JSON_OPTIONS_SERIALIZATION,
  38. "0",
  39. )
  40. == "1"
  41. ):
  42. return
  43. import pyarrow.json as pajson
  44. serialization_context._register_cloudpickle_serializer(
  45. pajson.ReadOptions,
  46. custom_serializer=lambda opts: (opts.use_threads, opts.block_size),
  47. custom_deserializer=lambda args: pajson.ReadOptions(*args),
  48. )
  49. def _register_arrow_json_parseoptions_serializer(serialization_context):
  50. if (
  51. os.environ.get(
  52. RAY_DISABLE_CUSTOM_ARROW_JSON_OPTIONS_SERIALIZATION,
  53. "0",
  54. )
  55. == "1"
  56. ):
  57. return
  58. import pyarrow.json as pajson
  59. serialization_context._register_cloudpickle_serializer(
  60. pajson.ParseOptions,
  61. custom_serializer=lambda opts: (
  62. opts.explicit_schema,
  63. opts.newlines_in_values,
  64. opts.unexpected_field_behavior,
  65. ),
  66. custom_deserializer=lambda args: pajson.ParseOptions(*args),
  67. )
  68. # Register custom Arrow data serializer to work around zero-copy slice pickling bug.
  69. # See https://issues.apache.org/jira/browse/ARROW-10739.
  70. def _register_arrow_data_serializer(serialization_context):
  71. """Custom reducer for Arrow data that works around a zero-copy slicing pickling
  72. bug by using the Arrow IPC format for the underlying serialization.
  73. Background:
  74. Arrow has both array-level slicing and buffer-level slicing; both are zero-copy,
  75. but the former has a serialization bug where the entire buffer is serialized
  76. instead of just the slice, while the latter's serialization works as expected
  77. and only serializes the slice of the buffer. I.e., array-level slicing doesn't
  78. propagate the slice down to the buffer when serializing the array.
  79. We work around this by registering a custom cloudpickle reducers for Arrow
  80. Tables that delegates serialization to the Arrow IPC format; thankfully, Arrow's
  81. IPC serialization has fixed this buffer truncation bug.
  82. See https://issues.apache.org/jira/browse/ARROW-10739.
  83. """
  84. if os.environ.get(RAY_DISABLE_CUSTOM_ARROW_DATA_SERIALIZATION, "0") == "1":
  85. return
  86. import pyarrow as pa
  87. serialization_context._register_cloudpickle_reducer(pa.Table, _arrow_table_reduce)
  88. serialization_context._register_cloudpickle_reducer(pa.Schema, _arrow_schema_reduce)
  89. def _arrow_schema_reduce(
  90. schema: "pyarrow.Schema",
  91. ) -> Tuple[Callable[["bytes"], "pyarrow.Schema"], Tuple[bytes]]:
  92. """Custom reducer for Arrow Schema that uses IPC serialization for performance.
  93. Arrow's native IPC serialization for schemas is significantly faster than
  94. cloudpickle (10-20x for serialization, 2-3x for deserialization), making
  95. this optimization particularly valuable for workloads with large schemas.
  96. """
  97. # Use Arrow's native IPC serialization which is much faster than cloudpickle
  98. return _restore_schema_from_ipc, (schema.serialize().to_pybytes(),)
  99. def _restore_schema_from_ipc(buf: bytes) -> "pyarrow.Schema":
  100. """Restore an Arrow Schema serialized to Arrow IPC format."""
  101. import pyarrow as pa
  102. return pa.ipc.read_schema(pa.BufferReader(buf))
  103. def _arrow_table_reduce(t: "pyarrow.Table"):
  104. """Custom reducer for Arrow Tables that works around a zero-copy slice pickling bug.
  105. Background:
  106. Arrow has both array-level slicing and buffer-level slicing; both are zero-copy,
  107. but the former has a serialization bug where the entire buffer is serialized
  108. instead of just the slice, while the latter's serialization works as expected
  109. and only serializes the slice of the buffer. I.e., array-level slicing doesn't
  110. propagate the slice down to the buffer when serializing the array.
  111. All that these copy methods do is, at serialization time, take the array-level
  112. slicing and translate them to buffer-level slicing, so only the buffer slice is
  113. sent over the wire instead of the entire buffer.
  114. See https://issues.apache.org/jira/browse/ARROW-10739.
  115. """
  116. global _serialization_fallback_set
  117. # Reduce the ChunkedArray columns.
  118. reduced_columns = []
  119. for column_name in t.column_names:
  120. column = t[column_name]
  121. try:
  122. # Delegate to ChunkedArray reducer.
  123. reduced_column = _arrow_chunked_array_reduce(column)
  124. except Exception as e:
  125. if not _is_dense_union(column.type) and is_in_test():
  126. # If running in a test and the column is not a dense union array
  127. # (which we expect to need a fallback), we want to raise the error,
  128. # not fall back.
  129. raise e from None
  130. if type(column.type) not in _serialization_fallback_set:
  131. logger.warning(
  132. "Failed to complete optimized serialization of Arrow Table, "
  133. f"serialization of column '{column_name}' of type {column.type} "
  134. "failed, so we're falling back to Arrow IPC serialization for the "
  135. "table. Note that this may result in slower serialization and more "
  136. "worker memory utilization. Serialization error:",
  137. exc_info=True,
  138. )
  139. _serialization_fallback_set.add(type(column.type))
  140. # Fall back to Arrow IPC-based workaround for the entire table.
  141. return _arrow_table_ipc_reduce(t)
  142. else:
  143. # Column reducer succeeded, add reduced column to list.
  144. reduced_columns.append(reduced_column)
  145. return _reconstruct_table, (reduced_columns, t.schema)
  146. def _reconstruct_table(
  147. reduced_columns: List[Tuple[List["pyarrow.Array"], "pyarrow.DataType"]],
  148. schema: "pyarrow.Schema",
  149. ) -> "pyarrow.Table":
  150. """Restore a serialized Arrow Table, reconstructing each reduced column."""
  151. import pyarrow as pa
  152. # Reconstruct each reduced column.
  153. columns = []
  154. for chunks_payload, type_ in reduced_columns:
  155. columns.append(_reconstruct_chunked_array(chunks_payload, type_))
  156. return pa.Table.from_arrays(columns, schema=schema)
  157. def _arrow_chunked_array_reduce(
  158. ca: "pyarrow.ChunkedArray",
  159. ) -> Tuple[List["PicklableArrayPayload"], "pyarrow.DataType"]:
  160. """Custom reducer for Arrow ChunkedArrays that works around a zero-copy slice
  161. pickling bug. This reducer does not return a reconstruction function, since it's
  162. expected to be reconstructed by the Arrow Table reconstructor.
  163. """
  164. # Convert chunks to serialization payloads.
  165. chunk_payloads = []
  166. for chunk in ca.chunks:
  167. chunk_payload = PicklableArrayPayload.from_array(chunk)
  168. chunk_payloads.append(chunk_payload)
  169. return chunk_payloads, ca.type
  170. def _reconstruct_chunked_array(
  171. chunks: List["PicklableArrayPayload"], type_: "pyarrow.DataType"
  172. ) -> "pyarrow.ChunkedArray":
  173. """Restore a serialized Arrow ChunkedArray from chunks and type."""
  174. import pyarrow as pa
  175. # Reconstruct chunks from serialization payloads.
  176. chunks = [chunk.to_array() for chunk in chunks]
  177. return pa.chunked_array(chunks, type_)
  178. @dataclass
  179. class PicklableArrayPayload:
  180. """Picklable array payload, holding data buffers and array metadata.
  181. This is a helper container for pickling and reconstructing nested Arrow Arrays while
  182. ensuring that the buffers that underly zero-copy slice views are properly truncated.
  183. """
  184. # Array type.
  185. type: "pyarrow.DataType"
  186. # Length of array.
  187. length: int
  188. # Underlying data buffers.
  189. buffers: List["pyarrow.Buffer"]
  190. # Cached null count.
  191. null_count: int
  192. # Slice offset into base array.
  193. offset: int
  194. # Serialized array payloads for nested (child) arrays.
  195. children: List["PicklableArrayPayload"]
  196. @classmethod
  197. def from_array(self, a: "pyarrow.Array") -> "PicklableArrayPayload":
  198. """Create a picklable array payload from an Arrow Array.
  199. This will recursively accumulate data buffer and metadata payloads that are
  200. ready for pickling; namely, the data buffers underlying zero-copy slice views
  201. will be properly truncated.
  202. """
  203. return _array_to_array_payload(a)
  204. def to_array(self) -> "pyarrow.Array":
  205. """Reconstruct an Arrow Array from this picklable payload."""
  206. return _array_payload_to_array(self)
  207. def _array_payload_to_array(payload: "PicklableArrayPayload") -> "pyarrow.Array":
  208. """Reconstruct an Arrow Array from a possibly nested PicklableArrayPayload."""
  209. import pyarrow as pa
  210. children = [child_payload.to_array() for child_payload in payload.children]
  211. if pa.types.is_dictionary(payload.type):
  212. # Dedicated path for reconstructing a DictionaryArray, since
  213. # Array.from_buffers() doesn't work for DictionaryArrays.
  214. assert len(children) == 2, len(children)
  215. indices, dictionary = children
  216. return pa.DictionaryArray.from_arrays(indices, dictionary)
  217. elif pa.types.is_map(payload.type) and len(children) > 1:
  218. # In pyarrow<7.0.0, the underlying map child array is not exposed, so we work
  219. # with the key and item arrays.
  220. assert len(children) == 3, len(children)
  221. offsets, keys, items = children
  222. return pa.MapArray.from_arrays(offsets, keys, items)
  223. elif isinstance(payload.type, pa.BaseExtensionType):
  224. assert len(children) == 1, len(children)
  225. storage = children[0]
  226. return payload.type.wrap_array(storage)
  227. else:
  228. # Common case: use Array.from_buffers() to construct an array of a certain type.
  229. return pa.Array.from_buffers(
  230. type=payload.type,
  231. length=payload.length,
  232. buffers=payload.buffers,
  233. null_count=payload.null_count,
  234. offset=payload.offset,
  235. children=children,
  236. )
  237. def _array_to_array_payload(a: "pyarrow.Array") -> "PicklableArrayPayload":
  238. """Serialize an Arrow Array to an PicklableArrayPayload for later pickling.
  239. This function's primary purpose is to dispatch to the handler for the input array
  240. type.
  241. """
  242. import pyarrow as pa
  243. if _is_dense_union(a.type):
  244. # Dense unions are not supported.
  245. # TODO(Clark): Support dense unions.
  246. raise NotImplementedError(
  247. "Custom slice view serialization of dense union arrays is not yet "
  248. "supported."
  249. )
  250. # Dispatch to handler for array type.
  251. if pa.types.is_null(a.type):
  252. return _null_array_to_array_payload(a)
  253. elif _is_primitive(a.type):
  254. return _primitive_array_to_array_payload(a)
  255. elif _is_binary(a.type):
  256. return _binary_array_to_array_payload(a)
  257. elif pa.types.is_list(a.type) or pa.types.is_large_list(a.type):
  258. return _list_array_to_array_payload(a)
  259. elif pa.types.is_fixed_size_list(a.type):
  260. return _fixed_size_list_array_to_array_payload(a)
  261. elif pa.types.is_struct(a.type):
  262. return _struct_array_to_array_payload(a)
  263. elif pa.types.is_union(a.type):
  264. return _union_array_to_array_payload(a)
  265. elif pa.types.is_dictionary(a.type):
  266. return _dictionary_array_to_array_payload(a)
  267. elif pa.types.is_map(a.type):
  268. return _map_array_to_array_payload(a)
  269. elif isinstance(a.type, pa.BaseExtensionType):
  270. return _extension_array_to_array_payload(a)
  271. else:
  272. raise ValueError("Unhandled Arrow array type:", a.type)
  273. def _is_primitive(type_: "pyarrow.DataType") -> bool:
  274. """Whether the provided Array type is primitive (boolean, numeric, temporal or
  275. fixed-size binary)."""
  276. import pyarrow as pa
  277. return (
  278. pa.types.is_integer(type_)
  279. or pa.types.is_floating(type_)
  280. or pa.types.is_decimal(type_)
  281. or pa.types.is_boolean(type_)
  282. or pa.types.is_temporal(type_)
  283. or pa.types.is_fixed_size_binary(type_)
  284. )
  285. def _is_binary(type_: "pyarrow.DataType") -> bool:
  286. """Whether the provided Array type is a variable-sized binary type."""
  287. import pyarrow as pa
  288. return (
  289. pa.types.is_string(type_)
  290. or pa.types.is_large_string(type_)
  291. or pa.types.is_binary(type_)
  292. or pa.types.is_large_binary(type_)
  293. )
  294. def _null_array_to_array_payload(a: "pyarrow.NullArray") -> "PicklableArrayPayload":
  295. """Serialize null array to PicklableArrayPayload."""
  296. # Buffer scheme: [None]
  297. return PicklableArrayPayload(
  298. type=a.type,
  299. length=len(a),
  300. buffers=[None], # Single null buffer is expected.
  301. null_count=a.null_count,
  302. offset=0,
  303. children=[],
  304. )
  305. def _primitive_array_to_array_payload(a: "pyarrow.Array") -> "PicklableArrayPayload":
  306. """Serialize primitive (numeric, temporal, boolean) arrays to
  307. PicklableArrayPayload.
  308. """
  309. assert _is_primitive(a.type), a.type
  310. # Buffer scheme: [bitmap, data]
  311. buffers = a.buffers()
  312. assert len(buffers) == 2, len(buffers)
  313. # Copy bitmap buffer, if needed.
  314. bitmap_buf = buffers[0]
  315. if a.null_count > 0:
  316. bitmap_buf = _copy_bitpacked_buffer_if_needed(bitmap_buf, a.offset, len(a))
  317. else:
  318. bitmap_buf = None
  319. # Copy data buffer, if needed.
  320. data_buf = buffers[1]
  321. if data_buf is not None:
  322. data_buf = _copy_buffer_if_needed(buffers[1], a.type, a.offset, len(a))
  323. return PicklableArrayPayload(
  324. type=a.type,
  325. length=len(a),
  326. buffers=[bitmap_buf, data_buf],
  327. null_count=a.null_count,
  328. offset=0,
  329. children=[],
  330. )
  331. def _binary_array_to_array_payload(a: "pyarrow.Array") -> "PicklableArrayPayload":
  332. """Serialize binary (variable-sized binary, string) arrays to
  333. PicklableArrayPayload.
  334. """
  335. assert _is_binary(a.type), a.type
  336. # Buffer scheme: [bitmap, value_offsets, data]
  337. buffers = a.buffers()
  338. assert len(buffers) == 3, len(buffers)
  339. # Copy bitmap buffer, if needed.
  340. if a.null_count > 0:
  341. bitmap_buf = _copy_bitpacked_buffer_if_needed(buffers[0], a.offset, len(a))
  342. else:
  343. bitmap_buf = None
  344. # Copy offset buffer, if needed.
  345. offset_buf = buffers[1]
  346. offset_buf, data_offset, data_length = _copy_offsets_buffer_if_needed(
  347. offset_buf, a.type, a.offset, len(a)
  348. )
  349. data_buf = buffers[2]
  350. data_buf = _copy_buffer_if_needed(data_buf, None, data_offset, data_length)
  351. return PicklableArrayPayload(
  352. type=a.type,
  353. length=len(a),
  354. buffers=[bitmap_buf, offset_buf, data_buf],
  355. null_count=a.null_count,
  356. offset=0,
  357. children=[],
  358. )
  359. def _list_array_to_array_payload(a: "pyarrow.Array") -> "PicklableArrayPayload":
  360. """Serialize list (regular and large) arrays to PicklableArrayPayload."""
  361. # Dedicated path for ListArrays. These arrays have a nested set of bitmap and
  362. # offset buffers, eventually bottoming out on a data buffer.
  363. # Buffer scheme:
  364. # [bitmap, offsets, bitmap, offsets, ..., bitmap, data]
  365. buffers = a.buffers()
  366. assert len(buffers) > 1, len(buffers)
  367. # Copy bitmap buffer, if needed.
  368. if a.null_count > 0:
  369. bitmap_buf = _copy_bitpacked_buffer_if_needed(buffers[0], a.offset, len(a))
  370. else:
  371. bitmap_buf = None
  372. # Copy offset buffer, if needed.
  373. offset_buf = buffers[1]
  374. offset_buf, child_offset, child_length = _copy_offsets_buffer_if_needed(
  375. offset_buf, a.type, a.offset, len(a)
  376. )
  377. # Propagate slice to child.
  378. child = a.values.slice(child_offset, child_length)
  379. return PicklableArrayPayload(
  380. type=a.type,
  381. length=len(a),
  382. buffers=[bitmap_buf, offset_buf],
  383. null_count=a.null_count,
  384. offset=0,
  385. children=[_array_to_array_payload(child)],
  386. )
  387. def _fixed_size_list_array_to_array_payload(
  388. a: "pyarrow.FixedSizeListArray",
  389. ) -> "PicklableArrayPayload":
  390. """Serialize fixed size list arrays to PicklableArrayPayload."""
  391. # Dedicated path for fixed-size lists.
  392. # Buffer scheme:
  393. # [bitmap, values_bitmap, values_data, values_subbuffers...]
  394. buffers = a.buffers()
  395. assert len(buffers) >= 1, len(buffers)
  396. # Copy bitmap buffer, if needed.
  397. if a.null_count > 0:
  398. bitmap_buf = _copy_bitpacked_buffer_if_needed(buffers[0], a.offset, len(a))
  399. else:
  400. bitmap_buf = None
  401. # Propagate slice to child.
  402. child_offset = a.type.list_size * a.offset
  403. child_length = a.type.list_size * len(a)
  404. child = a.values.slice(child_offset, child_length)
  405. return PicklableArrayPayload(
  406. type=a.type,
  407. length=len(a),
  408. buffers=[bitmap_buf],
  409. null_count=a.null_count,
  410. offset=0,
  411. children=[_array_to_array_payload(child)],
  412. )
  413. def _struct_array_to_array_payload(a: "pyarrow.StructArray") -> "PicklableArrayPayload":
  414. """Serialize struct arrays to PicklableArrayPayload."""
  415. # Dedicated path for StructArrays.
  416. # StructArrays have a top-level bitmap buffer and one or more children arrays.
  417. # Buffer scheme: [bitmap, None, child_bitmap, child_data, ...]
  418. buffers = a.buffers()
  419. assert len(buffers) >= 1, len(buffers)
  420. # Copy bitmap buffer, if needed.
  421. if a.null_count > 0:
  422. bitmap_buf = _copy_bitpacked_buffer_if_needed(buffers[0], a.offset, len(a))
  423. else:
  424. bitmap_buf = None
  425. # Get field children payload.
  426. # Offsets and truncations are already propagated to the field arrays, so we can
  427. # serialize them as-is.
  428. children = [_array_to_array_payload(a.field(i)) for i in range(a.type.num_fields)]
  429. return PicklableArrayPayload(
  430. type=a.type,
  431. length=len(a),
  432. buffers=[bitmap_buf],
  433. null_count=a.null_count,
  434. offset=0,
  435. children=children,
  436. )
  437. def _union_array_to_array_payload(a: "pyarrow.UnionArray") -> "PicklableArrayPayload":
  438. """Serialize union arrays to PicklableArrayPayload."""
  439. import pyarrow as pa
  440. # Dedicated path for UnionArrays.
  441. # UnionArrays have a top-level bitmap buffer and type code buffer, and one or
  442. # more children arrays.
  443. # Buffer scheme: [None, typecodes, child_bitmap, child_data, ...]
  444. assert not _is_dense_union(a.type)
  445. buffers = a.buffers()
  446. assert len(buffers) > 1, len(buffers)
  447. bitmap_buf = buffers[0]
  448. assert bitmap_buf is None, bitmap_buf
  449. # Copy type code buffer, if needed.
  450. type_code_buf = buffers[1]
  451. type_code_buf = _copy_buffer_if_needed(type_code_buf, pa.int8(), a.offset, len(a))
  452. # Get field children payload.
  453. # Offsets and truncations are already propagated to the field arrays, so we can
  454. # serialize them as-is.
  455. children = [_array_to_array_payload(a.field(i)) for i in range(a.type.num_fields)]
  456. return PicklableArrayPayload(
  457. type=a.type,
  458. length=len(a),
  459. buffers=[bitmap_buf, type_code_buf],
  460. null_count=a.null_count,
  461. offset=0,
  462. children=children,
  463. )
  464. def _dictionary_array_to_array_payload(
  465. a: "pyarrow.DictionaryArray",
  466. ) -> "PicklableArrayPayload":
  467. """Serialize dictionary arrays to PicklableArrayPayload."""
  468. # Dedicated path for DictionaryArrays.
  469. # Buffer scheme: [indices_bitmap, indices_data] (dictionary stored separately)
  470. indices_payload = _array_to_array_payload(a.indices)
  471. dictionary_payload = _array_to_array_payload(a.dictionary)
  472. return PicklableArrayPayload(
  473. type=a.type,
  474. length=len(a),
  475. buffers=[],
  476. null_count=a.null_count,
  477. offset=0,
  478. children=[indices_payload, dictionary_payload],
  479. )
  480. def _map_array_to_array_payload(a: "pyarrow.MapArray") -> "PicklableArrayPayload":
  481. """Serialize map arrays to PicklableArrayPayload."""
  482. import pyarrow as pa
  483. # Dedicated path for MapArrays.
  484. # Buffer scheme: [bitmap, offsets, child_struct_array_buffers, ...]
  485. buffers = a.buffers()
  486. assert len(buffers) > 0, len(buffers)
  487. # Copy bitmap buffer, if needed.
  488. if a.null_count > 0:
  489. bitmap_buf = _copy_bitpacked_buffer_if_needed(buffers[0], a.offset, len(a))
  490. else:
  491. bitmap_buf = None
  492. new_buffers = [bitmap_buf]
  493. # Copy offsets buffer, if needed.
  494. offset_buf = buffers[1]
  495. offset_buf, data_offset, data_length = _copy_offsets_buffer_if_needed(
  496. offset_buf, a.type, a.offset, len(a)
  497. )
  498. if isinstance(a, pa.lib.ListArray):
  499. # Map arrays directly expose the one child struct array in pyarrow>=7.0.0, which
  500. # is easier to work with than the raw buffers.
  501. new_buffers.append(offset_buf)
  502. children = [_array_to_array_payload(a.values.slice(data_offset, data_length))]
  503. else:
  504. # In pyarrow<7.0.0, the child struct array is not exposed, so we work with the
  505. # key and item arrays.
  506. buffers = a.buffers()
  507. assert len(buffers) > 2, len(buffers)
  508. # Reconstruct offsets array.
  509. offsets = pa.Array.from_buffers(
  510. pa.int32(), len(a) + 1, [bitmap_buf, offset_buf]
  511. )
  512. # Propagate slice to keys.
  513. keys = a.keys.slice(data_offset, data_length)
  514. # Propagate slice to items.
  515. items = a.items.slice(data_offset, data_length)
  516. children = [
  517. _array_to_array_payload(offsets),
  518. _array_to_array_payload(keys),
  519. _array_to_array_payload(items),
  520. ]
  521. return PicklableArrayPayload(
  522. type=a.type,
  523. length=len(a),
  524. buffers=new_buffers,
  525. null_count=a.null_count,
  526. offset=0,
  527. children=children,
  528. )
  529. def _extension_array_to_array_payload(
  530. a: "pyarrow.ExtensionArray",
  531. ) -> "PicklableArrayPayload":
  532. storage_payload = _array_to_array_payload(a.storage)
  533. return PicklableArrayPayload(
  534. type=a.type,
  535. length=len(a),
  536. buffers=[],
  537. null_count=a.null_count,
  538. offset=0,
  539. children=[storage_payload],
  540. )
  541. def _copy_buffer_if_needed(
  542. buf: "pyarrow.Buffer",
  543. type_: Optional["pyarrow.DataType"],
  544. offset: int,
  545. length: int,
  546. ) -> "pyarrow.Buffer":
  547. """Copy buffer, if needed."""
  548. import pyarrow as pa
  549. if type_ is not None and pa.types.is_boolean(type_):
  550. # Arrow boolean array buffers are bit-packed, with 8 entries per byte,
  551. # and are accessed via bit offsets.
  552. buf = _copy_bitpacked_buffer_if_needed(buf, offset, length)
  553. else:
  554. type_bytewidth = type_.bit_width // 8 if type_ is not None else 1
  555. buf = _copy_normal_buffer_if_needed(buf, type_bytewidth, offset, length)
  556. return buf
  557. def _copy_normal_buffer_if_needed(
  558. buf: "pyarrow.Buffer",
  559. byte_width: int,
  560. offset: int,
  561. length: int,
  562. ) -> "pyarrow.Buffer":
  563. """Copy buffer, if needed."""
  564. byte_offset = offset * byte_width
  565. byte_length = length * byte_width
  566. if offset > 0 or byte_length < buf.size:
  567. # Array is a zero-copy slice, so we need to copy to a new buffer before
  568. # serializing; this slice of the underlying buffer (not the array) will ensure
  569. # that the buffer is properly copied at pickle-time.
  570. buf = buf.slice(byte_offset, byte_length)
  571. return buf
  572. def _copy_bitpacked_buffer_if_needed(
  573. buf: "pyarrow.Buffer",
  574. offset: int,
  575. length: int,
  576. ) -> "pyarrow.Buffer":
  577. """Copy bit-packed binary buffer, if needed."""
  578. bit_offset = offset % 8
  579. byte_offset = offset // 8
  580. byte_length = _bytes_for_bits(bit_offset + length) // 8
  581. if offset > 0 or byte_length < buf.size:
  582. buf = buf.slice(byte_offset, byte_length)
  583. if bit_offset != 0:
  584. # Need to manually shift the buffer to eliminate the bit offset.
  585. buf = _align_bit_offset(buf, bit_offset, byte_length)
  586. return buf
  587. def _copy_offsets_buffer_if_needed(
  588. buf: "pyarrow.Buffer",
  589. arr_type: "pyarrow.DataType",
  590. offset: int,
  591. length: int,
  592. ) -> Tuple["pyarrow.Buffer", int, int]:
  593. """Copy the provided offsets buffer, returning the copied buffer and the
  594. offset + length of the underlying data.
  595. """
  596. import pyarrow as pa
  597. import pyarrow.compute as pac
  598. if (
  599. pa.types.is_large_list(arr_type)
  600. or pa.types.is_large_string(arr_type)
  601. or pa.types.is_large_binary(arr_type)
  602. or pa.types.is_large_unicode(arr_type)
  603. ):
  604. offset_type = pa.int64()
  605. else:
  606. offset_type = pa.int32()
  607. # Copy offset buffer, if needed.
  608. buf = _copy_buffer_if_needed(buf, offset_type, offset, length + 1)
  609. # Reconstruct the offset array so we can determine the offset and length
  610. # of the child array.
  611. offsets = pa.Array.from_buffers(offset_type, length + 1, [None, buf])
  612. child_offset = offsets[0].as_py()
  613. child_length = offsets[-1].as_py() - child_offset
  614. # Create new offsets aligned to 0 for the copied data buffer slice.
  615. offsets = pac.subtract(offsets, child_offset)
  616. if pa.types.is_int32(offset_type):
  617. # We need to cast the resulting Int64Array back down to an Int32Array.
  618. offsets = offsets.cast(offset_type, safe=False)
  619. buf = offsets.buffers()[1]
  620. return buf, child_offset, child_length
  621. def _bytes_for_bits(n: int) -> int:
  622. """Round up n to the nearest multiple of 8.
  623. This is used to get the byte-padded number of bits for n bits.
  624. """
  625. return (n + 7) & (-8)
  626. def _align_bit_offset(
  627. buf: "pyarrow.Buffer",
  628. bit_offset: int,
  629. byte_length: int,
  630. ) -> "pyarrow.Buffer":
  631. """Align the bit offset into the buffer with the front of the buffer by shifting
  632. the buffer and eliminating the offset.
  633. """
  634. import pyarrow as pa
  635. bytes_ = buf.to_pybytes()
  636. bytes_as_int = int.from_bytes(bytes_, sys.byteorder)
  637. bytes_as_int >>= bit_offset
  638. bytes_ = bytes_as_int.to_bytes(byte_length, sys.byteorder)
  639. return pa.py_buffer(bytes_)
  640. def _arrow_table_ipc_reduce(table: "pyarrow.Table"):
  641. """Custom reducer for Arrow Table that works around a zero-copy slicing pickling
  642. bug by using the Arrow IPC format for the underlying serialization.
  643. This is currently used as a fallback for unsupported types (or unknown bugs) for
  644. the manual buffer truncation workaround, e.g. for dense unions.
  645. """
  646. from pyarrow.ipc import RecordBatchStreamWriter
  647. from pyarrow.lib import BufferOutputStream
  648. output_stream = BufferOutputStream()
  649. with RecordBatchStreamWriter(output_stream, schema=table.schema) as wr:
  650. wr.write_table(table)
  651. # NOTE: output_stream.getvalue() materializes the serialized table to a single
  652. # contiguous bytestring, resulting in a few copy. This adds 1-2 extra copies on the
  653. # serialization side, and 1 extra copy on the deserialization side.
  654. return _restore_table_from_ipc, (output_stream.getvalue(),)
  655. def _restore_table_from_ipc(buf: bytes) -> "pyarrow.Table":
  656. """Restore an Arrow Table serialized to Arrow IPC format."""
  657. from pyarrow.ipc import RecordBatchStreamReader
  658. with RecordBatchStreamReader(buf) as reader:
  659. return reader.read_all()
  660. def _is_dense_union(type_: "pyarrow.DataType") -> bool:
  661. """Whether the provided Arrow type is a dense union."""
  662. import pyarrow as pa
  663. return pa.types.is_union(type_) and type_.mode == "dense"