context.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. import os
  2. import threading
  3. from dataclasses import dataclass
  4. from typing import Optional
  5. from ray.util.annotations import DeveloperAPI
  6. # The context singleton on this process.
  7. _default_context: "Optional[DAGContext]" = None
  8. _context_lock = threading.Lock()
  9. DEFAULT_SUBMIT_TIMEOUT_S = int(os.environ.get("RAY_CGRAPH_submit_timeout", 10))
  10. DEFAULT_GET_TIMEOUT_S = int(os.environ.get("RAY_CGRAPH_get_timeout", 10))
  11. DEFAULT_TEARDOWN_TIMEOUT_S = int(os.environ.get("RAY_CGRAPH_teardown_timeout", 30))
  12. DEFAULT_READ_ITERATION_TIMEOUT_S = float(
  13. os.environ.get("RAY_CGRAPH_read_iteration_timeout_s", 0.1)
  14. )
  15. # Default buffer size is 1MB.
  16. DEFAULT_BUFFER_SIZE_BYTES = int(os.environ.get("RAY_CGRAPH_buffer_size_bytes", 1e6))
  17. # The default number of in-flight executions that can be submitted before consuming the
  18. # output.
  19. DEFAULT_MAX_INFLIGHT_EXECUTIONS = int(
  20. os.environ.get("RAY_CGRAPH_max_inflight_executions", 10)
  21. )
  22. # The default number of results that can be buffered at the driver.
  23. DEFAULT_MAX_BUFFERED_RESULTS = int(
  24. os.environ.get("RAY_CGRAPH_max_buffered_results", 1000)
  25. )
  26. DEFAULT_OVERLAP_GPU_COMMUNICATION = bool(
  27. os.environ.get("RAY_CGRAPH_overlap_gpu_communication", 0)
  28. )
  29. @DeveloperAPI
  30. @dataclass
  31. class DAGContext:
  32. """Global settings for Ray DAG.
  33. You can configure parameters in the DAGContext by setting the environment
  34. variables, `RAY_CGRAPH_<param>` (e.g., `RAY_CGRAPH_buffer_size_bytes`) or Python.
  35. Examples:
  36. >>> from ray.dag import DAGContext
  37. >>> DAGContext.get_current().buffer_size_bytes
  38. 1000000
  39. >>> DAGContext.get_current().buffer_size_bytes = 500
  40. >>> DAGContext.get_current().buffer_size_bytes
  41. 500
  42. Args:
  43. submit_timeout: The maximum time in seconds to wait for execute()
  44. calls.
  45. get_timeout: The maximum time in seconds to wait when retrieving
  46. a result from the DAG during `ray.get`. This should be set to a
  47. value higher than the expected time to execute the entire DAG.
  48. teardown_timeout: The maximum time in seconds to wait for the DAG to
  49. cleanly shut down.
  50. read_iteration_timeout: The timeout in seconds for each read iteration
  51. that reads one of the input channels. If the timeout is reached, the
  52. read operation will be interrupted and will try to read the next
  53. input channel. It must be less than or equal to `get_timeout`.
  54. buffer_size_bytes: The initial buffer size in bytes for messages
  55. that can be passed between tasks in the DAG. The buffers will
  56. be automatically resized if larger messages are written to the
  57. channel.
  58. max_inflight_executions: The maximum number of in-flight executions that
  59. can be submitted via `execute` or `execute_async` before consuming
  60. the output using `ray.get()`. If the caller submits more executions,
  61. `RayCgraphCapacityExceeded` is raised.
  62. overlap_gpu_communication: (experimental) Whether to overlap GPU
  63. communication with computation during DAG execution. If True, the
  64. communication and computation can be overlapped, which can improve
  65. the performance of the DAG execution.
  66. """
  67. submit_timeout: int = DEFAULT_SUBMIT_TIMEOUT_S
  68. get_timeout: int = DEFAULT_GET_TIMEOUT_S
  69. teardown_timeout: int = DEFAULT_TEARDOWN_TIMEOUT_S
  70. read_iteration_timeout: float = DEFAULT_READ_ITERATION_TIMEOUT_S
  71. buffer_size_bytes: int = DEFAULT_BUFFER_SIZE_BYTES
  72. max_inflight_executions: int = DEFAULT_MAX_INFLIGHT_EXECUTIONS
  73. max_buffered_results: int = DEFAULT_MAX_BUFFERED_RESULTS
  74. overlap_gpu_communication: bool = DEFAULT_OVERLAP_GPU_COMMUNICATION
  75. def __post_init__(self):
  76. if self.read_iteration_timeout > self.get_timeout:
  77. raise ValueError(
  78. "RAY_CGRAPH_read_iteration_timeout_s "
  79. f"({self.read_iteration_timeout}) must be less than or equal to "
  80. f"RAY_CGRAPH_get_timeout ({self.get_timeout})"
  81. )
  82. @staticmethod
  83. def get_current() -> "DAGContext":
  84. """Get or create a singleton context.
  85. If the context has not yet been created in this process, it will be
  86. initialized with default settings.
  87. """
  88. global _default_context
  89. with _context_lock:
  90. if _default_context is None:
  91. _default_context = DAGContext()
  92. return _default_context