utils.py 24 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770
  1. import asyncio
  2. import collections
  3. import copy
  4. import errno
  5. import importlib
  6. import inspect
  7. import logging
  8. import random
  9. import re
  10. import time
  11. import uuid
  12. from decimal import ROUND_HALF_UP, Decimal
  13. from enum import Enum
  14. from functools import wraps
  15. from typing import Any, Callable, Dict, List, Optional, TypeVar, Union
  16. import requests
  17. import ray
  18. import ray.util.serialization_addons
  19. from ray._common.constants import HEAD_NODE_RESOURCE_NAME
  20. from ray._common.utils import get_random_alphanumeric_string, import_attr
  21. from ray._private.worker import LOCAL_MODE, SCRIPT_MODE
  22. from ray._raylet import MessagePackSerializer
  23. from ray.actor import ActorHandle
  24. from ray.serve._private.common import RequestMetadata, ServeComponentType
  25. from ray.serve._private.constants import HTTP_PROXY_TIMEOUT, SERVE_LOGGER_NAME
  26. from ray.types import ObjectRef
  27. from ray.util.serialization import StandaloneSerializationContext
  28. try:
  29. import pandas as pd
  30. except ImportError:
  31. pd = None
  32. try:
  33. import numpy as np
  34. except ImportError:
  35. np = None
  # Matches characters to be replaced in generated file names: anything outside
  # printable ASCII (0x20-0x7E), plus any of  < > : " / \ | ? *
  FILE_NAME_REGEX: str = r"[^\x20-\x7E]|[<>:\"/\\|?*]"
  # Number of leading bytes msgpack_deserialize skips before handing the rest
  # of a payload to MessagePackSerializer.loads.
  MESSAGE_PACK_OFFSET: int = 9
  def asyncio_grpc_exception_handler(loop, context):
      """Exception handler to filter out false positive BlockingIOErrors from gRPC.

      Everything except the known-spurious ``EAGAIN`` ``BlockingIOError`` coming
      from ``PollerCompletionQueue._handle_events`` is forwarded to
      ``loop.default_exception_handler``.

      Args:
          loop: The event loop reporting the error.
          context: The asyncio exception-context dict. The "exception" and
              "message" keys are read with ``.get`` and may be absent.
      """
      exc = context.get("exception")
      # A context without a "message" must not crash the handler with a
      # TypeError from the substring test below; treat it as empty.
      msg = context.get("message") or ""
      if (
          isinstance(exc, BlockingIOError)
          and exc.errno == errno.EAGAIN
          and "PollerCompletionQueue._handle_events" in msg
      ):
          # Known false positive from gRPC's poller: swallow it.
          return
      loop.default_exception_handler(context)
  def validate_ssl_config(
      ssl_certfile: Optional[str], ssl_keyfile: Optional[str]
  ) -> None:
      """Check that HTTPS is configured with both halves of the key pair.

      Args:
          ssl_certfile: Path to the SSL certificate file (falsy when unset).
          ssl_keyfile: Path to the SSL private key file (falsy when unset).

      Raises:
          ValueError: If exactly one of ``ssl_certfile`` / ``ssl_keyfile`` is set.
      """
      has_cert = bool(ssl_certfile)
      has_key = bool(ssl_keyfile)
      # Exactly one provided (truthiness XOR) is a misconfiguration.
      if has_cert != has_key:
          raise ValueError(
              "Both ssl_keyfile and ssl_certfile must be provided together "
              "to enable HTTPS."
          )
  # Shared error instance raised by resolve_deployment_response when a
  # streaming (generator) deployment response is passed as an argument to
  # another handle call.
  GENERATOR_COMPOSITION_NOT_SUPPORTED_ERROR = RuntimeError(
      "Streaming deployment handle results cannot be passed to "
      "downstream handle calls. If you have a use case requiring "
      "this feature, please file a feature request on GitHub."
  )
  # Global singleton enum used to emulate default (i.e. "not provided") option
  # values. None cannot be used for this, because None is itself a valid value
  # that a user may explicitly set for these options.
  class DEFAULT(Enum):
      """Sentinel type: ``DEFAULT.VALUE`` marks an option that was not set."""

      VALUE = 1
  class DeploymentOptionUpdateType(str, Enum):
      """How much work is needed to apply a change to a deployment option."""

      # Only the target state has to be set; nothing else needs to happen.
      LightWeight = "LightWeight"

      # The option is read from the deployment config by every DeploymentReplica
      # instance (tracked in DeploymentState), so those copies must be refreshed.
      NeedsReconfigure = "NeedsReconfigure"

      # The option is shipped to the replica actor; reconfigure() must be called
      # on the actor for the new value to take effect.
      NeedsActorReconfigure = "NeedsActorReconfigure"

      # Changing the option requires restarting all replicas.
      HeavyWeight = "HeavyWeight"
  # Generic placeholder for the Default alias below.
  T = TypeVar("T")
  # Type alias: a value of type Default[T] is either DEFAULT.VALUE ("not set")
  # or a real value of type T.
  Default = Union[DEFAULT, T]

  logger = logging.getLogger(SERVE_LOGGER_NAME)

  # Template used by get_component_file_name to build component file names.
  FILE_FMT = "{component_name}_{component_id}{suffix}"
  91. class _ServeCustomEncoders:
  92. """Group of custom encoders for common types that's not handled by FastAPI."""
  93. @staticmethod
  94. def encode_np_array(obj):
  95. assert isinstance(obj, np.ndarray)
  96. if obj.dtype.kind == "f": # floats
  97. obj = obj.astype(float)
  98. if obj.dtype.kind in {"i", "u"}: # signed and unsigned integers.
  99. obj = obj.astype(int)
  100. return obj.tolist()
  101. @staticmethod
  102. def encode_np_scaler(obj):
  103. assert isinstance(obj, np.generic)
  104. return obj.item()
  105. @staticmethod
  106. def encode_exception(obj):
  107. assert isinstance(obj, Exception)
  108. return str(obj)
  109. @staticmethod
  110. def encode_pandas_dataframe(obj):
  111. assert isinstance(obj, pd.DataFrame)
  112. return obj.to_dict(orient="records")
  # Type -> encoder mapping. The numpy/pandas entries are registered only when
  # the corresponding optional import above succeeded (np / pd not None).
  serve_encoders = {Exception: _ServeCustomEncoders.encode_exception}
  if np is not None:
      serve_encoders.update(
          {
              np.ndarray: _ServeCustomEncoders.encode_np_array,
              np.generic: _ServeCustomEncoders.encode_np_scaler,
          }
      )
  if pd is not None:
      serve_encoders[pd.DataFrame] = _ServeCustomEncoders.encode_pandas_dataframe
  @ray.remote(num_cpus=0)
  def block_until_http_ready(
      http_endpoint,
      backoff_time_s=1,
      check_ready=None,
      timeout=HTTP_PROXY_TIMEOUT,
  ):
      """Poll ``http_endpoint`` until it answers HTTP 200 (and ``check_ready``).

      Args:
          http_endpoint: URL to send GET requests to.
          backoff_time_s: Seconds to sleep between polling attempts.
          check_ready: Optional callable receiving the ``requests`` response and
              returning whether the endpoint should be considered ready.
          timeout: Overall deadline in seconds; a value <= 0 waits forever.

      Raises:
          TimeoutError: If the endpoint is not ready within ``timeout`` seconds.
      """
      http_is_ready = False
      start_time = time.time()
      while not http_is_ready:
          request_timeout_s = None
          if timeout > 0:
              # Bound each request by the remaining budget so a server that
              # accepts the connection but never answers cannot block this loop
              # past the overall deadline.
              remaining_s = timeout - (time.time() - start_time)
              request_timeout_s = max(remaining_s, 0.1)
          try:
              resp = requests.get(http_endpoint, timeout=request_timeout_s)
              # Explicit status check instead of `assert`: asserts are stripped
              # under `python -O`, which would make any response count as ready.
              if resp.status_code == 200:
                  if check_ready is None:
                      http_is_ready = True
                  else:
                      http_is_ready = check_ready(resp)
          except Exception:
              # Connection errors are expected while the proxy starts; retry.
              pass
          if 0 < timeout < time.time() - start_time:
              raise TimeoutError("HTTP proxy not ready after {} seconds.".format(timeout))
          time.sleep(backoff_time_s)
  def get_random_string(length: int = 8):
      """Return a random alphanumeric string via get_random_alphanumeric_string."""
      random_string = get_random_alphanumeric_string(length)
      return random_string
  def format_actor_name(actor_name, *modifiers):
      """Append each modifier to ``actor_name``, separated by "-".

      Example: ``format_actor_name("proxy", "node1")`` -> ``"proxy-node1"``.
      """
      suffix = "".join("-{}".format(modifier) for modifier in modifiers)
      if not suffix:
          return actor_name
      return actor_name + suffix
  # Class-level attributes that copy_class_metadata transfers from the wrapped
  # class to its wrapper class.
  CLASS_WRAPPER_METADATA_ATTRS = (
      "__name__",
      "__qualname__",
      "__module__",
      "__doc__",
      "__annotations__",
  )


  def copy_class_metadata(wrapper_cls, target_cls) -> None:
      """Copy common class-level metadata from ``target_cls`` onto ``wrapper_cls``.

      Every attribute in CLASS_WRAPPER_METADATA_ATTRS that ``target_cls`` has is
      copied over, except ``__annotations__``, which is merged: the wrapper's
      own annotations win and names missing from it are filled in from
      ``target_cls`` (only when ``target_cls`` has non-empty annotations).
      Finally ``wrapper_cls.__wrapped__`` is set to ``target_cls``.
      """
      for attr_name in CLASS_WRAPPER_METADATA_ATTRS:
          if attr_name != "__annotations__":
              if hasattr(target_cls, attr_name):
                  setattr(wrapper_cls, attr_name, getattr(target_cls, attr_name))
              continue
          source_annotations = getattr(target_cls, "__annotations__", None)
          if not source_annotations:
              continue
          combined = dict(wrapper_cls.__dict__.get("__annotations__", {}))
          for name, annotation in source_annotations.items():
              combined.setdefault(name, annotation)
          wrapper_cls.__annotations__ = combined
      wrapper_cls.__wrapped__ = target_cls
  def ensure_serialization_context():
      """Ensure the serialization addons are registered, even when Ray has not
      been started, by applying them to a standalone serialization context."""
      standalone_ctx = StandaloneSerializationContext()
      ray.util.serialization_addons.apply(standalone_ctx)
  def msgpack_serialize(obj):
      """Serialize ``obj`` with the global worker's serialization context.

      Returns:
          The serialized payload as ``bytes``.
      """
      serialization_ctx = ray._private.worker.global_worker.get_serialization_context()
      return serialization_ctx.serialize(obj).to_bytes()
  def msgpack_deserialize(data):
      """Deserialize a msgpack payload, skipping its ``MESSAGE_PACK_OFFSET`` prefix.

      Presumably ``data`` is the output of ``msgpack_serialize``; verify with
      callers. The first ``MESSAGE_PACK_OFFSET`` bytes are dropped and the rest
      is passed to ``MessagePackSerializer.loads``. Any exception it raises
      propagates unchanged (the former ``try/except: raise`` was a no-op).
      """
      # TODO: Ray does not provide a msgpack deserialization api.
      return MessagePackSerializer.loads(data[MESSAGE_PACK_OFFSET:], None)
  def merge_dict(dict1, dict2):
      """Sum two dicts key-wise; keys missing from one side count as 0.

      Returns None only when both inputs are None; a single None input is
      treated as an empty dict.
      """
      if dict1 is None and dict2 is None:
          return None
      left = {} if dict1 is None else dict1
      right = {} if dict2 is None else dict2
      return {
          key: sum(side.get(key, 0) for side in (left, right))
          for key in left.keys() | right.keys()
      }
  def parse_import_path(import_path: str):
      """Split an import path into ``(module_name, attribute_name)``.

      ``import_path`` has the form
      ``[subdirectory 1].[subdir 2]...[subdir n].[file name].[attribute name]``.
      The module name is everything before the last dot and the attribute name
      everything after it, so the attribute can be imported with
      ``from module_name import attr_name``.

      Raises:
          ValueError: If ``import_path`` contains no dot.
      """
      module_name, separator, attr_name = import_path.rpartition(".")
      if not separator:
          raise ValueError(
              f"Got {import_path} as import path. The import path "
              f"should at least specify the file name and "
              f"attribute name connected by a dot."
          )
      return module_name, attr_name
  def override_runtime_envs_except_env_vars(parent_env: Dict, child_env: Dict) -> Dict:
      """Creates a runtime_env dict by merging a parent and child environment.

      This method is not destructive. It leaves the parent and child envs
      the same.

      The merge is a shallow update where the child environment inherits the
      parent environment's settings. If the child environment specifies any
      env settings, those settings take precedence over the parent.

      - Note: env_vars are a special case. The child's env_vars are combined
        with the parent's (child values win on conflicting keys). A missing or
        ``None`` env_vars on either side is treated as empty.

      Args:
          parent_env: The environment to inherit settings from.
          child_env: The environment with override settings.

      Returns: A new dictionary containing the merged runtime_env settings.
          It always contains an "env_vars" key.

      Raises:
          TypeError: If a dictionary is not passed in for parent_env or child_env.
      """
      if not isinstance(parent_env, dict):
          raise TypeError(
              f'Got unexpected type "{type(parent_env)}" for parent_env. '
              "parent_env must be a dictionary."
          )
      if not isinstance(child_env, dict):
          raise TypeError(
              f'Got unexpected type "{type(child_env)}" for child_env. '
              "child_env must be a dictionary."
          )

      # Deep copies keep the caller's dicts untouched.
      merged = copy.deepcopy(parent_env)
      overrides = copy.deepcopy(child_env)

      # `or {}` also covers an explicit `env_vars: None`, which previously
      # crashed in dict.update().
      env_vars = merged.get("env_vars") or {}
      env_vars.update(overrides.get("env_vars") or {})

      merged.update(overrides)
      merged["env_vars"] = env_vars
      return merged
  class JavaActorHandleProxy:
      """Wraps actor handle and translate snake_case to camelCase.

      Names the wrapped handle already exposes (according to ``dir()`` at
      construction time) are forwarded unchanged. Any other name is converted
      to camelCase (first component kept, the rest ``str.title()``-ed) before
      being looked up on the handle.
      """

      def __init__(self, handle: ActorHandle):
          self.handle = handle
          # One-time snapshot of the attribute names available on the handle.
          self._available_attrs = set(dir(self.handle))

      def __getattr__(self, key: str):
          # Only called for names not found on the proxy itself.
          if key not in self._available_attrs:
              first, *rest = key.split("_")
              key = first + "".join(part.title() for part in rest)
          return getattr(self.handle, key)
  def require_packages(packages: List[str]):
      """Decorator making sure function run in specified environments

      On the first call of the decorated function, each package in
      ``packages`` is imported; if any raise ModuleNotFoundError an ImportError
      listing them is raised (and the check is retried on the next call).
      Once the check succeeds it is not repeated.

      Examples:
          >>> from ray.serve._private.utils import require_packages
          >>> @require_packages(["numpy", "package_a"]) # doctest: +SKIP
          ... def func(): # doctest: +SKIP
          ...     import numpy as np # doctest: +SKIP
          ...     ... # doctest: +SKIP
          >>> func() # doctest: +SKIP
          ImportError: <function func at ...> requires packages
          ['numpy', 'package_a'] to run but ['package_a'] are missing. Please
          `pip install` them or add them to `runtime_env`.

      Raises:
          ValueError: (at decoration time) if the decorated object is not a
              function/routine.
      """

      def decorator(func):
          # Tracks whether the import check already passed for this callable.
          # Kept in the closure rather than as an attribute on `func`, because
          # builtins and bound methods (both accepted by inspect.isroutine)
          # do not support attribute assignment.
          checked = False

          def check_import_once():
              nonlocal checked
              if checked:
                  return
              missing_packages = []
              for package in packages:
                  try:
                      importlib.import_module(package)
                  except ModuleNotFoundError:
                      missing_packages.append(package)
              if missing_packages:
                  raise ImportError(
                      f"{func} requires packages {packages} to run but "
                      f"{missing_packages} are missing. Please "
                      "`pip install` them or add them to "
                      "`runtime_env`."
                  )
              checked = True

          if inspect.iscoroutinefunction(func):

              @wraps(func)
              async def wrapped(*args, **kwargs):
                  check_import_once()
                  return await func(*args, **kwargs)

          elif inspect.isroutine(func):

              @wraps(func)
              def wrapped(*args, **kwargs):
                  check_import_once()
                  return func(*args, **kwargs)

          else:
              raise ValueError("Decorator expect callable functions.")

          return wrapped

      return decorator
  def in_interactive_shell():
      """Return True when ``__main__`` has no ``__file__`` (interactive session).

      Approach taken from:
      https://stackoverflow.com/questions/15411967/how-can-i-check-if-code-is-executed-in-the-ipython-notebook
      """
      import __main__

      main_has_file = hasattr(__main__, "__file__")
      return not main_has_file
  def snake_to_camel_case(snake_str: str) -> str:
      """Convert a snake case string to camel case.

      Leading/trailing underscores are dropped; only the first character of
      each later word is upper-cased (the rest of the word is kept as-is).
      """
      head, *tail = snake_str.strip("_").split("_")
      capitalized = [part[:1].upper() + part[1:] for part in tail]
      return head + "".join(capitalized)
  def check_obj_ref_ready_nowait(obj_ref: ObjectRef) -> bool:
      """Return whether ``obj_ref`` is ready, without blocking (timeout=0)."""
      ready_refs, _pending_refs = ray.wait([obj_ref], timeout=0)
      return len(ready_refs) == 1
  def extract_self_if_method_call(args: List[Any], func: Callable) -> Optional[object]:
      """Return ``args[0]`` if this call is a method call on it, else None.

      The call counts as a method call when ``args[0]`` has a truthy attribute
      named ``func.__name__`` whose ``__wrapped__`` equals ``func``. This is
      the most robust check found so far; ideally it would run when the
      decorator is applied rather than on every call.

      Arguments:
          args: arguments to the function/method call.
          func: the unbound function that was called.
      """
      if not args:
          return None
      candidate_self = args[0]
      bound = getattr(candidate_self, func.__name__, False)
      if not bound:
          return None
      inner = getattr(bound, "__wrapped__", False)
      return candidate_self if inner and inner == func else None
  def call_function_from_import_path(import_path: str) -> Any:
      """Call the function given import path.

      Args:
          import_path: The import path of the function to call.

      Raises:
          ValueError: If the import path is invalid.
          TypeError: If the import path is not callable.
          RuntimeError: if the function raises an exception during execution.

      Returns:
          The result of the function call.
      """
      try:
          callback_func = import_attr(import_path)
      except Exception as e:
          # Chain explicitly so the original import failure is reported as the
          # direct cause rather than as an incidental context.
          raise ValueError(
              f"The import path {import_path} cannot be imported: {e}"
          ) from e

      if not callable(callback_func):
          raise TypeError(f"The import path {import_path} is not callable.")

      try:
          return callback_func()
      except Exception as e:
          raise RuntimeError(
              f"The function {import_path} raised an exception: {e}"
          ) from e
  def get_head_node_id() -> str:
      """Get the head node id.

      Iterate through all nodes in the ray cluster and return the node id of the first
      alive node with head node resource.

      Raises:
          AssertionError: If no alive node carries the head node resource.
      """
      head_node_id = next(
          (
              node["NodeID"]
              for node in ray.nodes()
              if HEAD_NODE_RESOURCE_NAME in node["Resources"] and node["Alive"]
          ),
          None,
      )
      # Raised explicitly rather than via `assert`, which is stripped under
      # `python -O` and would silently return None. The AssertionError type is
      # kept for backward compatibility with existing callers.
      if head_node_id is None:
          raise AssertionError("Cannot find alive head node.")
      return head_node_id
  def calculate_remaining_timeout(
      *,
      timeout_s: Optional[float],
      start_time_s: float,
      curr_time_s: float,
  ) -> Optional[float]:
      """Return how much of ``timeout_s`` is left at ``curr_time_s``.

      A ``None`` or negative ``timeout_s`` is returned unchanged. Otherwise the
      result is ``timeout_s - (curr_time_s - start_time_s)`` clamped at 0.
      """
      is_passthrough = timeout_s is None or timeout_s < 0
      if is_passthrough:
          return timeout_s
      remaining_s = timeout_s - (curr_time_s - start_time_s)
      return max(0, remaining_s)
  def get_all_live_placement_group_names() -> List[str]:
      """Return names of placement groups whose scheduling state is not "REMOVED".

      Entries from ``ray.util.placement_group_table()`` with an empty/missing
      name are skipped; a missing scheduling state counts as live.
      """
      pg_table = ray.util.placement_group_table()

      def _is_live(pg_info) -> bool:
          state = pg_info.get("stats", {}).get("scheduling_state", "UNKNOWN")
          return state != "REMOVED"

      return [
          pg_info.get("name", "")
          for pg_info in pg_table.values()
          if pg_info.get("name", "") and _is_live(pg_info)
      ]
  def get_current_actor_id() -> str:
      """Gets the ID of the calling actor.

      Returns "DRIVER" in script or local mode, the actor ID inside an actor,
      and an empty string otherwise (including when the lookup raises).

      This function hangs when GCS is down due to the `ray.get_runtime_context()`
      call.
      """
      if ray.get_runtime_context().worker.mode in {SCRIPT_MODE, LOCAL_MODE}:
          return "DRIVER"
      try:
          actor_id = ray.get_runtime_context().get_actor_id()
      except Exception:
          return ""
      return "" if actor_id is None else actor_id
  def is_running_in_asyncio_loop() -> bool:
      """Return True when called while an asyncio event loop is running."""
      try:
          asyncio.get_running_loop()
      except RuntimeError:
          return False
      else:
          return True
  def get_capacity_adjusted_num_replicas(
      num_replicas: int, target_capacity: Optional[float]
  ) -> int:
      """Scale ``num_replicas`` by ``target_capacity`` percent.

      A ``None`` or 100 capacity leaves ``num_replicas`` unchanged. The result
      is 0 only when ``target_capacity`` or ``num_replicas`` is 0 (so
      autoscaling deployments can scale to zero); otherwise it is at least 1.
      Rounding is half-up via ``decimal`` (Python's ``round`` is half-to-even).
      """
      if target_capacity is None or target_capacity == 100:
          return num_replicas
      if 0 in (target_capacity, num_replicas):
          return 0

      scaled = Decimal(num_replicas * target_capacity) / Decimal(100.0)
      whole = int(scaled.to_integral_value(rounding=ROUND_HALF_UP))
      return max(1, whole)
  def generate_request_id() -> str:
      """Return a random UUID4-formatted request ID string."""
      # NOTE(edoakes): random.getrandbits is used instead of uuid.uuid4 because
      # it is significantly cheaper on CPU. It is not cryptographically secure,
      # which is acceptable for request IDs.
      # See https://bugs.python.org/issue45556 for discussion.
      random_bits = random.getrandbits(128)
      return str(uuid.UUID(int=random_bits, version=4))
  def inside_ray_client_context() -> bool:
      """Return ``ray.util.client.ray.is_connected()``."""
      client_api = ray.util.client.ray
      return client_api.is_connected()
  def get_component_file_name(
      component_name: str,
      component_id: str,
      component_type: Optional[ServeComponentType],
      suffix: str = "",
  ) -> str:
      """Get the component's file name. Replaces special characters with underscores.

      Characters matched by FILE_NAME_REGEX in ``component_name`` become "_".
      When ``component_type`` is given, its ``.value`` is prefixed to the name.
      The result is rendered with FILE_FMT.

      Args:
          component_name: Name of the component (sanitized before use).
          component_id: Identifier appended after the name.
          component_type: Optional component type whose ``.value`` prefixes
              the name.
          suffix: Text appended verbatim (e.g. a file extension).

      Returns:
          The formatted file name.
      """
      safe_name = re.sub(FILE_NAME_REGEX, "_", component_name)
      # The type prefix is applied for every component type. (A former branch
      # that rebuilt `component_name` for non-REPLICA types was dead code: its
      # result was never used.)
      if component_type is not None:
          safe_name = f"{component_type.value}_{safe_name}"
      return FILE_FMT.format(
          component_name=safe_name,
          component_id=component_id,
          suffix=suffix,
      )
  def validate_route_prefix(route_prefix: Union[DEFAULT, None, str]):
      """Raise ValueError if ``route_prefix`` is not a valid route prefix.

      ``None`` and ``DEFAULT.VALUE`` are accepted as "not set". Any other value
      must start with "/", must not end with "/" unless it is exactly "/", and
      must not contain "{" or "}".
      """
      if route_prefix is None or route_prefix is DEFAULT.VALUE:
          return

      if not route_prefix.startswith("/"):
          raise ValueError(
              f"Invalid route_prefix '{route_prefix}', "
              "must start with a forward slash ('/')."
          )

      has_trailing_slash = route_prefix.endswith("/")
      if has_trailing_slash and route_prefix != "/":
          raise ValueError(
              f"Invalid route_prefix '{route_prefix}', "
              "may not end with a trailing '/'."
          )

      if any(brace in route_prefix for brace in "{}"):
          raise ValueError(
              f"Invalid route_prefix '{route_prefix}', may not contain wildcards."
          )
  async def await_deployment_response(deployment_response):
      """Await ``deployment_response`` and return its result."""
      result = await deployment_response
      return result
  async def resolve_deployment_response(obj: Any, request_metadata: RequestMetadata):
      """Resolve `DeploymentResponse` objects to underlying object references.

      This enables composition without explicitly calling `_to_object_ref`.

      Returns an asyncio task for ``DeploymentResponse`` objects, a wrapped
      future for Ray ``ObjectRef`` objects when requests are sent by value,
      and None for anything else.

      Raises:
          RuntimeError: ``GENERATOR_COMPOSITION_NOT_SUPPORTED_ERROR`` when
              ``obj`` is a ``DeploymentResponseGenerator``.
      """
      from ray.serve.handle import DeploymentResponse, DeploymentResponseGenerator

      if isinstance(obj, DeploymentResponseGenerator):
          raise GENERATOR_COMPOSITION_NOT_SUPPORTED_ERROR

      if isinstance(obj, DeploymentResponse):
          if request_metadata._by_reference and obj.by_reference:
              # Sending by reference: convert the response to an object ref
              # in a background task.
              pending = obj._to_object_ref()
          else:
              # Otherwise resolve the response directly to its result.
              pending = await_deployment_response(obj)
          return asyncio.create_task(pending)

      if not request_metadata._by_reference and isinstance(obj, ray.ObjectRef):
          # Sending by value (i.e. gRPC): resolve Ray objects to mirror Ray's
          # own argument-resolution behavior.
          return asyncio.wrap_future(obj.future())

      return None
  def wait_for_interrupt() -> None:
      """Block forever until KeyboardInterrupt, then log and re-raise it."""
      poll_interval_s = 10
      try:
          while True:
              # Idle so Ray can keep printing logs to the terminal.
              time.sleep(poll_interval_s)
      except KeyboardInterrupt:
          logger.warning("Got KeyboardInterrupt, exiting...")
          # Re-raise so the main script can still shut down serve components.
          raise
  def is_grpc_enabled(grpc_config) -> bool:
      """Return True when the gRPC port is positive and servicers are configured."""
      if not grpc_config.port > 0:
          return False
      return len(grpc_config.grpc_servicer_functions) > 0
  class Semaphore:
      """Based on asyncio.Semaphore.

      This is a semaphore that can be used to limit the number of concurrent requests.
      Its maximum value is dynamic and is determined by the `get_value_fn` function.

      Unlike ``asyncio.Semaphore``, ``_value`` counts the slots currently *held*
      and is compared against ``get_value_fn()`` on every check. A waiter is
      only handed a slot while ``_value`` is below the current maximum, so
      lowering the maximum takes effect as holders release.

      Note: raising the maximum does not by itself wake blocked waiters; they
      are woken by the next ``release()``.
      """

      def __init__(self, get_value_fn: Callable[[], int]):
          # FIFO of futures for coroutines blocked in acquire(); created lazily.
          self._waiters = None
          # Slots currently held, including slots already handed to woken
          # waiters that have not resumed yet.
          self._value = 0
          self._get_value_fn = get_value_fn

      def __repr__(self):
          res = super().__repr__()
          extra = "locked" if self.locked() else f"unlocked, value:{self._value}"
          if self._waiters:
              extra = f"{extra}, waiters:{len(self._waiters)}"
          return f"<{res[1:-1]} [{extra}]>"

      async def __aenter__(self):
          await self.acquire()
          # We have no use for the "as ..." clause in the with
          # statement for locks.
          return None

      async def __aexit__(self, exc_type, exc, tb):
          self.release()

      def get_max_value(self):
          """Return the current maximum number of concurrent holders."""
          return self._get_value_fn()

      def locked(self):
          """Returns True if semaphore cannot be acquired immediately."""
          return self._value >= self.get_max_value() or (
              any(not w.cancelled() for w in (self._waiters or ()))
          )

      async def acquire(self):
          """Acquire a slot, waiting until one is available.

          If the semaphore is not ``locked()``, a slot is taken immediately.
          Otherwise this coroutine queues behind earlier waiters until a
          ``release()`` (or a preceding waiter) hands it a slot.

          Returns:
              True once a slot is held.

          Raises:
              asyncio.CancelledError: If cancelled while waiting. A slot that
                  was already handed to this waiter is given back first.
          """
          if not self.locked():
              self._value += 1
              return True

          if self._waiters is None:
              self._waiters = collections.deque()
          fut = asyncio.get_running_loop().create_future()
          self._waiters.append(fut)

          # Finally block should be called before the CancelledError
          # handling as we don't want CancelledError to call
          # _wake_up_next() and attempt to wake up itself.
          try:
              try:
                  await fut
              finally:
                  self._waiters.remove(fut)
          except asyncio.CancelledError:
              if not fut.cancelled():
                  # We were handed a slot but cancelled before resuming:
                  # give it back and pass it on if capacity allows.
                  self._value -= 1
                  self._wake_up_next()
              raise

          # There may be spare capacity (e.g. the maximum was raised); let the
          # next waiter try too. _wake_up_next() checks capacity itself.
          self._wake_up_next()
          return True

      def release(self):
          """Release a held slot and wake the next waiter if capacity allows."""
          self._value -= 1
          self._wake_up_next()

      def _wake_up_next(self):
          """Hand a slot to the first waiter that isn't done, if capacity allows.

          Re-checking the current maximum here (instead of unconditionally
          passing the released slot on) keeps the number of holders within
          ``get_max_value()`` after the maximum has been lowered.
          """
          if not self._waiters:
              return
          if self._value >= self.get_max_value():
              return
          for fut in self._waiters:
              if not fut.done():
                  self._value += 1
                  fut.set_result(True)
                  return