| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280 |
- # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import logging
- import re
- import shutil
- import sys
- from collections import OrderedDict, defaultdict
- from dataclasses import dataclass
- from typing import Any
- _DIGIT_RX = re.compile(r"(?<=\.)(\d+)(?=\.|$)") # numbers between dots or at the end
- def _pattern_of(key: str) -> str:
- """Replace every dot-delimited integer with '*' to get the structure."""
- return _DIGIT_RX.sub("*", key)
- def _fmt_indices(values: list[int], cutoff=10) -> str:
- """Format a list of ints as single number, {a, ..., b}, or first...last."""
- if len(values) == 1:
- return str(values[0])
- values = sorted(values)
- if len(values) > cutoff:
- return f"{values[0]}...{values[-1]}"
- return ", ".join(map(str, values))
def update_key_name(mapping: dict[str, Any]) -> dict[str, Any]:
    """
    Merge keys like 'layers.0.x', 'layers.1.x' into 'layers.{0, 1}.x'
    BUT only merge together keys that have the exact same value.

    Args:
        mapping (`dict[str, Any]` or any iterable of `str`):
            Either a `{key: value}` dict, or a plain collection of keys (e.g. a set). Plain collections
            carry no values, so their keys are merged on structure alone.

    Returns:
        For a dict input, a new `OrderedDict` `{merged_key: value}` in first-seen order. For any other
        iterable, the keys view of that dict (only the merged key names).
    """
    not_mapping = not isinstance(mapping, dict)
    if not_mapping:
        # Give every key the same placeholder value so that grouping by value never splits them.
        mapping = dict.fromkeys(mapping)

    # One group per (pattern, value): `index_sets[i]` collects the integers seen at the i-th '*'.
    groups: list[tuple[str, Any, list[set[int]]]] = []
    # Constant-time lookup of an existing group for hashable values; unhashable values fall back to a scan.
    lookup: dict[tuple[str, Any], list[set[int]]] = {}
    for key, val in mapping.items():
        digits = _DIGIT_RX.findall(key)
        patt = _pattern_of(key)
        try:
            index_sets = lookup.get((patt, val))
            hashable = True
        except TypeError:
            hashable = False
            index_sets = next((s for p, v, s in groups if p == patt and v == val), None)
        if index_sets is None:
            index_sets = [set() for _ in digits]
            groups.append((patt, val, index_sets))
            if hashable:
                lookup[(patt, val)] = index_sets
        for index_set, digit in zip(index_sets, digits):
            index_set.add(int(digit))

    # NOTE: two groups sharing a pattern can still render to the same merged key when their per-position
    # index sets coincide; in that case the later group overwrites the earlier one.
    out = OrderedDict()
    for patt, val, index_sets in groups:
        parts = patt.split("*")  # stars are between parts
        pieces = [parts[0]]
        for position, tail in enumerate(parts[1:]):
            indices = index_sets[position] if position < len(index_sets) else None
            if indices:
                text = _fmt_indices(sorted(indices))
                pieces.append("{" + text + "}" if len(indices) > 1 else text)
            else:
                # No index recorded for this star: it was a literal '*' already present in the key.
                pieces.append("*")
            pieces.append(tail)
        out["".join(pieces)] = val

    if not_mapping:
        return out.keys()
    return out
- _ansi_re = re.compile(r"\x1b\[[0-9;]*m")
- def _strip_ansi(s: str) -> str:
- return _ansi_re.sub("", str(s))
def _pad(text, width):
    """
    Right-pad `str(text)` with spaces up to `width` visible characters.

    ANSI escape sequences are not counted towards the visible length. Text that is already
    `width` characters or longer is returned unchanged (never truncated).
    """
    rendered = str(text)
    missing = width - len(_strip_ansi(rendered))
    if missing <= 0:
        return rendered
    return rendered + " " * missing
def _make_table(rows, headers):
    """
    Render `rows` under `headers` as a plain-text table.

    Cells are separated by " | ", and a "-+-" separator line follows the header line. Column widths are
    computed on the visible text, i.e. ignoring ANSI escape sequences. The number of columns is the length
    of the shortest of `headers` and all `rows`; extra cells are dropped.

    With no rows, only the header line and the separator line are produced, covering every header.

    Args:
        rows: List of rows, each row being a sequence of cells.
        headers: Sequence of column titles.

    Returns:
        The table as a single newline-joined string.
    """
    # compute display widths while ignoring ANSI codes
    # Transposing headers + rows gives one tuple per column; with no rows this is one 1-tuple per header.
    cols = list(zip(headers, *rows))
    widths = [max(len(_strip_ansi(cell)) for cell in col) for col in cols]
    header_line = " | ".join(_pad(h, w) for h, w in zip(headers, widths))
    sep_line = "-+-".join("-" * w for w in widths)
    body = [" | ".join(_pad(c, w) for c, w in zip(r, widths)) for r in rows]
    return "\n".join([header_line, sep_line, *body])
# ANSI escape sequences used to colorize/style the report. Each code must start with the ESC
# character (`\x1b`): that is what terminals interpret and what `_ansi_re` strips when measuring text.
PALETTE = {
    "reset": "\x1b[0m",
    "red": "\x1b[31m",
    "yellow": "\x1b[33m",
    "orange": "\x1b[38;5;208m",
    "purple": "\x1b[35m",
    "bold": "\x1b[1m",
    "italic": "\x1b[3m",
    "dim": "\x1b[2m",
}
- def _style(s, color):
- """Return color/style-formatted input `s` if `sys.stdout` is interactive, e.g. connected to a terminal."""
- if sys.stdout.isatty():
- return f"{PALETTE[color]}{s}{PALETTE['reset']}"
- else:
- return s
- def _get_terminal_width(default=80):
- try:
- return shutil.get_terminal_size().columns
- except Exception:
- return default
- @dataclass
- class LoadStateDictInfo:
- """
- Mutable container for state-dict loading results and diagnostics. Each entry in this structure is mutable,
- and will usually be mutated in-place during the loading pipeline.
- Attributes:
- missing_keys (`set[str]`):
- Keys that are missing from the loaded checkpoints but expected in the model's architecture.
- unexpected_keys (`set[str]`):
- Keys that are found in the checkpoints, but not expected in the model's architecture.
- mismatched_keys (`set[tuple[str, tuple[int], tuple[int]]]`):
- Keys that are found in the checkpoints and are expected in the model's architecture, but with a different shape.
- error_msgs ( `list[str]`):
- Some potential error messages.
- conversion_errors (`dict[str, str]`):
- Errors happening during the on-the-fly weight conversion process.
- """
- missing_keys: set[str]
- unexpected_keys: set[str]
- mismatched_keys: set[tuple[str, tuple[int], tuple[int]]]
- error_msgs: list[str]
- conversion_errors: dict[str, str]
- def missing_and_mismatched(self):
- """Return all effective missing keys, including `missing` and `mismatched` keys."""
- return self.missing_keys | {k[0] for k in self.mismatched_keys}
- def to_dict(self):
- # Does not include the `conversion_errors` to be coherent with legacy reporting in the tests
- return {
- "missing_keys": self.missing_keys,
- "unexpected_keys": self.unexpected_keys,
- "mismatched_keys": self.mismatched_keys,
- "error_msgs": self.error_msgs,
- }
- def create_loading_report(self) -> str | None:
- """Generate the minimal table of a loading report."""
- term_w = _get_terminal_width()
- rows = []
- tips = "\n\nNotes:"
- if self.unexpected_keys:
- tips += f"\n- {_style('UNEXPECTED:', 'orange')}\t" + _style(
- "can be ignored when loading from different task/architecture; not ok if you expect identical arch.",
- "italic",
- )
- for k in update_key_name(self.unexpected_keys):
- status = _style("UNEXPECTED", "orange")
- rows.append([k, status, "", ""])
- if self.missing_keys:
- tips += f"\n- {_style('MISSING:', 'red')}\t" + _style(
- "those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.",
- "italic",
- )
- for k in update_key_name(self.missing_keys):
- status = _style("MISSING", "red")
- rows.append([k, status, ""])
- if self.mismatched_keys:
- tips += f"\n- {_style('MISMATCH:', 'yellow')}\t" + _style(
- "ckpt weights were loaded, but they did not match the original empty weight shapes.", "italic"
- )
- iterator = {a: (b, c) for a, b, c in self.mismatched_keys}
- for key, (shape_ckpt, shape_model) in update_key_name(iterator).items():
- status = _style("MISMATCH", "yellow")
- data = [
- key,
- status,
- f"Reinit due to size mismatch - ckpt: {str(shape_ckpt)} vs model:{str(shape_model)}",
- ]
- rows.append(data)
- if self.conversion_errors:
- tips += f"\n- {_style('CONVERSION:', 'purple')}\t" + _style(
- "originate from the conversion scheme", "italic"
- )
- for k, v in update_key_name(self.conversion_errors).items():
- status = _style("CONVERSION", "purple")
- _details = f"\n\n{v}\n\n"
- rows.append([k, status, _details])
- # If nothing is wrong, return None
- if len(rows) == 0:
- return None
- headers = ["Key", "Status"]
- if term_w > 200:
- headers += ["Details"]
- else:
- headers += ["", ""]
- table = _make_table(rows, headers=headers)
- report = table + tips
- return report
def log_state_dict_report(
    model,
    pretrained_model_name_or_path: str,
    ignore_mismatched_sizes: bool,
    loading_info: LoadStateDictInfo,
    logger: logging.Logger | None = None,
):
    """
    Log a readable report about state_dict loading issues.

    The report itself is built by `loading_info.create_loading_report()`; nothing is logged when that
    returns `None`. Otherwise the report is logged as a single warning, prefixed by the model class name
    and the checkpoint it was loaded from.

    Args:
        model:
            The model being loaded; only its class name is used, in messages.
        pretrained_model_name_or_path (`str`):
            Origin of the checkpoint, shown in the report header.
        ignore_mismatched_sizes (`bool`):
            When `False`, any mismatched key raises after the report was logged.
        loading_info (`LoadStateDictInfo`):
            The diagnostics gathered while loading.
        logger (`logging.Logger`, *optional*):
            Logger to use. Defaults to this module's logger.

    Raises:
        RuntimeError: If `loading_info.error_msgs` is not empty (raised before any report is logged), if
            there are conversion errors, or if there are mismatched keys while `ignore_mismatched_sizes`
            is `False` (the last two are raised after the report was logged).
    """
    if logger is None:
        logger = logging.getLogger(__name__)

    # Re-raise errors early if needed
    if loading_info.error_msgs:
        error_msg = "\n\t".join(loading_info.error_msgs)
        if "size mismatch" in error_msg:
            error_msg += (
                "\n\tYou may consider adding `ignore_mismatched_sizes=True` to `from_pretrained(...)` if appropriate."
            )
        raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")

    # Create the report table
    report = loading_info.create_loading_report()
    if report is None:
        return

    # Style the title through `_style`, like the rest of the report, so that escape codes are only
    # emitted when stdout is a terminal.
    title = _style(f"{model.__class__.__name__} LOAD REPORT", "bold")
    prelude = f"{title} from: {pretrained_model_name_or_path}\n"

    # Log the report as warning
    logger.warning(prelude + report)

    # Re-raise in those case, after the report
    if loading_info.conversion_errors:
        raise RuntimeError(
            "We encountered some issues during automatic conversion of the weights. For details look at the `CONVERSION` entries of "
            "the above report!"
        )
    if not ignore_mismatched_sizes and loading_info.mismatched_keys:
        raise RuntimeError(
            "You set `ignore_mismatched_sizes` to `False`, thus raising an error. For details look at the above report!"
        )
|