| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297 |
- """Definition of the config object used by the Launch agent."""
- from __future__ import annotations
- from enum import Enum
- # ValidationError is imported for exception type checking purposes only.
- from pydantic import ( # type: ignore
- BaseModel,
- Field,
- ValidationError,
- root_validator,
- validator,
- )
- import wandb
- from wandb.sdk.launch.utils import (
- AZURE_BLOB_REGEX,
- AZURE_CONTAINER_REGISTRY_URI_REGEX,
- ELASTIC_CONTAINER_REGISTRY_URI_REGEX,
- GCP_ARTIFACT_REGISTRY_URI_REGEX,
- GCS_URI_RE,
- S3_URI_RE,
- )
- __all__ = [
- "ValidationError",
- "AgentConfig",
- ]
class EnvironmentType(str, Enum):
    """Closed set of values accepted for an environment's ``type`` field.

    The ``str`` mixin makes every member compare equal to its plain string
    value, so raw strings can be used wherever a member is expected.
    """

    # Member names intentionally mirror their string values.
    aws = "aws"
    gcp = "gcp"
    azure = "azure"
class RegistryType(str, Enum):
    """Closed set of values accepted for a registry's ``type`` field.

    The ``str`` mixin makes every member compare equal to its plain string
    value.
    """

    # Member names intentionally mirror their string values.
    ecr = "ecr"
    acr = "acr"
    gcr = "gcr"
class BuilderType(str, Enum):
    """Closed set of values accepted for a builder's ``type`` field.

    The ``str`` mixin makes every member compare equal to its plain string
    value.
    """

    # Member names intentionally mirror their string values.
    docker = "docker"
    kaniko = "kaniko"
    noop = "noop"
class TargetPlatform(str, Enum):
    """Closed set of values accepted for a builder's ``platform`` field.

    Values use the ``<os>/<arch>`` spelling; member names replace the slash
    with an underscore so they are valid identifiers.
    """

    linux_amd64 = "linux/amd64"
    linux_arm64 = "linux/arm64"
class RegistryConfig(BaseModel):
    """Configuration for the registry block.

    Extra fields are deliberately not forbidden here because:

    - every field supported by each registry should be allowed through,
    - the registry object itself is validated later, and
    - the registry block is being deprecated in favor of the ``destination``
      field of the builder block.
    """

    type: RegistryType | None = Field(
        None,
        description="The type of registry to use.",
    )
    uri: str | None = Field(
        None,
        description="The URI of the registry.",
    )

    @validator("uri")  # type: ignore
    @classmethod
    def validate_uri(cls, uri: str | None) -> str | None:
        """Validate ``uri`` as a container registry URI.

        Args:
            uri: The configured registry URI, or ``None`` when it is unset.

        Returns:
            ``None`` when ``uri`` is ``None``; otherwise whatever
            ``validate_registry_uri`` returns for it.

        Raises:
            ValueError: Propagated from ``validate_registry_uri``.
        """
        # The field is optional, so ``None`` may be handed to this validator.
        # ``validate_registry_uri`` calls ``uri.startswith`` and would fail on
        # it, so pass ``None`` through untouched, mirroring
        # ``BuilderConfig.validate_destination``.
        if uri is None:
            return None
        return validate_registry_uri(uri)
class EnvironmentConfig(BaseModel):
    """Settings accepted in the environment block of the agent config."""

    type: EnvironmentType | None = Field(
        default=None,
        description="The type of environment to use.",
    )
    # ``...`` (Ellipsis) is pydantic's marker for "no default value".
    region: str | None = Field(..., description="The region to use.")

    class Config:
        # Unknown keys are kept on the model; they only trigger the warning
        # emitted by ``check_extra_fields`` below.
        extra = "allow"

    @root_validator(pre=True)  # type: ignore
    @classmethod
    def check_extra_fields(cls, values: dict) -> dict:
        """Warn about each key that is neither ``type`` nor ``region``.

        The input mapping is returned unchanged; nothing is rejected here.
        """
        unknown_keys = [name for name in values if name not in ("type", "region")]
        for name in unknown_keys:
            wandb.termwarn(
                f"Unrecognized field {name} in environment block. Please check your config file."
            )
        return values
class BuilderConfig(BaseModel):
    """Settings accepted in the builder block of the agent config.

    Several fields are spelled with hyphens in the config file and are mapped
    onto their underscore attribute names through ``alias``.
    """

    type: BuilderType | None = Field(
        default=None,
        description="The type of builder to use.",
    )
    destination: str | None = Field(
        default=None,
        description="The destination to use for the built image. If not provided, "
        "the image will be pushed to the registry.",
    )
    platform: TargetPlatform | None = Field(
        default=None,
        description="The platform to use for the built image. If not provided, "
        "the platform will be detected automatically.",
    )
    build_context_store: str | None = Field(
        default=None,
        description="The build context store to use. Required for kaniko builds.",
        alias="build-context-store",
    )
    build_job_name: str | None = Field(
        default="wandb-launch-container-build",
        description="Name prefix of the build job.",
        alias="build-job-name",
    )
    secret_name: str | None = Field(
        default=None,
        description="The name of the secret to use for the build job.",
        alias="secret-name",
    )
    secret_key: str | None = Field(
        default=None,
        description="The key of the secret to use for the build job.",
        alias="secret-key",
    )
    kaniko_image: str | None = Field(
        default="gcr.io/kaniko-project/executor:latest",
        description="The image to use for the kaniko executor.",
        alias="kaniko-image",
    )

    @validator("build_context_store")  # type: ignore
    @classmethod
    def validate_build_context_store(
        cls, build_context_store: str | None
    ) -> str | None:
        """Check that the build context store is a supported storage URI.

        ``None`` is passed through. Any other value must match one of the S3,
        GCS or Azure blob URI patterns, otherwise ``ValueError`` is raised.
        """
        if build_context_store is None:
            return None
        storage_patterns = (S3_URI_RE, GCS_URI_RE, AZURE_BLOB_REGEX)
        if not any(pattern.match(build_context_store) for pattern in storage_patterns):
            raise ValueError(
                "Invalid build context store. Build context store must be a URI for an "
                "S3 bucket, GCS bucket, or Azure blob."
            )
        return build_context_store

    @root_validator(pre=True)  # type: ignore
    @classmethod
    def validate_docker(cls, values: dict) -> dict:
        """Return ``values`` untouched; docker builds have no required fields yet."""
        return values

    @validator("destination")  # type: ignore
    @classmethod
    def validate_destination(cls, destination: str | None) -> str | None:
        """Validate ``destination`` as a container registry URI.

        ``None`` is passed through; any other value is delegated to
        ``validate_registry_uri``.
        """
        return None if destination is None else validate_registry_uri(destination)
class AgentConfig(BaseModel):
    """Configuration for the Launch agent.

    Every field is optional and carries an explicit default. Unknown top-level
    keys are rejected (``extra = "forbid"``).
    """

    queues: list[str] = Field(
        default=[],
        description="The queues to use for this agent.",
    )
    # An explicit ``None`` default keeps this field optional regardless of how
    # the installed pydantic version treats an Optional annotation that has no
    # default, and matches how every other optional field here is declared.
    entity: str | None = Field(
        None,
        description="The W&B entity to use for this agent.",
    )
    max_jobs: int | None = Field(
        1,
        description="The maximum number of jobs to run concurrently.",
    )
    max_schedulers: int | None = Field(
        1,
        description="The maximum number of sweep schedulers to run concurrently.",
    )
    secure_mode: bool | None = Field(
        False,
        description="Whether to use secure mode for this agent. If True, the "
        "agent will reject runs that attempt to override the entrypoint or image.",
    )
    registry: RegistryConfig | None = Field(
        None,
        description="The registry to use.",
    )
    environment: EnvironmentConfig | None = Field(
        None,
        description="The environment to use.",
    )
    builder: BuilderConfig | None = Field(
        None,
        description="The builder to use.",
    )
    verbosity: int | None = Field(
        0,
        description="How verbose to print, 0 = default, 1 = verbose, 2 = very verbose",
    )
    stopped_run_timeout: int | None = Field(
        60,
        description="How many seconds to wait after receiving the stop command before forcibly cancelling a run.",
    )

    class Config:
        # Reject unrecognized top-level keys instead of silently ignoring them.
        extra = "forbid"
def validate_registry_uri(uri: str) -> str:
    """Check that ``uri`` names an image repository in a container registry.

    A leading ``https://`` is stripped first. The remainder is then tried, in
    order, against the GCP Artifact Registry, Azure Container Registry and
    Elastic Container Registry patterns.

    For the first pattern that matches:

    - a URI that carries a tag is rejected, and
    - a URI that lacks the image/repository component is rejected.

    A URI that matches none of the patterns is accepted, but a warning is
    printed saying the registry type is not recognized and that the user is
    responsible for making sure the agent can push to it.

    Args:
        uri: The registry URI to validate.

    Returns:
        The URI with any leading ``https://`` removed.

    Raises:
        ValueError: If the URI matches a recognized registry format but
            includes a tag or omits the image/repository name.
    """
    tag_msg = (
        "Destination for built images may not include a tag, but the URI provided "
        "includes the suffix '{tag}'. Please remove the tag and try again. The agent "
        "will automatically tag each image with a unique hash of the source code."
    )
    # One row per recognized registry:
    # (pattern, name of the group that must be non-empty, error if it is empty).
    recognized_registries = (
        (
            GCP_ARTIFACT_REGISTRY_URI_REGEX,
            "image_name",
            "An image name must be specified in the URI for a GCP Artifact Registry. "
            "Please provide a uri with the format "
            "'https://<region>-docker.pkg.dev/<project>/<repository>/<image>'.",
        ),
        (
            AZURE_CONTAINER_REGISTRY_URI_REGEX,
            "repository",
            "A repository name must be specified in the URI for an "
            "Azure Container Registry. Please provide a uri with the format "
            "'https://<registry-name>.azurecr.io/<repository>'.",
        ),
        (
            ELASTIC_CONTAINER_REGISTRY_URI_REGEX,
            "repository",
            "A repository name must be specified in the URI for an "
            "Elastic Container Registry. Please provide a uri with the format "
            "'https://<account-id>.dkr.ecr.<region>.amazonaws.com/<repository>'.",
        ),
    )

    scheme = "https://"
    if uri.startswith(scheme):
        uri = uri[len(scheme):]

    for pattern, required_group, missing_msg in recognized_registries:
        found = pattern.match(uri)
        if not found:
            continue
        tag = found.group("tag")
        if tag:
            raise ValueError(tag_msg.format(tag=tag))
        if not found.group(required_group):
            raise ValueError(missing_msg)
        return uri

    wandb.termwarn(
        f"Unable to recognize registry type in URI {uri}. You are responsible "
        "for ensuring the agent can push images to this registry."
    )
    return uri
|