handlers.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. """A Websocket Handler for emitting Jupyter server events.
  2. .. versionadded:: 2.0
  3. """
  4. from __future__ import annotations
  5. import json
  6. from datetime import datetime
  7. from typing import TYPE_CHECKING, Any, Optional, cast
  8. from jupyter_core.utils import ensure_async
  9. from tornado import web, websocket
  10. from jupyter_server.auth.decorator import authorized, ws_authenticated
  11. from jupyter_server.base.handlers import JupyterHandler
  12. from ...base.handlers import APIHandler
  13. AUTH_RESOURCE = "events"
  14. if TYPE_CHECKING:
  15. import jupyter_events.logger
  16. class SubscribeWebsocket(
  17. JupyterHandler,
  18. websocket.WebSocketHandler,
  19. ):
  20. """Websocket handler for subscribing to events"""
  21. auth_resource = AUTH_RESOURCE
  22. async def pre_get(self):
  23. """Handles authorization when
  24. attempting to subscribe to events emitted by
  25. Jupyter Server's eventbus.
  26. """
  27. user = self.current_user
  28. # authorize the user.
  29. authorized = await ensure_async(
  30. self.authorizer.is_authorized(self, user, "execute", "events")
  31. )
  32. if not authorized:
  33. raise web.HTTPError(403)
  34. @ws_authenticated
  35. async def get(self, *args, **kwargs):
  36. """Get an event socket."""
  37. await ensure_async(self.pre_get())
  38. res = super().get(*args, **kwargs)
  39. if res is not None:
  40. await res
  41. async def event_listener(
  42. self, logger: jupyter_events.logger.EventLogger, schema_id: str, data: dict[str, Any]
  43. ) -> None:
  44. """Write an event message."""
  45. capsule = dict(schema_id=schema_id, **data)
  46. self.write_message(json.dumps(capsule))
  47. def open(self):
  48. """Routes events that are emitted by Jupyter Server's
  49. EventBus to a WebSocket client in the browser.
  50. """
  51. self.event_logger.add_listener(listener=self.event_listener)
  52. def on_close(self):
  53. """Handle a socket close."""
  54. self.event_logger.remove_listener(listener=self.event_listener)
  55. def validate_model(
  56. data: dict[str, Any], registry: jupyter_events.schema_registry.SchemaRegistry
  57. ) -> None:
  58. """Validates for required fields in the JSON request body and verifies that
  59. a registered schema/version exists"""
  60. required_keys = {"schema_id", "version", "data"}
  61. for key in required_keys:
  62. if key not in data:
  63. message = f"Missing `{key}` in the JSON request body."
  64. raise Exception(message)
  65. schema_id = cast(str, data.get("schema_id"))
  66. # The case where a given schema_id isn't found,
  67. # jupyter_events raises a useful error, so there's no need to
  68. # handle that case here.
  69. schema = registry.get(schema_id)
  70. version = str(cast(str, data.get("version")))
  71. if schema.version != version:
  72. message = f"Unregistered version: {version!r}≠{schema.version!r} for `{schema_id}`"
  73. raise Exception(message)
  74. def get_timestamp(data: dict[str, Any]) -> Optional[datetime]:
  75. """Parses timestamp from the JSON request body"""
  76. try:
  77. if "timestamp" in data:
  78. timestamp = datetime.strptime(data["timestamp"], "%Y-%m-%dT%H:%M:%S%zZ")
  79. else:
  80. timestamp = None
  81. except Exception as e:
  82. raise web.HTTPError(
  83. 400,
  84. """Failed to parse timestamp from JSON request body,
  85. an ISO format datetime string with UTC offset is expected,
  86. for example, 2022-05-26T13:50:00+05:00Z""",
  87. ) from e
  88. return timestamp
  89. class EventHandler(APIHandler):
  90. """REST api handler for events"""
  91. auth_resource = AUTH_RESOURCE
  92. @web.authenticated
  93. @authorized
  94. async def post(self):
  95. """Emit an event."""
  96. payload = self.get_json_body()
  97. if payload is None:
  98. raise web.HTTPError(400, "No JSON data provided")
  99. try:
  100. validate_model(payload, self.event_logger.schemas)
  101. self.event_logger.emit(
  102. schema_id=cast(str, payload.get("schema_id")),
  103. data=cast("dict[str, Any]", payload.get("data")),
  104. timestamp_override=get_timestamp(payload),
  105. )
  106. self.set_status(204)
  107. self.finish()
  108. except Exception as e:
  109. # All known exceptions are raised by bad requests, e.g., bad
  110. # version, unregistered schema, invalid emission data payload, etc.
  111. raise web.HTTPError(400, str(e)) from e
  112. default_handlers = [
  113. (r"/api/events", EventHandler),
  114. (r"/api/events/subscribe", SubscribeWebsocket),
  115. ]