_validators.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349
  1. """Internal validation utilities that are specific to artifacts."""
  2. from __future__ import annotations
  3. import json
  4. import re
  5. from dataclasses import dataclass, field, replace
  6. from functools import singledispatch, wraps
  7. from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, TypeVar
  8. from pydantic.dataclasses import dataclass as pydantic_dataclass
  9. from typing_extensions import Concatenate, ParamSpec, Self
  10. from wandb._iterutils import always_list, unique_list
  11. from wandb._pydantic import from_json
  12. from wandb._strutils import nameof
  13. from wandb.util import json_friendly_val
  14. from .exceptions import ArtifactFinalizedError, ArtifactNotLoggedError
  15. if TYPE_CHECKING:
  16. from collections.abc import Iterable
  17. from typing import Final
  18. from wandb.sdk.artifacts.artifact import Artifact
# Type variables used in the decorator type hints further down in this module.
ArtifactT = TypeVar("ArtifactT", bound="Artifact")  # bounded by `Artifact`
SelfT = TypeVar("SelfT")  # type of the leading `self` argument in `MethodT`
R = TypeVar("R")  # return type of a wrapped method
P = ParamSpec("P")  # remaining parameters of a wrapped method

# Prefix checked by `is_artifact_registry_project` to recognize registry projects.
REGISTRY_PREFIX: Final[str] = "wandb-registry-"
# Upper bound on the number of top-level keys accepted by `validate_metadata`.
MAX_ARTIFACT_METADATA_KEYS: Final[int] = 100
# Maximum length enforced for artifact names, project names, and artifact type names.
NAME_MAXLEN: Final[int] = 128
# Characters rejected by `validate_artifact_name`.
INVALID_ARTIFACT_NAME_CHARS: Final[frozenset[str]] = frozenset("/")
# Characters rejected by `validate_project_name`.
INVALID_URL_CHARS: Final[frozenset[str]] = frozenset("/\\#?%:\r\n")
# Characters rejected in aliases and artifact type names.
ARTIFACT_SEP_CHARS: Final[frozenset[str]] = frozenset("/:")
  29. @dataclass
  30. class LinkArtifactFields:
  31. """Keep this list updated with fields where linked and source artifacts differ."""
  32. entity_name: str
  33. project_name: str
  34. name: str
  35. version: str
  36. aliases: list[str]
  37. # These fields shouldn't be user-editable, linked artifacts always have these values
  38. _is_link: Literal[True] = field(init=False, default=True)
  39. _linked_artifacts: list[Artifact] = field(init=False, default_factory=list)
  40. @property
  41. def is_link(self) -> bool:
  42. return self._is_link
  43. @property
  44. def linked_artifacts(self) -> list[Artifact]:
  45. return self._linked_artifacts
  46. def validate_artifact_name(name: str) -> str:
  47. """Validate the artifact name, returning it if successful.
  48. Raises:
  49. ValueError: If the artifact name is invalid.
  50. """
  51. if len(name) > NAME_MAXLEN:
  52. trunc_name = f"{name[:NAME_MAXLEN]} ..."
  53. raise ValueError(
  54. f"Artifact name is longer than {NAME_MAXLEN!r} characters: {trunc_name!r}"
  55. )
  56. if INVALID_ARTIFACT_NAME_CHARS.intersection(name):
  57. raise ValueError(
  58. "Artifact names must not contain any of the following characters: "
  59. f"{', '.join(sorted(INVALID_ARTIFACT_NAME_CHARS))}. Got: {name!r}"
  60. )
  61. return name
  62. def validate_project_name(name: str) -> str:
  63. """Validate a project name according to W&B rules.
  64. Return the original name if successful.
  65. Args:
  66. name: The project name string.
  67. Raises:
  68. ValueError: If the name is invalid (too long or contains invalid characters).
  69. """
  70. if not name:
  71. raise ValueError("Project name cannot be empty")
  72. if not (registry_name := name.removeprefix(REGISTRY_PREFIX)):
  73. raise ValueError("Registry name cannot be empty")
  74. if len(name) > NAME_MAXLEN:
  75. if registry_name != name:
  76. msg = f"Invalid registry name {registry_name!r}, must be {NAME_MAXLEN - len(REGISTRY_PREFIX)!r} characters or less"
  77. else:
  78. msg = f"Invalid project name {name!r}, must be {NAME_MAXLEN!r} characters or less"
  79. raise ValueError(msg)
  80. # Find the first occurrence of any invalid character
  81. if invalid_chars := set(INVALID_URL_CHARS).intersection(name):
  82. error_name = registry_name or name
  83. invalid_chars_repr = ", ".join(sorted(map(repr, invalid_chars)))
  84. raise ValueError(
  85. f"Invalid project/registry name {error_name!r}, cannot contain characters: {invalid_chars_repr!s}"
  86. )
  87. return name
  88. def validate_aliases(aliases: Iterable[str] | str) -> list[str]:
  89. """Validate the artifact aliases and return them as a list.
  90. Raises:
  91. ValueError: If any of the aliases contain invalid characters.
  92. """
  93. aliases_list = always_list(aliases)
  94. if any(ARTIFACT_SEP_CHARS.intersection(name) for name in aliases_list):
  95. invalid_chars = ", ".join(sorted(map(repr, ARTIFACT_SEP_CHARS)))
  96. raise ValueError(
  97. f"Aliases must not contain any of the following characters: {invalid_chars}"
  98. )
  99. return aliases_list
  100. def validate_artifact_types(types: Iterable[str] | str) -> list[str]:
  101. """Validate the artifact type names and return them as a list."""
  102. types_list = always_list(types)
  103. if any(ARTIFACT_SEP_CHARS.intersection(name) for name in types_list):
  104. invalid_chars = ", ".join(sorted(map(repr, ARTIFACT_SEP_CHARS)))
  105. raise ValueError(
  106. f"Artifact types must not contain any of the following characters: {invalid_chars}"
  107. )
  108. if any(len(name) > NAME_MAXLEN for name in types_list):
  109. raise ValueError(
  110. f"Artifact types must be less than or equal to {NAME_MAXLEN!r} characters"
  111. )
  112. return types_list
# Matches one or more "words" made of `\w` characters and hyphens, separated by runs
# of one or more spaces, with no leading or trailing spaces.
# NOTE(review): `$` also matches just before a trailing newline, so `.match()` alone
# accepts e.g. "tag\n"; prefer `.fullmatch()` where that matters.
TAG_REGEX: re.Pattern[str] = re.compile(r"^[-\w]+( +[-\w]+)*$")
"""Regex pattern for valid tag names."""
  115. def validate_tags(tags: Iterable[str] | str) -> list[str]:
  116. """Validate artifact tag names and return them as a deduped list.
  117. In the case of duplicates, keep the first tag and maintain the order of
  118. appearance.
  119. Raises:
  120. ValueError: If any of the tags contain invalid characters.
  121. """
  122. tags_list = unique_list(always_list(tags))
  123. if any(not TAG_REGEX.match(tag) for tag in tags_list):
  124. raise ValueError(
  125. "Invalid tag(s). "
  126. "Tags must only contain alphanumeric characters separated by hyphens, underscores, and/or spaces."
  127. )
  128. return tags_list
# Artifact types starting with this prefix are rejected by `validate_artifact_type`.
RESERVED_ARTIFACT_TYPE_PREFIX: Final[str] = "wandb-"
"""Internal, reserved artifact type prefix."""
# Maps an artifact type to the artifact-name prefix that is off-limits for that type.
# `validate_artifact_type` rejects a (type, name) pair when the name starts with the
# prefix registered for its type.
RESERVED_ARTIFACT_NAME_PREFIX_BY_TYPE: Final[dict[str, str]] = {
    "job": "",  # Empty prefix means ALL artifact names are reserved for this artifact type
    "run_table": "run-",
    "code": "source-",
}
"""Lookup of internal, reserved `Artifact.name` prefixes by `Artifact.type`."""
  137. def validate_artifact_type(typ: str, name: str) -> str:
  138. """Validate the artifact type and return it as a string."""
  139. if (
  140. # Check if the artifact name is disallowed, based on the artifact type
  141. (
  142. # This check MUST be against `None`, since "" disallows ALL artifact names
  143. (bad_prefix := RESERVED_ARTIFACT_NAME_PREFIX_BY_TYPE.get(typ)) is not None
  144. and name.startswith(bad_prefix)
  145. )
  146. or
  147. # Check if the artifact type is disallowed
  148. typ.startswith(RESERVED_ARTIFACT_TYPE_PREFIX)
  149. ):
  150. raise ValueError(
  151. f"Artifact type {typ!r} is reserved for internal use. "
  152. "Please use a different type."
  153. )
  154. return typ
  155. @singledispatch
  156. def validate_metadata(metadata: dict[str, Any] | str | None) -> dict[str, Any]:
  157. """Validate the artifact metadata and return it as a dict."""
  158. raise TypeError(f"Cannot parse {type(metadata)} as artifact metadata")
  159. @validate_metadata.register(type(None))
  160. @validate_metadata.register(str)
  161. def _(metadata: str | None) -> dict[str, Any]:
  162. return validate_metadata(from_json(metadata)) if metadata else {}
  163. @validate_metadata.register(dict)
  164. def _(metadata: dict[str, Any]) -> dict[str, Any]:
  165. # NOTE: The backend doesn't currently allow JS-compatible `+/-Infinity` values.
  166. # Forbid them here to avoid surprises, but revisit if we add future backend support.
  167. # Note that prior behavior already converts `NaN` values to `None` (client-side).
  168. metadata = from_json(json.dumps(json_friendly_val(metadata), allow_nan=False))
  169. if len(metadata) > MAX_ARTIFACT_METADATA_KEYS:
  170. raise ValueError(
  171. f"Artifact must not have more than {MAX_ARTIFACT_METADATA_KEYS!r} metadata keys."
  172. )
  173. return metadata
  174. def validate_ttl_duration_seconds(ttl_duration_seconds: int) -> int | None:
  175. """Validate the `ttlDurationSeconds` value from a GraphQL response.
  176. A non-positive value indicates that TTL is DISABLED (-2), which we
  177. convert to `None`.
  178. """
  179. return ttl_duration_seconds if ttl_duration_seconds > 0 else None
# ----------------------------------------------------------------------------
# Parameterized by (SelfT, P, R): a callable taking a `SelfT` first argument, followed
# by the parameters captured by `P`, and returning `R`.
MethodT = Callable[Concatenate[SelfT, P], R]
"""Generic type hint for an instance method, e.g. for use with decorators."""
  183. def ensure_logged(method: MethodT[ArtifactT, P, R]) -> MethodT[ArtifactT, P, R]:
  184. """Ensure an artifact method runs only if the artifact has been logged.
  185. If the method is called on an artifact that's not logged, `ArtifactNotLoggedError`
  186. is raised.
  187. """
  188. # For clarity, use the qualified (full) name of the method
  189. method_fullname = nameof(method)
  190. @wraps(method)
  191. def wrapper(self: ArtifactT, *args: P.args, **kwargs: P.kwargs) -> R:
  192. if self.is_draft():
  193. raise ArtifactNotLoggedError(fullname=method_fullname, obj=self)
  194. return method(self, *args, **kwargs)
  195. return wrapper
  196. def ensure_not_finalized(method: MethodT[ArtifactT, P, R]) -> MethodT[ArtifactT, P, R]:
  197. """Ensure an `Artifact` method runs only if the artifact is not finalized.
  198. If the method is called on an artifact that's not logged, `ArtifactFinalizedError`
  199. is raised.
  200. """
  201. # For clarity, use the qualified (full) name of the method
  202. method_fullname = nameof(method)
  203. @wraps(method)
  204. def wrapper(self: ArtifactT, *args: P.args, **kwargs: P.kwargs) -> R:
  205. if self._final:
  206. raise ArtifactFinalizedError(fullname=method_fullname, obj=self)
  207. return method(self, *args, **kwargs)
  208. return wrapper
  209. def is_artifact_registry_project(project: str) -> bool:
  210. return project.startswith(REGISTRY_PREFIX)
  211. def remove_registry_prefix(project: str) -> str:
  212. if not is_artifact_registry_project(project):
  213. raise ValueError(
  214. f"Project {project!r} is not a registry project. Must start with: {REGISTRY_PREFIX!r}"
  215. )
  216. return project.removeprefix(REGISTRY_PREFIX)
  217. @pydantic_dataclass
  218. class ArtifactPath:
  219. name: str
  220. """The collection or artifact version name."""
  221. project: Optional[str] = None # noqa: UP045
  222. """The project name."""
  223. prefix: Optional[str] = None # noqa: UP045
  224. """Typically the entity or org name."""
  225. @classmethod
  226. def from_str(cls, path: str) -> Self:
  227. """Instantiate by parsing a string artifact path.
  228. Raises:
  229. ValueError: If the string is not a valid artifact path.
  230. """
  231. # Separate the alias first, which may itself contain slashes.
  232. # If there's no alias, note that both sep and alias will be empty.
  233. collection_path, sep, alias = path.partition(":")
  234. prefix, project = None, None # defaults, if missing
  235. if len(parts := collection_path.split("/")) == 1:
  236. name = parts[0]
  237. elif len(parts) == 2:
  238. project, name = parts
  239. elif len(parts) == 3:
  240. prefix, project, name = parts
  241. else:
  242. raise ValueError(f"Invalid artifact path: {path!r}")
  243. return cls(prefix=prefix, project=project, name=f"{name}{sep}{alias}")
  244. def to_str(self) -> str:
  245. """Returns the slash-separated string representation of the path."""
  246. ordered_parts = (self.prefix, self.project, self.name)
  247. return "/".join(part for part in ordered_parts if part)
  248. def with_defaults(
  249. self,
  250. *,
  251. prefix: str | None = None,
  252. project: str | None = None,
  253. ) -> Self:
  254. """Returns a copy of this path with missing values set to the given defaults."""
  255. return replace(
  256. self,
  257. prefix=self.prefix or prefix,
  258. project=self.project or project,
  259. )
  260. def is_registry_path(self) -> bool:
  261. """Returns True if this path appears to be a registry path."""
  262. return bool((p := self.project) and is_artifact_registry_project(p))
@pydantic_dataclass
class FullArtifactPath(ArtifactPath):
    """Same as ArtifactPath, but with all parts required."""

    # Each field is redeclared as a plain `str` with no default, so `project` and
    # `prefix` (optional with a `None` default on `ArtifactPath`) are required here.
    name: str
    project: str
    prefix: str