test_utils.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353
  1. """Test utilities for Ray.
  2. This module contains test utility classes that are distributed with the Ray package
  3. and can be used by external libraries and tests. These utilities must remain in
  4. _common/ (not in tests/) to be accessible in the Ray package distribution.
  5. """
  6. import asyncio
  7. import inspect
  8. import os
  9. import threading
  10. import time
  11. import traceback
  12. import uuid
  13. from collections.abc import Awaitable
  14. from contextlib import contextmanager
  15. from enum import Enum
  16. from typing import Any, Callable, Dict, Iterator, List, Optional, Set
  17. import ray
  18. import ray._common.usage.usage_lib as ray_usage_lib
  19. import ray._private.utils
  20. from ray._common.network_utils import build_address
  21. @ray.remote(num_cpus=0)
  22. class SignalActor:
  23. """A Ray actor for coordinating test execution through signals.
  24. Useful for testing async coordination, waiting for specific states,
  25. and synchronizing multiple actors or tasks in tests.
  26. """
  27. def __init__(self):
  28. self.ready_event = asyncio.Event()
  29. self.num_waiters = 0
  30. def send(self, clear: bool = False):
  31. self.ready_event.set()
  32. if clear:
  33. self.ready_event.clear()
  34. async def wait(self, should_wait: bool = True):
  35. if should_wait:
  36. self.num_waiters += 1
  37. await self.ready_event.wait()
  38. self.num_waiters -= 1
  39. async def cur_num_waiters(self) -> int:
  40. return self.num_waiters
  41. @ray.remote(num_cpus=0)
  42. class Semaphore:
  43. """A Ray actor implementing a semaphore for test coordination.
  44. Useful for testing resource limiting, concurrency control,
  45. and coordination between multiple actors or tasks.
  46. """
  47. def __init__(self, value: int = 1):
  48. self._sema = asyncio.Semaphore(value=value)
  49. async def acquire(self):
  50. await self._sema.acquire()
  51. async def release(self):
  52. self._sema.release()
  53. async def locked(self) -> bool:
  54. return self._sema.locked()
# Names exported by a star-import of this module.
# NOTE(review): only the two actor classes are listed; the helper functions
# defined further down in this module are not. Confirm this is intentional.
__all__ = ["SignalActor", "Semaphore"]
  56. def wait_for_condition(
  57. condition_predictor: Callable[..., bool],
  58. timeout: float = 10,
  59. retry_interval_ms: float = 100,
  60. raise_exceptions: bool = False,
  61. **kwargs: Any,
  62. ):
  63. """Wait until a condition is met or time out with an exception.
  64. Args:
  65. condition_predictor: A function that predicts the condition.
  66. timeout: Maximum timeout in seconds.
  67. retry_interval_ms: Retry interval in milliseconds.
  68. raise_exceptions: If true, exceptions that occur while executing
  69. condition_predictor won't be caught and instead will be raised.
  70. **kwargs: Arguments to pass to the condition_predictor.
  71. Returns:
  72. None: Returns when the condition is met.
  73. Raises:
  74. RuntimeError: If the condition is not met before the timeout expires.
  75. """
  76. start = time.time()
  77. last_ex = None
  78. while time.time() - start <= timeout:
  79. try:
  80. if condition_predictor(**kwargs):
  81. return
  82. except Exception:
  83. if raise_exceptions:
  84. raise
  85. last_ex = ray._private.utils.format_error_message(traceback.format_exc())
  86. time.sleep(retry_interval_ms / 1000.0)
  87. message = "The condition wasn't met before the timeout expired."
  88. if last_ex is not None:
  89. message += f" Last exception: {last_ex}"
  90. raise RuntimeError(message)
  91. async def async_wait_for_condition(
  92. condition_predictor: Callable[..., Awaitable[bool]],
  93. timeout: float = 10,
  94. retry_interval_ms: float = 100,
  95. **kwargs: Any,
  96. ):
  97. """Wait until a condition is met or time out with an exception.
  98. Args:
  99. condition_predictor: A function that predicts the condition.
  100. timeout: Maximum timeout in seconds.
  101. retry_interval_ms: Retry interval in milliseconds.
  102. **kwargs: Arguments to pass to the condition_predictor.
  103. Returns:
  104. None: Returns when the condition is met.
  105. Raises:
  106. RuntimeError: If the condition is not met before the timeout expires.
  107. """
  108. start = time.time()
  109. last_ex = None
  110. while time.time() - start <= timeout:
  111. try:
  112. if inspect.iscoroutinefunction(condition_predictor):
  113. if await condition_predictor(**kwargs):
  114. return
  115. else:
  116. if condition_predictor(**kwargs):
  117. return
  118. except Exception as ex:
  119. last_ex = ex
  120. await asyncio.sleep(retry_interval_ms / 1000.0)
  121. message = "The condition wasn't met before the timeout expired."
  122. if last_ex is not None:
  123. message += f" Last exception: {last_ex}"
  124. raise RuntimeError(message)
  125. @contextmanager
  126. def simulate_s3_bucket(
  127. port: int = 5002,
  128. region: str = "us-west-2",
  129. ) -> Iterator[str]:
  130. """Context manager that simulates an S3 bucket and yields the URI.
  131. Args:
  132. port: The port of the localhost endpoint where S3 is being served.
  133. region: The S3 region.
  134. Yields:
  135. str: URI for the simulated S3 bucket.
  136. """
  137. from moto.server import ThreadedMotoServer
  138. old_env = os.environ
  139. os.environ["AWS_ACCESS_KEY_ID"] = "testing"
  140. os.environ["AWS_SECRET_ACCESS_KEY"] = "testing"
  141. os.environ["AWS_SECURITY_TOKEN"] = "testing"
  142. os.environ["AWS_SESSION_TOKEN"] = "testing"
  143. s3_server = f"http://{build_address('localhost', port)}"
  144. server = ThreadedMotoServer(port=port)
  145. server.start()
  146. url = f"s3://{uuid.uuid4().hex}?region={region}&endpoint_override={s3_server}"
  147. yield url
  148. server.stop()
  149. os.environ = old_env
class TelemetryCallsite(Enum):
    """Where `check_library_usage_telemetry` invokes the library under test."""

    # Call directly in the calling (driver) process.
    DRIVER = "driver"
    # Call from the constructor of a remote actor.
    ACTOR = "actor"
    # Call from inside a remote task.
    TASK = "task"
  154. def _get_library_usages() -> Set[str]:
  155. return set(
  156. ray_usage_lib.get_library_usages_to_report(
  157. ray.experimental.internal_kv.internal_kv_get_gcs_client()
  158. )
  159. )
  160. def _get_extra_usage_tags() -> Dict[str, str]:
  161. return ray_usage_lib.get_extra_usage_tags_to_report(
  162. ray.experimental.internal_kv.internal_kv_get_gcs_client()
  163. )
def check_library_usage_telemetry(
    use_lib_fn: Callable[[], None],
    *,
    callsite: TelemetryCallsite,
    expected_library_usages: List[Set[str]],
    expected_extra_usage_tags: Optional[Dict[str, str]] = None,
) -> None:
    """Helper for writing tests to validate library usage telemetry.

    `use_lib_fn` is a callable that will be called from the provided callsite.
    After calling it, the telemetry data to export will be validated against
    expected_library_usages and expected_extra_usage_tags.

    Args:
        use_lib_fn: Zero-argument callable that exercises the library.
        callsite: Where to invoke `use_lib_fn`: directly in the caller, in the
            constructor of a remote actor, or inside a remote task.
        expected_library_usages: Acceptable outcomes; the recorded set of
            library usages must equal one of the listed sets.
        expected_extra_usage_tags: If given and non-empty, every key is looked
            up in the recorded extra usage tags and must map to the same value.

    Raises:
        AssertionError: If library usages were already recorded before the
            call, the callsite is not recognized, or the recorded telemetry
            does not match the expectations.
    """
    # Precondition: nothing may have been recorded before the library is used.
    assert len(_get_library_usages()) == 0, _get_library_usages()

    if callsite == TelemetryCallsite.DRIVER:
        use_lib_fn()
    elif callsite == TelemetryCallsite.ACTOR:

        @ray.remote
        class A:
            def __init__(self):
                use_lib_fn()

        a = A.remote()
        # Wait for the actor to be ready so its constructor has run.
        ray.get(a.__ray_ready__.remote())
    elif callsite == TelemetryCallsite.TASK:

        @ray.remote
        def f():
            use_lib_fn()

        ray.get(f.remote())
    else:
        assert False, f"Unrecognized callsite: {callsite}"

    library_usages = _get_library_usages()
    extra_usage_tags = _get_extra_usage_tags()

    # The recorded set must equal one of the acceptable sets.
    assert library_usages in expected_library_usages, library_usages
    if expected_extra_usage_tags:
        # Only the expected keys are checked; other recorded tags are ignored.
        assert all(
            [extra_usage_tags[k] == v for k, v in expected_extra_usage_tags.items()]
        ), extra_usage_tags
  200. class FakeTimer:
  201. def __init__(self, start_time: Optional[float] = None):
  202. self._lock = threading.Lock()
  203. self.reset(start_time=start_time)
  204. def reset(self, start_time: Optional[float] = None):
  205. with self._lock:
  206. if start_time is None:
  207. start_time = time.time()
  208. self._curr = start_time
  209. def time(self) -> float:
  210. return self._curr
  211. def advance(self, by: float):
  212. with self._lock:
  213. self._curr += by
  214. def realistic_sleep(self, amt: float):
  215. with self._lock:
  216. self._curr += amt + 0.001
  217. def is_named_tuple(cls):
  218. """Return True if cls is a namedtuple and False otherwise."""
  219. b = cls.__bases__
  220. if len(b) != 1 or b[0] is not tuple:
  221. return False
  222. f = getattr(cls, "_fields", None)
  223. if not isinstance(f, tuple):
  224. return False
  225. return all(type(n) is str for n in f)
  226. def assert_tensors_equivalent(obj1, obj2):
  227. """
  228. Recursively compare objects with special handling for torch.Tensor.
  229. Tensors are considered equivalent if:
  230. - Same dtype and shape
  231. - Same device type (e.g., both 'cpu' or both 'cuda'), index ignored
  232. - Values are equal (or close for floats)
  233. """
  234. import torch
  235. if isinstance(obj1, torch.Tensor) and isinstance(obj2, torch.Tensor):
  236. # 1. dtype
  237. assert obj1.dtype == obj2.dtype, f"dtype mismatch: {obj1.dtype} vs {obj2.dtype}"
  238. # 2. shape
  239. assert obj1.shape == obj2.shape, f"shape mismatch: {obj1.shape} vs {obj2.shape}"
  240. # 3. device type must match (cpu/cpu or cuda/cuda), ignore index
  241. assert (
  242. obj1.device.type == obj2.device.type
  243. ), f"Device type mismatch: {obj1.device} vs {obj2.device}"
  244. # 4. Compare values safely on CPU
  245. t1_cpu = obj1.cpu()
  246. t2_cpu = obj2.cpu()
  247. if obj1.dtype.is_floating_point or obj1.dtype.is_complex:
  248. assert torch.allclose(
  249. t1_cpu, t2_cpu, atol=1e-6, rtol=1e-5
  250. ), "Floating-point tensors not close"
  251. else:
  252. assert torch.equal(t1_cpu, t2_cpu), "Integer/bool tensors not equal"
  253. return
  254. # Type must match
  255. if type(obj1) is not type(obj2):
  256. raise AssertionError(f"Type mismatch: {type(obj1)} vs {type(obj2)}")
  257. # Handle namedtuples
  258. if is_named_tuple(type(obj1)):
  259. assert len(obj1) == len(obj2)
  260. for a, b in zip(obj1, obj2):
  261. assert_tensors_equivalent(a, b)
  262. elif isinstance(obj1, dict):
  263. assert obj1.keys() == obj2.keys()
  264. for k in obj1:
  265. assert_tensors_equivalent(obj1[k], obj2[k])
  266. elif isinstance(obj1, (list, tuple)):
  267. assert len(obj1) == len(obj2)
  268. for a, b in zip(obj1, obj2):
  269. assert_tensors_equivalent(a, b)
  270. elif hasattr(obj1, "__dict__") and hasattr(obj2, "__dict__"):
  271. # Compare user-defined objects by their public attributes
  272. keys1 = {
  273. k
  274. for k in obj1.__dict__.keys()
  275. if not k.startswith("_ray_") and k != "_pytype_"
  276. }
  277. keys2 = {
  278. k
  279. for k in obj2.__dict__.keys()
  280. if not k.startswith("_ray_") and k != "_pytype_"
  281. }
  282. assert keys1 == keys2, f"Object attribute keys differ: {keys1} vs {keys2}"
  283. for k in keys1:
  284. assert_tensors_equivalent(obj1.__dict__[k], obj2.__dict__[k])
  285. else:
  286. # Fallback for primitives: int, float, str, bool, etc.
  287. assert obj1 == obj2, f"Non-tensor values differ: {obj1} vs {obj2}"