loading_report.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import logging
  15. import re
  16. import shutil
  17. import sys
  18. from collections import OrderedDict, defaultdict
  19. from dataclasses import dataclass
  20. from typing import Any
  21. _DIGIT_RX = re.compile(r"(?<=\.)(\d+)(?=\.|$)") # numbers between dots or at the end
  22. def _pattern_of(key: str) -> str:
  23. """Replace every dot-delimited integer with '*' to get the structure."""
  24. return _DIGIT_RX.sub("*", key)
  25. def _fmt_indices(values: list[int], cutoff=10) -> str:
  26. """Format a list of ints as single number, {a, ..., b}, or first...last."""
  27. if len(values) == 1:
  28. return str(values[0])
  29. values = sorted(values)
  30. if len(values) > cutoff:
  31. return f"{values[0]}...{values[-1]}"
  32. return ", ".join(map(str, values))
  33. def update_key_name(mapping: dict[str, Any]) -> dict[str, Any]:
  34. """
  35. Merge keys like 'layers.0.x', 'layers.1.x' into 'layers.{0, 1}.x'
  36. BUT only merge together keys that have the exact same value.
  37. Returns a new dict {merged_key: value}.
  38. """
  39. # (pattern, value) -> list[set[int]] (per-star index values)
  40. not_mapping = False
  41. if not isinstance(mapping, dict):
  42. mapping = {k: k for k in mapping}
  43. not_mapping = True
  44. bucket: dict[str, list[set[int] | Any]] = defaultdict(list)
  45. for key, val in mapping.items():
  46. digs = _DIGIT_RX.findall(key)
  47. patt = _pattern_of(key)
  48. for i, d in enumerate(digs):
  49. if len(bucket[patt]) <= i:
  50. bucket[patt].append(set())
  51. bucket[patt][i].add(int(d))
  52. bucket[patt].append(val)
  53. out_items = {}
  54. for patt, values in bucket.items():
  55. sets, val = values[:-1], values[-1]
  56. parts = patt.split("*") # stars are between parts
  57. final = parts[0]
  58. for i in range(1, len(parts)):
  59. if i - 1 < len(sets) and sets[i - 1]:
  60. insert = _fmt_indices(sorted(sets[i - 1]))
  61. if len(sets[i - 1]) > 1:
  62. final += "{" + insert + "}"
  63. else:
  64. final += insert
  65. else:
  66. final += "*"
  67. final += parts[i]
  68. out_items[final] = val
  69. out = OrderedDict(out_items)
  70. if not_mapping:
  71. return out.keys()
  72. return out
  73. _ansi_re = re.compile(r"\x1b\[[0-9;]*m")
  74. def _strip_ansi(s: str) -> str:
  75. return _ansi_re.sub("", str(s))
  76. def _pad(text, width):
  77. t = str(text)
  78. pad = max(0, width - len(_strip_ansi(t)))
  79. return t + " " * pad
  80. def _make_table(rows, headers):
  81. # compute display widths while ignoring ANSI codes
  82. cols = list(zip(*([headers] + rows))) if rows else [headers]
  83. widths = [max(len(_strip_ansi(x)) for x in col) for col in cols]
  84. header_line = " | ".join(_pad(h, w) for h, w in zip(headers, widths))
  85. sep_line = "-+-".join("-" * w for w in widths)
  86. body = [" | ".join(_pad(c, w) for c, w in zip(r, widths)) for r in rows]
  87. return "\n".join([header_line, sep_line] + body)
  88. PALETTE = {
  89. "reset": "",
  90. "red": "",
  91. "yellow": "",
  92. "orange": "",
  93. "purple": "",
  94. "bold": "",
  95. "italic": "",
  96. "dim": "",
  97. }
  98. def _style(s, color):
  99. """Return color/style-formatted input `s` if `sys.stdout` is interactive, e.g. connected to a terminal."""
  100. if sys.stdout.isatty():
  101. return f"{PALETTE[color]}{s}{PALETTE['reset']}"
  102. else:
  103. return s
  104. def _get_terminal_width(default=80):
  105. try:
  106. return shutil.get_terminal_size().columns
  107. except Exception:
  108. return default
  109. @dataclass
  110. class LoadStateDictInfo:
  111. """
  112. Mutable container for state-dict loading results and diagnostics. Each entry in this structure is mutable,
  113. and will usually be mutated in-place during the loading pipeline.
  114. Attributes:
  115. missing_keys (`set[str]`):
  116. Keys that are missing from the loaded checkpoints but expected in the model's architecture.
  117. unexpected_keys (`set[str]`):
  118. Keys that are found in the checkpoints, but not expected in the model's architecture.
  119. mismatched_keys (`set[tuple[str, tuple[int], tuple[int]]]`):
  120. Keys that are found in the checkpoints and are expected in the model's architecture, but with a different shape.
  121. error_msgs ( `list[str]`):
  122. Some potential error messages.
  123. conversion_errors (`dict[str, str]`):
  124. Errors happening during the on-the-fly weight conversion process.
  125. """
  126. missing_keys: set[str]
  127. unexpected_keys: set[str]
  128. mismatched_keys: set[tuple[str, tuple[int], tuple[int]]]
  129. error_msgs: list[str]
  130. conversion_errors: dict[str, str]
  131. def missing_and_mismatched(self):
  132. """Return all effective missing keys, including `missing` and `mismatched` keys."""
  133. return self.missing_keys | {k[0] for k in self.mismatched_keys}
  134. def to_dict(self):
  135. # Does not include the `conversion_errors` to be coherent with legacy reporting in the tests
  136. return {
  137. "missing_keys": self.missing_keys,
  138. "unexpected_keys": self.unexpected_keys,
  139. "mismatched_keys": self.mismatched_keys,
  140. "error_msgs": self.error_msgs,
  141. }
  142. def create_loading_report(self) -> str | None:
  143. """Generate the minimal table of a loading report."""
  144. term_w = _get_terminal_width()
  145. rows = []
  146. tips = "\n\nNotes:"
  147. if self.unexpected_keys:
  148. tips += f"\n- {_style('UNEXPECTED:', 'orange')}\t" + _style(
  149. "can be ignored when loading from different task/architecture; not ok if you expect identical arch.",
  150. "italic",
  151. )
  152. for k in update_key_name(self.unexpected_keys):
  153. status = _style("UNEXPECTED", "orange")
  154. rows.append([k, status, "", ""])
  155. if self.missing_keys:
  156. tips += f"\n- {_style('MISSING:', 'red')}\t" + _style(
  157. "those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.",
  158. "italic",
  159. )
  160. for k in update_key_name(self.missing_keys):
  161. status = _style("MISSING", "red")
  162. rows.append([k, status, ""])
  163. if self.mismatched_keys:
  164. tips += f"\n- {_style('MISMATCH:', 'yellow')}\t" + _style(
  165. "ckpt weights were loaded, but they did not match the original empty weight shapes.", "italic"
  166. )
  167. iterator = {a: (b, c) for a, b, c in self.mismatched_keys}
  168. for key, (shape_ckpt, shape_model) in update_key_name(iterator).items():
  169. status = _style("MISMATCH", "yellow")
  170. data = [
  171. key,
  172. status,
  173. f"Reinit due to size mismatch - ckpt: {str(shape_ckpt)} vs model:{str(shape_model)}",
  174. ]
  175. rows.append(data)
  176. if self.conversion_errors:
  177. tips += f"\n- {_style('CONVERSION:', 'purple')}\t" + _style(
  178. "originate from the conversion scheme", "italic"
  179. )
  180. for k, v in update_key_name(self.conversion_errors).items():
  181. status = _style("CONVERSION", "purple")
  182. _details = f"\n\n{v}\n\n"
  183. rows.append([k, status, _details])
  184. # If nothing is wrong, return None
  185. if len(rows) == 0:
  186. return None
  187. headers = ["Key", "Status"]
  188. if term_w > 200:
  189. headers += ["Details"]
  190. else:
  191. headers += ["", ""]
  192. table = _make_table(rows, headers=headers)
  193. report = table + tips
  194. return report
  195. def log_state_dict_report(
  196. model,
  197. pretrained_model_name_or_path: str,
  198. ignore_mismatched_sizes: bool,
  199. loading_info: LoadStateDictInfo,
  200. logger: logging.Logger | None = None,
  201. ):
  202. """
  203. Log a readable report about state_dict loading issues.
  204. This version is terminal-size aware: for very small terminals it falls back to a compact
  205. Key | Status view so output doesn't wrap badly.
  206. """
  207. if logger is None:
  208. logger = logging.getLogger(__name__)
  209. # Re-raise errors early if needed
  210. if loading_info.error_msgs:
  211. error_msg = "\n\t".join(loading_info.error_msgs)
  212. if "size mismatch" in error_msg:
  213. error_msg += (
  214. "\n\tYou may consider adding `ignore_mismatched_sizes=True` to `from_pretrained(...)` if appropriate."
  215. )
  216. raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
  217. # Create the report table
  218. report = loading_info.create_loading_report()
  219. if report is None:
  220. return
  221. prelude = f"{PALETTE['bold']}{model.__class__.__name__} LOAD REPORT{PALETTE['reset']} from: {pretrained_model_name_or_path}\n"
  222. # Log the report as warning
  223. logger.warning(prelude + report)
  224. # Re-raise in those case, after the report
  225. if loading_info.conversion_errors:
  226. raise RuntimeError(
  227. "We encountered some issues during automatic conversion of the weights. For details look at the `CONVERSION` entries of "
  228. "the above report!"
  229. )
  230. if not ignore_mismatched_sizes and loading_info.mismatched_keys:
  231. raise RuntimeError(
  232. "You set `ignore_mismatched_sizes` to `False`, thus raising an error. For details look at the above report!"
  233. )