dt_namespace.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. from __future__ import annotations
  2. from dataclasses import dataclass
  3. from typing import TYPE_CHECKING, Callable, Literal
  4. import pyarrow
  5. import pyarrow.compute as pc
  6. from ray.data.datatype import DataType
  7. from ray.data.expressions import pyarrow_udf
  8. if TYPE_CHECKING:
  9. from ray.data.expressions import Expr, UDFExpr
  10. TemporalUnit = Literal[
  11. "year",
  12. "quarter",
  13. "month",
  14. "week",
  15. "day",
  16. "hour",
  17. "minute",
  18. "second",
  19. "millisecond",
  20. "microsecond",
  21. "nanosecond",
  22. ]
  23. @dataclass
  24. class _DatetimeNamespace:
  25. """Datetime namespace for operations on datetime-typed expression columns."""
  26. _expr: "Expr"
  27. def _unary_temporal_int(
  28. self, func: Callable[[pyarrow.Array], pyarrow.Array]
  29. ) -> "UDFExpr":
  30. """Helper for year/month/… that return int32."""
  31. @pyarrow_udf(return_dtype=DataType.int32())
  32. def _udf(arr: pyarrow.Array) -> pyarrow.Array:
  33. return func(arr)
  34. return _udf(self._expr)
  35. # extractors
  36. def year(self) -> "UDFExpr":
  37. """Extract year component."""
  38. return self._unary_temporal_int(pc.year)
  39. def month(self) -> "UDFExpr":
  40. """Extract month component."""
  41. return self._unary_temporal_int(pc.month)
  42. def day(self) -> "UDFExpr":
  43. """Extract day component."""
  44. return self._unary_temporal_int(pc.day)
  45. def hour(self) -> "UDFExpr":
  46. """Extract hour component."""
  47. return self._unary_temporal_int(pc.hour)
  48. def minute(self) -> "UDFExpr":
  49. """Extract minute component."""
  50. return self._unary_temporal_int(pc.minute)
  51. def second(self) -> "UDFExpr":
  52. """Extract second component."""
  53. return self._unary_temporal_int(pc.second)
  54. # formatting
  55. def strftime(self, fmt: str) -> "UDFExpr":
  56. """Format timestamps with a strftime pattern."""
  57. @pyarrow_udf(return_dtype=DataType.string())
  58. def _format(arr: pyarrow.Array) -> pyarrow.Array:
  59. return pc.strftime(arr, format=fmt)
  60. return _format(self._expr)
  61. # rounding
  62. def ceil(self, unit: TemporalUnit) -> "UDFExpr":
  63. """Ceil timestamps to the next multiple of the given unit."""
  64. return_dtype = self._expr.data_type
  65. @pyarrow_udf(return_dtype=return_dtype)
  66. def _ceil(arr: pyarrow.Array) -> pyarrow.Array:
  67. return pc.ceil_temporal(arr, multiple=1, unit=unit)
  68. return _ceil(self._expr)
  69. def floor(self, unit: TemporalUnit) -> "UDFExpr":
  70. """Floor timestamps to the previous multiple of the given unit."""
  71. return_dtype = self._expr.data_type
  72. @pyarrow_udf(return_dtype=return_dtype)
  73. def _floor(arr: pyarrow.Array) -> pyarrow.Array:
  74. return pc.floor_temporal(arr, multiple=1, unit=unit)
  75. return _floor(self._expr)
  76. def round(self, unit: TemporalUnit) -> "UDFExpr":
  77. """Round timestamps to the nearest multiple of the given unit."""
  78. return_dtype = self._expr.data_type
  79. @pyarrow_udf(return_dtype=return_dtype)
  80. def _round(arr: pyarrow.Array) -> pyarrow.Array:
  81. return pc.round_temporal(arr, multiple=1, unit=unit)
  82. return _round(self._expr)