_utils.py 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. # mypy: allow-untyped-defs
  2. import logging
  3. from contextlib import contextmanager
  4. from typing import cast
# Module-level logger; _group_membership_management uses it to report a
# failed store.wait() while waiting for another worker's membership token.
logger: logging.Logger = logging.getLogger(__name__)
  6. @contextmanager
  7. def _group_membership_management(store, name, is_join):
  8. token_key = "RpcGroupManagementToken"
  9. join_or_leave = "join" if is_join else "leave"
  10. my_token = f"Token_for_{name}_{join_or_leave}"
  11. while True:
  12. # Retrieve token from store to signal start of rank join/leave critical section
  13. returned = store.compare_set(token_key, "", my_token).decode()
  14. if returned == my_token:
  15. # Yield to the function this context manager wraps
  16. yield
  17. # Finished, now exit and release token
  18. # Update from store to signal end of rank join/leave critical section
  19. store.set(token_key, "")
  20. # Other will wait for this token to be set before they execute
  21. store.set(my_token, "Done")
  22. break
  23. else:
  24. # Store will wait for the token to be released
  25. try:
  26. store.wait([returned])
  27. except RuntimeError:
  28. logger.error(
  29. "Group membership token %s timed out waiting for %s to be released.",
  30. my_token,
  31. returned,
  32. )
  33. raise
  34. def _update_group_membership(worker_info, my_devices, reverse_device_map, is_join):
  35. from . import api, TensorPipeAgent
  36. agent = cast(TensorPipeAgent, api._get_current_rpc_agent())
  37. ret = agent._update_group_membership(
  38. worker_info, my_devices, reverse_device_map, is_join
  39. )
  40. return ret