_utils.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. # SPDX-License-Identifier: MIT
  2. from __future__ import annotations
  3. import platform
  4. import sys
  5. from dataclasses import dataclass
  6. from typing import Any
  7. from .exceptions import InvalidHashError, UnsupportedParametersError
  8. from .low_level import Type
# The class of ``None``, bound to a module-level name via ``type(None)``.
NoneType = type(None)
  10. def _check_types(**kw: Any) -> str | None:
  11. """
  12. Check each ``name: (value, types)`` in *kw*.
  13. Returns a human-readable string of all violations or `None``.
  14. """
  15. errors = []
  16. for name, (value, types) in kw.items():
  17. if not isinstance(value, types):
  18. if isinstance(types, tuple):
  19. types = ", or ".join(t.__name__ for t in types)
  20. else:
  21. types = types.__name__
  22. errors.append(
  23. f"'{name}' must be a {types} (got {type(value).__name__})"
  24. )
  25. if errors != []:
  26. return ", ".join(errors) + "."
  27. return None
  28. def _is_wasm() -> bool:
  29. return sys.platform == "emscripten" or platform.machine() in [
  30. "wasm32",
  31. "wasm64",
  32. ]
  33. def _decoded_str_len(length: int) -> int:
  34. """
  35. Compute how long an encoded string of length *l* becomes.
  36. """
  37. rem = length % 4
  38. if rem == 3:
  39. last_group_len = 2
  40. elif rem == 2:
  41. last_group_len = 1
  42. else:
  43. last_group_len = 0
  44. return length // 4 * 3 + last_group_len
  45. @dataclass
  46. class Parameters:
  47. """
  48. Argon2 hash parameters.
  49. See :doc:`parameters` on how to pick them.
  50. Attributes:
  51. type: Hash type.
  52. version: Argon2 version.
  53. salt_len: Length of the salt in bytes.
  54. hash_len: Length of the hash in bytes.
  55. time_cost: Time cost in iterations.
  56. memory_cost: Memory cost in kibibytes.
  57. parallelism: Number of parallel threads.
  58. .. versionadded:: 18.2.0
  59. """
  60. type: Type
  61. version: int
  62. salt_len: int
  63. hash_len: int
  64. time_cost: int
  65. memory_cost: int
  66. parallelism: int
  67. __slots__ = (
  68. "hash_len",
  69. "memory_cost",
  70. "parallelism",
  71. "salt_len",
  72. "time_cost",
  73. "type",
  74. "version",
  75. )
  76. _NAME_TO_TYPE = {"argon2id": Type.ID, "argon2i": Type.I, "argon2d": Type.D}
  77. _REQUIRED_KEYS = sorted(("v", "m", "t", "p"))
  78. def extract_parameters(hash: str) -> Parameters:
  79. """
  80. Extract parameters from an encoded *hash*.
  81. Args:
  82. hash: An encoded Argon2 hash string.
  83. Returns:
  84. The parameters used to create the hash.
  85. .. versionadded:: 18.2.0
  86. """
  87. parts = hash.split("$")
  88. # Backwards compatibility for Argon v1.2 hashes
  89. if len(parts) == 5:
  90. parts.insert(2, "v=18")
  91. if len(parts) != 6:
  92. raise InvalidHashError
  93. if parts[0]:
  94. raise InvalidHashError
  95. try:
  96. type = _NAME_TO_TYPE[parts[1]]
  97. kvs = {
  98. k: int(v)
  99. for k, v in (
  100. s.split("=") for s in [parts[2], *parts[3].split(",")]
  101. )
  102. }
  103. except Exception: # noqa: BLE001
  104. raise InvalidHashError from None
  105. if sorted(kvs.keys()) != _REQUIRED_KEYS:
  106. raise InvalidHashError
  107. return Parameters(
  108. type=type,
  109. salt_len=_decoded_str_len(len(parts[4])),
  110. hash_len=_decoded_str_len(len(parts[5])),
  111. version=kvs["v"],
  112. time_cost=kvs["t"],
  113. memory_cost=kvs["m"],
  114. parallelism=kvs["p"],
  115. )
  116. def validate_params_for_platform(params: Parameters) -> None:
  117. """
  118. Validate *params* against current platform.
  119. Args:
  120. params: Parameters to be validated
  121. Returns:
  122. None
  123. """
  124. if _is_wasm() and params.parallelism != 1:
  125. msg = "In WebAssembly environments `parallelism` must be 1."
  126. raise UnsupportedParametersError(msg)