telemetry.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. from __future__ import annotations
  2. import re
  3. import sys
  4. from contextlib import AbstractContextManager
  5. from types import TracebackType
  6. from typing import TYPE_CHECKING
  7. import wandb
  8. from wandb.proto.wandb_telemetry_pb2 import Imports as TelemetryImports
  9. from wandb.proto.wandb_telemetry_pb2 import TelemetryRecord
  10. # avoid cycle, use string type reference
  11. if TYPE_CHECKING:
  12. from .. import wandb_run
# Marker that opens an in-source W&B code label, e.g. ``@wandbcode{name, key=value}``;
# ``_parse_label_lines`` looks for this substring on each line it scans.
_LABEL_TOKEN: str = "@wandbcode{"
  14. class _TelemetryObject:
  15. _run: wandb_run.Run | None
  16. _obj: TelemetryRecord
  17. def __init__(
  18. self,
  19. run: wandb_run.Run | None = None,
  20. obj: TelemetryRecord | None = None,
  21. ) -> None:
  22. self._run = run or wandb.run
  23. self._obj = obj or TelemetryRecord()
  24. def __enter__(self) -> TelemetryRecord:
  25. return self._obj
  26. def __exit__(
  27. self,
  28. exctype: type[BaseException] | None,
  29. excinst: BaseException | None,
  30. exctb: TracebackType | None,
  31. ) -> None:
  32. if not self._run:
  33. return
  34. self._run._telemetry_callback(self._obj)
  35. def context(
  36. run: wandb_run.Run | None = None, obj: TelemetryRecord | None = None
  37. ) -> AbstractContextManager[TelemetryRecord]:
  38. return _TelemetryObject(run=run, obj=obj)
  39. MATCH_RE = re.compile(r"(?P<code>[a-zA-Z0-9_-]+)[,}](?P<rest>.*)")
  40. def _parse_label_lines(lines: list[str]) -> dict[str, str]:
  41. seen = False
  42. ret = {}
  43. for line in lines:
  44. idx = line.find(_LABEL_TOKEN)
  45. if idx < 0:
  46. # Stop parsing on first non token line after match
  47. if seen:
  48. break
  49. continue
  50. seen = True
  51. label_str = line[idx + len(_LABEL_TOKEN) :]
  52. # match identifier (first token without key=value syntax (optional)
  53. # Note: Parse is fairly permissive as it does not enforce strict syntax
  54. r = MATCH_RE.match(label_str)
  55. if r:
  56. ret["code"] = r.group("code").replace("-", "_")
  57. label_str = r.group("rest")
  58. # match rest of tokens on one line
  59. tokens = re.findall(
  60. r'([a-zA-Z0-9_]+)\s*=\s*("[a-zA-Z0-9_-]*"|[a-zA-Z0-9_-]*)[,}]', label_str
  61. )
  62. for k, v in tokens:
  63. ret[k] = v.strip('"').replace("-", "_")
  64. return ret
  65. def list_telemetry_imports(only_imported: bool = False) -> set[str]:
  66. import_telemetry_set = {
  67. desc.name
  68. for desc in TelemetryImports.DESCRIPTOR.fields
  69. if desc.type == desc.TYPE_BOOL
  70. }
  71. if only_imported:
  72. imported_modules_set = set(sys.modules)
  73. return imported_modules_set.intersection(import_telemetry_set)
  74. return import_telemetry_set
  75. __all__ = [
  76. "TelemetryImports",
  77. "TelemetryRecord",
  78. "context",
  79. "list_telemetry_imports",
  80. ]