hashutil.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. from __future__ import annotations
  2. import base64
  3. import hashlib
  4. import logging
  5. import mmap
  6. import time
  7. from typing import TYPE_CHECKING
  8. from typing_extensions import TypeAlias
  9. from wandb.sdk.lib.paths import StrPath
  10. if TYPE_CHECKING:
  11. import _hashlib # type: ignore[import-not-found]
# Module-level logger; md5_file_b64 uses it to report unusually slow hashes.
logger = logging.getLogger(__name__)

# String aliases naming the digest/identifier formats this module produces and
# accepts. At runtime they are plain ``str`` (no validation is performed); the
# distinct names document intent for readers and type checkers.
#
# In the future, consider relying on pydantic to validate these types via e.g.
# - Base64Str: https://docs.pydantic.dev/latest/api/types/#pydantic.types.Base64Str
# - a custom EncodedStr + Encoder impl: https://docs.pydantic.dev/latest/api/types/#pydantic.types.EncodedStr
#
# Note that so long as we continue to support Pydantic v1, the options above will require a compatible shim/backport
# implementation, since those types are not in Pydantic v1.

# NOTE(review): not referenced in the visible code; presumably an HTTP /
# object-store ETag value -- confirm against callers.
ETag: TypeAlias = str
# Lowercase hex text of an MD5 digest, as returned by ``hexdigest()`` / ``bytes.hex()``.
HexMD5: TypeAlias = str
# Standard-alphabet base64 text of a raw MD5 digest.
B64MD5: TypeAlias = str
  22. def _md5(data: bytes = b"") -> _hashlib.HASH:
  23. """Allow FIPS-compliant md5 hash when supported."""
  24. return hashlib.md5(data, usedforsecurity=False)
  25. def md5_string(string: str) -> B64MD5:
  26. return _b64_from_hasher(_md5(string.encode("utf-8")))
  27. def _b64_from_hasher(hasher: _hashlib.HASH) -> B64MD5:
  28. return B64MD5(base64.b64encode(hasher.digest()).decode("ascii"))
  29. def b64_to_hex_id(string: B64MD5) -> HexMD5:
  30. return HexMD5(base64.standard_b64decode(string).hex())
  31. def hex_to_b64_id(encoded_string: str | bytes) -> B64MD5:
  32. if isinstance(encoded_string, bytes):
  33. encoded_string = encoded_string.decode("utf-8")
  34. as_str = bytes.fromhex(encoded_string)
  35. return B64MD5(base64.standard_b64encode(as_str).decode("utf-8"))
  36. def md5_file_b64(*paths: StrPath) -> B64MD5:
  37. start_time = time.monotonic()
  38. digest = _b64_from_hasher(_md5_file_hasher(*paths))
  39. hash_time_seconds = time.monotonic() - start_time
  40. if hash_time_seconds > 1.0:
  41. logger.debug(
  42. "Computed MD5 hash for file. paths=%s, hashTimeMs=%d",
  43. paths,
  44. int(hash_time_seconds * 1000),
  45. )
  46. return digest
  47. def md5_file_hex(*paths: StrPath) -> HexMD5:
  48. return HexMD5(_md5_file_hasher(*paths).hexdigest())
  49. _KB: int = 1_024
  50. _CHUNKSIZE: int = 128 * _KB
  51. """Chunk size (in bytes) for iteratively reading from file, if needed."""
  52. def _md5_file_hasher(*paths: StrPath) -> _hashlib.HASH:
  53. md5_hash = _md5()
  54. # Note: We use str paths (instead of pathlib.Path objs) for minor perf improvements.
  55. for path in sorted(map(str, paths)):
  56. with open(path, "rb") as f:
  57. try:
  58. with mmap.mmap(f.fileno(), length=0, access=mmap.ACCESS_READ) as mview:
  59. md5_hash.update(mview)
  60. except OSError:
  61. # This occurs if the mmap-ed file is on a different/mounted filesystem,
  62. # so we'll fall back on a less performant implementation.
  63. # Note: At the time of implementation, the walrus operator `:=`
  64. # is avoided to maintain support for users on python 3.7.
  65. # Consider revisiting once 3.7 support is no longer needed.
  66. chunk = f.read(_CHUNKSIZE)
  67. while chunk:
  68. md5_hash.update(chunk)
  69. chunk = f.read(_CHUNKSIZE)
  70. except ValueError:
  71. # This occurs when mmap-ing an empty file, which can be skipped.
  72. # See: https://github.com/python/cpython/blob/986a4e1b6fcae7fe7a1d0a26aea446107dd58dd2/Modules/mmapmodule.c#L1589
  73. pass
  74. return md5_hash