spark_driver.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. import sentry_sdk
  2. from sentry_sdk.integrations import Integration
  3. from sentry_sdk.utils import capture_internal_exceptions, ensure_integration_enabled
  4. from typing import TYPE_CHECKING
  5. if TYPE_CHECKING:
  6. from typing import Any
  7. from typing import Optional
  8. from sentry_sdk._types import Event, Hint
  9. from pyspark import SparkContext
  10. class SparkIntegration(Integration):
  11. identifier = "spark"
  12. @staticmethod
  13. def setup_once() -> None:
  14. _setup_sentry_tracing()
  15. def _set_app_properties() -> None:
  16. """
  17. Set properties in driver that propagate to worker processes, allowing for workers to have access to those properties.
  18. This allows worker integration to have access to app_name and application_id.
  19. """
  20. from pyspark import SparkContext
  21. spark_context = SparkContext._active_spark_context
  22. if spark_context:
  23. spark_context.setLocalProperty(
  24. "sentry_app_name",
  25. spark_context.appName,
  26. )
  27. spark_context.setLocalProperty(
  28. "sentry_application_id",
  29. spark_context.applicationId,
  30. )
  31. def _start_sentry_listener(sc: "SparkContext") -> None:
  32. """
  33. Start java gateway server to add custom `SparkListener`
  34. """
  35. from pyspark.java_gateway import ensure_callback_server_started
  36. gw = sc._gateway
  37. ensure_callback_server_started(gw)
  38. listener = SentryListener()
  39. sc._jsc.sc().addSparkListener(listener)
  40. def _add_event_processor(sc: "SparkContext") -> None:
  41. scope = sentry_sdk.get_isolation_scope()
  42. @scope.add_event_processor
  43. def process_event(event: "Event", hint: "Hint") -> "Optional[Event]":
  44. with capture_internal_exceptions():
  45. if sentry_sdk.get_client().get_integration(SparkIntegration) is None:
  46. return event
  47. if sc._active_spark_context is None:
  48. return event
  49. event.setdefault("user", {}).setdefault("id", sc.sparkUser())
  50. event.setdefault("tags", {}).setdefault(
  51. "executor.id", sc._conf.get("spark.executor.id")
  52. )
  53. event["tags"].setdefault(
  54. "spark-submit.deployMode",
  55. sc._conf.get("spark.submit.deployMode"),
  56. )
  57. event["tags"].setdefault("driver.host", sc._conf.get("spark.driver.host"))
  58. event["tags"].setdefault("driver.port", sc._conf.get("spark.driver.port"))
  59. event["tags"].setdefault("spark_version", sc.version)
  60. event["tags"].setdefault("app_name", sc.appName)
  61. event["tags"].setdefault("application_id", sc.applicationId)
  62. event["tags"].setdefault("master", sc.master)
  63. event["tags"].setdefault("spark_home", sc.sparkHome)
  64. event.setdefault("extra", {}).setdefault("web_url", sc.uiWebUrl)
  65. return event
  66. def _activate_integration(sc: "SparkContext") -> None:
  67. _start_sentry_listener(sc)
  68. _set_app_properties()
  69. _add_event_processor(sc)
  70. def _patch_spark_context_init() -> None:
  71. from pyspark import SparkContext
  72. spark_context_init = SparkContext._do_init
  73. @ensure_integration_enabled(SparkIntegration, spark_context_init)
  74. def _sentry_patched_spark_context_init(
  75. self: "SparkContext", *args: "Any", **kwargs: "Any"
  76. ) -> "Optional[Any]":
  77. rv = spark_context_init(self, *args, **kwargs)
  78. _activate_integration(self)
  79. return rv
  80. SparkContext._do_init = _sentry_patched_spark_context_init
  81. def _setup_sentry_tracing() -> None:
  82. from pyspark import SparkContext
  83. if SparkContext._active_spark_context is not None:
  84. _activate_integration(SparkContext._active_spark_context)
  85. return
  86. _patch_spark_context_init()
  87. class SparkListener:
  88. def onApplicationEnd(self, applicationEnd: "Any") -> None: # noqa: N802,N803
  89. pass
  90. def onApplicationStart(self, applicationStart: "Any") -> None: # noqa: N802,N803
  91. pass
  92. def onBlockManagerAdded(self, blockManagerAdded: "Any") -> None: # noqa: N802,N803
  93. pass
  94. def onBlockManagerRemoved(self, blockManagerRemoved: "Any") -> None: # noqa: N802,N803
  95. pass
  96. def onBlockUpdated(self, blockUpdated: "Any") -> None: # noqa: N802,N803
  97. pass
  98. def onEnvironmentUpdate(self, environmentUpdate: "Any") -> None: # noqa: N802,N803
  99. pass
  100. def onExecutorAdded(self, executorAdded: "Any") -> None: # noqa: N802,N803
  101. pass
  102. def onExecutorBlacklisted(self, executorBlacklisted: "Any") -> None: # noqa: N802,N803
  103. pass
  104. def onExecutorBlacklistedForStage( # noqa: N802
  105. self,
  106. executorBlacklistedForStage: "Any", # noqa: N803
  107. ) -> None:
  108. pass
  109. def onExecutorMetricsUpdate(self, executorMetricsUpdate: "Any") -> None: # noqa: N802,N803
  110. pass
  111. def onExecutorRemoved(self, executorRemoved: "Any") -> None: # noqa: N802,N803
  112. pass
  113. def onJobEnd(self, jobEnd: "Any") -> None: # noqa: N802,N803
  114. pass
  115. def onJobStart(self, jobStart: "Any") -> None: # noqa: N802,N803
  116. pass
  117. def onNodeBlacklisted(self, nodeBlacklisted: "Any") -> None: # noqa: N802,N803
  118. pass
  119. def onNodeBlacklistedForStage(self, nodeBlacklistedForStage: "Any") -> None: # noqa: N802,N803
  120. pass
  121. def onNodeUnblacklisted(self, nodeUnblacklisted: "Any") -> None: # noqa: N802,N803
  122. pass
  123. def onOtherEvent(self, event: "Any") -> None: # noqa: N802,N803
  124. pass
  125. def onSpeculativeTaskSubmitted(self, speculativeTask: "Any") -> None: # noqa: N802,N803
  126. pass
  127. def onStageCompleted(self, stageCompleted: "Any") -> None: # noqa: N802,N803
  128. pass
  129. def onStageSubmitted(self, stageSubmitted: "Any") -> None: # noqa: N802,N803
  130. pass
  131. def onTaskEnd(self, taskEnd: "Any") -> None: # noqa: N802,N803
  132. pass
  133. def onTaskGettingResult(self, taskGettingResult: "Any") -> None: # noqa: N802,N803
  134. pass
  135. def onTaskStart(self, taskStart: "Any") -> None: # noqa: N802,N803
  136. pass
  137. def onUnpersistRDD(self, unpersistRDD: "Any") -> None: # noqa: N802,N803
  138. pass
  139. class Java:
  140. implements = ["org.apache.spark.scheduler.SparkListenerInterface"]
  141. class SentryListener(SparkListener):
  142. def _add_breadcrumb(
  143. self,
  144. level: str,
  145. message: str,
  146. data: "Optional[dict[str, Any]]" = None,
  147. ) -> None:
  148. sentry_sdk.get_isolation_scope().add_breadcrumb(
  149. level=level, message=message, data=data
  150. )
  151. def onJobStart(self, jobStart: "Any") -> None: # noqa: N802,N803
  152. sentry_sdk.get_isolation_scope().clear_breadcrumbs()
  153. message = "Job {} Started".format(jobStart.jobId())
  154. self._add_breadcrumb(level="info", message=message)
  155. _set_app_properties()
  156. def onJobEnd(self, jobEnd: "Any") -> None: # noqa: N802,N803
  157. level = ""
  158. message = ""
  159. data = {"result": jobEnd.jobResult().toString()}
  160. if jobEnd.jobResult().toString() == "JobSucceeded":
  161. level = "info"
  162. message = "Job {} Ended".format(jobEnd.jobId())
  163. else:
  164. level = "warning"
  165. message = "Job {} Failed".format(jobEnd.jobId())
  166. self._add_breadcrumb(level=level, message=message, data=data)
  167. def onStageSubmitted(self, stageSubmitted: "Any") -> None: # noqa: N802,N803
  168. stage_info = stageSubmitted.stageInfo()
  169. message = "Stage {} Submitted".format(stage_info.stageId())
  170. data = {"name": stage_info.name()}
  171. attempt_id = _get_attempt_id(stage_info)
  172. if attempt_id is not None:
  173. data["attemptId"] = attempt_id
  174. self._add_breadcrumb(level="info", message=message, data=data)
  175. _set_app_properties()
  176. def onStageCompleted(self, stageCompleted: "Any") -> None: # noqa: N802,N803
  177. from py4j.protocol import Py4JJavaError # type: ignore
  178. stage_info = stageCompleted.stageInfo()
  179. message = ""
  180. level = ""
  181. data = {"name": stage_info.name()}
  182. attempt_id = _get_attempt_id(stage_info)
  183. if attempt_id is not None:
  184. data["attemptId"] = attempt_id
  185. # Have to Try Except because stageInfo.failureReason() is typed with Scala Option
  186. try:
  187. data["reason"] = stage_info.failureReason().get()
  188. message = "Stage {} Failed".format(stage_info.stageId())
  189. level = "warning"
  190. except Py4JJavaError:
  191. message = "Stage {} Completed".format(stage_info.stageId())
  192. level = "info"
  193. self._add_breadcrumb(level=level, message=message, data=data)
  194. def _get_attempt_id(stage_info: "Any") -> "Optional[int]":
  195. try:
  196. return stage_info.attemptId()
  197. except Exception:
  198. pass
  199. try:
  200. return stage_info.attemptNumber()
  201. except Exception:
  202. pass
  203. return None