| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349 |
- """Internal validation utilities that are specific to artifacts."""
- from __future__ import annotations
- import json
- import re
- from dataclasses import dataclass, field, replace
- from functools import singledispatch, wraps
- from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, TypeVar
- from pydantic.dataclasses import dataclass as pydantic_dataclass
- from typing_extensions import Concatenate, ParamSpec, Self
- from wandb._iterutils import always_list, unique_list
- from wandb._pydantic import from_json
- from wandb._strutils import nameof
- from wandb.util import json_friendly_val
- from .exceptions import ArtifactFinalizedError, ArtifactNotLoggedError
- if TYPE_CHECKING:
- from collections.abc import Iterable
- from typing import Final
- from wandb.sdk.artifacts.artifact import Artifact
- ArtifactT = TypeVar("ArtifactT", bound="Artifact")
- SelfT = TypeVar("SelfT")
- R = TypeVar("R")
- P = ParamSpec("P")
- REGISTRY_PREFIX: Final[str] = "wandb-registry-"
- MAX_ARTIFACT_METADATA_KEYS: Final[int] = 100
- NAME_MAXLEN: Final[int] = 128
- INVALID_ARTIFACT_NAME_CHARS: Final[frozenset[str]] = frozenset("/")
- INVALID_URL_CHARS: Final[frozenset[str]] = frozenset("/\\#?%:\r\n")
- ARTIFACT_SEP_CHARS: Final[frozenset[str]] = frozenset("/:")
@dataclass
class LinkArtifactFields:
    """Container for the fields on which a linked artifact differs from its source.

    Keep this list updated with fields where linked and source artifacts differ.
    """

    entity_name: str
    project_name: str
    name: str
    version: str
    aliases: list[str]

    # Excluded from `__init__` so callers cannot override them: a linked
    # artifact always reports `True` and starts with an empty list here.
    _is_link: Literal[True] = field(init=False, default=True)
    _linked_artifacts: list[Artifact] = field(init=False, default_factory=list)

    @property
    def is_link(self) -> bool:
        """Read-only accessor for the `_is_link` flag."""
        return self._is_link

    @property
    def linked_artifacts(self) -> list[Artifact]:
        """Read-only accessor for the `_linked_artifacts` list."""
        return self._linked_artifacts
def validate_artifact_name(name: str) -> str:
    """Check an artifact name and hand it back unchanged when it is acceptable.

    Args:
        name: The candidate artifact name.

    Returns:
        The same `name` that was passed in.

    Raises:
        ValueError: If the name is longer than `NAME_MAXLEN` characters, or
            contains any character in `INVALID_ARTIFACT_NAME_CHARS`.
    """
    if len(name) > NAME_MAXLEN:
        # Only echo the leading part of an over-long name in the error message.
        shortened = f"{name[:NAME_MAXLEN]} ..."
        raise ValueError(
            f"Artifact name is longer than {NAME_MAXLEN!r} characters: {shortened!r}"
        )

    if any(char in INVALID_ARTIFACT_NAME_CHARS for char in name):
        forbidden = ", ".join(sorted(INVALID_ARTIFACT_NAME_CHARS))
        raise ValueError(
            "Artifact names must not contain any of the following characters: "
            f"{forbidden}. Got: {name!r}"
        )

    return name
def validate_project_name(name: str) -> str:
    """Check a project (or registry project) name and return it unchanged.

    A name that starts with `REGISTRY_PREFIX` is reported as a registry name
    in error messages, with the prefix stripped.

    Args:
        name: The project name string.

    Returns:
        The same `name` that was passed in.

    Raises:
        ValueError: If the name is empty, consists only of the registry
            prefix, is longer than `NAME_MAXLEN`, or contains any character
            in `INVALID_URL_CHARS`.
    """
    if not name:
        raise ValueError("Project name cannot be empty")

    registry_name = name.removeprefix(REGISTRY_PREFIX)
    if not registry_name:
        raise ValueError("Registry name cannot be empty")

    # The prefix was stripped only if the name actually started with it.
    is_registry = registry_name != name

    if len(name) > NAME_MAXLEN:
        if is_registry:
            raise ValueError(
                f"Invalid registry name {registry_name!r}, must be {NAME_MAXLEN - len(REGISTRY_PREFIX)!r} characters or less"
            )
        raise ValueError(
            f"Invalid project name {name!r}, must be {NAME_MAXLEN!r} characters or less"
        )

    # Collect every forbidden character present in the name, not just the first.
    offending = INVALID_URL_CHARS.intersection(name)
    if offending:
        # `registry_name` is non-empty here and equals `name` for non-registry projects.
        offending_repr = ", ".join(sorted(repr(char) for char in offending))
        raise ValueError(
            f"Invalid project/registry name {registry_name!r}, cannot contain characters: {offending_repr!s}"
        )

    return name
def validate_aliases(aliases: Iterable[str] | str) -> list[str]:
    """Normalize artifact aliases to a list and reject separator characters.

    Args:
        aliases: A single alias or an iterable of aliases.

    Returns:
        The aliases as produced by `always_list`.

    Raises:
        ValueError: If any alias contains a character in `ARTIFACT_SEP_CHARS`.
    """
    aliases_list = always_list(aliases)

    for alias in aliases_list:
        if ARTIFACT_SEP_CHARS.isdisjoint(alias):
            continue
        # The message lists the whole forbidden set, not only the offending character.
        forbidden = ", ".join(sorted(repr(char) for char in ARTIFACT_SEP_CHARS))
        raise ValueError(
            f"Aliases must not contain any of the following characters: {forbidden}"
        )

    return aliases_list
def validate_artifact_types(types: Iterable[str] | str) -> list[str]:
    """Normalize artifact type names to a list and validate each of them.

    Args:
        types: A single type name or an iterable of type names.

    Returns:
        The type names as produced by `always_list`.

    Raises:
        ValueError: If any type name contains a character in
            `ARTIFACT_SEP_CHARS`, or is longer than `NAME_MAXLEN` characters.
    """
    types_list = always_list(types)

    # First pass: separator characters are checked across ALL names before
    # any length check, so this error takes precedence.
    for type_name in types_list:
        if not ARTIFACT_SEP_CHARS.isdisjoint(type_name):
            forbidden = ", ".join(sorted(repr(char) for char in ARTIFACT_SEP_CHARS))
            raise ValueError(
                f"Artifact types must not contain any of the following characters: {forbidden}"
            )

    # Second pass: length limit.
    for type_name in types_list:
        if len(type_name) > NAME_MAXLEN:
            raise ValueError(
                f"Artifact types must be less than or equal to {NAME_MAXLEN!r} characters"
            )

    return types_list
TAG_REGEX: re.Pattern[str] = re.compile(r"^[-\w]+( +[-\w]+)*$")
"""Regex pattern for valid tag names."""


def validate_tags(tags: Iterable[str] | str) -> list[str]:
    """Validate artifact tag names and return them as a deduped list.

    In the case of duplicates, keep the first tag and maintain the order of
    appearance.

    Args:
        tags: A single tag or an iterable of tags.

    Returns:
        The tags after `always_list` and `unique_list` have been applied.

    Raises:
        ValueError: If any of the tags contain invalid characters.
    """
    tags_list = unique_list(always_list(tags))

    # Use `fullmatch`, not `match`: in the default regex mode `$` also matches
    # just before a trailing newline, so `match` would accept a tag such as
    # "release\n". `fullmatch` requires the entire string to be consumed.
    if not all(TAG_REGEX.fullmatch(tag) for tag in tags_list):
        raise ValueError(
            "Invalid tag(s).  "
            "Tags must only contain alphanumeric characters separated by hyphens, underscores, and/or spaces."
        )
    return tags_list
RESERVED_ARTIFACT_TYPE_PREFIX: Final[str] = "wandb-"
"""Internal, reserved artifact type prefix."""

RESERVED_ARTIFACT_NAME_PREFIX_BY_TYPE: Final[dict[str, str]] = {
    # An empty prefix reserves EVERY artifact name for that artifact type.
    "job": "",
    "run_table": "run-",
    "code": "source-",
}
"""Lookup of internal, reserved `Artifact.name` prefixes by `Artifact.type`."""


def validate_artifact_type(typ: str, name: str) -> str:
    """Reject reserved artifact type/name combinations, else return the type.

    Args:
        typ: The artifact type.
        name: The artifact name, checked against the prefix reserved for `typ`.

    Returns:
        The same `typ` that was passed in.

    Raises:
        ValueError: If `typ` starts with `RESERVED_ARTIFACT_TYPE_PREFIX`, or
            if `name` starts with the prefix reserved for `typ` in
            `RESERVED_ARTIFACT_NAME_PREFIX_BY_TYPE`.
    """
    reserved_name_prefix = RESERVED_ARTIFACT_NAME_PREFIX_BY_TYPE.get(typ)

    # Compare against `None` rather than truthiness: an empty-string prefix
    # is meaningful and disallows ALL artifact names for that type.
    name_is_reserved = reserved_name_prefix is not None and name.startswith(
        reserved_name_prefix
    )
    type_is_reserved = typ.startswith(RESERVED_ARTIFACT_TYPE_PREFIX)

    if name_is_reserved or type_is_reserved:
        raise ValueError(
            f"Artifact type {typ!r} is reserved for internal use. "
            "Please use a different type."
        )
    return typ
@singledispatch
def validate_metadata(metadata: dict[str, Any] | str | None) -> dict[str, Any]:
    """Validate artifact metadata and return it as a dict.

    Dispatches on the runtime type of `metadata`; handlers are registered
    below for `dict`, `str` and `None`. This fallback handles every other type.

    Raises:
        TypeError: If `metadata` is of a type with no registered handler.
    """
    raise TypeError(f"Cannot parse {type(metadata)} as artifact metadata")


@validate_metadata.register(type(None))
@validate_metadata.register(str)
def _(metadata: str | None) -> dict[str, Any]:
    # `None` and the empty string both mean "no metadata".
    if not metadata:
        return {}
    # Otherwise parse the string and re-dispatch on the parsed value's type.
    parsed = from_json(metadata)
    return validate_metadata(parsed)


@validate_metadata.register(dict)
def _(metadata: dict[str, Any]) -> dict[str, Any]:
    # NOTE: The backend doesn't currently allow JS-compatible `+/-Infinity` values.
    # `allow_nan=False` forbids them here to avoid surprises; revisit if backend
    # support is added. Prior behavior already converts `NaN` values to `None`
    # (client-side).
    friendly = json_friendly_val(metadata)
    serialized = json.dumps(friendly, allow_nan=False)
    normalized = from_json(serialized)

    if len(normalized) <= MAX_ARTIFACT_METADATA_KEYS:
        return normalized

    raise ValueError(
        f"Artifact must not have more than {MAX_ARTIFACT_METADATA_KEYS!r} metadata keys."
    )
def validate_ttl_duration_seconds(ttl_duration_seconds: int) -> int | None:
    """Normalize the `ttlDurationSeconds` value from a GraphQL response.

    A non-positive value indicates that TTL is DISABLED (-2), which is
    converted to `None`.

    Args:
        ttl_duration_seconds: The raw value from the response.

    Returns:
        The value itself when it is strictly positive, otherwise `None`.
    """
    if ttl_duration_seconds > 0:
        return ttl_duration_seconds
    return None
- # ----------------------------------------------------------------------------
MethodT = Callable[Concatenate[SelfT, P], R]
"""Generic type hint for an instance method, e.g. for use with decorators."""


def ensure_logged(method: MethodT[ArtifactT, P, R]) -> MethodT[ArtifactT, P, R]:
    """Decorate an artifact method so it only runs on a logged artifact.

    When the decorated method is invoked on an artifact whose `is_draft()`
    returns a truthy value, `ArtifactNotLoggedError` is raised and the
    wrapped method is not called.

    Args:
        method: The instance method to guard.

    Returns:
        A wrapper carrying the metadata of `method` (via `functools.wraps`).
    """
    # Resolved once at decoration time; for clarity, errors report the
    # qualified (full) name of the method.
    qualified_name = nameof(method)

    @wraps(method)
    def guarded(self: ArtifactT, *args: P.args, **kwargs: P.kwargs) -> R:
        if not self.is_draft():
            return method(self, *args, **kwargs)
        raise ArtifactNotLoggedError(fullname=qualified_name, obj=self)

    return guarded
def ensure_not_finalized(method: MethodT[ArtifactT, P, R]) -> MethodT[ArtifactT, P, R]:
    """Ensure an `Artifact` method runs only if the artifact is not finalized.

    If the method is called on an artifact that is already finalized (its
    `_final` attribute is truthy), `ArtifactFinalizedError` is raised and the
    wrapped method is not called.

    Args:
        method: The instance method to guard.

    Returns:
        A wrapper carrying the metadata of `method` (via `functools.wraps`).
    """
    # For clarity, use the qualified (full) name of the method
    method_fullname = nameof(method)

    @wraps(method)
    def wrapper(self: ArtifactT, *args: P.args, **kwargs: P.kwargs) -> R:
        if self._final:
            raise ArtifactFinalizedError(fullname=method_fullname, obj=self)
        return method(self, *args, **kwargs)

    return wrapper
def is_artifact_registry_project(project: str) -> bool:
    """Return True if the project name starts with `REGISTRY_PREFIX`."""
    has_registry_prefix = project.startswith(REGISTRY_PREFIX)
    return has_registry_prefix
def remove_registry_prefix(project: str) -> str:
    """Strip the leading `REGISTRY_PREFIX` from a registry project name.

    Args:
        project: A project name expected to start with `REGISTRY_PREFIX`.

    Returns:
        The project name without the registry prefix.

    Raises:
        ValueError: If `project` is not a registry project name.
    """
    if is_artifact_registry_project(project):
        return project.removeprefix(REGISTRY_PREFIX)

    raise ValueError(
        f"Project {project!r} is not a registry project. Must start with: {REGISTRY_PREFIX!r}"
    )
@pydantic_dataclass
class ArtifactPath:
    """A parsed artifact path made of an optional prefix, optional project, and a name."""

    name: str
    """The collection or artifact version name."""
    project: Optional[str] = None  # noqa: UP045
    """The project name."""
    prefix: Optional[str] = None  # noqa: UP045
    """Typically the entity or org name."""

    @classmethod
    def from_str(cls, path: str) -> Self:
        """Build an instance by parsing a string artifact path.

        Accepted shapes are `name`, `project/name` and `prefix/project/name`,
        each optionally followed by `:alias`.

        Raises:
            ValueError: If the string is not a valid artifact path.
        """
        # Split at the first colon BEFORE splitting on slashes: everything after
        # it is the alias, which is kept verbatim even if it contains slashes.
        # When there is no colon, both `sep` and `alias` are empty strings.
        collection_path, sep, alias = path.partition(":")

        # The last slash-separated segment is the name; up to two segments may precede it.
        *qualifiers, name = collection_path.split("/")
        if len(qualifiers) > 2:
            raise ValueError(f"Invalid artifact path: {path!r}")

        # Left-pad with `None` so missing leading parts fall back to their defaults.
        prefix, project = (None, None, *qualifiers)[-2:]

        return cls(name=f"{name}{sep}{alias}", project=project, prefix=prefix)

    def to_str(self) -> str:
        """Returns the slash-separated string representation of the path."""
        # Falsy parts (None or empty) are skipped.
        return "/".join(filter(None, (self.prefix, self.project, self.name)))

    def with_defaults(
        self,
        *,
        prefix: str | None = None,
        project: str | None = None,
    ) -> Self:
        """Returns a copy of this path with missing values set to the given defaults."""
        # Values already set on this path take precedence over the defaults.
        filled = {
            "prefix": self.prefix or prefix,
            "project": self.project or project,
        }
        return replace(self, **filled)

    def is_registry_path(self) -> bool:
        """Returns True if this path appears to be a registry path."""
        project = self.project
        if not project:
            return False
        return bool(is_artifact_registry_project(project))
@pydantic_dataclass
class FullArtifactPath(ArtifactPath):
    """Same as ArtifactPath, but with all parts required.

    The fields are redeclared without defaults, so `project` and `prefix`
    must always be supplied.
    """

    name: str
    project: str
    prefix: str
|