actions.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. """Actions that are triggered by W&B Automations."""
  2. from __future__ import annotations
  3. from typing import Annotated, Any, Literal, Optional, Union
  4. from pydantic import BeforeValidator, Field
  5. from typing_extensions import Self, TypeVar, get_args
  6. from wandb._pydantic import GQLBase, GQLId
  7. from wandb._strutils import nameof
  8. from ._generated import (
  9. AlertSeverity,
  10. GenericWebhookActionFields,
  11. GenericWebhookActionInput,
  12. NoOpActionFields,
  13. NoOpTriggeredActionInput,
  14. NotificationActionFields,
  15. NotificationActionInput,
  16. QueueJobActionFields,
  17. )
  18. from ._validators import (
  19. JsonEncoded,
  20. LenientStrEnum,
  21. default_if_none,
  22. parse_input_action,
  23. parse_saved_action,
  24. upper_if_str,
  25. )
  26. from .integrations import SlackIntegration, WebhookIntegration
  # Generic type variable (no use of it is visible in this part of the module).
  T = TypeVar("T")


  # NOTE: Name shortened for readability and defined publicly for easier access
  class ActionType(LenientStrEnum):
      """The type of action triggered by an automation."""

      # Each member's value is the same string as its name.
      QUEUE_JOB = "QUEUE_JOB"  # NOTE: Deprecated for creation
      NOTIFICATION = "NOTIFICATION"
      GENERIC_WEBHOOK = "GENERIC_WEBHOOK"
      NO_OP = "NO_OP"
  35. # ------------------------------------------------------------------------------
  36. # Saved types: for parsing response data from saved automations
  37. # NOTE: `QueueJobActionInput` for defining a Launch job is deprecated,
  38. # so while we allow parsing it from previously saved Automations, we deliberately
  39. # don't currently expose it in the API for creating automations.
  class SavedLaunchJobAction(QueueJobActionFields):
      # Saved-only representation of a queue-job (Launch) action: it exists so
      # that previously saved automations can still be parsed.
      #
      # The tag below is pinned to `ActionType.QUEUE_JOB` and defaults to it.
      action_type: Literal[ActionType.QUEUE_JOB] = ActionType.QUEUE_JOB
  42. # FIXME: Find a better place to put these OR a better way to handle the
  43. # conversion from `InputAction` -> `SavedAction`.
  44. #
  45. # Necessary placeholder class defs for converting:
  46. # - `SendNotification -> SavedNotificationAction`
  47. # - `SendWebhook -> SavedWebhookAction`
  48. #
  49. # The "input" types (`Send{Notification,Webhook}`) will only have an `integration_id`,
  50. # and we don't want/need to fetch the other `{Slack,Webhook}Integration` fields if
  51. # we can avoid it.
  class _SlackIntegrationStub(GQLBase):
      # Minimal stand-in for a Slack integration: just the type name and the ID.

      # Populated from the `__typename` alias; frozen and hidden from the repr.
      typename__: Annotated[
          Literal["SlackIntegration"],
          Field(repr=False, frozen=True, alias="__typename"),
      ] = "SlackIntegration"

      # ID of the referenced integration.
      id: GQLId
  class _WebhookIntegrationStub(GQLBase):
      # Minimal stand-in for a webhook integration: just the type name and the ID.

      # Populated from the `__typename` alias; frozen and hidden from the repr.
      typename__: Annotated[
          Literal["GenericWebhookIntegration"],
          Field(repr=False, frozen=True, alias="__typename"),
      ] = "GenericWebhookIntegration"

      # ID of the referenced integration.
      id: GQLId
  class SavedNotificationAction(NotificationActionFields, frozen=False):
      # Saved (response-side) form of a notification action.

      # Pinned to `ActionType.NOTIFICATION`, which is also the default.
      action_type: Literal[ActionType.NOTIFICATION] = ActionType.NOTIFICATION

      # Only the stub (type name + ID) of the Slack integration is modelled here.
      integration: _SlackIntegrationStub

      # Declared as optional, with no default given in this class.
      # NOTE(review): presumably these re-declare fields of the generated
      # `NotificationActionFields` base -- confirm against `_generated`.
      title: Optional[str]
      message: Optional[str]
      severity: Optional[AlertSeverity]
  class SavedWebhookAction(GenericWebhookActionFields, frozen=False):
      # Saved (response-side) form of a webhook action.

      # Pinned to `ActionType.GENERIC_WEBHOOK`, which is also the default.
      action_type: Literal[ActionType.GENERIC_WEBHOOK] = ActionType.GENERIC_WEBHOOK

      # Only the stub (type name + ID) of the webhook integration is modelled here.
      integration: _WebhookIntegrationStub

      # The GraphQL schema (and hence the generated class) effectively types
      # `requestPayload` as a string.  The type is overridden here because the
      # structure of the JSON-serialized data is known and must be anticipated.
      request_payload: Optional[JsonEncoded[dict[str, Any]]] = None  # type: ignore[assignment]
  class SavedNoOpAction(NoOpActionFields, frozen=False):
      # Saved (response-side) form of the "do nothing" action.

      # Pinned to `ActionType.NO_OP`, which is also the default.
      action_type: Literal[ActionType.NO_OP] = ActionType.NO_OP

      # `default_if_none` runs before validation; the field itself is frozen
      # and left out of the repr.
      no_op: Annotated[
          bool,
          BeforeValidator(default_if_none),
          Field(frozen=True, repr=False),
      ] = True
      """Placeholder field, only needed to conform to schema requirements.
      There should never be a need to set this field explicitly, as its value is ignored.
      """
  # Annotated union of every saved action model -- meant for type annotations.
  # `parse_saved_action` is applied before validation.
  # NOTE(review): the union is discriminated on `typename__`; that field is not
  # declared on the saved-action classes in this module, so it presumably comes
  # from the generated `*Fields` base classes -- confirm.
  SavedAction = Annotated[
      Union[
          SavedLaunchJobAction,
          SavedNotificationAction,
          SavedWebhookAction,
          SavedNoOpAction,
      ],
      BeforeValidator(parse_saved_action),
      Field(discriminator="typename__"),
  ]

  # Plain tuple of the member classes of the union -- meant for runtime type checks.
  # The first argument of the `Annotated[...]` alias is the `Union[...]` itself,
  # whose own arguments are the member classes.
  SavedActionTypes: tuple[type, ...] = get_args(get_args(SavedAction)[0])
  100. # ------------------------------------------------------------------------------
  101. # Input types: for creating or updating automations
  class _BaseActionInput(GQLBase):
      # Shared base of the input (create/update) action models defined below.

      # Frozen field; the subclasses in this module pin it to a single literal value.
      action_type: Annotated[ActionType, Field(frozen=True)]
      """The kind of action to be triggered."""
  class SendNotification(_BaseActionInput, NotificationActionInput):
      """Defines an automation action that sends a (Slack) notification."""

      action_type: Literal[ActionType.NOTIFICATION] = ActionType.NOTIFICATION

      integration_id: GQLId
      """The ID of the Slack integration that will be used to send the notification."""

      # Note: Validation aliases preserve continuity with the prior `wandb.alert()` API.
      title: str = ""
      """The title of the sent notification."""

      message: Annotated[str, Field(validation_alias="text")] = ""
      """The message body of the sent notification."""

      severity: Annotated[
          AlertSeverity,
          BeforeValidator(upper_if_str),  # Be helpful by ensuring uppercase strings
          Field(validation_alias="level"),
      ] = AlertSeverity.INFO
      """The severity (`INFO`, `WARN`, `ERROR`) of the sent notification."""

      @classmethod
      def from_integration(
          cls,
          integration: SlackIntegration,
          *,
          title: str = "",
          text: str = "",
          level: AlertSeverity = AlertSeverity.INFO,
      ) -> Self:
          """Define a notification action that sends to the given (Slack) integration.

          Args:
              integration: The integration whose `id` becomes `integration_id`.
              title: Value for the `title` field.
              text: Value for the `message` field.
              level: Value for the `severity` field.

          Returns:
              A new instance of this class built from the given values.
          """
          # The keyword-only parameters use the alias spellings (`text`, `level`),
          # whereas the model is populated here through its field names.
          init_kwargs: dict[str, Any] = {
              "integration_id": integration.id,
              "title": title,
              "message": text,
              "severity": level,
          }
          return cls(**init_kwargs)
  class SendWebhook(_BaseActionInput, GenericWebhookActionInput):
      """Defines an automation action that sends a webhook request."""

      action_type: Literal[ActionType.GENERIC_WEBHOOK] = ActionType.GENERIC_WEBHOOK

      integration_id: GQLId
      """The ID of the webhook integration that will be used to send the request."""

      # overrides the generated field type to parse/serialize JSON strings
      request_payload: Optional[JsonEncoded[dict[str, Any]]] = Field(  # type: ignore[assignment]
          alias="requestPayload",
          default=None,
      )
      """The payload, possibly with template variables, to send in the webhook request."""

      @classmethod
      def from_integration(
          cls,
          integration: WebhookIntegration,
          *,
          payload: Optional[JsonEncoded[dict[str, Any]]] = None,
      ) -> Self:
          """Define a webhook action that sends to the given (webhook) integration.

          Args:
              integration: The integration whose `id` becomes `integration_id`.
              payload: Value for the `request_payload` field; `None` by default.

          Returns:
              A new instance of this class built from the given values.
          """
          target_id = integration.id
          return cls(integration_id=target_id, request_payload=payload)
  class DoNothing(_BaseActionInput, NoOpTriggeredActionInput, frozen=True):
      """Defines an automation action that intentionally does nothing."""

      action_type: Literal[ActionType.NO_OP] = ActionType.NO_OP

      # `default_if_none` runs before validation of the value.
      no_op: Annotated[bool, BeforeValidator(default_if_none)] = True
      """Placeholder field which exists only to satisfy backend schema requirements.
      There should never be a need to set this field explicitly, as its value is ignored.
      """
  # Annotated union of every input action model -- meant for type annotations.
  # `parse_input_action` is applied before validation, and the union is
  # discriminated on each member's `action_type` field.
  InputAction = Annotated[
      Union[
          SendNotification,
          SendWebhook,
          DoNothing,
      ],
      BeforeValidator(parse_input_action),
      Field(discriminator="action_type"),
  ]

  # Plain tuple of the member classes of the union -- meant for runtime type checks.
  # The first argument of the `Annotated[...]` alias is the `Union[...]` itself,
  # whose own arguments are the member classes.
  InputActionTypes: tuple[type, ...] = get_args(get_args(InputAction)[0])
  # Public names: the action-type enum plus one entry per input action class,
  # each obtained by calling `nameof` on the class.
  __all__ = ["ActionType", *map(nameof, InputActionTypes)]