# validation.py (3.0 KB)
  1. import filecmp
  2. import logging
  3. import os
  4. import requests
  5. import wandb
  6. logger = logging.getLogger(__name__)
  7. logger.setLevel(logging.INFO)
  8. def _compare_artifact_manifests(
  9. src_art: wandb.Artifact, dst_art: wandb.Artifact
  10. ) -> list:
  11. problems = []
  12. if isinstance(dst_art, wandb.CommError):
  13. return ["commError"]
  14. if src_art.digest != dst_art.digest:
  15. problems.append(f"digest mismatch {src_art.digest=}, {dst_art.digest=}")
  16. for name, src_entry in src_art.manifest.entries.items():
  17. dst_entry = dst_art.manifest.entries.get(name)
  18. if dst_entry is None:
  19. problems.append(f"missing manifest entry {name=}, {src_entry=}")
  20. continue
  21. for attr in ["path", "digest", "size"]:
  22. if getattr(src_entry, attr) != getattr(dst_entry, attr):
  23. problems.append(
  24. f"manifest entry mismatch {attr=}, {getattr(src_entry, attr)=}, {getattr(dst_entry, attr)=}"
  25. )
  26. return problems
  27. def _compare_artifact_dirs(src_dir, dst_dir) -> list:
  28. def compare(src_dir, dst_dir):
  29. comparison = filecmp.dircmp(src_dir, dst_dir)
  30. differences = {
  31. "left_only": comparison.left_only,
  32. "right_only": comparison.right_only,
  33. "diff_files": comparison.diff_files,
  34. "subdir_differences": {},
  35. }
  36. # Recursively find differences in subdirectories
  37. for subdir in comparison.subdirs:
  38. subdir_src = os.path.join(src_dir, subdir)
  39. subdir_dst = os.path.join(dst_dir, subdir)
  40. subdir_differences = compare(subdir_src, subdir_dst)
  41. # If there are differences, add them to the result
  42. if subdir_differences and any(subdir_differences.values()):
  43. differences["subdir_differences"][subdir] = subdir_differences
  44. if all(not diff for diff in differences.values()):
  45. return None
  46. return differences
  47. return compare(src_dir, dst_dir)
  48. def _check_entries_are_downloadable(art):
  49. entries = _collect_entries(art)
  50. return all(_check_entry_is_downloable(entry) for entry in entries)
  51. def _collect_entries(art):
  52. has_next_page = True
  53. cursor = None
  54. entries = []
  55. while has_next_page:
  56. attrs = art._fetch_file_urls(cursor)
  57. has_next_page = attrs["pageInfo"]["hasNextPage"]
  58. cursor = attrs["pageInfo"]["endCursor"]
  59. for edge in attrs["edges"]:
  60. name = edge["node"]["name"]
  61. entry = art.get_entry(name)
  62. entry._download_url = edge["node"]["directUrl"]
  63. entries.append(entry)
  64. return entries
  65. def _check_entry_is_downloable(entry):
  66. url = entry._download_url
  67. expected_size = entry.size
  68. try:
  69. resp = requests.head(url, allow_redirects=True)
  70. except Exception:
  71. logger.exception(f"Problem validating {entry=}")
  72. return False
  73. if resp.status_code != 200:
  74. return False
  75. actual_size = resp.headers.get("content-length", -1)
  76. actual_size = int(actual_size)
  77. return expected_size == actual_size