| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287 |
- from __future__ import annotations
- from collections.abc import Collection
- from typing import Annotated, Any, Final, Optional, Protocol, TypedDict, Union
- from pydantic import Field
- from typing_extensions import Self, Unpack
- from wandb._pydantic import GQLId, GQLInput, computed_field, model_validator, to_json
- from ._filters import MongoLikeFilter
- from ._generated import (
- CreateFilterTriggerInput,
- QueueJobActionInput,
- TriggeredActionConfig,
- UpdateFilterTriggerInput,
- )
- from ._validators import parse_input_action
- from .actions import (
- ActionType,
- DoNothing,
- InputAction,
- SavedAction,
- SendNotification,
- SendWebhook,
- )
- from .automations import Automation, NewAutomation
- from .events import (
- EventType,
- InputEvent,
- RunMetricFilter,
- SavedEvent,
- _WrappedSavedEventFilter,
- )
- from .scopes import AutomationScope, ScopeType
INVALID_INPUT_EVENTS: Final[Collection[EventType]] = (EventType.UPDATE_ARTIFACT_ALIAS,)
"""Event types that new or edited automations may NOT be given.

These members still exist so that previously-saved automations using them can
be parsed; they are rejected only when assigned as new values.
"""

INVALID_INPUT_ACTIONS: Final[Collection[ActionType]] = (ActionType.QUEUE_JOB,)
"""Action types that new or edited automations may NOT be given.

These members still exist so that previously-saved automations using them can
be parsed; they are rejected only when assigned as new values.
"""

ALWAYS_SUPPORTED_EVENTS: Final[Collection[EventType]] = frozenset(
    (
        EventType.CREATE_ARTIFACT,
        EventType.LINK_ARTIFACT,
        EventType.ADD_ARTIFACT_ALIAS,
    )
)
"""Event types expected to work on every current, non-EOL server version."""

ALWAYS_SUPPORTED_ACTIONS: Final[Collection[ActionType]] = frozenset(
    (
        ActionType.NOTIFICATION,
        ActionType.GENERIC_WEBHOOK,
    )
)
"""Action types expected to work on every current, non-EOL server version."""
class HasId(Protocol):
    """Structural type: any object that exposes a string `id` attribute."""

    id: str
def extract_id(obj: HasId | str) -> str:
    """Return the ID for `obj`.

    If `obj` has an `id` attribute, that attribute's value is returned;
    otherwise `obj` itself (normally already an ID string) is returned as-is.
    """
    return getattr(obj, "id", obj)
# ---------------------------------------------------------------------------
# Lookup from each action type to the `TriggeredActionConfig` field name under
# which that action's input must be nested in the GraphQL request.
ACTION_CONFIG_KEYS: dict[ActionType, str] = dict(
    [
        (ActionType.NOTIFICATION, "notification_action_input"),
        (ActionType.GENERIC_WEBHOOK, "generic_webhook_action_input"),
        (ActionType.NO_OP, "no_op_action_input"),
        (ActionType.QUEUE_JOB, "queue_job_action_input"),
    ]
)
class InputActionConfig(TriggeredActionConfig):
    """Prepares action configuration data for saving an automation.

    Holds one optional field per action type. When built by
    `prepare_action_config_input`, only the field named by `ACTION_CONFIG_KEYS`
    for the action's type is set; the rest stay `None`.
    """

    # NOTE: `QueueJobActionInput` for defining a Launch job is deprecated,
    # so while it's allowed here to update EXISTING automations, we don't
    # currently expose it through the public API to create NEW automations.
    # NOTE(review): `ValidatedUpdateInput` also rejects `ActionType.QUEUE_JOB`
    # (it is listed in `INVALID_INPUT_ACTIONS`), so this field may be
    # unreachable via `prepare_to_update` as well — confirm that's intended.
    queue_job_action_input: Optional[QueueJobActionInput] = None
    # Input for `ActionType.NOTIFICATION` actions.
    notification_action_input: Optional[SendNotification] = None
    # Input for `ActionType.GENERIC_WEBHOOK` actions.
    generic_webhook_action_input: Optional[SendWebhook] = None
    # Input for `ActionType.NO_OP` actions.
    no_op_action_input: Optional[DoNothing] = None
def prepare_action_config_input(obj: SavedAction | InputAction) -> dict[str, Any]:
    """Return the action's input nested under its `TriggeredActionConfig` key.

    The GraphQL input schemas `CreateFilterTriggerInput` and
    `UpdateFilterTriggerInput` expect the action payload to sit under a
    type-specific key, as listed in `ACTION_CONFIG_KEYS`.

    Raises:
        KeyError: If the action's type has no entry in `ACTION_CONFIG_KEYS`.
    """
    # Let the inner validators turn a SavedAction into an InputAction, if needed.
    action = parse_input_action(obj)
    config_key = ACTION_CONFIG_KEYS[action.action_type]
    config = InputActionConfig(**{config_key: action})
    return config.model_dump()
def prepare_event_filter_input(
    obj: _WrappedSavedEventFilter | MongoLikeFilter | RunMetricFilter,
) -> str:
    """Serialize an event filter to the JSON string the input schemas expect.

    A `_WrappedSavedEventFilter` is unwrapped to its inner `.filter` before
    serialization; any other filter is serialized directly. The result is
    used for both `CreateFilterTriggerInput` and `UpdateFilterTriggerInput`.
    """
    # Saved event filters and input event filters differ by one level of
    # nesting, which must be reconciled to satisfy the backend's schemas.
    # Run/run-metric filters have no such wrapper.
    #
    # Yes, this is confusing, but the under-the-hood backend logic requires it.
    unwrapped = obj.filter if isinstance(obj, _WrappedSavedEventFilter) else obj
    return to_json(unwrapped)
class WriteAutomationsKwargs(TypedDict, total=False):
    """Keyword arguments that can be passed to create or update an automation.

    Every key is optional (`total=False`). Any key that is supplied overrides
    the corresponding field taken from the object passed to
    `prepare_to_create` / `prepare_to_update`.
    """

    name: str
    description: str
    enabled: bool
    # NOTE(review): `ValidatedCreateInput` declares no `scope` field and uses
    # `extra="forbid"`, so passing `scope` to `prepare_to_create` appears to
    # fail validation (creation derives the scope from `event.scope`). Only
    # `ValidatedUpdateInput` consumes `scope` directly — confirm.
    scope: AutomationScope
    event: InputEvent
    action: InputAction
class ValidatedCreateInput(GQLInput, extra="forbid", frozen=True):
    """Validated automation parameters, prepared for creating a new automation.

    Note: Users should never need to instantiate this class directly.
    """

    name: str
    description: Optional[str] = None
    enabled: bool = True

    # ------------------------------------------------------------------------------
    # Set on instantiation, but used to derive other fields and deliberately
    # EXCLUDED from the final GraphQL request vars
    # NOTE: there is no standalone `scope` field here; the scope is derived from
    # `event.scope`, and `extra="forbid"` rejects any unexpected keyword.
    event: Annotated[InputEvent, Field(exclude=True)]
    action: Annotated[InputAction, Field(exclude=True)]

    # ------------------------------------------------------------------------------
    # Derived fields to match the input schemas

    @computed_field
    def scope_type(self) -> ScopeType:
        # The automation's scope is inherited from its triggering event.
        return self.event.scope.scope_type

    @computed_field
    def scope_id(self) -> GQLId:
        return self.event.scope.id

    @computed_field
    def triggering_event_type(self) -> EventType:
        return self.event.event_type

    @computed_field
    def event_filter(self) -> str:
        # JSON-serialized filter in the shape the backend expects.
        return prepare_event_filter_input(self.event.filter)

    @computed_field
    def triggered_action_type(self) -> ActionType:
        return self.action.action_type

    @computed_field
    def triggered_action_config(self) -> dict[str, Any]:
        # Action input nested under its type-specific config key.
        return prepare_action_config_input(self.action)

    # ------------------------------------------------------------------------------
    # Custom validation
    # These run after field validation; the event check runs before the action check.

    @model_validator(mode="after")
    def _forbid_legacy_event_types(self) -> Self:
        # Reject event types that exist only so legacy automations can be parsed.
        if (type_ := self.event.event_type) in INVALID_INPUT_EVENTS:
            raise ValueError(f"{type_!r} events cannot be assigned to automations.")
        return self

    @model_validator(mode="after")
    def _forbid_legacy_action_types(self) -> Self:
        # Reject action types that exist only so legacy automations can be parsed.
        if (type_ := self.action.action_type) in INVALID_INPUT_ACTIONS:
            raise ValueError(f"{type_!r} actions cannot be assigned to automations.")
        return self
class ValidatedUpdateInput(GQLInput, extra="ignore", frozen=True):
    """Validated automation parameters, prepared for updating an existing automation.

    Accepts both InputEvent/InputAction (user-supplied for the update) and
    SavedEvent/SavedAction (carried over from the existing saved automation).
    This avoids the coercion bug where routing through Automation(event: SavedEvent)
    silently drops InputEvent filters.

    Uses extra="ignore" (rather than "forbid") because dict(Automation) includes
    fields like typename__, created_at, updated_at that are not relevant for the
    update payload.
    """

    # Identifies the automation being updated; presumably supplied by
    # `dict(Automation)` in `prepare_to_update` — verify against callers.
    id: GQLId
    # `None` means "not provided" for these optional fields.
    name: Optional[str] = None
    description: Optional[str] = None
    enabled: Optional[bool] = None

    # Used only to derive the computed fields below; excluded from the payload.
    event: Annotated[Union[InputEvent, SavedEvent], Field(exclude=True)]
    action: Annotated[Union[InputAction, SavedAction], Field(exclude=True)]
    # NOTE(review): unlike `ValidatedCreateInput`, the scope comes from this
    # separate field rather than `event.scope`. If a new InputEvent with a
    # different scope is passed without `scope`, the saved scope is kept — confirm.
    scope: Annotated[AutomationScope, Field(exclude=True)]

    # --------------------------------------------------------------------------
    # Derived fields to match the input schemas

    @computed_field
    def scope_type(self) -> ScopeType:
        return self.scope.scope_type

    @computed_field
    def scope_id(self) -> GQLId:
        return self.scope.id

    @computed_field
    def triggering_event_type(self) -> EventType:
        return self.event.event_type

    @computed_field
    def event_filter(self) -> str:
        # Handles both saved (wrapped) and input event filters.
        return prepare_event_filter_input(self.event.filter)

    @computed_field
    def triggered_action_type(self) -> ActionType:
        return self.action.action_type

    @computed_field
    def triggered_action_config(self) -> dict[str, Any]:
        # SavedAction values are converted to InputAction inside the helper.
        return prepare_action_config_input(self.action)

    # --------------------------------------------------------------------------
    # Custom validation
    # NOTE(review): these checks also apply to values carried over from the saved
    # automation, so an automation that already uses a legacy event/action type
    # cannot be updated at all (e.g. just renamed) — confirm that's intended.

    @model_validator(mode="after")
    def _forbid_legacy_event_types(self) -> Self:
        if (type_ := self.event.event_type) in INVALID_INPUT_EVENTS:
            raise ValueError(f"{type_!r} events cannot be assigned to automations.")
        return self

    @model_validator(mode="after")
    def _forbid_legacy_action_types(self) -> Self:
        if (type_ := self.action.action_type) in INVALID_INPUT_ACTIONS:
            raise ValueError(f"{type_!r} actions cannot be assigned to automations.")
        return self
def prepare_to_create(
    obj: NewAutomation | None = None,
    /,
    **kwargs: Unpack[WriteAutomationsKwargs],
) -> CreateFilterTriggerInput:
    """Prepares the payload to create an automation in a GraphQL request.

    Args:
        obj: Optional `NewAutomation` whose fields seed the payload.
        **kwargs: Field values that override those taken from `obj`. If `obj`
            is omitted, the payload is built from these alone.

    Returns:
        The `CreateFilterTriggerInput` built from the validated parameters.

    Raises:
        ValueError: If validation fails, e.g. a required field is missing or a
            legacy event/action type is used. Pydantic's validation errors
            subclass `ValueError`.
    """
    if obj is None:
        obj_dict: dict[str, Any] = dict(kwargs)
    else:
        # Drop top-level fields left as `None` on the source object. None of the
        # non-optional fields on `ValidatedCreateInput` (`name: str`,
        # `enabled: bool`, `event`, `action`) can accept `None`. Dropping them
        # lets defaults apply (e.g. `enabled=True`) instead of failing
        # validation. Keyword args are merged afterwards, so they still win.
        obj_dict = {k: v for k, v in obj.model_dump().items() if v is not None}
        obj_dict.update(kwargs)

    vobj = ValidatedCreateInput(**obj_dict)
    return CreateFilterTriggerInput.model_validate(vobj)
def prepare_to_update(
    obj: Automation | None = None,
    /,
    **kwargs: Unpack[WriteAutomationsKwargs],
) -> UpdateFilterTriggerInput:
    """Build the GraphQL payload for updating an existing automation.

    Field values come from `obj` (when given), with any keyword arguments
    taking precedence. The merged values are validated through
    `ValidatedUpdateInput` and converted to an `UpdateFilterTriggerInput`.
    """
    # Start from the saved automation's fields, if any, then apply overrides.
    # `dict(obj)` keeps nested values as model instances (no dump/re-parse).
    merged: dict[str, Any] = dict(obj) if obj else {}
    merged.update(kwargs)

    validated = ValidatedUpdateInput(**merged)
    return UpdateFilterTriggerInput.model_validate(validated)
|