events.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438
  1. """Events that trigger W&B Automations."""
  2. from __future__ import annotations
  3. from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, Union
  4. from pydantic import AfterValidator, Field
  5. from typing_extensions import get_args
  6. from wandb._pydantic import GQLBase, model_validator, pydantic_isinstance
  7. from wandb._strutils import nameof
  8. from ._filters import And, MongoLikeFilter
  9. from ._filters.expressions import FilterableField
  10. from ._filters.run_metrics import (
  11. MetricChangeFilter,
  12. MetricThresholdFilter,
  13. MetricVal,
  14. MetricZScoreFilter,
  15. )
  16. from ._filters.run_states import StateFilter, StateOperand
  17. from ._generated import FilterEventFields
  18. from ._validators import (
  19. JsonEncoded,
  20. LenientStrEnum,
  21. ensure_json,
  22. wrap_mutation_event_filter,
  23. wrap_run_event_run_filter,
  24. )
  25. from .actions import InputAction, InputActionTypes, SavedActionTypes
  26. from .scopes import ArtifactCollectionScope, AutomationScope, ProjectScope
  27. if TYPE_CHECKING:
  28. from .automations import NewAutomation
# NOTE: Re-defined publicly with a more readable name for easier access
# (presumably mirrors a generated GraphQL enum -- confirm against `._generated`).
class EventType(LenientStrEnum):
    """The type of event that triggers an automation.

    Each member's *name* is the public Python name; its *value* is the string
    the backend expects (e.g. `LINK_ARTIFACT` is sent as "LINK_MODEL", and
    `RUN_METRIC_THRESHOLD` as "RUN_METRIC").
    """

    # ---------------------------------------------------------------------------
    # Events triggered by GraphQL mutations
    UPDATE_ARTIFACT_ALIAS = "UPDATE_ARTIFACT_ALIAS"  # NOTE: Avoid in new automations
    CREATE_ARTIFACT = "CREATE_ARTIFACT"
    ADD_ARTIFACT_ALIAS = "ADD_ARTIFACT_ALIAS"
    LINK_ARTIFACT = "LINK_MODEL"
    # Note: "LINK_MODEL" is the (legacy) value expected by the backend, but we
    # name it "LINK_ARTIFACT" here in the public API for clarity and consistency.

    # ---------------------------------------------------------------------------
    # Events triggered by Run conditions
    # Public name differs from the backend value "RUN_METRIC" for clarity next
    # to the change/z-score variants below.
    RUN_METRIC_THRESHOLD = "RUN_METRIC"
    RUN_METRIC_CHANGE = "RUN_METRIC_CHANGE"
    RUN_STATE = "RUN_STATE"
    RUN_METRIC_ZSCORE = "RUN_METRIC_ZSCORE"
# ------------------------------------------------------------------------------
# Saved types: for parsing response data from saved automations
# Note: In GQL responses containing saved automation data, the filter is wrapped
# in an extra `filter` key.


# Models the `{"filter": <JSON-encoded filter>}` wrapper that saved mutation
# (artifact) events carry in GQL responses.
class _WrappedSavedEventFilter(GQLBase):  # from: TriggeringFilterEvent
    # JSON-encoded MongoDB-style filter expression. Defaults to `And()`
    # (presumably an empty conjunction, i.e. "no extra conditions" -- confirm).
    filter: JsonEncoded[MongoLikeFilter] = And()
  52. class _WrappedMetricThresholdFilter(GQLBase): # from: RunMetricFilter
  53. event_type: Annotated[
  54. Literal[EventType.RUN_METRIC_THRESHOLD],
  55. Field(exclude=True, repr=False),
  56. ] = EventType.RUN_METRIC_THRESHOLD
  57. threshold_filter: MetricThresholdFilter
  58. @model_validator(mode="before")
  59. @classmethod
  60. def _nest_inner_filter(cls, v: Any) -> Any:
  61. # Yeah, we've got a lot of nesting due to backend schema constraints.
  62. if pydantic_isinstance(v, MetricThresholdFilter):
  63. return cls(threshold_filter=v)
  64. return v
  65. class _WrappedMetricChangeFilter(GQLBase): # from: RunMetricFilter
  66. event_type: Annotated[
  67. Literal[EventType.RUN_METRIC_CHANGE],
  68. Field(exclude=True, repr=False),
  69. ] = EventType.RUN_METRIC_CHANGE
  70. change_filter: MetricChangeFilter
  71. @model_validator(mode="before")
  72. @classmethod
  73. def _nest_inner_filter(cls, v: Any) -> Any:
  74. # Yeah, we've got a lot of nesting due to backend schema constraints.
  75. if pydantic_isinstance(v, MetricChangeFilter):
  76. return cls(change_filter=v)
  77. return v
  78. class _WrappedMetricZScoreFilter(GQLBase): # from: RunMetricFilter
  79. event_type: Annotated[
  80. Literal[EventType.RUN_METRIC_ZSCORE],
  81. Field(exclude=True, repr=False),
  82. ] = EventType.RUN_METRIC_ZSCORE
  83. zscore_filter: MetricZScoreFilter
  84. @model_validator(mode="before")
  85. @classmethod
  86. def _nest_inner_filter(cls, v: Any) -> Any:
  87. if pydantic_isinstance(v, MetricZScoreFilter):
  88. return cls(zscore_filter=v)
  89. return v
  90. class RunMetricFilter(GQLBase): # from: TriggeringRunMetricEvent
  91. run: Annotated[
  92. JsonEncoded[MongoLikeFilter],
  93. AfterValidator(wrap_run_event_run_filter),
  94. Field(alias="run_filter"),
  95. ] = And()
  96. """Filters that must match any runs that will trigger this event."""
  97. metric: Annotated[
  98. Union[
  99. _WrappedMetricThresholdFilter,
  100. _WrappedMetricChangeFilter,
  101. _WrappedMetricZScoreFilter,
  102. ],
  103. Field(alias="run_metric_filter"),
  104. ]
  105. """Metric condition(s) that must be satisfied for this event to trigger."""
  106. # ------------------------------------------------------------------------------
  107. legacy_metric_filter: Annotated[
  108. Optional[JsonEncoded[MetricThresholdFilter]],
  109. Field(alias="metric_filter", deprecated=True),
  110. ] = None
  111. """Deprecated legacy field for defining run metric threshold events.
  112. For new automations, use the `metric` field (JSON alias `run_metric_filter`).
  113. """
  114. @model_validator(mode="before")
  115. @classmethod
  116. def _nest_metric_filter(cls, v: Any) -> Any:
  117. # If no run filter is given, automatically nest the metric filter and
  118. # let inner validators reshape further as needed.
  119. if pydantic_isinstance(
  120. v, (MetricThresholdFilter, MetricChangeFilter, MetricZScoreFilter)
  121. ):
  122. return cls(metric=v)
  123. return v
  124. class RunStateFilter(GQLBase): # from: TriggeringRunStateEvent
  125. """Represents a filter for triggering events based on changes in run states."""
  126. run: Annotated[
  127. JsonEncoded[MongoLikeFilter],
  128. AfterValidator(wrap_run_event_run_filter),
  129. Field(alias="run_filter"),
  130. ] = And()
  131. """Filters that must match any runs that will trigger this event."""
  132. state: Annotated[StateFilter, Field(alias="run_state_filter")]
  133. """Run state condition(s) that must be satisfied for this event to trigger."""
  134. @model_validator(mode="before")
  135. @classmethod
  136. def _nest_state_filter(cls, v: Any) -> Any:
  137. # If no run filter is given, automatically nest the state filter and
  138. # let inner validators reshape further as needed.
  139. if pydantic_isinstance(v, StateFilter):
  140. return cls(state=v)
  141. return v
class SavedEvent(FilterEventFields):  # from: FilterEventTriggeringCondition
    """A triggering event from a saved automation."""

    # Marked frozen, so it cannot be reassigned on a parsed instance. The
    # `type: ignore` is presumably needed because this re-declares a field
    # inherited from the generated `FilterEventFields` base -- confirm.
    event_type: Annotated[EventType, Field(frozen=True)]  # type: ignore[assignment]

    # We override the type of the `filter` field in order to enforce the expected
    # structure for the JSON data when validating and serializing.
    # Union members: the `{"filter": ...}` wrapper used by mutation events, or a
    # run-metric / run-state condition.
    filter: JsonEncoded[
        Union[_WrappedSavedEventFilter, RunMetricFilter, RunStateFilter]
    ]
    """The condition(s) under which this event triggers an automation."""
  151. # ------------------------------------------------------------------------------
  152. # Input types: for creating or updating automations
  153. # Note: The GQL input for `eventFilter` does NOT wrap the filter in an extra
  154. # `filter` key, unlike the `eventFilter` in GQL responses for saved automations.
  155. class _BaseEventInput(GQLBase):
  156. event_type: EventType
  157. scope: AutomationScope
  158. """The scope of the event."""
  159. filter: JsonEncoded[Any]
  160. def then(self, action: InputAction) -> NewAutomation:
  161. """Define a new Automation in which this event triggers the given action."""
  162. from .automations import NewAutomation
  163. if isinstance(action, (InputActionTypes, SavedActionTypes)):
  164. return NewAutomation(event=self, action=action)
  165. raise TypeError(f"Expected a valid action, got: {nameof(type(action))!r}")
  166. def __rshift__(self, other: InputAction) -> NewAutomation:
  167. """Implement `event >> action` to define an automation."""
  168. return self.then(other)
  169. # ------------------------------------------------------------------------------
  170. # Events that trigger on specific mutations in the backend
  171. class _BaseMutationEventInput(_BaseEventInput):
  172. filter: Annotated[
  173. JsonEncoded[MongoLikeFilter],
  174. AfterValidator(wrap_mutation_event_filter),
  175. ] = And()
  176. """Additional conditions(s), if any, that are required for this event to trigger."""
class OnLinkArtifact(_BaseMutationEventInput):
    """A new artifact is linked to a collection.

    Examples:
        Define an event that triggers when an artifact is linked to the
        collection "my-collection" with the alias "prod":

        ```python
        from wandb import Api
        from wandb.automations import OnLinkArtifact, ArtifactEvent

        api = Api()
        collection = api.artifact_collection(name="my-collection", type_name="model")

        event = OnLinkArtifact(
            scope=collection,
            filter=ArtifactEvent.alias.eq("prod"),
        )
        ```
    """

    # Discriminator for the `InputEvent` union; serialized with the backend's
    # legacy value "LINK_MODEL".
    event_type: Literal[EventType.LINK_ARTIFACT] = EventType.LINK_ARTIFACT
class OnAddArtifactAlias(_BaseMutationEventInput):
    """A new alias is assigned to an artifact.

    Examples:
        Define an event that triggers whenever the alias "prod" is assigned to
        any artifact in the collection "my-collection":

        ```python
        from wandb import Api
        from wandb.automations import OnAddArtifactAlias, ArtifactEvent

        api = Api()
        collection = api.artifact_collection(name="my-collection", type_name="model")

        event = OnAddArtifactAlias(
            scope=collection,
            filter=ArtifactEvent.alias.eq("prod"),
        )
        ```
    """

    # Discriminator for the `InputEvent` union.
    event_type: Literal[EventType.ADD_ARTIFACT_ALIAS] = EventType.ADD_ARTIFACT_ALIAS
class OnCreateArtifact(_BaseMutationEventInput):
    """A new artifact is created.

    Examples:
        Define an event that triggers when a new artifact is created in the
        collection "my-collection":

        ```python
        from wandb import Api
        from wandb.automations import OnCreateArtifact

        api = Api()
        collection = api.artifact_collection(name="my-collection", type_name="model")

        event = OnCreateArtifact(scope=collection)
        ```
    """

    # Discriminator for the `InputEvent` union.
    event_type: Literal[EventType.CREATE_ARTIFACT] = EventType.CREATE_ARTIFACT

    # Narrower than the inherited `AutomationScope`: unlike the other mutation
    # events here, creation events only accept an artifact-collection scope.
    scope: ArtifactCollectionScope
    """The scope of the event: must be an artifact collection."""
# ------------------------------------------------------------------------------
# Events that trigger on run conditions


# Shared base for run-condition events: narrows the inherited `scope` to a project.
class _BaseRunEventInput(_BaseEventInput):
    scope: ProjectScope
    """The scope of the event: must be a project."""
  232. class OnRunMetric(_BaseRunEventInput):
  233. """A run metric satisfies a user-defined condition.
  234. Examples:
  235. Define an event that triggers for any run in project "my-project" when
  236. the average of the last 5 values of metric "my-metric" exceeds 123.45:
  237. ```python
  238. from wandb import Api
  239. from wandb.automations import OnRunMetric, RunEvent
  240. api = Api()
  241. project = api.project(name="my-project")
  242. event = OnRunMetric(
  243. scope=project,
  244. filter=RunEvent.metric("my-metric").avg(5).gt(123.45),
  245. )
  246. ```
  247. """
  248. event_type: Literal[
  249. EventType.RUN_METRIC_THRESHOLD,
  250. EventType.RUN_METRIC_CHANGE,
  251. EventType.RUN_METRIC_ZSCORE,
  252. ]
  253. filter: JsonEncoded[RunMetricFilter]
  254. """Run and/or metric condition(s) that must be satisfied for this event to trigger."""
  255. @model_validator(mode="before")
  256. @classmethod
  257. def _infer_event_type(cls, data: Any) -> Any:
  258. """Infer the event type from the inner filter during validation.
  259. This supports both "threshold" and "change" metric filters, which can
  260. only be determined after parsing and validating the inner JSON data.
  261. """
  262. if isinstance(data, dict) and (raw_filter := data.get("filter")):
  263. # At this point, `raw_filter` may or may not be JSON-serialized
  264. parsed_filter = RunMetricFilter.model_validate_json(ensure_json(raw_filter))
  265. return {**data, "event_type": parsed_filter.metric.event_type}
  266. return data
  267. class OnRunState(_BaseRunEventInput):
  268. """A run state changes.
  269. Examples:
  270. Define an event that triggers for any run in project "my-project" when
  271. its state changes to "finished" (i.e. succeeded) or "failed":
  272. ```python
  273. from wandb import Api
  274. from wandb.automations import OnRunState
  275. api = Api()
  276. project = api.project(name="my-project")
  277. event = OnRunState(
  278. scope=project,
  279. filter=RunEvent.state.in_(["finished", "failed"]),
  280. )
  281. ```
  282. """
  283. event_type: Literal[EventType.RUN_STATE] = EventType.RUN_STATE
  284. filter: JsonEncoded[RunStateFilter]
  285. """Run state condition(s) that must be satisfied for this event to trigger."""
  286. # for type annotations
  287. InputEvent = Annotated[
  288. Union[
  289. OnLinkArtifact,
  290. OnAddArtifactAlias,
  291. OnCreateArtifact,
  292. OnRunMetric,
  293. OnRunState,
  294. ],
  295. Field(discriminator="event_type"),
  296. ]
  297. # for runtime type checks
  298. InputEventTypes: tuple[type, ...] = get_args(InputEvent.__origin__) # type: ignore[attr-defined]
  299. # ----------------------------------------------------------------------------
  300. class RunEvent:
  301. name = FilterableField(server_name="display_name")
  302. # `Run.name` is actually filtered on `Run.display_name` in the backend.
  303. # We can't reasonably expect users to know this a priori, so
  304. # automatically fix it here.
  305. state = StateOperand()
  306. @staticmethod
  307. def metric(name: str) -> MetricVal:
  308. """Define a metric filter condition."""
  309. return MetricVal(name=name)
  310. class ArtifactEvent:
  311. alias = FilterableField()
  312. MetricThresholdFilter.model_rebuild()
  313. RunMetricFilter.model_rebuild()
  314. _WrappedSavedEventFilter.model_rebuild()
  315. OnLinkArtifact.model_rebuild()
  316. OnAddArtifactAlias.model_rebuild()
  317. OnCreateArtifact.model_rebuild()
  318. OnRunMetric.model_rebuild()
  319. __all__ = [
  320. "EventType",
  321. *(nameof(cls) for cls in InputEventTypes),
  322. "RunEvent",
  323. "ArtifactEvent",
  324. "MetricThresholdFilter",
  325. "MetricChangeFilter",
  326. "MetricZScoreFilter",
  327. ]