| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
777177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041 |
- from __future__ import annotations
- import colorsys
- import contextlib
- import dataclasses
- import enum
- import importlib
- import importlib.util
- import itertools
- import json
- import logging
- import math
- import numbers
- import os
- import pathlib
- import platform
- import queue
- import random
- import re
- import secrets
- import shlex
- import socket
- import string
- import sys
- import tarfile
- import tempfile
- import threading
- import time
- import types
- import urllib
- from collections.abc import Iterable, Mapping, Sequence
- from dataclasses import asdict, is_dataclass
- from datetime import date, datetime, timedelta
- from gzip import GzipFile
- from importlib import import_module
- from sys import getsizeof
- from types import ModuleType
- from typing import IO, TYPE_CHECKING, Callable, TextIO, Union
- from typing_extensions import Any, Generator, TypeGuard, TypeVar, deprecated
- import wandb
- import wandb.env
- from wandb.errors import (
- AuthenticationError,
- CommError,
- UsageError,
- WandbCoreNotAvailableError,
- )
- from wandb.errors.term import terminput
- from wandb.sdk.lib import filesystem, runid
- from wandb.sdk.lib.json_util import dump, dumps
- from wandb.sdk.lib.paths import FilePathStr, StrPath
- if TYPE_CHECKING:
- from requests import Response
- from wandb.sdk.artifacts.artifact import Artifact
# Type of a retry-check callback: it receives the raised exception and returns
# either a bool or a timedelta.
CheckRetryFnType = Callable[[Exception], Union[bool, timedelta]]

T = TypeVar("T")

logger = logging.getLogger(__name__)

# Names of modules whose import already failed once; get_module() adds to it
# and consults it before trying an import again.
_not_importable: set[str] = set()

LAUNCH_JOB_ARTIFACT_SLOT_NAME = "_wandb_job"

# 10 MiB minus 100 KiB.
MAX_LINE_BYTES = (10 << 20) - (100 << 10)  # imposed by back end

# True when a ".git" entry exists one directory above this file's directory.
IS_GIT = os.path.exists(os.path.join(os.path.dirname(__file__), "..", ".git"))

# From https://docs.docker.com/engine/reference/commandline/tag/
# "Name components may contain lowercase letters, digits and separators.
# A separator is defined as a period, one or two underscores, or one or more dashes.
# A name component may not start or end with a separator."
DOCKER_IMAGE_NAME_SEPARATOR = "(?:__|[._]|[-]+)"
# A separator at the very start of the string.
RE_DOCKER_IMAGE_NAME_SEPARATOR_START = re.compile("^" + DOCKER_IMAGE_NAME_SEPARATOR)
# A separator at the very end of the string.
RE_DOCKER_IMAGE_NAME_SEPARATOR_END = re.compile(DOCKER_IMAGE_NAME_SEPARATOR + "$")
# Two or more separators in a row.
RE_DOCKER_IMAGE_NAME_SEPARATOR_REPEAT = re.compile(DOCKER_IMAGE_NAME_SEPARATOR + "{2,}")
# Any character other than lowercase letters, digits, ".", "_" and "-".
RE_DOCKER_IMAGE_NAME_CHARS = re.compile(r"[^a-z0-9._\-]")

# (unit label, size in bytes) pairs in ascending order, decimal (powers of 10) units.
POW_10_BYTES = [
    ("B", 10**0),
    ("KB", 10**3),
    ("MB", 10**6),
    ("GB", 10**9),
    ("TB", 10**12),
    ("PB", 10**15),
    ("EB", 10**18),
]

# (unit label, size in bytes) pairs in ascending order, binary (powers of 2) units.
POW_2_BYTES = [
    ("B", 2**0),
    ("KiB", 2**10),
    ("MiB", 2**20),
    ("GiB", 2**30),
    ("TiB", 2**40),
    ("PiB", 2**50),
    ("EiB", 2**60),
]
def vendor_setup() -> Callable:
    """Create a function that restores user paths after vendor imports.

    This enables us to use the vendor directory for packages we don't depend on. Call
    the returned function after imports are complete. If you don't you may modify the
    user's path which is never good.

    Usage:
    ```python
    reset_path = vendor_setup()
    # do any vendor imports...
    reset_path()
    ```

    Returns:
        A zero-argument callable that rebinds ``sys.path`` to a snapshot taken
        before the vendor directories were inserted.
    """
    # Snapshot first, so the reset closure is unaffected by the inserts below.
    original_path = list(sys.path)

    def reset_import_path() -> None:
        sys.path = original_path

    parent_dir = os.path.abspath(os.path.dirname(__file__))
    vendor_dir = os.path.join(parent_dir, "vendor")
    vendor_packages = (
        "gql-0.2.0",
        "graphql-core-1.1",
        "watchdog_0_9_0",
        "promise-2.3.0",
    )
    package_dirs = [os.path.join(vendor_dir, p) for p in vendor_packages]
    for p in [vendor_dir, *package_dirs]:
        if p not in sys.path:
            # Index 1 keeps sys.path[0] first while still placing the vendored
            # copies ahead of every other entry.
            sys.path.insert(1, p)

    return reset_import_path


def vendor_import(name: str) -> Any:
    """Import ``name`` with the vendor directories temporarily on ``sys.path``.

    Args:
        name: Absolute, dot-separated module name.

    Returns:
        The imported module.

    Raises:
        ImportError: If the module cannot be imported. ``sys.path`` is restored
            before the exception propagates.
    """
    reset_path = vendor_setup()
    try:
        return import_module(name)
    finally:
        # Restore the path even when the import fails; otherwise the vendor
        # directories would stay on the user's sys.path.
        reset_path()
class LazyModuleState:
    """Bookkeeping for one lazily imported module.

    Holds the module object, a flag recording that loading has begun, and a
    re-entrant lock that serializes calls to ``load``.
    """

    def __init__(self, module: types.ModuleType) -> None:
        self.module = module
        self.load_started = False
        self.lock = threading.RLock()

    def load(self) -> None:
        """Execute the module body once; later calls return immediately."""
        with self.lock:
            if not self.load_started:
                # Set the flag before executing so that re-entrant calls made
                # on this thread (the lock is an RLock) return right away.
                self.load_started = True

                spec = self.module.__spec__
                assert spec is not None
                loader = spec.loader
                assert loader is not None
                loader.exec_module(self.module)

                # From here on the module behaves like a regular module.
                self.module.__class__ = types.ModuleType

                # Expose a submodule as an attribute of its parent package so
                # that ``parent.child`` works through normal attribute access.
                parent_name, _sep, attr_name = self.module.__name__.rpartition(".")
                if parent_name:
                    setattr(sys.modules[parent_name], attr_name, self.module)
class LazyModule(types.ModuleType):
    """Module type whose attribute access first triggers the pending load.

    Every get, set and delete calls ``load()`` on the object stored under
    ``__lazy_module_state__`` before touching the attribute. The state is
    fetched with ``object.__getattribute__`` so the lookup does not recurse
    into these overrides.
    """

    def __getattribute__(self, name: str) -> Any:
        object.__getattribute__(self, "__lazy_module_state__").load()
        return object.__getattribute__(self, name)

    def __setattr__(self, name: str, value: Any) -> None:
        object.__getattribute__(self, "__lazy_module_state__").load()
        object.__setattr__(self, name, value)

    def __delattr__(self, name: str) -> None:
        object.__getattribute__(self, "__lazy_module_state__").load()
        object.__delattr__(self, name)
def import_module_lazy(name: str) -> types.ModuleType:
    """Import a module lazily, only when it is used.

    Inspired by importlib.util.LazyLoader, but improved so that the module loading is
    thread-safe. Circular dependency between modules can lead to a deadlock if the two
    modules are loaded from different threads.

    :param (str) name: Dot-separated module path. E.g., 'scipy.stats'.
    :return: The entry already present in ``sys.modules`` if there is one,
        otherwise a new module object that is registered in ``sys.modules``
        and executed on first attribute access.
    :raises ModuleNotFoundError: If no import spec can be found for ``name``.
    """
    try:
        return sys.modules[name]
    except KeyError:
        pass

    # Done outside the ``except`` block so that a failure here is not reported
    # as having happened "during handling of" the KeyError above.
    spec = importlib.util.find_spec(name)
    if spec is None:
        raise ModuleNotFoundError(f"No module named {name!r}", name=name)

    module = importlib.util.module_from_spec(spec)
    # Attach the state while the object is still a plain module, then switch
    # its class so that later attribute access triggers the load.
    module.__lazy_module_state__ = LazyModuleState(module)  # type: ignore
    module.__class__ = LazyModule
    sys.modules[name] = module
    return module
def get_module(
    name: str,
    required: str | None = None,
    lazy: bool = True,
) -> Any:
    """Return the named module, or None when it cannot be imported.

    Absolute import is required. A module whose import failed once is
    remembered and not attempted again.

    :param (str) name: Dot-separated module path. E.g., 'scipy.stats'.
    :param (str) required: If given, the message of the `wandb.Error` raised
        when the module is not importable.
    :param (bool) lazy: If True, return a lazy loader for the module.
    :return: (module|None) If import succeeds, the module will be returned.
    """
    if name not in _not_importable:
        importer = import_module_lazy if lazy else import_module
        try:
            return importer(name)
        except Exception:
            _not_importable.add(name)
            # The traceback is only logged when the caller said the module
            # is required.
            if required:
                logger.exception(f"Error importing optional module {name}")

    if required and name in _not_importable:
        raise wandb.Error(required)
    return None
def get_optional_module(name) -> importlib.ModuleInterface | None:  # type: ignore
    """Return `get_module(name)`: called with the default arguments."""
    module = get_module(name)
    return module
# numpy is optional: `np` is whatever get_module() returns for it.
np = get_module("numpy")

# pandas availability is decided from its import spec alone.
pandas_spec = importlib.util.find_spec("pandas")
pd_available = pandas_spec is not None

# TODO: Revisit these limits
VALUE_BYTES_LIMIT = 100000
@deprecated("Read the `app_url` setting from the appropriate Settings object.")
def app_url(api_url: str) -> str:
    """Return the W&B UI URL for the given API URL.

    A truthy value from `wandb.env.get_app_url()` takes precedence and is
    returned with leading and trailing slashes stripped. Otherwise the result
    is derived from `api_url` by `api_to_app_url`.
    """
    env_override = wandb.env.get_app_url()
    if env_override:
        return str(env_override.strip("/"))
    return api_to_app_url(api_url)
def api_to_app_url(api_url: str) -> str:
    """Derive the app (UI) URL from an API URL.

    Unlike the deprecated `app_url()`, this is a pure function: it does
    not consult environment variables.

    Rules, in order:
      * no "://api." in the URL (wandb/local): returned unchanged
      * "://api.wandb.test" (dev mode): "://api." becomes "://app."
      * "://api.wandb." (cloud): "://api." becomes "://"
      * any other "://api." (on-prem cloud): "://api." becomes "://app."

    Rewritten URLs have leading and trailing slashes stripped.
    """
    if "://api." not in api_url:
        # wandb/local
        return api_url

    is_dev = "://api.wandb.test" in api_url
    is_cloud = "://api.wandb." in api_url and not is_dev
    # Only the public cloud drops the subdomain; dev mode and on-prem
    # deployments serve the UI from "app.".
    replacement = "://" if is_cloud else "://app."
    return api_url.replace("://api.", replacement).strip("/")
def get_full_typename(o: Any) -> Any:
    """Return the "<module>.<class name>" string for an object's class.

    Working from names avoids needing to import (and therefore depend on)
    PyTorch, TensorFlow, etc. For a module object the module's own
    `__name__` is returned instead.
    """
    cls = o.__class__
    qualified = cls.__module__ + "." + cls.__name__
    if qualified in ("builtins.module", "__builtin__.module"):
        return o.__name__
    return qualified
def get_h5_typename(o: Any) -> Any:
    """Return a short type label for `o`.

    TensorFlow and PyTorch tensor types map to "tensorflow.Tensor" and
    "torch.Tensor". Anything else is labelled with the first component of its
    class's module path followed by the class name.
    """
    typename = get_full_typename(o)
    if is_tf_tensor_typename(typename):
        return "tensorflow.Tensor"
    if is_pytorch_tensor_typename(typename):
        return "torch.Tensor"

    cls = o.__class__
    top_level_package = cls.__module__.split(".")[0]
    return top_level_package + "." + cls.__name__
def is_uri(string: str) -> bool:
    """Return True when `string` parses with a non-empty URL scheme."""
    scheme = urllib.parse.urlparse(string).scheme
    return bool(scheme)
def local_file_uri_to_path(uri: str) -> str:
    """Convert URI to local filesystem path.

    When `uri` starts with "file:", only the path component of the parsed
    URI is used; otherwise `uri` is used as given. Either way the value is
    passed through `urllib.request.url2pathname`.
    """
    # Import the submodules explicitly: a bare `import urllib` does not
    # guarantee that `urllib.parse` and `urllib.request` have been loaded.
    from urllib.parse import urlparse
    from urllib.request import url2pathname

    path = urlparse(uri).path if uri.startswith("file:") else uri
    return url2pathname(path)
def get_local_path_or_none(path_or_uri: str) -> str | None:
    """Return the local filesystem path for the argument, or None.

    The argument counts as local when it has no URL scheme, or when its
    scheme is "file" and it has no network location (e.g. "file:///x").
    Local values are converted with `local_file_uri_to_path`; anything else
    yields None.
    """
    parsed = urllib.parse.urlparse(path_or_uri)
    has_no_scheme = len(parsed.scheme) == 0
    is_hostless_file_uri = parsed.scheme == "file" and len(parsed.netloc) == 0
    if has_no_scheme or is_hostless_file_uri:
        return local_file_uri_to_path(path_or_uri)
    return None
def check_windows_valid_filename(path: int | str) -> bool:
    r"""Verify that the given path does not contain any invalid characters for a Windows filename.

    The following characters are rejected:
        < > : " \ | ? *

    A forward slash is also invalid inside a Windows filename, but it is
    accepted here because the argument is a path and "/" is its separator.

    For more details, refer to the official documentation:
    https://learn.microsoft.com/en-us/windows/win32/fileio/naming-a-file#naming-conventions

    Args:
        path: The file path to check, which can be either an integer or a string.

    Returns:
        bool: True if the path does not contain any invalid characters, False otherwise.
    """
    # str() lets an integer path be checked instead of raising TypeError
    # inside re.search.
    return re.search(r'[<>:"\\|?*]', str(path)) is None
def make_file_path_upload_safe(path: str) -> str:
    r"""Make the provided path safe for file upload.

    The path is made safe by:
        1. Removing any leading separators to prevent writing to absolute paths
        2. Replacing '.' and '..' components with underscores to prevent
           directory traversal attacks

    On Windows, forward slashes are first converted to the platform separator.

    Raises:
        ValueError: If running on Windows and the path contains invalid filename characters
            (\, :, *, ?, ", <, >, |)
    """
    on_windows = platform.system() == "Windows"
    if on_windows:
        if not check_windows_valid_filename(path):
            raise ValueError(
                f"Path {path} is invalid. Please remove invalid filename characters"
                r' (\, :, *, ?, ", <, >, |)'
            )
        # The path passed validation, so any separator in it is a forward
        # slash; switch to the native separator.
        path = str(path).replace("/", os.sep)

    # Strip leading separators so the result is always a relative path.
    relative = path.lstrip(os.sep)

    # Neutralize "." and ".." components; every other component is kept as is.
    sanitized = []
    for segment in relative.split(os.sep):
        if segment in (os.curdir, os.pardir):
            segment = segment.replace(".", "_")
        sanitized.append(segment)

    return os.sep.join(sanitized)
def make_tarfile(
    output_filename: str,
    source_dir: str,
    archive_name: str,
    custom_filter: Callable | None = None,
) -> None:
    """Write `source_dir` to `output_filename` as a gzipped tar archive.

    Member modification times are set to 0, and the gzip header carries an
    empty filename and an mtime of 0.

    Args:
        output_filename: Path of the .tar.gz file to write.
        source_dir: Directory (or file) to add to the archive.
        archive_name: Name the source is given inside the archive.
        custom_filter: Optional callable applied to each TarInfo after its
            mtime has been zeroed; its return value is passed to
            `TarFile.add` as the filter result.
    """

    def _zero_mtime(member: tarfile.TarInfo) -> tarfile.TarInfo | None:
        member.mtime = 0
        if custom_filter is None:
            return member
        return custom_filter(member)

    fd, tar_path = tempfile.mkstemp()
    try:
        # Step 1: build the uncompressed tar in a temporary file.
        with tarfile.open(tar_path, "w") as tar:
            tar.add(source_dir, arcname=archive_name, filter=_zero_mtime)

        # Step 2: gzip it. Don't include the tar's filename or modification
        # time in the zipped archive
        # (see https://docs.python.org/3/library/gzip.html#gzip.GzipFile).
        with open(output_filename, "wb") as out_file:
            with GzipFile(filename="", fileobj=out_file, mode="wb", mtime=0) as gz:
                with open(tar_path, "rb") as tar_file:
                    gz.write(tar_file.read())
    finally:
        os.close(fd)
        os.remove(tar_path)
def is_tf_tensor(obj: Any) -> bool:
    """Return True when `obj` is an instance of `tensorflow.Tensor`.

    TensorFlow is imported inside the function, on each call.
    """
    import tensorflow as tf  # type: ignore

    return isinstance(obj, tf.Tensor)
def is_tf_tensor_typename(typename: str) -> bool:
    """Return True for names under "tensorflow." that mention Tensor or Variable."""
    if not typename.startswith("tensorflow."):
        return False
    return any(marker in typename for marker in ("Tensor", "Variable"))
def is_tf_eager_tensor_typename(typename: str) -> bool:
    """Return True for names under "tensorflow." that mention EagerTensor."""
    if not typename.startswith("tensorflow."):
        return False
    return "EagerTensor" in typename
def is_pytorch_tensor(obj: Any) -> bool:
    """Return True when `obj` is an instance of `torch.Tensor`.

    PyTorch is imported inside the function, on each call.
    """
    import torch as _torch  # type: ignore

    return isinstance(obj, _torch.Tensor)
def is_pytorch_tensor_typename(typename: str) -> bool:
    """Return True for names under "torch." that mention Tensor or Variable."""
    if not typename.startswith("torch."):
        return False
    return any(marker in typename for marker in ("Tensor", "Variable"))
def is_jax_tensor_typename(typename: str) -> bool:
    """Return True for names under "jaxlib." that mention Array."""
    if not typename.startswith("jaxlib."):
        return False
    return "Array" in typename
def get_jax_tensor(obj: Any) -> Any:
    """Return `jax.device_get(obj)`.

    JAX is imported inside the function, on each call.
    """
    from jax import device_get  # type: ignore

    return device_get(obj)
def is_fastai_tensor_typename(typename: str) -> bool:
    """Return True for names under "fastai." that mention Tensor."""
    if not typename.startswith("fastai."):
        return False
    return "Tensor" in typename
def is_pandas_data_frame_typename(typename: str) -> bool:
    """Return True for names under "pandas." that mention DataFrame."""
    if not typename.startswith("pandas."):
        return False
    return "DataFrame" in typename
def is_matplotlib_typename(typename: str) -> bool:
    """Return True when the type name lives under the "matplotlib." namespace."""
    prefix = "matplotlib."
    return typename[: len(prefix)] == prefix
def is_plotly_typename(typename: str) -> bool:
    """Return True when the type name lives under the "plotly." namespace."""
    prefix = "plotly."
    return typename[: len(prefix)] == prefix
- def is_plotly_figure_typename(typename: str) -> bool:
- return typename.startswith("plotly.") and typename.endswith(".Figure")
def is_numpy_array(obj: Any) -> bool:
    """Return True when numpy is available and `obj` is a `numpy.ndarray`.

    Always returns a real bool, including False when the module-level `np`
    is None or otherwise falsy.
    """
    return bool(np) and isinstance(obj, np.ndarray)
def is_pandas_data_frame(obj: Any) -> bool:
    """Return True when `obj` is a pandas DataFrame.

    When `pd_available` is set, this is an isinstance check against
    `pandas.DataFrame`. Otherwise it falls back to matching the object's
    full type name.
    """
    if not pd_available:
        return is_pandas_data_frame_typename(get_full_typename(obj))

    from pandas import DataFrame

    return isinstance(obj, DataFrame)
def ensure_matplotlib_figure(obj: Any) -> Any:
    """Extract the current figure from a matplotlib object.

    Return the object itself if it's a figure. If `obj` is the
    `matplotlib.pyplot` module, its current figure (`gcf()`) is used; if it
    is any other non-Figure object with a `figure` attribute, that attribute
    is used.

    As a side effect, `Spine.is_frame_like` is (re)assigned on every call.

    Raises:
        ValueError: If the object can't be converted to a Figure.
    """
    from matplotlib.figure import Figure  # type: ignore

    # there are combinations of plotly and matplotlib versions that don't work well together,
    # this patches matplotlib to add a removed method that plotly assumes exists
    from matplotlib.spines import Spine  # type: ignore

    def is_frame_like(self: Any) -> bool:
        """Return True if directly on axes frame.

        This is useful for determining if a spine is the edge of an
        old style MPL plot. If so, this function will return True.
        """
        position = self._position or ("outward", 0.0)
        if isinstance(position, str):
            if position == "center":
                position = ("axes", 0.5)
            elif position == "zero":
                position = ("data", 0)
        if len(position) != 2:
            raise ValueError("position should be 2-tuple")
        position_type, amount = position  # type: ignore
        return bool(position_type == "outward" and amount == 0)

    Spine.is_frame_like = is_frame_like

    # If the caller passed the pyplot module, it is necessarily already in
    # sys.modules. Looking it up there avoids an AttributeError on
    # `matplotlib.pyplot` when pyplot was never imported, and the identity
    # check avoids calling `==` on arbitrary objects (array-likes may return
    # a non-boolean result).
    pyplot = sys.modules.get("matplotlib.pyplot")
    if pyplot is not None and obj is pyplot:
        obj = obj.gcf()
    elif (not isinstance(obj, Figure)) and hasattr(obj, "figure"):
        obj = obj.figure

    if not isinstance(obj, Figure):
        raise ValueError(
            "Only matplotlib.pyplot or matplotlib.pyplot.Figure objects are accepted."
        )
    return obj
def matplotlib_to_plotly(obj: Any) -> Any:
    """Convert a matplotlib object to plotly via `plotly.tools.mpl_to_plotly`.

    The object is first resolved to a Figure with `ensure_matplotlib_figure`.
    `plotly.tools` is requested as a required module, so `get_module` raises
    with the install hint below when plotly is not importable.
    """
    figure = ensure_matplotlib_figure(obj)
    plotly_tools = get_module(
        "plotly.tools",
        required=(
            "plotly is required to log interactive plots, install with: "
            "`pip install plotly` or convert the plot to an image with `wandb.Image(plt)`"
        ),
    )
    return plotly_tools.mpl_to_plotly(figure)
def matplotlib_contains_images(obj: Any) -> bool:
    """Return True when any axes of the resolved figure holds at least one image."""
    figure = ensure_matplotlib_figure(obj)
    for axes in figure.axes:
        if len(axes.images) > 0:
            return True
    return False
def _numpy_generic_convert(obj: Any) -> Any:
    """Convert a numpy scalar to a plain Python value.

    NaN floats become None. Values that `.item()` leaves as numpy scalars
    and whose dtype is a float kind or bfloat16 are cast with `float()`.
    """
    value = obj.item()

    if isinstance(value, float) and math.isnan(value):
        return None

    if isinstance(value, np.generic) and (
        value.dtype.kind == "f" or value.dtype == "bfloat16"
    ):
        # value is a numpy float with precision greater than that of native
        # python float (i.e., float96 or float128) or it is of custom type
        # such as bfloat16. In these cases .item() does not return a native
        # python float, so cast it down to a 64bit float explicitly.
        return float(value)

    return value
def _sanitize_numpy_keys(
    d: dict,
    visited: dict[int, dict] | None = None,
) -> tuple[dict, bool]:
    """Return a copy of `d` in which numpy scalar keys are converted.

    Nested dict values are processed recursively; other values are carried
    over unchanged.

    Args:
        d: The dictionary to sanitize.
        visited: Maps `id()` of each input dict already seen to its output
            dict. Used internally by the recursion.

    Returns:
        A sanitized dictionary, and a boolean indicating whether any key was
        converted.
    """
    sanitized: dict[Any, Any] = {}

    # Support self-referencing dictionaries: a dict that has already been
    # seen is answered with its existing output, which preserves the
    # recursive structure of the input.
    if visited is None:
        visited = {}
    elif id(d) in visited:
        return visited[id(d)], False
    visited[id(d)] = sanitized

    changed = False
    for key, value in d.items():
        if isinstance(value, dict):
            value, child_changed = _sanitize_numpy_keys(value, visited)
            changed = changed or child_changed
        if isinstance(key, np.generic):
            key = _numpy_generic_convert(key)
            changed = True
        sanitized[key] = value

    return sanitized, changed
def json_friendly(  # noqa: C901
    obj: Any,
) -> tuple[Any, bool] | tuple[None | str | float, bool]:
    """Convert an object into something that's more becoming of JSON.

    Works in two stages. First, framework tensors (recognized by type name)
    are unwrapped. Second, the resulting value is matched against a chain of
    known types and converted by the first branch that applies.

    Returns:
        A `(value, converted)` tuple. `converted` is False only when no
        branch of the second stage matched; for dicts it is the flag
        returned by `_sanitize_numpy_keys`.
    """
    # Assume a conversion happens; only the final `else` resets this.
    converted = True
    typename = get_full_typename(obj)

    # Stage 1: unwrap framework tensors.
    if is_tf_eager_tensor_typename(typename):
        obj = obj.numpy()
    elif is_tf_tensor_typename(typename):
        try:
            obj = obj.eval()
        except RuntimeError:
            obj = obj.numpy()
    elif is_pytorch_tensor_typename(typename) or is_fastai_tensor_typename(typename):
        try:
            if obj.requires_grad:
                obj = obj.detach()
        except AttributeError:
            pass  # before 0.4 is only present on variables

        try:
            obj = obj.data
        except RuntimeError:
            pass  # happens for Tensors before 0.4

        if obj.size():
            obj = obj.cpu().detach().numpy()
        else:
            # An empty size() is returned directly via .item(), skipping the
            # second stage and the size warning below.
            return obj.item(), True
    elif is_jax_tensor_typename(typename):
        obj = get_jax_tensor(obj)

    # Stage 2: a fresh `if` chain, so values produced by stage 1 are also
    # handled here.
    if is_numpy_array(obj):
        if obj.size == 1:
            obj = obj.flatten()[0]
        elif obj.size <= 32:
            obj = obj.tolist()
        # Arrays with more than 32 elements are left as arrays.
    elif np and isinstance(obj, np.generic):
        obj = _numpy_generic_convert(obj)
    elif isinstance(obj, bytes):
        obj = obj.decode("utf-8")
    elif isinstance(obj, (datetime, date)):
        obj = obj.isoformat()
    elif callable(obj):
        # Callables are represented as "<module>.<qualname>" when both
        # attributes exist, otherwise by str().
        obj = (
            f"{obj.__module__}.{obj.__qualname__}"
            if hasattr(obj, "__qualname__") and hasattr(obj, "__module__")
            else str(obj)
        )
    elif isinstance(obj, float) and math.isnan(obj):
        obj = None
    elif isinstance(obj, dict) and np:
        obj, converted = _sanitize_numpy_keys(obj)
    elif isinstance(obj, set):
        # set is not json serializable, so we convert it to tuple
        obj = tuple(obj)
    elif isinstance(obj, enum.Enum):
        obj = obj.name
    else:
        converted = False

    # Warn when the (possibly converted) value's getsizeof() exceeds the limit.
    if getsizeof(obj) > VALUE_BYTES_LIMIT:
        wandb.termwarn(
            f"Serializing object of type {type(obj).__name__} that is {getsizeof(obj)} bytes"
        )
    return obj, converted
def json_friendly_val(val: Any) -> Any:
    """Make any value (including dict, slice, sequence, dataclass) JSON friendly.

    Dict values and sequence items are converted recursively; dict keys are
    kept as they are. A slice becomes a dict of its start, step and stop.
    Dataclass instances are converted with `asdict` and processed again.
    Any remaining value whose class is not a builtin is replaced by its
    `str()`.
    """
    if isinstance(val, dict):
        return {key: json_friendly_val(item) for key, item in val.items()}

    if isinstance(val, slice):
        return dict(slice_start=val.start, slice_step=val.step, slice_stop=val.stop)

    val, _ = json_friendly(val)

    # str is itself a Sequence, so it has to be excluded explicitly.
    if isinstance(val, Sequence) and not isinstance(val, str):
        return [json_friendly_val(item) for item in val]

    # is_dataclass() is also true for dataclass types; only instances qualify.
    if is_dataclass(val) and not isinstance(val, type):
        return json_friendly_val(asdict(val))

    if val.__class__.__module__ not in ("builtins", "__builtin__"):
        return str(val)
    return val
def alias_is_version_index(alias: str) -> bool:
    """Return True for aliases made of "v" followed by a numeric string, e.g. "v3"."""
    if not alias.startswith("v"):
        return False
    # For a bare "v" the remainder is empty, and "".isnumeric() is False.
    return alias[1:].isnumeric()
def convert_plots(obj: Any) -> Any:
    """Turn matplotlib and plotly objects into a plotly JSON payload.

    A matplotlib object is first converted with `plotly.tools.mpl_to_plotly`.
    If the result is a plotly object, a dict with "_type" set to "plotly" and
    "plot" set to `obj.to_plotly_json()` is returned. Any other object is
    returned unchanged.
    """
    if is_matplotlib_typename(get_full_typename(obj)):
        plotly_tools = get_module(
            "plotly.tools",
            required=(
                "plotly is required to log interactive plots, install with: "
                "`pip install plotly` or convert the plot to an image with `wandb.Image(plt)`"
            ),
        )
        obj = plotly_tools.mpl_to_plotly(obj)

    if not is_plotly_typename(get_full_typename(obj)):
        return obj
    return {"_type": "plotly", "plot": obj.to_plotly_json()}
def maybe_compress_history(obj: Any) -> tuple[Any, bool]:
    """Replace a numpy array of more than 32 elements with a histogram.

    Returns:
        `(histogram_json, True)` built from a 32-bin `wandb.Histogram` for
        such arrays, `(obj, False)` for anything else.
    """
    is_large_array = np and isinstance(obj, np.ndarray) and obj.size > 32
    if not is_large_array:
        return obj, False
    return wandb.Histogram(obj, num_bins=32).to_json(), True
def maybe_compress_summary(obj: Any, h5_typename: str) -> tuple[Any, bool]:
    """Replace a numpy array of more than 32 elements with summary statistics.

    Args:
        obj: The value to inspect.
        h5_typename: Stored under "_type" in the summary.

    Returns:
        `(summary, True)` for such arrays, where `summary` holds the
        variance, mean, min, max, the 10th/25th/75th/90th percentiles and
        the element count. Every statistic is converted to a native Python
        number with `.item()`. Anything else yields `(obj, False)`.
    """
    if not (np and isinstance(obj, np.ndarray) and obj.size > 32):
        return obj, False

    summary = {
        "_type": h5_typename,  # may not be ndarray
        "var": np.var(obj).item(),
        "mean": np.mean(obj).item(),
        "min": np.amin(obj).item(),
        "max": np.amax(obj).item(),
        # .item() keeps the percentiles consistent with the statistics above;
        # np.percentile otherwise hands back numpy scalars.
        "10%": np.percentile(obj, 10).item(),
        "25%": np.percentile(obj, 25).item(),
        "75%": np.percentile(obj, 75).item(),
        "90%": np.percentile(obj, 90).item(),
        "size": obj.size,
    }
    return summary, True
def launch_browser(attempt_launch_browser: bool = True) -> bool:
    """Decide if we should launch a browser.

    A falsy `attempt_launch_browser` is returned as is. Otherwise the answer
    is False when any of these hold:
      * the platform is Linux and none of DISPLAY, WAYLAND_DISPLAY or
        MIR_SOCKET is set,
      * `webbrowser.get()` raises `webbrowser.Error`,
      * the browser it returns has a `name` on the text-browser blocklist.
    """
    import webbrowser

    display_env_vars = ("DISPLAY", "WAYLAND_DISPLAY", "MIR_SOCKET")
    text_browsers = ("www-browser", "lynx", "links", "elinks", "w3m")

    decision = attempt_launch_browser
    if not decision:
        return decision

    has_display = any(os.getenv(var) for var in display_env_vars)
    if "linux" in sys.platform and not has_display:
        decision = False

    try:
        browser = webbrowser.get()
    except webbrowser.Error:
        return False

    if hasattr(browser, "name") and browser.name in text_browsers:
        return False
    return decision
def generate_id(length: int = 8) -> str:
    """Return `runid.generate_id(length)`.

    Do not use this; use wandb.sdk.lib.runid.generate_id instead.
    This is kept only for legacy code.
    """
    new_id = runid.generate_id(length)
    return new_id
def parse_tfjob_config() -> Any:
    """Parse the TF_CONFIG environment variable as JSON.

    Returns:
        The decoded value, or False when the variable is unset, empty, or
        not valid JSON.
    """
    raw_config = os.getenv("TF_CONFIG")
    if not raw_config:
        return False
    try:
        return json.loads(raw_config)
    except ValueError:
        return False
class WandBJSONEncoder(json.JSONEncoder):
    """A JSON Encoder that handles some extra types.

    Objects exposing a `json_encode` method are encoded with it. Everything
    else goes through `json_friendly`, and falls back to the base encoder
    when that performs no conversion.
    """

    def default(self, obj: Any) -> Any:
        if hasattr(obj, "json_encode"):
            return obj.json_encode()

        # Note: a `to_json` attribute is deliberately not consulted here.
        friendly, converted = json_friendly(obj)
        if not converted:
            return json.JSONEncoder.default(self, obj)
        return friendly
class WandBJSONEncoderOld(json.JSONEncoder):
    """A JSON Encoder that handles some extra types.

    Values are passed through `json_friendly` and then
    `maybe_compress_summary`. The result is returned when `json_friendly`
    reported a conversion; otherwise it is handed to the base encoder.
    """

    def default(self, obj: Any) -> Any:
        friendly, converted = json_friendly(obj)
        friendly, _ = maybe_compress_summary(friendly, get_h5_typename(obj))
        if not converted:
            return json.JSONEncoder.default(self, friendly)
        return friendly
class WandBHistoryJSONEncoder(json.JSONEncoder):
    """A JSON Encoder that handles some extra types.

    This encoder turns numpy like objects with a size > 32 into histograms.
    """

    def default(self, obj: Any) -> Any:
        friendly, converted = json_friendly(obj)
        friendly, _ = maybe_compress_history(friendly)
        if not converted:
            return json.JSONEncoder.default(self, friendly)
        return friendly
class JSONEncoderUncompressed(json.JSONEncoder):
    """A JSON Encoder that writes numpy values out in full.

    Arrays are encoded as nested lists via `tolist()` and numpy scalars as
    the value returned by `.item()`; nothing is summarized or compressed.
    """

    def default(self, obj: Any) -> Any:
        if np and isinstance(obj, np.ndarray):
            return obj.tolist()
        if np and isinstance(obj, np.generic):
            # Return the converted scalar. Passing it on to the base
            # `default` would always raise TypeError, which made every
            # non-numeric numpy scalar (e.g. np.bool_) unserializable.
            return obj.item()
        return json.JSONEncoder.default(self, obj)
def json_dump_safer(obj: Any, fp: IO[str], **kwargs: Any) -> None:
    """Write `obj` to `fp` as JSON, encoding extra types with WandBJSONEncoder."""
    encoder = WandBJSONEncoder
    return dump(obj, fp, cls=encoder, **kwargs)
def json_dumps_safer(obj: Any, **kwargs: Any) -> str:
    """Return `obj` as a JSON string, encoding extra types with WandBJSONEncoder."""
    encoder = WandBJSONEncoder
    return dumps(obj, cls=encoder, **kwargs)
# This is used for dumping raw json into files
def json_dump_uncompressed(obj: Any, fp: IO[str], **kwargs: Any) -> None:
    """Write `obj` to `fp` as JSON, encoding numpy values with JSONEncoderUncompressed."""
    encoder = JSONEncoderUncompressed
    return dump(obj, fp, cls=encoder, **kwargs)
def json_dumps_safer_history(obj: Any, **kwargs: Any) -> str:
    """Return `obj` as a JSON string, using WandBHistoryJSONEncoder (histograms included)."""
    encoder = WandBHistoryJSONEncoder
    return dumps(obj, cls=encoder, **kwargs)
def make_json_if_not_number(
    v: int | float | str | Mapping | Sequence,
) -> int | float | str:
    """Return ints and floats unchanged; JSON-encode everything else.

    Strings are JSON-encoded too. Encoding is done by `json_dumps_safer`.
    """
    if not isinstance(v, (int, float)):
        return json_dumps_safer(v)
    return v
def make_safe_for_json(obj: Any) -> Any:
    """Recursively replace floats that JSON cannot represent with strings.

    Mappings become plain dicts and non-string sequences become lists.
    NaN and the infinities are replaced with "NaN", "Infinity" and
    "-Infinity"; every other value is returned unchanged.
    """
    if isinstance(obj, Mapping):
        return {key: make_safe_for_json(value) for key, value in obj.items()}
    if isinstance(obj, str):
        # strings are Sequences too, so they must be handled before the
        # generic Sequence branch
        return obj
    if isinstance(obj, Sequence):
        return [make_safe_for_json(element) for element in obj]
    if not isinstance(obj, float):
        return obj

    # W&B backend and UI handle these strings
    if obj != obj:  # NaN is the only float that differs from itself
        return "NaN"
    if obj == float("+inf"):
        return "Infinity"
    if obj == float("-inf"):
        return "-Infinity"
    return obj
def no_retry_4xx(e: Exception) -> bool:
    """Retry predicate that converts non-retryable 4xx responses to `UsageError`.

    Returns:
        True for anything that is not an `HTTPError`, for statuses outside
        the 4xx range, and for 429.

    Raises:
        UsageError: for every other 4xx status, using the first error
            message found in the JSON response body.
    """
    from requests import HTTPError

    if not isinstance(e, HTTPError):
        return True
    assert e.response is not None

    status = e.response.status_code
    is_client_error = 400 <= status < 500
    if status == 429 or not is_client_error:
        return True

    payload = json.loads(e.response.content)
    raise UsageError(payload["errors"][0]["message"])
def parse_backend_error_messages(response: Response) -> list[str]:
    """Extract the error messages carried by a backend response.

    An empty list is returned whenever the body is not JSON, is not a JSON
    object, or does not contain errors in one of the recognized shapes.

    Args:
        response: A response to an HTTP request to the W&B server.
    """
    from requests import JSONDecodeError

    try:
        payload = response.json()
    except JSONDecodeError:
        return []
    if not isinstance(payload, dict):
        return []

    def _extract(entry: Any) -> str | None:
        # An error entry is either the message itself or an object with a
        # string "message" field.
        if isinstance(entry, str):
            return entry
        if isinstance(entry, dict):
            text = entry.get("message")
            if text and isinstance(text, str):
                return text
        return None

    # A single error lives under "error"; several live in a list under "errors".
    single = payload.get("error")
    if single:
        text = _extract(single)
        return [text] if text else []

    many = payload.get("errors")
    if many and isinstance(many, list):
        return [text for text in map(_extract, many) if text]
    return []
def no_retry_auth(e: Any) -> bool:
    """Retry predicate that fails fast on auth, permission and not-found errors.

    Returns:
        True when the request should be retried (non-HTTP errors, errors
        without a response, and statuses other than 400/401/403/404/409);
        False for 400 and 409, and for 404 responses that carry backend
        error messages.

    Raises:
        AuthenticationError: on 401.
        CommError: on 403.
        LookupError: on 404 when the response carries no error message.
    """
    from requests import HTTPError

    # the caller may pass a wrapper that stores the real error in `.exception`
    if hasattr(e, "exception"):
        e = e.exception
    if not isinstance(e, HTTPError) or e.response is None:
        return True

    status = e.response.status_code

    # Don't retry bad request errors; raise immediately
    if status in (400, 409):
        return False
    # Retry all non-forbidden/unauthorized/not-found errors.
    if status not in (401, 403, 404):
        return True

    # Crash with more informational message on forbidden/unauthorized errors.
    if status == 401:
        raise AuthenticationError(
            "The API key you provided is either invalid or missing. "
            f"If the `{wandb.env.API_KEY}` environment variable is set, make sure it is correct. "
            "Otherwise, to resolve this issue, you may try running the 'wandb login --relogin' command. "
            "If you are using a local server, make sure that you're using the correct hostname. "
            "If you're not sure, you can try logging in again using the 'wandb login --relogin --host [hostname]' command."
            f"(Error {e.response.status_code}: {e.response.reason})"
        )

    if status == 403:
        if wandb.run:
            raise CommError(f"Permission denied to access {wandb.run.path}")
        raise CommError(
            "It appears that you do not have permission to access the requested resource. "
            "Please reach out to the project owner to grant you access. "
            "If you have the correct permissions, verify that there are no issues with your networking setup."
            f"(Error {e.response.status_code}: {e.response.reason})"
        )

    # Only 404 is left. If the backend supplied its own error messages, let
    # them surface; otherwise raise a more generic not-found error.
    if parse_backend_error_messages(e.response):
        return False
    raise LookupError(
        f"Failed to find resource. Please make sure you have the correct resource path. "
        f"(Error {e.response.status_code}: {e.response.reason})"
    )
def check_retry_conflict(e: Any) -> bool | None:
    """Decide whether an error is an HTTP 409 conflict that should be retried.

    Returns:
        True - Should retry this operation
        None - No decision, let someone else decide
    """
    from requests import HTTPError

    # the caller may pass a wrapper that stores the real error in `.exception`
    if hasattr(e, "exception"):
        e = e.exception

    if not isinstance(e, HTTPError):
        return None
    if e.response is None:
        return None
    return True if e.response.status_code == 409 else None
def check_retry_conflict_or_gone(e: Any) -> bool | None:
    """Decide retry behavior for HTTP 409 (conflict) and 410 (gone) errors.

    Returns:
        True - Should retry this operation (409)
        False - Should not retry this operation (410)
        None - No decision, let someone else decide
    """
    from requests import HTTPError

    # the caller may pass a wrapper that stores the real error in `.exception`
    if hasattr(e, "exception"):
        e = e.exception

    if not isinstance(e, HTTPError) or e.response is None:
        return None

    status = e.response.status_code
    if status == 409:
        return True
    if status == 410:
        return False
    return None
def make_check_retry_fn(
    fallback_retry_fn: CheckRetryFnType,
    check_fn: Callable[[Exception], bool | None],
    check_timedelta: timedelta | None = None,
) -> CheckRetryFnType:
    """Build a check_retry_fn which can be used by lib.Retry().

    Args:
        fallback_retry_fn: Consulted when `check_fn` returns None, i.e. it
            could not decide whether a retry should happen.
        check_fn: Returns a bool if a retry decision was made, or None if
            unsure.
        check_timedelta: Optional retry timeout returned instead of True
            when `check_fn` asks for a retry.
    """

    def check_retry_fn(e: Exception) -> bool | timedelta:
        decision = check_fn(e)
        if decision is None:
            return fallback_retry_fn(e)
        if decision is False:
            return False
        return check_timedelta if check_timedelta else True

    return check_retry_fn
def find_runner(program: str) -> None | list | list[str]:
    """Return a command that will run program.

    Only non-executable regular files are inspected: a `#!` first line is
    split into an interpreter command, and a `.py` file without one is run
    with the current Python interpreter.

    Args:
        program: The string name of the program to try to run.

    Returns:
        commandline list of strings to run the program (eg. with
        subprocess.call()) or None
    """
    if not os.path.isfile(program) or os.access(program, os.X_OK):
        return None

    # program is a path to a non-executable file
    try:
        # `with` guarantees the handle is closed; it was previously leaked.
        with open(program) as opened:
            first_line = opened.readline().strip()
    except OSError:
        return None

    if first_line.startswith("#!"):
        return shlex.split(first_line[2:])
    if program.endswith(".py"):
        return [sys.executable]
    return None
def downsample(values: Sequence, target_length: int) -> list:
    """Pick `target_length` evenly spaced items from 1d values.

    The first and last items are always included and fractional positions
    are rounded down. `values` may be any iterable, including a generator;
    inputs shorter than `target_length` are returned in full.

    Raises:
        UsageError: if `target_length` is not greater than 1.
    """
    if not target_length > 1:
        raise UsageError("target_length must be > 1")

    points = list(values)
    count = len(points)
    if count < target_length:
        return points

    step = float(count - 1) / (target_length - 1)
    return [points[int(slot * step)] for slot in range(target_length)]
def has_num(dictionary: Mapping, key: Any) -> bool:
    """Return True if `key` is present in `dictionary` and maps to a number."""
    if key not in dictionary:
        return False
    return isinstance(dictionary[key], numbers.Number)
def docker_image_regex(image: str) -> Any:
    """Match `image` against the pattern for valid docker image names.

    Returns:
        The `re.Match` object on success; None when `image` is empty or
        does not match.
    """
    if not image:
        return None
    return re.match(
        r"^(?:(?=[^:\/]{1,253})(?!-)[a-zA-Z0-9-]{1,63}(?<!-)(?:\.(?!-)[a-zA-Z0-9-]{1,63}(?<!-))*(?::[0-9]{1,5})?/)?((?![._-])(?:[a-z0-9._-]*)(?<![._-])(?:/(?![._-])[a-z0-9._-]*(?<![._-]))*)(?::(?![.-])[a-zA-Z0-9_.-]{1,128})?$",
        image,
    )
def image_from_docker_args(args: list[str]) -> str | None:
    """Scan docker run args and attempt to find the most likely docker image argument.

    It excludes any arguments that start with a dash, and the argument after it if it
    isn't a boolean switch. This can be improved, we currently fallback gracefully when
    this fails.

    Args:
        args: The docker arguments, optionally starting with "run". The list
            is not modified.

    Returns:
        The most likely image argument, or None if no candidate was found.
    """
    bool_args = (
        "-t",
        "--tty",
        "--rm",
        "--privileged",
        "--oom-kill-disable",
        "--no-healthcheck",
        "-i",
        "--interactive",
        "--init",
        "--help",
        "--detach",
        "-d",
        "--sig-proxy",
        "-it",
        "-itd",
    )
    # Slice instead of pop(0) so the caller's list is left untouched.
    if args and args[0] == "run":
        args = args[1:]

    # -2 lets the very first argument qualify through the `i - 2` test below
    # when no flag has been seen yet.
    last_flag = -2
    last_arg = ""
    possible_images = []
    for i, arg in enumerate(args):
        if arg.startswith("-"):
            last_flag = i
            last_arg = arg
        elif "@sha256:" in arg:
            # Because our regex doesn't match digests
            possible_images.append(arg)
        elif docker_image_regex(arg):
            if last_flag == i - 2:
                possible_images.append(arg)
            elif "=" in last_arg:
                possible_images.append(arg)
            elif last_arg in bool_args and last_flag == i - 1:
                possible_images.append(arg)

    # Prefer a candidate that looks like a qualified reference
    # (tag, digest or repository path); otherwise take the first one.
    for img in possible_images:
        if ":" in img or "@" in img or "/" in img:
            return img
    return possible_images[0] if possible_images else None
def load_yaml(file: Any) -> Any:
    """Parse YAML from `file` with `yaml.safe_load` and return the result."""
    from yaml import safe_load

    return safe_load(file)
def image_id_from_k8s() -> str | None:
    """Ping the k8s metadata service for the image id.

    Specify the KUBERNETES_NAMESPACE environment variable if your pods are not in the
    default namespace:

    - name: KUBERNETES_NAMESPACE valueFrom:
        fieldRef:
          fieldPath: metadata.namespace

    Returns:
        The image id of the first container of this pod with a leading
        "docker-pullable://" removed, or None if the service account token
        is unavailable, the request fails, or the response lacks the field.
    """
    token_path = "/var/run/secrets/kubernetes.io/serviceaccount/token"

    if not os.path.exists(token_path):
        return None

    try:
        with open(token_path) as token_file:
            token = token_file.read()
    except FileNotFoundError:
        logger.warning(f"Token file not found at {token_path}.")
        return None
    except PermissionError as e:
        current_uid = os.getuid()
        warning = (
            f"Unable to read the token file at {token_path} due to permission error ({e})."
            f"The current user id is {current_uid}. "
            "Consider changing the securityContext to run the container as the current user."
        )
        logger.warning(warning)
        wandb.termwarn(warning)
        return None

    if not token:
        return None

    import requests

    k8s_server = "https://{}:{}/api/v1/namespaces/{}/pods/{}".format(
        os.getenv("KUBERNETES_SERVICE_HOST"),
        os.getenv("KUBERNETES_PORT_443_TCP_PORT"),
        os.getenv("KUBERNETES_NAMESPACE", "default"),
        os.getenv("HOSTNAME"),
    )
    try:
        res = requests.get(
            k8s_server,
            verify="/var/run/secrets/kubernetes.io/serviceaccount/ca.crt",
            timeout=3,
            headers={"Authorization": f"Bearer {token}"},
        )
        res.raise_for_status()
    except requests.RequestException:
        return None

    try:
        image_id = str(res.json()["status"]["containerStatuses"][0]["imageID"])
    except (ValueError, KeyError, IndexError):
        logger.exception("Error checking kubernetes for image id")
        return None

    # str.strip() treats its argument as a *set of characters* and trims
    # both ends, which can eat trailing digest characters such as "abcde".
    # Remove the scheme only as a literal prefix.
    prefix = "docker-pullable://"
    if image_id.startswith(prefix):
        image_id = image_id[len(prefix) :]
    return image_id
def async_call(target: Callable, timeout: int | float | None = None) -> Callable:
    """Wrap a method to run in the background with an optional timeout.

    Returns a new method that will call the original with any args, waiting for upto
    timeout seconds. This new method blocks on the original and returns the result or
    None if timeout was reached, along with the thread. You can check thread.is_alive()
    to determine if a timeout was reached. If an exception is thrown in the thread, we
    reraise it.

    Args:
        target: The callable to run in a daemon thread.
        timeout: Seconds to wait for the result; None waits indefinitely.
    """

    def wrapped_target(q: queue.Queue, *args: Any, **kwargs: Any) -> Any:
        try:
            q.put(target(*args, **kwargs))
        except Exception as e:
            # hand the failure to the waiting caller instead of losing it
            q.put(e)

    def wrapper(
        *args: Any, **kwargs: Any
    ) -> tuple[Exception, threading.Thread] | tuple[None, threading.Thread]:
        # One queue per invocation: with a shared queue, the late result of
        # a call that timed out would be handed to the *next* call.
        q: queue.Queue = queue.Queue()
        thread = threading.Thread(
            target=wrapped_target, args=(q,) + args, kwargs=kwargs
        )
        thread.daemon = True
        thread.start()
        try:
            result = q.get(True, timeout)
        except queue.Empty:
            return None, thread
        if isinstance(result, Exception):
            # Re-raise as-is so the traceback captured in the worker thread
            # is preserved.
            raise result
        return result, thread

    return wrapper
def read_many_from_queue(
    q: queue.Queue, max_items: int, queue_timeout: int | float
) -> list:
    """Read a batch of at most `max_items` items from `q`.

    Blocks for up to `queue_timeout` seconds waiting for the first item,
    then drains whatever else is immediately available.

    Args:
        q: The queue to read from.
        max_items: Upper bound on the batch size. The item obtained by the
            initial blocking read is always returned, even if this is < 1.
        queue_timeout: Seconds to wait for the first item.

    Returns:
        The items read, in queue order; an empty list if the wait timed out.
    """
    try:
        first = q.get(True, queue_timeout)
    except queue.Empty:
        return []

    items = [first]
    # The blocking read already used one slot of the budget; previously the
    # loop ran max_items more times and could return max_items + 1 items.
    while len(items) < max_items:
        try:
            items.append(q.get_nowait())
        except queue.Empty:
            break
    return items
def stopwatch_now() -> float:
    """Return a timestamp suitable for measuring intervals.

    The value comes from the monotonic clock, so it never goes backwards.
    """
    now = time.monotonic()
    return now
def class_colors(class_count: int) -> list[list[int]]:
    """Return one color per class.

    Class 0 is black; the remaining classes get equally spaced, fully
    saturated hues as produced by `colorsys.hsv_to_rgb`.
    """
    palette = [[0, 0, 0]]
    for index in range(class_count - 1):
        hue = index / (class_count - 1.0)
        palette.append(colorsys.hsv_to_rgb(hue, 1.0, 1.0))  # type: ignore
    return palette
def prompt_choices(
    choices: Sequence[str],
    input_timeout: float | None = None,
) -> str:
    """Ask the user to pick one of `choices` by number.

    A single choice is returned immediately without prompting. Otherwise
    the prompt repeats until a valid number is entered.

    Raises:
        TimeoutError: if input_timeout is specified and expires.
        NotATerminalError: if the output device is not capable.
        KeyboardInterrupt: if the user aborts by pressing Ctrl+C.
    """
    if len(choices) == 1:
        return choices[0]

    for number, label in enumerate(choices, start=1):
        wandb.termlog(f"({number}) {label}")

    while True:
        answer = terminput("Enter your choice: ", timeout=input_timeout)

        # If the user presses enter without typing anything, try again.
        if not answer:
            continue

        try:
            index = int(answer) - 1
        except ValueError:
            index = -1

        if not 0 <= index < len(choices):
            wandb.termwarn("Invalid choice")
            continue

        selected = choices[index]
        wandb.termlog(f"You chose {selected!r}")
        return selected
def guess_data_type(shape: Sequence[int], risky: bool = False) -> str | None:
    """Guess what kind of data a tensor holds from its shape.

    Args:
        shape (Sequence[int]): The shape of the data
        risky(bool): also make guesses that are more likely to be wrong.

    Returns:
        "label", "image", "segmentation_mask", or None when no guess is made.
    """
    rank = len(shape)

    # (samples,) or (samples,logits)
    if rank in (1, 2):
        return "label"

    # Assume image mask like fashion mnist: (no color channel)
    # This is risky because RNNs often have 3 dim tensors: batch, time, channels
    if rank == 3:
        return "image" if risky else None

    if rank == 4:
        # (samples, height, width, Y \ RGB \ RGBA) is an image; any other
        # channel count is treated as (samples, height, width, logits)
        return "image" if shape[-1] in (1, 3, 4) else "segmentation_mask"

    return None
def download_file_from_url(
    dest_path: str, source_url: str, api_key: str | None = None
) -> None:
    """Stream `source_url` into the file at `dest_path`.

    The request uses basic auth with the user "api" and `api_key` (or an
    empty password). The parent directory of `dest_path` is created when
    the path contains a separator.

    Raises:
        requests.HTTPError: if the server responds with an error status.
    """
    import requests

    credentials = ("api", api_key or "")
    response = requests.get(source_url, auth=credentials, stream=True, timeout=5)
    response.raise_for_status()

    if os.sep in dest_path:
        parent_dir = os.path.dirname(dest_path)
        filesystem.mkdir_exists_ok(parent_dir)

    with fsync_open(dest_path, "wb") as out_file:
        for chunk in response.iter_content(chunk_size=1024):
            out_file.write(chunk)
def download_file_into_memory(source_url: str, api_key: str | None = None) -> bytes:
    """Download `source_url` and return the response body as bytes.

    The request uses basic auth with the user "api" and `api_key` (or an
    empty password).

    Raises:
        requests.HTTPError: if the server responds with an error status.
    """
    import requests

    credentials = ("api", api_key or "")
    response = requests.get(source_url, auth=credentials, stream=True, timeout=5)
    response.raise_for_status()
    return response.content
def isatty(ob: IO) -> bool:
    """Return `ob.isatty()`, or False when `ob` has no `isatty` attribute."""
    if not hasattr(ob, "isatty"):
        return False
    return ob.isatty()
def to_human_size(size: int, units: list[tuple[str, Any]] | None = None) -> str:
    """Format `size` using the first unit that keeps the number below 1024.

    Args:
        size: The size to format.
        units: `(suffix, divisor)` pairs tried in order; defaults to
            `POW_10_BYTES`. The last unit is used when none fits.

    Returns:
        The size divided by the chosen divisor, rounded to one decimal,
        followed by the unit suffix.
    """
    units = units or POW_10_BYTES
    final_index = len(units) - 1
    for index, (unit, value) in enumerate(units):
        factor = round(float(size) / value, 1)
        if factor < 1024 or index == final_index:
            return f"{factor}{unit}"
def from_human_size(size: str, units: list[tuple[str, Any]] | None = None) -> int:
    """Parse a size such as `10`, `10B` or `10 B` into an integer.

    Args:
        size: A number optionally followed by a unit suffix
            (matched case-insensitively).
        units: `(suffix, multiplier)` pairs; defaults to `POW_10_BYTES`.

    Raises:
        ValueError: if `size` does not start with a number.
    """
    units = units or POW_10_BYTES
    multipliers = {name.upper(): scale for name, scale in units}
    pattern = re.compile(
        r"(\d+\.?\d*)\s*({})?".format("|".join(multipliers)), re.IGNORECASE
    )

    found = pattern.match(size)
    if found is None:
        raise ValueError("size must be of the form `10`, `10B` or `10 B`.")

    number, suffix = found.group(1), found.group(2)
    # a bare number carries no unit and is taken as-is
    scale = multipliers[suffix.upper()] if suffix else 1
    return int(float(number) * scale)
def auto_project_name(program: str | None) -> str:
    """Derive a default project name from the enclosing git repository.

    Args:
        program: Path of the program being run, or None.

    Returns:
        "uncategorized" outside of a git repository. Otherwise the repo
        directory name, extended with the program's directory relative to
        the repo root (path separators replaced by "_") when the program
        lives in a subdirectory of the repo.
    """
    # if we're in git, set project name to git repo name + relative path within repo
    from wandb.sdk.lib.gitlib import GitRepo

    root_dir = GitRepo().root_dir
    if root_dir is None:
        return "uncategorized"
    # On windows, GitRepo returns paths in unix style, but os.path is windows
    # style. Coerce here.
    root_dir = to_native_slash_path(root_dir)
    repo_name = os.path.basename(root_dir)
    if program is None:
        return str(repo_name)
    if not os.path.isabs(program):
        program = os.path.join(os.curdir, program)
    prog_dir = os.path.dirname(os.path.abspath(program))

    # Compare on a directory boundary: a bare startswith(root_dir) also
    # accepts sibling directories such as "/repo2" for root "/repo", which
    # then yields a "../"-style relative path in the project name.
    root_prefix = root_dir.rstrip(os.sep) + os.sep
    if prog_dir != root_dir and not prog_dir.startswith(root_prefix):
        return str(repo_name)

    project = repo_name
    sub_path = os.path.relpath(prog_dir, root_dir)
    if sub_path != ".":
        project += "-" + sub_path
    return str(project.replace(os.sep, "_"))
- # TODO(hugh): Deprecate version here and use wandb/sdk/lib/paths.py
def to_forward_slash_path(path: str) -> str:
    """Return `path` with backslashes replaced by "/" when running on Windows.

    On every other platform the path is returned unchanged.
    """
    on_windows = platform.system() == "Windows"
    return path.replace("\\", "/") if on_windows else path
- # TODO(hugh): Deprecate version here and use wandb/sdk/lib/paths.py
def to_native_slash_path(path: str) -> FilePathStr:
    """Return `path` with every "/" replaced by the platform's `os.sep`."""
    native = path.replace("/", os.sep)
    return FilePathStr(native)
def check_and_warn_old(files: list[str]) -> bool:
    """Warn if `files` contains "wandb-metadata.json".

    Returns:
        True (after printing the warning) when the file is listed,
        otherwise False.
    """
    if "wandb-metadata.json" not in files:
        return False

    wandb.termwarn("These runs were logged with a previous version of wandb.")
    wandb.termwarn(
        "Run pip install wandb<0.10.0 to get the old library and sync your runs."
    )
    return True
class ImportMetaHook:
    """`sys.meta_path` hook that runs callbacks when chosen modules are imported.

    NOTE(review): `find_module`/`load_module` are the legacy finder/loader
    protocol; confirm the interpreter versions in use still consult them.
    """

    def __init__(self) -> None:
        # modules that were loaded through this hook, by full name
        self.modules: dict[str, ModuleType] = {}
        # callbacks to run after a module is imported, by full name
        self.on_import: dict[str, list] = {}

    def add(self, fullname: str, on_import: Callable) -> None:
        """Register `on_import` to be called when `fullname` is imported."""
        callbacks = self.on_import.setdefault(fullname, [])
        callbacks.append(on_import)

    def install(self) -> None:
        """Put this hook at the front of `sys.meta_path`."""
        sys.meta_path.insert(0, self)  # type: ignore

    def uninstall(self) -> None:
        """Remove this hook from `sys.meta_path`."""
        sys.meta_path.remove(self)  # type: ignore

    def find_module(
        self, fullname: str, path: str | None = None
    ) -> ImportMetaHook | None:
        """Claim `fullname` only if callbacks are registered for it."""
        return self if fullname in self.on_import else None

    def load_module(self, fullname: str) -> ModuleType:
        """Import `fullname`, record it, and run its registered callbacks."""
        # Step aside for the duration of the import so this hook is not
        # consulted again for the same module.
        self.uninstall()
        mod = importlib.import_module(fullname)
        self.install()

        self.modules[fullname] = mod
        for callback in self.on_import.get(fullname) or ():
            callback()
        return mod

    def get_modules(self) -> tuple[str, ...]:
        """Return the names of the modules loaded through this hook."""
        return tuple(self.modules)

    def get_module(self, module: str) -> ModuleType:
        """Return a module previously loaded through this hook."""
        return self.modules[module]
# Process-wide hook instance, created and installed on first use.
_import_hook: ImportMetaHook | None = None


def add_import_hook(fullname: str, on_import: Callable) -> None:
    """Register `on_import` to run when the module `fullname` is imported.

    The shared `ImportMetaHook` is created and installed the first time
    this function is called.
    """
    global _import_hook

    hook = _import_hook
    if hook is None:
        _import_hook = hook = ImportMetaHook()
        hook.install()
    hook.add(fullname, on_import)
def host_from_path(path: str | None) -> str:
    """Return the host of the path.

    Args:
        path: A URL-like string, or None.

    Returns:
        The network location part of `path`, or "" when `path` is None,
        empty, or has no host.
    """
    # urlparse(None) produces a bytes result, and str(b"") is the literal
    # text "b''" rather than an empty string.
    if not path:
        return ""
    return str(urllib.parse.urlparse(path).netloc)
def uri_from_path(path: str | None) -> str:
    """Return the URI of the path.

    Args:
        path: A URL-like string, or None.

    Returns:
        The path component of `path` without its single leading "/", or ""
        when `path` is None, empty, or has no path component.
    """
    if not path:
        return ""
    url_path = urllib.parse.urlparse(path).path
    # startswith() is safe on an empty path, unlike indexing url_path[0]
    return str(url_path[1:] if url_path.startswith("/") else url_path)
def is_unicode_safe(stream: TextIO) -> bool:
    """Return True if the stream's declared encoding is UTF-8.

    Streams without an `encoding` attribute (or with an empty one) are
    treated as not UTF-8.
    """
    encoding = getattr(stream, "encoding", None)
    if not encoding:
        return False
    return encoding.lower() in {"utf-8", "utf_8"}
def rand_alphanumeric(
    length: int = 8, rand: ModuleType | random.Random | None = None
) -> str:
    """Return `length` random uppercase hex digits (deprecated).

    A deprecation message is printed on every call.

    Args:
        length: Number of characters to generate.
        rand: Source of randomness providing `choice`; defaults to the
            `random` module.
    """
    wandb.termerror("rand_alphanumeric is deprecated, use 'secrets.token_hex'")
    source = rand or random
    hex_digits = "0123456789ABCDEF"
    return "".join(source.choice(hex_digits) for _ in range(length))
@contextlib.contextmanager
def fsync_open(
    path: StrPath, mode: str = "w", encoding: str | None = None
) -> Generator[IO[Any], None, None]:
    """Open `path` and flush and fsync the file when the block exits normally.

    The file is always closed. The flush/fsync step is skipped when the
    body of the `with` block raises.
    """
    stream = open(path, mode, encoding=encoding)
    try:
        yield stream

        stream.flush()
        os.fsync(stream.fileno())
    finally:
        stream.close()
def _is_kaggle() -> bool:
    """Return True when running inside a Kaggle kernel.

    Detected via the KAGGLE_KERNEL_RUN_TYPE environment variable or an
    already-imported `kaggle_environments` module.
    """
    if os.getenv("KAGGLE_KERNEL_RUN_TYPE") is not None:
        return True
    return "kaggle_environments" in sys.modules
def _has_internet() -> bool:
    """Returns whether we have internet access.

    Checks for internet access by attempting to open a TCP connection to
    8.8.8.8 (Google's public DNS server) on port 53, with a 0.5 second
    timeout.
    """
    try:
        connection = socket.create_connection(("8.8.8.8", 53), 0.5)
        connection.close()
    except OSError:
        return False
    else:
        return True
def _is_likely_kaggle() -> bool:
    """Return True if the user is probably a Kaggle user.

    Used for telemetry to mark first runs from Kagglers. True when running
    in a Kaggle kernel, when `~/.kaggle/kaggle.json` exists, or when the
    `kaggle` module has been imported.
    """
    if _is_kaggle():
        return True

    credentials = os.path.expanduser(os.path.join("~", ".kaggle", "kaggle.json"))
    if os.path.exists(credentials):
        return True

    return "kaggle" in sys.modules
def _is_databricks() -> bool:
    """Return True when running inside a Databricks notebook.

    Looks for an imported `dbutils` module whose `shell.sc.appName` equals
    "Databricks Shell"; a missing link anywhere in that chain means False.
    """
    # walk dbutils -> shell -> sc -> appName, tolerating missing attributes
    node = sys.modules.get("dbutils")
    for attribute in ("shell", "sc", "appName"):
        if not hasattr(node, attribute):
            return False
        node = getattr(node, attribute)
    return bool(node == "Databricks Shell")
def _is_py_requirements_or_dockerfile(path: str) -> bool:
    """Return True if `path` names a .py file, a Dockerfile*, or requirements.txt.

    Only the final path component is examined.
    """
    file_name = os.path.basename(path)
    if file_name == "requirements.txt":
        return True
    return file_name.endswith(".py") or file_name.startswith("Dockerfile")
def artifact_to_json(artifact: Artifact) -> dict[str, Any]:
    """Return the "artifactVersion" JSON representation of `artifact`.

    The sequence name is the part of `artifact.source_name` before the
    first ":".
    """
    payload: dict[str, Any] = {"_type": "artifactVersion", "_version": "v0"}
    payload["id"] = artifact.id
    payload["version"] = artifact.source_version
    payload["sequenceName"] = artifact.source_name.partition(":")[0]
    payload["usedAs"] = artifact.use_as
    return payload
def check_dict_contains_nested_artifact(d: dict, nested: bool = False) -> bool:
    """Return True if an artifact appears below the top level of `d`.

    Args:
        d: The dictionary to search; nested dicts are searched recursively.
        nested: Whether `d` itself is already a nested dictionary. Artifact
            values only count when this is True.
    """
    for value in d.values():
        if isinstance(value, dict):
            if check_dict_contains_nested_artifact(value, True):
                return True
            continue
        is_artifact = isinstance(value, wandb.Artifact) or _is_artifact_string(value)
        if is_artifact and nested:
            return True
    return False
def load_json_yaml_dict(config: str) -> Any:
    """Load a config from a .json/.yaml file path or from an inline JSON string.

    Args:
        config: A path ending in ".json" or ".yaml", or a JSON document.

    Returns:
        The parsed content, or None when `config` is neither such a path
        nor valid JSON.
    """
    import yaml

    ext = os.path.splitext(config)[-1]
    if ext in (".json", ".yaml"):
        parse = json.load if ext == ".json" else yaml.safe_load
        with open(config) as f:
            return parse(f)

    # not a recognized file path: treat the argument itself as JSON
    try:
        return json.loads(config)
    except ValueError:
        return None
def _parse_entity_project_item(path: str) -> tuple:
    """Split a path of the form {item}, {project}/{item} or {entity}/{project}/{item}.

    Args:
        path: `str`, input path with at most three "/"-separated parts.

    Returns:
        tuple of length 3 - (item, project, entity); parts that are missing
        from `path` are returned as "".

    Raises:
        ValueError: if `path` has more than three parts.

    Example:
        item, project, entity = _parse_entity_project_item("myproj/mymodel:best")

        assert entity == ""
        assert project == "myproj"
        assert item == "mymodel:best"
    """
    words = path.split("/")
    if len(words) > 3:
        raise ValueError(
            "Invalid path: must be str the form {item}, {project}/{item}, or {entity}/{project}/{item}"
        )

    # left-pad to (entity, project, item), then flip the order
    missing = 3 - len(words)
    entity, project, item = [""] * missing + words
    return (item, project, entity)
def _resolve_aliases(aliases: str | Iterable[str] | None) -> list[str]:
    """Add the 'latest' alias and ensure that all aliases are unique.

    Takes in `aliases` which can be None, str, or List[str] and returns list[str].
    Ensures that "latest" is always present in the returned list. The
    original order of the aliases is preserved; "latest" is appended if it
    was not already given.

    Args:
        aliases: `aliases: str | Iterable[str] | None`

    Returns:
        list[str], with "latest" always present.

    Raises:
        ValueError: if `aliases` is not iterable.

    Usage:

    ```python
    aliases = _resolve_aliases(["best", "dev"])
    assert aliases == ["best", "dev", "latest"]

    aliases = _resolve_aliases("boom")
    assert aliases == ["boom", "latest"]
    ```
    """
    aliases = aliases or ["latest"]

    if isinstance(aliases, str):
        aliases = [aliases]

    try:
        # dict.fromkeys de-duplicates while keeping first-seen order; a set
        # would return the aliases in arbitrary order.
        return list(dict.fromkeys([*aliases, "latest"]))
    except TypeError as exc:
        raise ValueError("`aliases` must be Iterable or None") from exc
def _is_artifact_object(v: Any) -> TypeGuard[wandb.Artifact]:
    """Return True if `v` is a `wandb.Artifact` instance."""
    artifact_cls = wandb.Artifact
    return isinstance(v, artifact_cls)
def _is_artifact_string(v: Any) -> TypeGuard[str]:
    """Return True if `v` is a string starting with "wandb-artifact://"."""
    if not isinstance(v, str):
        return False
    return v.startswith("wandb-artifact://")
def _is_artifact_version_weave_dict(v: Any) -> TypeGuard[dict]:
    """Return True if `v` is a dict whose "_type" is "artifactVersion"."""
    if not isinstance(v, dict):
        return False
    return v.get("_type") == "artifactVersion"
def _is_artifact_representation(v: Any) -> bool:
    """Return True if `v` is an artifact in any supported representation.

    The representations are an artifact object, a "wandb-artifact://"
    string, or an "artifactVersion" dict; they are checked in that order.
    """
    if _is_artifact_object(v):
        return True
    if _is_artifact_string(v):
        return True
    return _is_artifact_version_weave_dict(v)
def parse_artifact_string(v: str) -> tuple[str, str | None, bool]:
    """Parse a "wandb-artifact://" string.

    Accepted forms (after the "wandb-artifact://" prefix, optionally
    preceded by a base URI such as "https://host/"):

    - `_id/{artifact_id}`
    - `{entity}/{project}/{name_and_alias_or_version}`

    Args:
        v: The artifact string.

    Returns:
        A tuple `(reference, base_uri, is_id)`: `reference` is the artifact
        id or the "entity/project/name" path, `base_uri` is
        "scheme://netloc" when one was embedded (else None), and `is_id`
        tells which of the two forms was used.

    Raises:
        ValueError: if `v` does not match either form.
    """
    prefix = "wandb-artifact://"
    if not v.startswith(prefix):
        raise ValueError(f"Invalid artifact string: {v}")
    parsed_v = v[len(prefix) :]

    base_uri = None
    url_info = urllib.parse.urlparse(parsed_v)
    if url_info.scheme != "":
        base_uri = f"{url_info.scheme}://{url_info.netloc}"
        # drop the empty segment in front of the path's leading "/"
        parts = url_info.path.split("/")[1:]
    else:
        parts = parsed_v.split("/")

    if parts and parts[0] == "_id":
        # `_id` without an id used to escape as IndexError; report it the
        # same way as every other malformed string.
        if len(parts) < 2:
            raise ValueError(f"Invalid artifact string: {v}")
        # for now can't fetch paths but this will be supported in the future
        # when we allow passing typed media objects, this can be extended
        # to include paths
        return parts[1], base_uri, True

    if len(parts) < 3:
        raise ValueError(f"Invalid artifact string: {v}")

    # for now can't fetch paths but this will be supported in the future
    # when we allow passing typed media objects, this can be extended
    # to include paths
    entity, project, name_and_alias_or_version = parts[:3]
    return f"{entity}/{project}/{name_and_alias_or_version}", base_uri, False
def _get_max_cli_version() -> str | None:
    """Return `wandb.api.max_cli_version()` as a string, or None if it is None."""
    version = wandb.api.max_cli_version()
    if version is None:
        return None
    return str(version)
def ensure_text(
    string: str | bytes, encoding: str = "utf-8", errors: str = "strict"
) -> str:
    """Coerce `string` to str, decoding bytes with `encoding` and `errors`.

    Raises:
        TypeError: if `string` is neither `str` nor `bytes`.
    """
    if isinstance(string, str):
        return string
    if isinstance(string, bytes):
        return string.decode(encoding, errors)
    raise TypeError(f"not expecting type {type(string)!r}")
def make_artifact_name_safe(name: str) -> str:
    """Make an artifact name safe for use in artifacts.

    Every character other than letters, digits, dashes, underscores and
    dots is replaced with "_". Names longer than 128 characters are
    shortened to their first and last 63 characters joined by "..".
    """
    # artifact names may only contain alphanumeric characters, dashes, underscores, and dots.
    cleaned = re.sub(r"[^a-zA-Z0-9_\-.]", "_", name)
    if len(cleaned) > 128:
        # truncate with dots in the middle
        cleaned = cleaned[:63] + ".." + cleaned[-63:]
    return cleaned
def make_docker_image_name_safe(name: str) -> str:
    """Make a docker image name safe for use in artifacts.

    The name is lower-cased and run through the module's docker-image-name
    patterns in order; "image" is returned if nothing is left afterwards.
    """
    substitutions = (
        (RE_DOCKER_IMAGE_NAME_CHARS, "__"),
        (RE_DOCKER_IMAGE_NAME_SEPARATOR_REPEAT, "__"),
        (RE_DOCKER_IMAGE_NAME_SEPARATOR_START, ""),
        (RE_DOCKER_IMAGE_NAME_SEPARATOR_END, ""),
    )
    result = name.lower()
    for pattern, replacement in substitutions:
        result = pattern.sub(replacement, result)
    return result or "image"
def merge_dicts(
    source: dict[str, Any],
    destination: dict[str, Any],
) -> dict[str, Any]:
    """Recursively merge `source` into `destination`.

    This mutates the destination and its nested dictionaries and lists.

    A `dict` value is merged recursively when the destination also holds a
    `dict` under the same key, and a `list` value extends the destination's
    `list`. In every other case the source value overwrites the
    destination value.

    Returns:
        The (mutated) `destination`.
    """
    for key, incoming in source.items():
        current = destination.get(key)
        if isinstance(incoming, dict) and isinstance(current, dict):
            merge_dicts(incoming, current)
        elif isinstance(incoming, list) and isinstance(current, list):
            current.extend(incoming)
        else:
            destination[key] = incoming
    return destination
def coalesce(*arg: Any) -> Any:
    """Return the first argument that is not None, or None if there is none.

    Similar to ?? in C#.
    """
    for candidate in arg:
        if candidate is not None:
            return candidate
    return None
def recursive_cast_dictlike_to_dict(d: dict[str, Any]) -> dict[str, Any]:
    """Convert nested dict-like values of `d` into plain dicts, in place.

    A value counts as dict-like when it has a `keys` attribute. Values that
    are already dicts are kept and only processed recursively.

    Returns:
        The same (mutated) dictionary `d`.
    """
    for key, value in list(d.items()):
        if isinstance(value, dict):
            recursive_cast_dictlike_to_dict(value)
            continue
        if hasattr(value, "keys"):
            converted = dict(value)
            d[key] = converted
            recursive_cast_dictlike_to_dict(converted)
    return d
def remove_keys_with_none_values(d: dict[str, Any] | Any) -> dict[str, Any] | Any:
    """Return a copy of `d` without keys whose values are None or empty.

    Nested dicts are pruned recursively. Non-dict inputs are returned
    unchanged, and a dict with nothing left after pruning becomes None.
    """
    # otherwise iterrows will create a bunch of ugly charts
    if not isinstance(d, dict):
        return d

    pruned = {}
    for key, value in d.items():
        kept = remove_keys_with_none_values(value)
        if kept is None:
            continue
        if isinstance(kept, dict) and len(kept) == 0:
            continue
        pruned[key] = kept
    return pruned or None
def batched(n: int, iterable: Iterable[T]) -> Generator[list[T], None, None]:
    """Yield successive lists of up to `n` items taken from `iterable`.

    The final batch may be shorter; iteration stops at the first empty batch.
    """
    source = iter(iterable)
    while True:
        chunk = list(itertools.islice(source, n))
        if not chunk:
            return
        yield chunk
def random_string(length: int = 12) -> str:
    """Generate a random string of lowercase ASCII letters and digits.

    Characters are drawn with `secrets.choice`.

    :param length: Length of the string to generate.
    :return: Random string.
    """
    alphabet = string.ascii_lowercase + string.digits
    picks = [secrets.choice(alphabet) for _ in range(length)]
    return "".join(picks)
def sample_with_exponential_decay_weights(
    xs: Iterable | Iterable[Iterable],
    ys: Iterable[Iterable],
    keys: Iterable | None = None,
    sample_size: int = 1500,
) -> tuple[list, list, list | None]:
    """Sample from a list of lists with weights that decay exponentially.

    Index `i` out of `n` entries is drawn with probability proportional to
    `exp(-i / n)`, so earlier entries are more likely. The same indices are
    used for `xs`, `ys` and `keys`, which keeps them aligned.

    May be used with the wandb.plot.line_series function.

    Args:
        xs: The x values; the first dimension is sampled.
        ys: The y values, indexed like `xs`.
        keys: Optional keys, indexed like `xs`.
        sample_size: Number of indices to draw (with replacement).

    Returns:
        The sampled xs, ys and keys as lists; keys is None when no keys
        were given.
    """
    xs_array = np.array(xs)
    ys_array = np.array(ys)
    keys_array = np.array(keys) if keys else None
    weights = np.exp(-np.arange(len(xs_array)) / len(xs_array))
    weights /= np.sum(weights)
    sampled_indices = np.random.choice(len(xs_array), size=sample_size, p=weights)
    sampled_xs = xs_array[sampled_indices].tolist()
    sampled_ys = ys_array[sampled_indices].tolist()
    # Test against None explicitly: the truth value of a multi-element
    # array is ambiguous and raises ValueError.
    sampled_keys = (
        keys_array[sampled_indices].tolist() if keys_array is not None else None
    )
    return sampled_xs, sampled_ys, sampled_keys
@dataclasses.dataclass(frozen=True)
class InstalledDistribution:
    """Immutable name/version pair describing an installed distribution.

    Attributes:
        key: The distribution's name.
        version: The distribution's version string.
    """

    key: str
    version: str
def working_set() -> Iterable[InstalledDistribution]:
    """Yield an `InstalledDistribution` for every installed distribution.

    Distributions whose metadata cannot be read are skipped.
    """
    from importlib.metadata import distributions

    for dist in distributions():
        try:
            yield InstalledDistribution(key=dist.metadata["Name"], version=dist.version)
        except (KeyError, UnicodeDecodeError, TypeError):
            # In some distributions, the "Name" attribute may not be present,
            # or the metadata itself may be None or malformed, which can raise
            # KeyError, UnicodeDecodeError, or TypeError.
            # For additional context, see: https://github.com/python/importlib_metadata/issues/371.
            continue
def get_core_path() -> str:
    """Returns the path to the wandb-core binary.

    The binary is looked up at `bin/wandb-core` next to this module.

    Returns:
        str: The path to the wandb-core binary.

    Raises:
        WandbCoreNotAvailableError: If the binary does not exist at that path.
    """
    bin_path = pathlib.Path(__file__).parent.joinpath("bin", "wandb-core")
    if bin_path.exists():
        return str(bin_path)

    raise WandbCoreNotAvailableError(
        f"File not found: {bin_path}."
        " Please contact support at support@wandb.com."
        f" Your platform is: {platform.platform()}."
    )
def time_string_to_seconds(time_str: str) -> int:
    """Convert a time period string into a number of seconds.

    Args:
        time_str: Time period string like "10s", "5m", "8h", "8d", "6M", "1y"
            Accepted values are:
            - s (seconds)
            - m (minutes)
            - h (hours)
            - d (days)
            - M (months, counted as 30 days)
            - y (years, counted as 365 days)

    Returns:
        Number of seconds in the time period; 0 for an empty string.

    Raises:
        ValueError: If the format is invalid

    Examples:
        >>> time_string_to_seconds("10s")
        10
        >>> time_string_to_seconds("5m")
        300
        >>> time_string_to_seconds("8d")
        691200
        >>> time_string_to_seconds("6M")
        15552000
        >>> time_string_to_seconds("1y")
        31536000
    """
    if not time_str:
        return 0

    seconds_per_unit = {
        "s": 1,
        "m": 60,
        "h": 3600,
        "d": 86400,
        "M": 2592000,  # 30 days
        "y": 31536000,  # 365 days
    }

    # a whole number followed by exactly one unit letter
    parsed = re.match(r"^(\d+)([smhdMy])$", time_str)
    if parsed is None:
        raise ValueError(
            f"Invalid time period format: {time_str}. "
            "Expected format: <number><unit> where unit is s (seconds), "
            "m (minutes), h (hours), d (days), M (months), or y (years)"
        )

    count, unit = parsed.groups()
    return int(count) * seconds_per_unit[unit]
|