tracing_helper.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574
  1. import importlib
  2. import inspect
  3. import logging
  4. import os
  5. from contextlib import contextmanager
  6. from functools import wraps
  7. from inspect import Parameter
  8. from types import ModuleType
  9. from typing import (
  10. Any,
  11. Callable,
  12. Dict,
  13. Generator,
  14. List,
  15. MutableMapping,
  16. Optional,
  17. Sequence,
  18. Union,
  19. cast,
  20. )
  21. import ray
  22. import ray._private.worker
  23. from ray._private.inspect_util import (
  24. is_class_method,
  25. is_function_or_method,
  26. is_static_method,
  27. )
  28. from ray.runtime_context import get_runtime_context
  29. logger = logging.getLogger(__name__)
  30. class _OpenTelemetryProxy:
  31. """
  32. This proxy makes it possible for tracing to be disabled when opentelemetry
  33. is not installed on the cluster, but is installed locally.
  34. The check for `opentelemetry`'s existence must happen where the functions
  35. are executed because `opentelemetry` may be present where the functions
  36. are pickled. This can happen when `ray[full]` is installed locally by `ray`
  37. (no extra dependencies) is installed on the cluster.
  38. """
  39. allowed_functions = {"trace", "context", "propagate", "Context"}
  40. def __getattr__(self, name):
  41. if name in _OpenTelemetryProxy.allowed_functions:
  42. return getattr(self, f"_{name}")()
  43. else:
  44. raise AttributeError(f"Attribute does not exist: {name}")
  45. def _trace(self):
  46. return self._try_import("opentelemetry.trace")
  47. def _context(self):
  48. return self._try_import("opentelemetry.context")
  49. def _propagate(self):
  50. return self._try_import("opentelemetry.propagate")
  51. def _Context(self):
  52. context = self._context()
  53. if context:
  54. return context.context.Context
  55. else:
  56. return None
  57. def try_all(self):
  58. self._trace()
  59. self._context()
  60. self._propagate()
  61. self._Context()
  62. def _try_import(self, module):
  63. try:
  64. return importlib.import_module(module)
  65. except ImportError:
  66. if _is_tracing_enabled():
  67. raise ImportError(
  68. "Install OpenTelemetry with "
  69. "'pip install opentelemetry-api==1.34.1 opentelemetry-sdk==1.34.1 opentelemetry-exporter-otlp==1.34.1' "
  70. "to enable tracing. See the Ray documentation for details: "
  71. "https://docs.ray.io/en/latest/ray-observability/user-guides/ray-tracing.html#installation"
  72. )
# Module-level tracing switch; starts off and is set to True by
# _enable_tracing().
_global_is_tracing_enabled = False
# Holds the _OpenTelemetryProxy instance once _enable_tracing() has run;
# None until then.
_opentelemetry = None
def _is_tracing_enabled() -> bool:
    """Report whether tracing has been switched on in this process.

    Returns the module-level ``_global_is_tracing_enabled`` flag, which
    starts out ``False`` (tracing is off by default) and is set to ``True``
    by ``_enable_tracing()``.
    """
    return _global_is_tracing_enabled
  79. def _enable_tracing():
  80. global _global_is_tracing_enabled, _opentelemetry
  81. _global_is_tracing_enabled = True
  82. _opentelemetry = _OpenTelemetryProxy()
  83. _opentelemetry.try_all()
  84. def _sort_params_list(params_list: List[Parameter]):
  85. """Given a list of Parameters, if a kwargs Parameter exists,
  86. move it to the end of the list."""
  87. for i, param in enumerate(params_list):
  88. if param.kind == Parameter.VAR_KEYWORD:
  89. params_list.append(params_list.pop(i))
  90. break
  91. return params_list
  92. def _add_param_to_signature(function: Callable, new_param: Parameter):
  93. """Add additional Parameter to function signature."""
  94. old_sig = inspect.signature(function)
  95. old_sig_list_repr = list(old_sig.parameters.values())
  96. # If new_param is already in signature, do not add it again.
  97. if any(param.name == new_param.name for param in old_sig_list_repr):
  98. return old_sig
  99. new_params = _sort_params_list(old_sig_list_repr + [new_param])
  100. new_sig = old_sig.replace(parameters=new_params)
  101. return new_sig
  102. class _ImportFromStringError(Exception):
  103. pass
  104. def _import_from_string(import_str: Union[ModuleType, str]) -> ModuleType:
  105. """Given a string that is in format "<module>:<attribute>",
  106. import the attribute."""
  107. if not isinstance(import_str, str):
  108. return import_str
  109. module_str, _, attrs_str = import_str.partition(":")
  110. if not module_str or not attrs_str:
  111. message = (
  112. 'Import string "{import_str}" must be in format' '"<module>:<attribute>".'
  113. )
  114. raise _ImportFromStringError(message.format(import_str=import_str))
  115. try:
  116. module = importlib.import_module(module_str)
  117. except ImportError as exc:
  118. if exc.name != module_str:
  119. raise exc from None
  120. message = 'Could not import module "{module_str}".'
  121. raise _ImportFromStringError(message.format(module_str=module_str))
  122. instance = module
  123. try:
  124. for attr_str in attrs_str.split("."):
  125. instance = getattr(instance, attr_str)
  126. except AttributeError:
  127. message = 'Attribute "{attrs_str}" not found in module "{module_str}".'
  128. raise _ImportFromStringError(
  129. message.format(attrs_str=attrs_str, module_str=module_str)
  130. )
  131. return instance
  132. class _DictPropagator:
  133. def inject_current_context() -> Dict[Any, Any]:
  134. """Inject trace context into otel propagator."""
  135. context_dict: Dict[Any, Any] = {}
  136. _opentelemetry.propagate.inject(context_dict)
  137. return context_dict
  138. def extract(context_dict: Dict[Any, Any]) -> "_opentelemetry.Context":
  139. """Given a trace context, extract as a Context."""
  140. return cast(
  141. _opentelemetry.Context, _opentelemetry.propagate.extract(context_dict)
  142. )
  143. @contextmanager
  144. def _use_context(
  145. parent_context: "_opentelemetry.Context",
  146. ) -> Generator[None, None, None]:
  147. """Uses the Ray trace context for the span."""
  148. if parent_context is not None:
  149. new_context = parent_context
  150. else:
  151. new_context = _opentelemetry.Context()
  152. token = _opentelemetry.context.attach(new_context)
  153. try:
  154. yield
  155. finally:
  156. _opentelemetry.context.detach(token)
  157. def _function_hydrate_span_args(function_name: str):
  158. """Get the Attributes of the function that will be reported as attributes
  159. in the trace."""
  160. runtime_context = get_runtime_context()
  161. span_args = {
  162. "ray.remote": "function",
  163. "ray.function": function_name,
  164. "ray.pid": str(os.getpid()),
  165. "ray.job_id": runtime_context.get_job_id(),
  166. "ray.node_id": runtime_context.get_node_id(),
  167. }
  168. # We only get task ID for workers
  169. if ray._private.worker.global_worker.mode == ray._private.worker.WORKER_MODE:
  170. task_id = runtime_context.get_task_id()
  171. if task_id:
  172. span_args["ray.task_id"] = task_id
  173. worker_id = getattr(ray._private.worker.global_worker, "worker_id", None)
  174. if worker_id:
  175. span_args["ray.worker_id"] = worker_id.hex()
  176. return span_args
  177. def _function_span_producer_name(func: Callable[..., Any]) -> str:
  178. """Returns the function span name that has span kind of producer."""
  179. return f"{func} ray.remote"
  180. def _function_span_consumer_name(func: Callable[..., Any]) -> str:
  181. """Returns the function span name that has span kind of consumer."""
  182. return f"{func} ray.remote_worker"
  183. def _actor_hydrate_span_args(
  184. class_: Union[str, Callable[..., Any]],
  185. method: Union[str, Callable[..., Any]],
  186. ):
  187. """Get the Attributes of the actor that will be reported as attributes
  188. in the trace."""
  189. if callable(class_):
  190. class_ = class_.__name__
  191. if callable(method):
  192. method = method.__name__
  193. runtime_context = get_runtime_context()
  194. span_args = {
  195. "ray.remote": "actor",
  196. "ray.actor_class": class_,
  197. "ray.actor_method": method,
  198. "ray.function": f"{class_}.{method}",
  199. "ray.pid": str(os.getpid()),
  200. "ray.job_id": runtime_context.get_job_id(),
  201. "ray.node_id": runtime_context.get_node_id(),
  202. }
  203. # We only get actor ID for workers
  204. if ray._private.worker.global_worker.mode == ray._private.worker.WORKER_MODE:
  205. actor_id = runtime_context.get_actor_id()
  206. if actor_id:
  207. span_args["ray.actor_id"] = actor_id
  208. worker_id = getattr(ray._private.worker.global_worker, "worker_id", None)
  209. if worker_id:
  210. span_args["ray.worker_id"] = worker_id.hex()
  211. return span_args
  212. def _actor_span_producer_name(
  213. class_: Union[str, Callable[..., Any]],
  214. method: Union[str, Callable[..., Any]],
  215. ) -> str:
  216. """Returns the actor span name that has span kind of producer."""
  217. if not isinstance(class_, str):
  218. class_ = class_.__name__
  219. if not isinstance(method, str):
  220. method = method.__name__
  221. return f"{class_}.{method} ray.remote"
  222. def _actor_span_consumer_name(
  223. class_: Union[str, Callable[..., Any]],
  224. method: Union[str, Callable[..., Any]],
  225. ) -> str:
  226. """Returns the actor span name that has span kind of consumer."""
  227. if not isinstance(class_, str):
  228. class_ = class_.__name__
  229. if not isinstance(method, str):
  230. method = method.__name__
  231. return f"{class_}.{method} ray.remote_worker"
  232. def _tracing_task_invocation(method):
  233. """Trace the execution of a remote task. Inject
  234. the current span context into kwargs for propagation."""
  235. @wraps(method)
  236. def _invocation_remote_span(
  237. self,
  238. args: Any = None, # from tracing
  239. kwargs: MutableMapping[Any, Any] = None, # from tracing
  240. *_args: Any, # from Ray
  241. **_kwargs: Any, # from Ray
  242. ) -> Any:
  243. # If tracing feature flag is not on, perform a no-op.
  244. # Tracing doesn't work for cross lang yet.
  245. if not _is_tracing_enabled() or self._is_cross_language:
  246. if kwargs is not None:
  247. assert "_ray_trace_ctx" not in kwargs
  248. return method(self, args, kwargs, *_args, **_kwargs)
  249. assert "_ray_trace_ctx" not in kwargs
  250. tracer = _opentelemetry.trace.get_tracer(__name__)
  251. with tracer.start_as_current_span(
  252. _function_span_producer_name(self._function_name),
  253. kind=_opentelemetry.trace.SpanKind.PRODUCER,
  254. attributes=_function_hydrate_span_args(self._function_name),
  255. ):
  256. # Inject a _ray_trace_ctx as a dictionary
  257. kwargs["_ray_trace_ctx"] = _DictPropagator.inject_current_context()
  258. return method(self, args, kwargs, *_args, **_kwargs)
  259. return _invocation_remote_span
  260. def _inject_tracing_into_function(function):
  261. """Wrap the function argument passed to RemoteFunction's __init__ so that
  262. future execution of that function will include tracing.
  263. Use the provided trace context from kwargs.
  264. """
  265. if not _is_tracing_enabled():
  266. return function
  267. function.__signature__ = _add_param_to_signature(
  268. function,
  269. inspect.Parameter(
  270. "_ray_trace_ctx", inspect.Parameter.KEYWORD_ONLY, default=None
  271. ),
  272. )
  273. @wraps(function)
  274. def _function_with_tracing(
  275. *args: Any,
  276. _ray_trace_ctx: Optional[Dict[str, Any]] = None,
  277. **kwargs: Any,
  278. ) -> Any:
  279. if _ray_trace_ctx is None:
  280. return function(*args, **kwargs)
  281. tracer = _opentelemetry.trace.get_tracer(__name__)
  282. function_name = function.__module__ + "." + function.__name__
  283. # Retrieves the context from the _ray_trace_ctx dictionary we injected
  284. with _use_context(
  285. _DictPropagator.extract(_ray_trace_ctx)
  286. ), tracer.start_as_current_span(
  287. _function_span_consumer_name(function_name),
  288. kind=_opentelemetry.trace.SpanKind.CONSUMER,
  289. attributes=_function_hydrate_span_args(function_name),
  290. ):
  291. return function(*args, **kwargs)
  292. return _function_with_tracing
  293. def _tracing_actor_creation(method):
  294. """Trace the creation of an actor. Inject
  295. the current span context into kwargs for propagation."""
  296. @wraps(method)
  297. def _invocation_actor_class_remote_span(
  298. self,
  299. args: Any = tuple(), # from tracing
  300. kwargs: MutableMapping[Any, Any] = None, # from tracing
  301. *_args: Any, # from Ray
  302. **_kwargs: Any, # from Ray
  303. ):
  304. if kwargs is None:
  305. kwargs = {}
  306. # If tracing feature flag is not on, perform a no-op
  307. if not _is_tracing_enabled():
  308. assert "_ray_trace_ctx" not in kwargs
  309. return method(self, args, kwargs, *_args, **_kwargs)
  310. class_name = self.__ray_metadata__.class_name
  311. method_name = "__init__"
  312. assert "_ray_trace_ctx" not in _kwargs
  313. tracer = _opentelemetry.trace.get_tracer(__name__)
  314. with tracer.start_as_current_span(
  315. name=_actor_span_producer_name(class_name, method_name),
  316. kind=_opentelemetry.trace.SpanKind.PRODUCER,
  317. attributes=_actor_hydrate_span_args(class_name, method_name),
  318. ) as span:
  319. # Inject a _ray_trace_ctx as a dictionary
  320. kwargs["_ray_trace_ctx"] = _DictPropagator.inject_current_context()
  321. result = method(self, args, kwargs, *_args, **_kwargs)
  322. span.set_attribute("ray.actor_id", result._ray_actor_id.hex())
  323. return result
  324. return _invocation_actor_class_remote_span
  325. def _tracing_actor_method_invocation(method):
  326. """Trace the invocation of an actor method."""
  327. @wraps(method)
  328. def _start_span(
  329. self,
  330. args: Sequence[Any] = None,
  331. kwargs: MutableMapping[Any, Any] = None,
  332. *_args: Any,
  333. **_kwargs: Any,
  334. ) -> Any:
  335. # If tracing feature flag is not on, perform a no-op
  336. if not _is_tracing_enabled() or self._actor._ray_is_cross_language:
  337. if kwargs is not None:
  338. assert "_ray_trace_ctx" not in kwargs
  339. return method(self, args, kwargs, *_args, **_kwargs)
  340. class_name = self._actor._ray_actor_creation_function_descriptor.class_name
  341. method_name = self._method_name
  342. assert "_ray_trace_ctx" not in _kwargs
  343. tracer = _opentelemetry.trace.get_tracer(__name__)
  344. with tracer.start_as_current_span(
  345. name=_actor_span_producer_name(class_name, method_name),
  346. kind=_opentelemetry.trace.SpanKind.PRODUCER,
  347. attributes=_actor_hydrate_span_args(class_name, method_name),
  348. ) as span:
  349. # Inject a _ray_trace_ctx as a dictionary
  350. kwargs["_ray_trace_ctx"] = _DictPropagator.inject_current_context()
  351. span.set_attribute("ray.actor_id", self._actor._ray_actor_id.hex())
  352. return method(self, args, kwargs, *_args, **_kwargs)
  353. return _start_span
def _inject_tracing_into_class(_cls):
    """Given a class that will be made into an actor,
    inject tracing into all of the methods.

    Every eligible function or method found on ``_cls`` gets a keyword-only
    ``_ray_trace_ctx`` parameter added to the signature of its fully
    unwrapped original, and is replaced on the class by a wrapper that opens
    a CONSUMER span when a trace context is supplied. Static methods, class
    methods, (async) generator functions and ``__del__`` are skipped.
    ``_cls`` is modified in place and returned.
    """

    # Builds the synchronous tracing wrapper around ``method``.
    def span_wrapper(method: Callable[..., Any]) -> Any:
        def _resume_span(
            self: Any,
            *_args: Any,
            _ray_trace_ctx: Optional[Dict[str, Any]] = None,
            **_kwargs: Any,
        ) -> Any:
            """
            Wrap the user's function with a function that
            will extract the trace context and run the call inside a
            consumer span.
            """
            # If tracing feature flag is not on (or no context was
            # propagated), perform a no-op
            if not _is_tracing_enabled() or _ray_trace_ctx is None:
                return method(self, *_args, **_kwargs)

            tracer: _opentelemetry.trace.Tracer = _opentelemetry.trace.get_tracer(
                __name__
            )

            # Retrieves the context from the _ray_trace_ctx dictionary we
            # injected.
            with _use_context(
                _DictPropagator.extract(_ray_trace_ctx)
            ), tracer.start_as_current_span(
                _actor_span_consumer_name(self.__class__.__name__, method),
                kind=_opentelemetry.trace.SpanKind.CONSUMER,
                attributes=_actor_hydrate_span_args(self.__class__.__name__, method),
            ):
                return method(self, *_args, **_kwargs)

        return _resume_span

    # Builds the asynchronous tracing wrapper around coroutine ``method``.
    def async_span_wrapper(method: Callable[..., Any]) -> Any:
        async def _resume_span(
            self: Any,
            *_args: Any,
            _ray_trace_ctx: Optional[Dict[str, Any]] = None,
            **_kwargs: Any,
        ) -> Any:
            """
            Wrap the user's coroutine function with a coroutine that
            will extract the trace context and await the call inside a
            consumer span.
            """
            # If tracing feature flag is not on (or no context was
            # propagated), perform a no-op
            if not _is_tracing_enabled() or _ray_trace_ctx is None:
                return await method(self, *_args, **_kwargs)

            tracer = _opentelemetry.trace.get_tracer(__name__)

            # Retrieves the context from the _ray_trace_ctx dictionary we
            # injected, or starts a new context
            with _use_context(
                _DictPropagator.extract(_ray_trace_ctx)
            ), tracer.start_as_current_span(
                _actor_span_consumer_name(self.__class__.__name__, method.__name__),
                kind=_opentelemetry.trace.SpanKind.CONSUMER,
                attributes=_actor_hydrate_span_args(
                    self.__class__.__name__, method.__name__
                ),
            ):
                return await method(self, *_args, **_kwargs)

        return _resume_span

    methods = inspect.getmembers(_cls, is_function_or_method)
    for name, method in methods:
        # Skip tracing for staticmethod or classmethod, because these methods
        # might not be called directly by remote calls. Additionally, they are
        # tricky to get wrapped and unwrapped.
        if is_static_method(_cls, name) or is_class_method(method):
            continue

        if inspect.isgeneratorfunction(method) or inspect.isasyncgenfunction(method):
            # Right now, this method somehow changes the signature of the method
            # when they are generators.
            # TODO(sang): Fix it.
            continue

        # Don't decorate the __del__ magic method.
        # It's because the __del__ can be called after Python
        # modules are garbage collected, which means the modules
        # used for the decorator (e.g., `span_wrapper`) may not be
        # available. For example, it is not guaranteed that
        # `_is_tracing_enabled` is available when `__del__` is called.
        # Tracing `__del__` is also not very useful.
        # https://joekuan.wordpress.com/2015/06/30/python-3-__del__-method-and-imported-modules/ # noqa
        if name == "__del__":
            continue

        # If the method is already wrapped, we still need to set __signature__
        # on the deeply unwrapped original. This is because cloudpickle doesn't
        # preserve __signature__ attributes, and _ActorClassMethodMetadata.create
        # uses inspect.unwrap which goes all the way to the original method.
        unwrapped_method = inspect.unwrap(method)

        # Add _ray_trace_ctx to the UNWRAPPED method's signature.
        # This ensures inspect.unwrap() will find the signature.
        # Note: We always set the signature, even if it was already set by a
        # previous call, because the signature might have been lost during
        # serialization/deserialization.
        unwrapped_method.__signature__ = _add_param_to_signature(
            unwrapped_method,
            inspect.Parameter(
                "_ray_trace_ctx", inspect.Parameter.KEYWORD_ONLY, default=None
            ),
        )

        # If method was already wrapped by tracing (e.g., preserved through
        # cloudpickle), don't re-wrap it. We use a custom marker attribute
        # instead of __wrapped__ because __wrapped__ could be from any
        # decorator, not just tracing.
        if getattr(method, "__ray_tracing_wrapped__", False):
            continue

        if inspect.iscoroutinefunction(method):
            # If the method is async, use the async wrapper instead of the
            # sync one
            wrapped_method = wraps(method)(async_span_wrapper(method))
        else:
            wrapped_method = wraps(method)(span_wrapper(method))

        # Mark the wrapped method so we don't re-wrap it if this class
        # is processed again (e.g., after cloudpickle round-trip).
        wrapped_method.__ray_tracing_wrapped__ = True
        setattr(_cls, name, wrapped_method)

    return _cls