from __future__ import annotations import base64 import hashlib import logging import mmap import time from typing import TYPE_CHECKING from typing_extensions import TypeAlias from wandb.sdk.lib.paths import StrPath if TYPE_CHECKING: import _hashlib # type: ignore[import-not-found] logger = logging.getLogger(__name__) # In the future, consider relying on pydantic to validate these types via e.g. # - Base64Str: https://docs.pydantic.dev/latest/api/types/#pydantic.types.Base64Str # - a custom EncodedStr + Encoder impl: https://docs.pydantic.dev/latest/api/types/#pydantic.types.EncodedStr # # Note that so long as we continue to support Pydantic v1, the options above will require a compatible shim/backport # implementation, since those types are not in Pydantic v1. ETag: TypeAlias = str HexMD5: TypeAlias = str B64MD5: TypeAlias = str def _md5(data: bytes = b"") -> _hashlib.HASH: """Allow FIPS-compliant md5 hash when supported.""" return hashlib.md5(data, usedforsecurity=False) def md5_string(string: str) -> B64MD5: return _b64_from_hasher(_md5(string.encode("utf-8"))) def _b64_from_hasher(hasher: _hashlib.HASH) -> B64MD5: return B64MD5(base64.b64encode(hasher.digest()).decode("ascii")) def b64_to_hex_id(string: B64MD5) -> HexMD5: return HexMD5(base64.standard_b64decode(string).hex()) def hex_to_b64_id(encoded_string: str | bytes) -> B64MD5: if isinstance(encoded_string, bytes): encoded_string = encoded_string.decode("utf-8") as_str = bytes.fromhex(encoded_string) return B64MD5(base64.standard_b64encode(as_str).decode("utf-8")) def md5_file_b64(*paths: StrPath) -> B64MD5: start_time = time.monotonic() digest = _b64_from_hasher(_md5_file_hasher(*paths)) hash_time_seconds = time.monotonic() - start_time if hash_time_seconds > 1.0: logger.debug( "Computed MD5 hash for file. paths=%s, hashTimeMs=%d", paths, int(hash_time_seconds * 1000), ) return digest def md5_file_hex(*paths: StrPath) -> HexMD5: return HexMD5(_md5_file_hasher(*paths).hexdigest()) _KB: int = 1_024 _CHUNKSIZE: int = 128 * _KB """Chunk size (in bytes) for iteratively reading from file, if needed.""" def _md5_file_hasher(*paths: StrPath) -> _hashlib.HASH: md5_hash = _md5() # Note: We use str paths (instead of pathlib.Path objs) for minor perf improvements. for path in sorted(map(str, paths)): with open(path, "rb") as f: try: with mmap.mmap(f.fileno(), length=0, access=mmap.ACCESS_READ) as mview: md5_hash.update(mview) except OSError: # This occurs if the mmap-ed file is on a different/mounted filesystem, # so we'll fall back on a less performant implementation. # Note: At the time of implementation, the walrus operator `:=` # is avoided to maintain support for users on python 3.7. # Consider revisiting once 3.7 support is no longer needed. chunk = f.read(_CHUNKSIZE) while chunk: md5_hash.update(chunk) chunk = f.read(_CHUNKSIZE) except ValueError: # This occurs when mmap-ing an empty file, which can be skipped. # See: https://github.com/python/cpython/blob/986a4e1b6fcae7fe7a1d0a26aea446107dd58dd2/Modules/mmapmodule.c#L1589 pass return md5_hash