auto_init_hook.py 793 B

1234567891011121314151617181920212223242526272829303132
  1. import os
  2. import threading
  3. from functools import wraps
  4. import ray
  5. auto_init_lock = threading.Lock()
  6. enable_auto_connect = os.environ.get("RAY_ENABLE_AUTO_CONNECT", "") != "0"
  7. def auto_init_ray():
  8. if enable_auto_connect and not ray.is_initialized():
  9. with auto_init_lock:
  10. if not ray.is_initialized():
  11. ray.init()
  12. def wrap_auto_init(fn):
  13. @wraps(fn)
  14. def auto_init_wrapper(*args, **kwargs):
  15. auto_init_ray()
  16. return fn(*args, **kwargs)
  17. return auto_init_wrapper
  18. def wrap_auto_init_for_all_apis(api_names):
  19. """Wrap public APIs with automatic ray.init."""
  20. for api_name in api_names:
  21. api = getattr(ray, api_name, None)
  22. assert api is not None, api_name
  23. setattr(ray, api_name, wrap_auto_init(api))