registry.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. from __future__ import annotations
  2. from typing import TYPE_CHECKING, Any, Optional
  3. from pydantic import Field
  4. from typing_extensions import Self
  5. from wandb._pydantic import GQLId, field_validator
  6. from wandb._strutils import nameof
  7. from wandb.apis.public.registries._freezable_list import AddOnlyArtifactTypesList
  8. from wandb.apis.public.registries._utils import Visibility
  9. from .base_model import ArtifactsBase
  10. if TYPE_CHECKING:
  11. from wandb.sdk.artifacts._generated import RegistryFragment
  12. class RegistryData(ArtifactsBase):
  13. """Transport-free model for local `Registry` data.
  14. For now, this is separated from the public `Registry` class
  15. to more easily ensure continuity in the public `Registry` API.
  16. """
  17. id: GQLId = Field(frozen=True)
  18. """The unique, encoded ID for this registry."""
  19. created_at: str = Field(frozen=True)
  20. """When this registry was created."""
  21. updated_at: Optional[str] = Field(frozen=True)
  22. """When this registry was last updated."""
  23. organization: str = Field(frozen=True)
  24. """The organization of the registry."""
  25. entity: str = Field(frozen=True)
  26. """The organization entity of the registry."""
  27. name: str = Field(min_length=1) # Disallow empty strings
  28. """The name of the registry without the `wandb-registry-` project prefix."""
  29. description: Optional[str] = None
  30. """The description, if any, of the registry."""
  31. allow_all_artifact_types: bool
  32. """Whether all artifact types are allowed in the registry."""
  33. artifact_types: AddOnlyArtifactTypesList = Field(
  34. default_factory=AddOnlyArtifactTypesList
  35. )
  36. """The artifact types allowed in the registry.
  37. The meaning of this list depends on `allow_all_artifact_types`:
  38. - If True: `artifact_types` are the previously saved or currently used
  39. types in the registry.
  40. - If False: `artifact_types` are the only allowed artifact types in the
  41. registry.
  42. """
  43. visibility: Visibility = Field(alias="access")
  44. """The visibility of the registry."""
  45. @property
  46. def full_name(self) -> str:
  47. """The project name with the expected `wandb-registry-` prefix."""
  48. from wandb.sdk.artifacts._validators import REGISTRY_PREFIX
  49. return f"{REGISTRY_PREFIX}{self.name}"
  50. @field_validator("artifact_types", mode="plain")
  51. def _validate_artifact_types(cls, v: Any) -> AddOnlyArtifactTypesList:
  52. """Coerce `artifact_types` to an AddOnlyArtifactTypesList."""
  53. from wandb.sdk.artifacts._generated.fragments import (
  54. RegistryFragmentArtifactTypes,
  55. )
  56. if isinstance(v, RegistryFragmentArtifactTypes):
  57. # This is a GQL connection object, so we need to extract the node names
  58. return AddOnlyArtifactTypesList(e.node.name for e in v.edges if e.node)
  59. # By default, assume we were passed an iterable of strings
  60. return AddOnlyArtifactTypesList(v)
  61. @classmethod
  62. def from_fragment(cls, obj: RegistryFragment) -> Self:
  63. from wandb.sdk.artifacts._validators import remove_registry_prefix
  64. if not obj.entity.organization:
  65. raise ValueError(
  66. f"Unable to parse registry organization from {nameof(type(obj))!r} object"
  67. )
  68. return cls(
  69. id=obj.id,
  70. created_at=obj.created_at,
  71. updated_at=obj.updated_at,
  72. organization=obj.entity.organization.name,
  73. entity=obj.entity.name,
  74. name=remove_registry_prefix(obj.name),
  75. description=obj.description,
  76. allow_all_artifact_types=obj.allow_all_artifact_types,
  77. artifact_types=obj.artifact_types,
  78. visibility=obj.access,
  79. )