| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102 |
- from __future__ import annotations
- import base64
- import hashlib
- import logging
- import mmap
- import time
- from typing import TYPE_CHECKING
- from typing_extensions import TypeAlias
- from wandb.sdk.lib.paths import StrPath
- if TYPE_CHECKING:
- import _hashlib # type: ignore[import-not-found]
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# In the future, consider relying on pydantic to validate these types via e.g.
# - Base64Str: https://docs.pydantic.dev/latest/api/types/#pydantic.types.Base64Str
# - a custom EncodedStr + Encoder impl: https://docs.pydantic.dev/latest/api/types/#pydantic.types.EncodedStr
#
# Note that so long as we continue to support Pydantic v1, the options above will require a compatible shim/backport
# implementation, since those types are not in Pydantic v1.

# The aliases below are plain `str` at runtime: they document intent for
# readers and type checkers, but perform no validation.

# NOTE(review): presumably an entity tag returned by a storage/HTTP backend;
# it is not used in the code visible here -- confirm against callers.
ETag: TypeAlias = str
# An MD5 digest rendered as a hexadecimal string.
HexMD5: TypeAlias = str
# An MD5 digest rendered as a base64 string.
B64MD5: TypeAlias = str
- def _md5(data: bytes = b"") -> _hashlib.HASH:
- """Allow FIPS-compliant md5 hash when supported."""
- return hashlib.md5(data, usedforsecurity=False)
def md5_string(string: str) -> B64MD5:
    """Return the base64-encoded MD5 digest of ``string`` (encoded as UTF-8)."""
    utf8_bytes = string.encode("utf-8")
    hasher = _md5(utf8_bytes)
    return _b64_from_hasher(hasher)
def _b64_from_hasher(hasher: _hashlib.HASH) -> B64MD5:
    """Return the hasher's current digest as a base64 string."""
    raw_digest = hasher.digest()
    b64_bytes = base64.b64encode(raw_digest)
    return B64MD5(b64_bytes.decode("ascii"))
def b64_to_hex_id(string: B64MD5) -> HexMD5:
    """Convert a base64-encoded value into its hexadecimal representation."""
    raw_bytes = base64.standard_b64decode(string)
    return HexMD5(raw_bytes.hex())
def hex_to_b64_id(encoded_string: str | bytes) -> B64MD5:
    """Convert a hexadecimal value into its base64 representation.

    ``encoded_string`` may be given as text or as UTF-8 encoded bytes.
    """
    hex_text = (
        encoded_string.decode("utf-8")
        if isinstance(encoded_string, bytes)
        else encoded_string
    )
    raw_bytes = bytes.fromhex(hex_text)
    b64_bytes = base64.standard_b64encode(raw_bytes)
    return B64MD5(b64_bytes.decode("utf-8"))
def md5_file_b64(*paths: StrPath) -> B64MD5:
    """Return the base64-encoded MD5 digest of the contents of the given files.

    A debug log entry is emitted when hashing takes longer than one second.
    """
    started_at = time.monotonic()
    hasher = _md5_file_hasher(*paths)
    b64_digest = _b64_from_hasher(hasher)
    elapsed_seconds = time.monotonic() - started_at

    # Only unusually slow hashes are worth reporting.
    if elapsed_seconds > 1.0:
        elapsed_ms = int(elapsed_seconds * 1000)
        logger.debug(
            "Computed MD5 hash for file. paths=%s, hashTimeMs=%d",
            paths,
            elapsed_ms,
        )

    return b64_digest
def md5_file_hex(*paths: StrPath) -> HexMD5:
    """Return the hex-encoded MD5 digest of the contents of the given files."""
    hasher = _md5_file_hasher(*paths)
    return HexMD5(hasher.hexdigest())
- _KB: int = 1_024
- _CHUNKSIZE: int = 128 * _KB
- """Chunk size (in bytes) for iteratively reading from file, if needed."""
def _md5_file_hasher(*paths: StrPath) -> _hashlib.HASH:
    """Build an MD5 hasher fed with the contents of every file in ``paths``.

    Files are consumed in the sorted order of their string paths, so the
    result does not depend on the order in which the paths were passed.
    """
    hasher = _md5()

    # Note: plain str paths (rather than pathlib.Path objs) give a minor perf win.
    ordered_paths = sorted(str(p) for p in paths)
    for file_path in ordered_paths:
        with open(file_path, "rb") as fh:
            try:
                # Fast path: map the whole file and hash it in a single call.
                with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mapped:
                    hasher.update(mapped)
            except OSError:
                # This occurs if the mmap-ed file is on a different/mounted filesystem,
                # so fall back on a less performant chunked read.
                # Note: the walrus operator `:=` is deliberately avoided here to keep
                # supporting users on python 3.7.
                # NOTE(review): `_md5` in this module already passes `usedforsecurity=`
                # to hashlib.md5 (a python 3.9+ keyword), so this restriction may be
                # stale -- confirm before simplifying the loop.
                while True:
                    block = fh.read(_CHUNKSIZE)
                    if not block:
                        break
                    hasher.update(block)
            except ValueError:
                # This occurs when mmap-ing an empty file, which contributes nothing
                # to the hash and can be skipped.
                # See: https://github.com/python/cpython/blob/986a4e1b6fcae7fe7a1d0a26aea446107dd58dd2/Modules/mmapmodule.c#L1589
                pass

    return hasher
|