filenames.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. from __future__ import annotations
  2. import os
  3. from collections.abc import Generator
  4. from typing import Callable
  5. WANDB_DIRS = ("wandb", ".wandb")
  6. CONFIG_FNAME = "config.yaml"
  7. OUTPUT_FNAME = "output.log"
  8. DIFF_FNAME = "diff.patch"
  9. SUMMARY_FNAME = "wandb-summary.json"
  10. METADATA_FNAME = "wandb-metadata.json"
  11. REQUIREMENTS_FNAME = "requirements.txt"
  12. HISTORY_FNAME = "wandb-history.jsonl"
  13. EVENTS_FNAME = "wandb-events.jsonl"
  14. JOBSPEC_FNAME = "wandb-jobspec.json"
  15. CONDA_ENVIRONMENTS_FNAME = "conda-environment.yaml"
  16. def is_wandb_file(name: str) -> bool:
  17. return name in (
  18. METADATA_FNAME,
  19. CONFIG_FNAME,
  20. REQUIREMENTS_FNAME,
  21. OUTPUT_FNAME,
  22. DIFF_FNAME,
  23. CONDA_ENVIRONMENTS_FNAME,
  24. ) or name.startswith("wandb")
  25. def filtered_dir(
  26. root: str,
  27. include_fn: Callable[[str, str], bool] | Callable[[str], bool],
  28. exclude_fn: Callable[[str, str], bool] | Callable[[str], bool],
  29. ) -> Generator[str, None, None]:
  30. """Simple generator to walk a directory."""
  31. import inspect
  32. # compatibility with old API, which didn't pass root
  33. def _include_fn(path: str, root: str) -> bool:
  34. return (
  35. include_fn(path, root) # type: ignore
  36. if len(inspect.signature(include_fn).parameters) == 2
  37. else include_fn(path) # type: ignore
  38. )
  39. def _exclude_fn(path: str, root: str) -> bool:
  40. return (
  41. exclude_fn(path, root) # type: ignore
  42. if len(inspect.signature(exclude_fn).parameters) == 2
  43. else exclude_fn(path) # type: ignore
  44. )
  45. for dirpath, _, files in os.walk(root):
  46. for fname in files:
  47. file_path = os.path.join(dirpath, fname)
  48. if _include_fn(file_path, root) and not _exclude_fn(file_path, root):
  49. yield file_path
  50. def exclude_wandb_fn(path: str, root: str) -> bool:
  51. return any(
  52. os.path.relpath(path, root).startswith(wandb_dir + os.sep)
  53. for wandb_dir in WANDB_DIRS
  54. )