| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438 |
- """Events that trigger W&B Automations."""
- from __future__ import annotations
- from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, Union
- from pydantic import AfterValidator, Field
- from typing_extensions import get_args
- from wandb._pydantic import GQLBase, model_validator, pydantic_isinstance
- from wandb._strutils import nameof
- from ._filters import And, MongoLikeFilter
- from ._filters.expressions import FilterableField
- from ._filters.run_metrics import (
- MetricChangeFilter,
- MetricThresholdFilter,
- MetricVal,
- MetricZScoreFilter,
- )
- from ._filters.run_states import StateFilter, StateOperand
- from ._generated import FilterEventFields
- from ._validators import (
- JsonEncoded,
- LenientStrEnum,
- ensure_json,
- wrap_mutation_event_filter,
- wrap_run_event_run_filter,
- )
- from .actions import InputAction, InputActionTypes, SavedActionTypes
- from .scopes import ArtifactCollectionScope, AutomationScope, ProjectScope
- if TYPE_CHECKING:
- from .automations import NewAutomation
- # NOTE: Re-defined publicly with a more readable name for easier access
class EventType(LenientStrEnum):
    """The type of event that triggers an automation."""

    # ---------------------------------------------------------------------------
    # Mutation events: fired by specific GraphQL mutations.
    UPDATE_ARTIFACT_ALIAS = "UPDATE_ARTIFACT_ALIAS"  # NOTE: Avoid in new automations
    CREATE_ARTIFACT = "CREATE_ARTIFACT"
    ADD_ARTIFACT_ALIAS = "ADD_ARTIFACT_ALIAS"

    # The backend expects the (legacy) value "LINK_MODEL". The member is exposed
    # as `LINK_ARTIFACT` in the public API for clarity and consistency.
    LINK_ARTIFACT = "LINK_MODEL"

    # ---------------------------------------------------------------------------
    # Run events: fired when a run satisfies some condition.
    # Note that the threshold event is sent over the wire as "RUN_METRIC".
    RUN_METRIC_THRESHOLD = "RUN_METRIC"
    RUN_METRIC_CHANGE = "RUN_METRIC_CHANGE"
    RUN_STATE = "RUN_STATE"
    RUN_METRIC_ZSCORE = "RUN_METRIC_ZSCORE"
- # ------------------------------------------------------------------------------
- # Saved types: for parsing response data from saved automations
- # Note: In GQL responses containing saved automation data, the filter is wrapped
- # in an extra `filter` key.
class _WrappedSavedEventFilter(GQLBase):  # from: TriggeringFilterEvent
    # Saved-automation payloads nest the actual filter under a `filter` key.
    # The field falls back to `And()` (no operands) when no value is supplied.
    filter: JsonEncoded[MongoLikeFilter] = And()
class _WrappedMetricThresholdFilter(GQLBase):  # from: RunMetricFilter
    # Tag recording which kind of metric filter is wrapped. It is excluded from
    # serialized output and from the repr.
    event_type: Annotated[
        Literal[EventType.RUN_METRIC_THRESHOLD],
        Field(exclude=True, repr=False),
    ] = EventType.RUN_METRIC_THRESHOLD

    threshold_filter: MetricThresholdFilter

    @model_validator(mode="before")
    @classmethod
    def _nest_inner_filter(cls, v: Any) -> Any:
        """Wrap a bare threshold filter under the `threshold_filter` key.

        Any other input is returned untouched for normal validation.
        """
        # The extra level of nesting is imposed by backend schema constraints.
        if not pydantic_isinstance(v, MetricThresholdFilter):
            return v
        return cls(threshold_filter=v)
class _WrappedMetricChangeFilter(GQLBase):  # from: RunMetricFilter
    # Tag recording which kind of metric filter is wrapped. It is excluded from
    # serialized output and from the repr.
    event_type: Annotated[
        Literal[EventType.RUN_METRIC_CHANGE],
        Field(exclude=True, repr=False),
    ] = EventType.RUN_METRIC_CHANGE

    change_filter: MetricChangeFilter

    @model_validator(mode="before")
    @classmethod
    def _nest_inner_filter(cls, v: Any) -> Any:
        """Wrap a bare change filter under the `change_filter` key.

        Any other input is returned untouched for normal validation.
        """
        # The extra level of nesting is imposed by backend schema constraints.
        if not pydantic_isinstance(v, MetricChangeFilter):
            return v
        return cls(change_filter=v)
class _WrappedMetricZScoreFilter(GQLBase):  # from: RunMetricFilter
    # Tag recording which kind of metric filter is wrapped. It is excluded from
    # serialized output and from the repr.
    event_type: Annotated[
        Literal[EventType.RUN_METRIC_ZSCORE],
        Field(exclude=True, repr=False),
    ] = EventType.RUN_METRIC_ZSCORE

    zscore_filter: MetricZScoreFilter

    @model_validator(mode="before")
    @classmethod
    def _nest_inner_filter(cls, v: Any) -> Any:
        """Wrap a bare z-score filter under the `zscore_filter` key.

        Any other input is returned untouched for normal validation.
        """
        if not pydantic_isinstance(v, MetricZScoreFilter):
            return v
        return cls(zscore_filter=v)
class RunMetricFilter(GQLBase):  # from: TriggeringRunMetricEvent
    run: Annotated[
        JsonEncoded[MongoLikeFilter],
        AfterValidator(wrap_run_event_run_filter),
        Field(alias="run_filter"),
    ] = And()
    """Filters that must match any runs that will trigger this event."""

    # Exactly one of the wrapped metric-filter kinds; each wrapper accepts the
    # corresponding bare filter and nests it itself.
    metric: Annotated[
        Union[
            _WrappedMetricThresholdFilter,
            _WrappedMetricChangeFilter,
            _WrappedMetricZScoreFilter,
        ],
        Field(alias="run_metric_filter"),
    ]
    """Metric condition(s) that must be satisfied for this event to trigger."""

    # ------------------------------------------------------------------------------
    # Legacy field, kept alongside `metric` and marked deprecated.
    legacy_metric_filter: Annotated[
        Optional[JsonEncoded[MetricThresholdFilter]],
        Field(alias="metric_filter", deprecated=True),
    ] = None
    """Deprecated legacy field for defining run metric threshold events.

    For new automations, use the `metric` field (JSON alias `run_metric_filter`).
    """

    @model_validator(mode="before")
    @classmethod
    def _nest_metric_filter(cls, v: Any) -> Any:
        """Accept a bare metric filter in place of a full `RunMetricFilter`.

        When only a metric filter is given (i.e. no run filter), it is placed
        under `metric`; the inner wrapper validators reshape it further.
        """
        bare_metric_types = (
            MetricThresholdFilter,
            MetricChangeFilter,
            MetricZScoreFilter,
        )
        if not pydantic_isinstance(v, bare_metric_types):
            return v
        return cls(metric=v)
class RunStateFilter(GQLBase):  # from: TriggeringRunStateEvent
    """Represents a filter for triggering events based on changes in run states."""

    run: Annotated[
        JsonEncoded[MongoLikeFilter],
        AfterValidator(wrap_run_event_run_filter),
        Field(alias="run_filter"),
    ] = And()
    """Filters that must match any runs that will trigger this event."""

    state: Annotated[StateFilter, Field(alias="run_state_filter")]
    """Run state condition(s) that must be satisfied for this event to trigger."""

    @model_validator(mode="before")
    @classmethod
    def _nest_state_filter(cls, v: Any) -> Any:
        """Accept a bare `StateFilter` in place of a full `RunStateFilter`.

        When only a state filter is given (i.e. no run filter), it is placed
        under `state`; inner validators reshape it further as needed.
        """
        if not pydantic_isinstance(v, StateFilter):
            return v
        return cls(state=v)
class SavedEvent(FilterEventFields):  # from: FilterEventTriggeringCondition
    """A triggering event from a saved automation."""

    event_type: Annotated[EventType, Field(frozen=True)]  # type: ignore[assignment]

    # The inherited `filter` type is overridden so that the JSON payload is
    # validated against, and serialized from, one of the expected structures.
    filter: JsonEncoded[
        Union[_WrappedSavedEventFilter, RunMetricFilter, RunStateFilter]
    ]
    """The condition(s) under which this event triggers an automation."""
- # ------------------------------------------------------------------------------
- # Input types: for creating or updating automations
- # Note: The GQL input for `eventFilter` does NOT wrap the filter in an extra
- # `filter` key, unlike the `eventFilter` in GQL responses for saved automations.
class _BaseEventInput(GQLBase):
    # Common fields and helpers shared by every input event model.
    event_type: EventType

    scope: AutomationScope
    """The scope of the event."""

    filter: JsonEncoded[Any]

    def then(self, action: InputAction) -> NewAutomation:
        """Define a new Automation in which this event triggers the given action.

        Raises:
            TypeError: If `action` is not one of the recognized input or
                saved action types.
        """
        # Imported at call time; the module-level import of `NewAutomation`
        # is only performed under `TYPE_CHECKING`.
        from .automations import NewAutomation

        if not isinstance(action, (InputActionTypes, SavedActionTypes)):
            raise TypeError(f"Expected a valid action, got: {nameof(type(action))!r}")
        return NewAutomation(event=self, action=action)

    def __rshift__(self, other: InputAction) -> NewAutomation:
        """Implement `event >> action` to define an automation."""
        return self.then(other)
- # ------------------------------------------------------------------------------
- # Events that trigger on specific mutations in the backend
class _BaseMutationEventInput(_BaseEventInput):
    # Base for events fired by backend mutations. The filter is optional and
    # defaults to `And()`; after validation it is passed through
    # `wrap_mutation_event_filter`.
    filter: Annotated[
        JsonEncoded[MongoLikeFilter],
        AfterValidator(wrap_mutation_event_filter),
    ] = And()
    """Additional condition(s), if any, that are required for this event to trigger."""
class OnLinkArtifact(_BaseMutationEventInput):
    """A new artifact is linked to a collection.

    Examples:
        Define an event that triggers when an artifact is linked to the
        collection "my-collection" with the alias "prod":

        ```python
        from wandb import Api
        from wandb.automations import OnLinkArtifact, ArtifactEvent

        api = Api()
        collection = api.artifact_collection(name="my-collection", type_name="model")

        event = OnLinkArtifact(
            scope=collection,
            filter=ArtifactEvent.alias.eq("prod"),
        )
        ```
    """

    # Fixed discriminator value for this event class.
    event_type: Literal[EventType.LINK_ARTIFACT] = EventType.LINK_ARTIFACT
class OnAddArtifactAlias(_BaseMutationEventInput):
    """A new alias is assigned to an artifact.

    Examples:
        Define an event that triggers whenever the alias "prod" is assigned to
        any artifact in the collection "my-collection":

        ```python
        from wandb import Api
        from wandb.automations import OnAddArtifactAlias, ArtifactEvent

        api = Api()
        collection = api.artifact_collection(name="my-collection", type_name="model")

        event = OnAddArtifactAlias(
            scope=collection,
            filter=ArtifactEvent.alias.eq("prod"),
        )
        ```
    """

    # Fixed discriminator value for this event class.
    event_type: Literal[EventType.ADD_ARTIFACT_ALIAS] = EventType.ADD_ARTIFACT_ALIAS
class OnCreateArtifact(_BaseMutationEventInput):
    """A new artifact is created.

    Examples:
        Define an event that triggers when a new artifact is created in the
        collection "my-collection":

        ```python
        from wandb import Api
        from wandb.automations import OnCreateArtifact

        api = Api()
        collection = api.artifact_collection(name="my-collection", type_name="model")

        event = OnCreateArtifact(scope=collection)
        ```
    """

    # Fixed discriminator value for this event class.
    event_type: Literal[EventType.CREATE_ARTIFACT] = EventType.CREATE_ARTIFACT

    # Narrows the inherited `scope` field to artifact collections only.
    scope: ArtifactCollectionScope
    """The scope of the event: must be an artifact collection."""
- # ------------------------------------------------------------------------------
- # Events that trigger on run conditions
class _BaseRunEventInput(_BaseEventInput):
    # Base for run-condition events; narrows the inherited `scope` to projects.
    scope: ProjectScope
    """The scope of the event: must be a project."""
class OnRunMetric(_BaseRunEventInput):
    """A run metric satisfies a user-defined condition.

    Examples:
        Define an event that triggers for any run in project "my-project" when
        the average of the last 5 values of metric "my-metric" exceeds 123.45:

        ```python
        from wandb import Api
        from wandb.automations import OnRunMetric, RunEvent

        api = Api()
        project = api.project(name="my-project")

        event = OnRunMetric(
            scope=project,
            filter=RunEvent.metric("my-metric").avg(5).gt(123.45),
        )
        ```
    """

    # No default here: the concrete value is filled in from the filter by
    # `_infer_event_type` before field validation runs.
    event_type: Literal[
        EventType.RUN_METRIC_THRESHOLD,
        EventType.RUN_METRIC_CHANGE,
        EventType.RUN_METRIC_ZSCORE,
    ]

    filter: JsonEncoded[RunMetricFilter]
    """Run and/or metric condition(s) that must be satisfied for this event to trigger."""

    @model_validator(mode="before")
    @classmethod
    def _infer_event_type(cls, data: Any) -> Any:
        """Infer the event type from the inner filter during validation.

        This supports "threshold", "change", and "z-score" metric filters.
        Which kind was given can only be determined after parsing and
        validating the inner JSON data.

        If `data` is a dict with a truthy `"filter"` entry, a copy of it is
        returned with `"event_type"` set from the parsed filter, replacing any
        `"event_type"` already present. Any other input is returned unchanged.
        """
        if isinstance(data, dict) and (raw_filter := data.get("filter")):
            # At this point, `raw_filter` may or may not be JSON-serialized
            parsed_filter = RunMetricFilter.model_validate_json(ensure_json(raw_filter))
            return {**data, "event_type": parsed_filter.metric.event_type}
        return data
class OnRunState(_BaseRunEventInput):
    """A run state changes.

    Examples:
        Define an event that triggers for any run in project "my-project" when
        its state changes to "finished" (i.e. succeeded) or "failed":

        ```python
        from wandb import Api
        from wandb.automations import OnRunState, RunEvent

        api = Api()
        project = api.project(name="my-project")

        event = OnRunState(
            scope=project,
            filter=RunEvent.state.in_(["finished", "failed"]),
        )
        ```
    """

    # Fixed discriminator value for this event class.
    event_type: Literal[EventType.RUN_STATE] = EventType.RUN_STATE

    filter: JsonEncoded[RunStateFilter]
    """Run state condition(s) that must be satisfied for this event to trigger."""
# For type annotations: the union of all input event models, discriminated
# on their `event_type` field.
InputEvent = Annotated[
    Union[
        OnLinkArtifact,
        OnAddArtifactAlias,
        OnCreateArtifact,
        OnRunMetric,
        OnRunState,
    ],
    Field(discriminator="event_type"),
]

# For runtime type checks: the member classes of the union above, as a tuple.
# `InputEvent.__origin__` is the `Union[...]` wrapped by `Annotated`.
InputEventTypes: tuple[type, ...] = get_args(InputEvent.__origin__)  # type: ignore[attr-defined]
- # ----------------------------------------------------------------------------
class RunEvent:
    """Namespace of helpers for building run-event filter conditions."""

    # `Run.name` is actually filtered on `Run.display_name` in the backend.
    # Users can't reasonably be expected to know this a priori, so the field
    # is mapped to the server-side name automatically.
    name = FilterableField(server_name="display_name")

    state = StateOperand()

    @staticmethod
    def metric(name: str) -> MetricVal:
        """Define a metric filter condition."""
        return MetricVal(name=name)
class ArtifactEvent:
    """Namespace of helpers for building artifact-event filter conditions."""

    alias = FilterableField()
# Rebuild the models below now that every name in this module is defined.
# The order matches the original sequence of explicit calls.
for _model in (
    MetricThresholdFilter,
    RunMetricFilter,
    _WrappedSavedEventFilter,
    OnLinkArtifact,
    OnAddArtifactAlias,
    OnCreateArtifact,
    OnRunMetric,
):
    _model.model_rebuild()
del _model  # don't leave the loop variable in the module namespace
# Public API: the event-type enum, every input event class (names derived from
# `InputEventTypes`), the filter-building namespaces, and the metric filters.
__all__ = [
    "EventType",
    *map(nameof, InputEventTypes),
    "RunEvent",
    "ArtifactEvent",
    "MetricThresholdFilter",
    "MetricChangeFilter",
    "MetricZScoreFilter",
]
|