optimizations.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. import warnings
  2. import dask
  3. from dask import core
  4. from dask.dataframe.core import _concat
  5. from dask.highlevelgraph import HighLevelGraph
  6. from .scheduler import MultipleReturnFunc, multiple_return_get
  7. try:
  8. from dask.dataframe.optimize import optimize
  9. from dask.dataframe.shuffle import SimpleShuffleLayer, shuffle_group
  10. except ImportError:
  11. # SimpleShuffleLayer doesn't exist in this version of Dask.
  12. # This is the case for dask>=2025.1.0.
  13. SimpleShuffleLayer = None
  14. try:
  15. import dask_expr # noqa: F401
  16. SimpleShuffleLayer = None
  17. except ImportError:
  18. pass
  19. if SimpleShuffleLayer is not None:
  20. class MultipleReturnSimpleShuffleLayer(SimpleShuffleLayer):
  21. @classmethod
  22. def clone(cls, layer: SimpleShuffleLayer):
  23. # TODO(Clark): Probably don't need this since SimpleShuffleLayer
  24. # implements __copy__() and the shallow clone should be enough?
  25. return cls(
  26. name=layer.name,
  27. column=layer.column,
  28. npartitions=layer.npartitions,
  29. npartitions_input=layer.npartitions_input,
  30. ignore_index=layer.ignore_index,
  31. name_input=layer.name_input,
  32. meta_input=layer.meta_input,
  33. parts_out=layer.parts_out,
  34. annotations=layer.annotations,
  35. )
  36. def __repr__(self):
  37. return (
  38. f"MultipleReturnSimpleShuffleLayer<name='{self.name}', "
  39. f"npartitions={self.npartitions}>"
  40. )
  41. def __reduce__(self):
  42. attrs = [
  43. "name",
  44. "column",
  45. "npartitions",
  46. "npartitions_input",
  47. "ignore_index",
  48. "name_input",
  49. "meta_input",
  50. "parts_out",
  51. "annotations",
  52. ]
  53. return (
  54. MultipleReturnSimpleShuffleLayer,
  55. tuple(getattr(self, attr) for attr in attrs),
  56. )
  57. def _cull(self, parts_out):
  58. return MultipleReturnSimpleShuffleLayer(
  59. self.name,
  60. self.column,
  61. self.npartitions,
  62. self.npartitions_input,
  63. self.ignore_index,
  64. self.name_input,
  65. self.meta_input,
  66. parts_out=parts_out,
  67. )
  68. def _construct_graph(self):
  69. """Construct graph for a simple shuffle operation."""
  70. shuffle_group_name = "group-" + self.name
  71. shuffle_split_name = "split-" + self.name
  72. dsk = {}
  73. n_parts_out = len(self.parts_out)
  74. for part_out in self.parts_out:
  75. # TODO(Clark): Find better pattern than in-scheduler concat.
  76. _concat_list = [
  77. (shuffle_split_name, part_out, part_in)
  78. for part_in in range(self.npartitions_input)
  79. ]
  80. dsk[(self.name, part_out)] = (_concat, _concat_list, self.ignore_index)
  81. for _, _part_out, _part_in in _concat_list:
  82. dsk[(shuffle_split_name, _part_out, _part_in)] = (
  83. multiple_return_get,
  84. (shuffle_group_name, _part_in),
  85. _part_out,
  86. )
  87. if (shuffle_group_name, _part_in) not in dsk:
  88. dsk[(shuffle_group_name, _part_in)] = (
  89. MultipleReturnFunc(
  90. shuffle_group,
  91. n_parts_out,
  92. ),
  93. (self.name_input, _part_in),
  94. self.column,
  95. 0,
  96. self.npartitions,
  97. self.npartitions,
  98. self.ignore_index,
  99. self.npartitions,
  100. )
  101. return dsk
  102. def rewrite_simple_shuffle_layer(dsk, keys):
  103. if not isinstance(dsk, HighLevelGraph):
  104. dsk = HighLevelGraph.from_collections(id(dsk), dsk, dependencies=())
  105. else:
  106. dsk = dsk.copy()
  107. layers = dsk.layers.copy()
  108. for key, layer in layers.items():
  109. if type(layer) is SimpleShuffleLayer:
  110. dsk.layers[key] = MultipleReturnSimpleShuffleLayer.clone(layer)
  111. return dsk
  112. def dataframe_optimize(dsk, keys, **kwargs):
  113. if not isinstance(keys, (list, set)):
  114. keys = [keys]
  115. keys = list(core.flatten(keys))
  116. if not isinstance(dsk, HighLevelGraph):
  117. dsk = HighLevelGraph.from_collections(id(dsk), dsk, dependencies=())
  118. dsk = rewrite_simple_shuffle_layer(dsk, keys=keys)
  119. return optimize(dsk, keys, **kwargs)
  120. else:
  121. def dataframe_optimize(dsk, keys, **kwargs):
  122. warnings.warn(
  123. "Custom dataframe shuffle optimization only works on "
  124. "dask>=2024.11.0,<2025.1.0, you are on version "
  125. f"{dask.__version__}."
  126. "Doing no additional optimization aside from the default one."
  127. )
  128. return None