"""Telemetry helpers.

Provides a context manager for filling in ``TelemetryRecord`` messages, a
parser for ``@wandbcode{...}`` labels, and a helper that lists the import
names tracked by the ``TelemetryImports`` message.
"""

from __future__ import annotations

import re
import sys
from contextlib import AbstractContextManager
from types import TracebackType
from typing import TYPE_CHECKING

import wandb
from wandb.proto.wandb_telemetry_pb2 import Imports as TelemetryImports
from wandb.proto.wandb_telemetry_pb2 import TelemetryRecord

# avoid cycle, use string type reference
if TYPE_CHECKING:
    from .. import wandb_run


# Marker that introduces a label inside a line of text, e.g. "# @wandbcode{...}".
_LABEL_TOKEN: str = "@wandbcode{"


class _TelemetryObject:
    """Context manager that yields a ``TelemetryRecord`` and, on exit, passes
    it to the run's ``_telemetry_callback`` (if a run is available).
    """

    _run: wandb_run.Run | None
    _obj: TelemetryRecord

    def __init__(
        self,
        run: wandb_run.Run | None = None,
        obj: TelemetryRecord | None = None,
    ) -> None:
        # Fall back to the global run; this is resolved once, at construction
        # time, not when the context exits.
        self._run = run or wandb.run
        self._obj = obj or TelemetryRecord()

    def __enter__(self) -> TelemetryRecord:
        return self._obj

    def __exit__(
        self,
        exctype: type[BaseException] | None,
        excinst: BaseException | None,
        exctb: TracebackType | None,
    ) -> None:
        # The record is submitted even if the ``with`` body raised; returning
        # None lets any such exception propagate unchanged.
        if not self._run:
            return
        self._run._telemetry_callback(self._obj)


def context(
    run: wandb_run.Run | None = None, obj: TelemetryRecord | None = None
) -> AbstractContextManager[TelemetryRecord]:
    """Return a context manager for populating a telemetry record.

    Args:
        run: Run whose ``_telemetry_callback`` receives the record on exit.
            When omitted, ``wandb.run`` at the time of this call is used.
        obj: Record to populate. A new, empty ``TelemetryRecord`` is created
            when omitted.

    Returns:
        A context manager whose ``__enter__`` yields the record. If no run is
        available, the record is not submitted on exit.
    """
    return _TelemetryObject(run=run, obj=obj)


# Optional leading label identifier (group "code"), terminated by "," or "}",
# followed by the remainder of the label (group "rest").
MATCH_RE = re.compile(r"(?P<code>[a-zA-Z0-9_-]+)[,}](?P<rest>.*)")

# One key=value pair terminated by "," or "}"; the value may be double-quoted.
_LABEL_KV_RE = re.compile(
    r'([a-zA-Z0-9_]+)\s*=\s*("[a-zA-Z0-9_-]*"|[a-zA-Z0-9_-]*)[,}]'
)


def _parse_label_lines(lines: list[str]) -> dict[str, str]:
    """Parse ``@wandbcode{...}`` labels from a sequence of lines.

    Lines without the label token are skipped until the first labelled line;
    parsing then stops at the next line lacking the token, so only the first
    consecutive run of labelled lines is considered.

    A label may start with a bare identifier, stored under ``"code"``, followed
    by ``key=value`` pairs. Dashes in the identifier and in values become
    underscores, and surrounding double quotes are stripped from values.
    Later entries overwrite earlier ones with the same key.

    Example:
        ``['# @wandbcode{my-code, key1=val-1, key2="v2"}']`` parses to
        ``{"code": "my_code", "key1": "val_1", "key2": "v2"}``.

    Args:
        lines: Lines of text to scan.

    Returns:
        Mapping of label keys to normalized values; empty if no label is found.
    """
    seen = False
    ret: dict[str, str] = {}
    for line in lines:
        idx = line.find(_LABEL_TOKEN)
        if idx < 0:
            # Stop parsing on first non token line after match
            if seen:
                break
            continue
        seen = True
        label_str = line[idx + len(_LABEL_TOKEN) :]

        # Match the optional identifier (the first token without key=value
        # syntax). The parse is permissive and does not enforce strict syntax.
        r = MATCH_RE.match(label_str)
        if r:
            ret["code"] = r.group("code").replace("-", "_")
            label_str = r.group("rest")

        # Match the remaining key=value tokens on this line.
        for k, v in _LABEL_KV_RE.findall(label_str):
            ret[k] = v.strip('"').replace("-", "_")
    return ret


def list_telemetry_imports(only_imported: bool = False) -> set[str]:
    """Return the names of the boolean fields of ``TelemetryImports``.

    Args:
        only_imported: If true, keep only names that are also keys of
            ``sys.modules`` at call time. This assumes field names match
            module names — confirm against the proto definition.

    Returns:
        The set of matching field names.
    """
    import_telemetry_set = {
        desc.name
        for desc in TelemetryImports.DESCRIPTOR.fields
        if desc.type == desc.TYPE_BOOL
    }
    if only_imported:
        imported_modules_set = set(sys.modules)
        return imported_modules_set.intersection(import_telemetry_set)
    return import_telemetry_set


__all__ = [
    "TelemetryImports",
    "TelemetryRecord",
    "context",
    "list_telemetry_imports",
]