_typing.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
# Copyright 2026 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. """Typing helpers shared across the Transformers library."""
  15. from __future__ import annotations
  16. import logging
  17. from collections.abc import Mapping, MutableMapping
  18. from typing import TYPE_CHECKING, Any, Protocol, TypeAlias
  19. if TYPE_CHECKING:
  20. import torch
  21. from .cache_utils import Cache
  22. # A few helpful type aliases
  23. Level: TypeAlias = int
  24. ExcInfo: TypeAlias = (
  25. None
  26. | bool
  27. | BaseException
  28. | tuple[type[BaseException], BaseException, object] # traceback is `types.TracebackType`, but keep generic here
  29. )
  30. class TransformersLogger(Protocol):
  31. # ---- Core Logger identity / configuration ----
  32. name: str
  33. level: int
  34. parent: logging.Logger | None
  35. propagate: bool
  36. disabled: bool
  37. handlers: list[logging.Handler]
  38. # Exists on Logger; default is True. (Not heavily used, but is part of API.)
  39. raiseExceptions: bool
  40. # ---- Standard methods ----
  41. def setLevel(self, level: Level) -> None: ...
  42. def isEnabledFor(self, level: Level) -> bool: ...
  43. def getEffectiveLevel(self) -> int: ...
  44. def getChild(self, suffix: str) -> logging.Logger: ...
  45. def addHandler(self, hdlr: logging.Handler) -> None: ...
  46. def removeHandler(self, hdlr: logging.Handler) -> None: ...
  47. def hasHandlers(self) -> bool: ...
  48. # ---- Logging calls ----
  49. def debug(self, msg: object, *args: object, **kwargs: object) -> None: ...
  50. def info(self, msg: object, *args: object, **kwargs: object) -> None: ...
  51. def warning(self, msg: object, *args: object, **kwargs: object) -> None: ...
  52. def warn(self, msg: object, *args: object, **kwargs: object) -> None: ...
  53. def error(self, msg: object, *args: object, **kwargs: object) -> None: ...
  54. def exception(self, msg: object, *args: object, exc_info: ExcInfo = True, **kwargs: object) -> None: ...
  55. def critical(self, msg: object, *args: object, **kwargs: object) -> None: ...
  56. def fatal(self, msg: object, *args: object, **kwargs: object) -> None: ...
  57. # The lowest-level primitive
  58. def log(self, level: Level, msg: object, *args: object, **kwargs: object) -> None: ...
  59. # ---- Record-level / formatting ----
  60. def makeRecord(
  61. self,
  62. name: str,
  63. level: Level,
  64. fn: str,
  65. lno: int,
  66. msg: object,
  67. args: tuple[object, ...] | Mapping[str, object],
  68. exc_info: ExcInfo,
  69. func: str | None = None,
  70. extra: Mapping[str, object] | None = None,
  71. sinfo: str | None = None,
  72. ) -> logging.LogRecord: ...
  73. def handle(self, record: logging.LogRecord) -> None: ...
  74. def findCaller(
  75. self,
  76. stack_info: bool = False,
  77. stacklevel: int = 1,
  78. ) -> tuple[str, int, str, str | None]: ...
  79. def callHandlers(self, record: logging.LogRecord) -> None: ...
  80. def getMessage(self) -> str: ... # NOTE: actually on LogRecord; included rarely; safe to omit if you want
  81. def _log(
  82. self,
  83. level: Level,
  84. msg: object,
  85. args: tuple[object, ...] | Mapping[str, object],
  86. exc_info: ExcInfo = None,
  87. extra: Mapping[str, object] | None = None,
  88. stack_info: bool = False,
  89. stacklevel: int = 1,
  90. ) -> None: ...
  91. # ---- Filters ----
  92. def addFilter(self, filt: logging.Filter) -> None: ...
  93. def removeFilter(self, filt: logging.Filter) -> None: ...
  94. @property
  95. def filters(self) -> list[logging.Filter]: ...
  96. def filter(self, record: logging.LogRecord) -> bool: ...
  97. # ---- Convenience helpers ----
  98. def setFormatter(self, fmt: logging.Formatter) -> None: ... # mostly on handlers; present on adapters sometimes
  99. def debugStack(self, msg: object, *args: object, **kwargs: object) -> None: ... # not std; safe no-op if absent
  100. # ---- stdlib dictConfig-friendly / extra storage ----
  101. # Logger has `manager` and can have arbitrary attributes; Protocol can't express arbitrary attrs,
  102. # but we can at least include `__dict__` to make "extra attributes" less painful.
  103. __dict__: MutableMapping[str, Any]
  104. # ---- Transformers logger specific methods ----
  105. def warning_advice(self, msg: object, *args: object, **kwargs: object) -> None: ...
  106. def warning_once(self, msg: object, *args: object, **kwargs: object) -> None: ...
  107. def info_once(self, msg: object, *args: object, **kwargs: object) -> None: ...
  108. class GenerativePreTrainedModel(Protocol):
  109. """Protocol for the model interface that GenerationMixin expects.
  110. GenerationMixin is designed to be mixed into PreTrainedModel subclasses. This Protocol documents the
  111. attributes and methods the mixin relies on from its host class. It is *not* used at runtime — its
  112. purpose is to help the ``ty`` type checker resolve ``self.<attr>`` accesses inside the mixin.
  113. """
  114. config: Any # PretrainedConfig — kept as Any to avoid circular imports
  115. device: torch.device
  116. dtype: torch.dtype
  117. main_input_name: str
  118. base_model_prefix: str
  119. _is_stateful: bool
  120. hf_quantizer: Any
  121. encoder: Any
  122. hf_device_map: dict[str, Any]
  123. _cache: Cache
  124. generation_config: Any # GenerationConfig
  125. def __getattr__(self, name: str) -> Any: ...
  126. def forward(self, *args: Any, **kwargs: Any) -> Any: ...
  127. def __call__(self, *args: Any, **kwargs: Any) -> Any: ...
  128. def can_generate(self) -> bool: ...
  129. def get_encoder(self) -> Any: ...
  130. def get_output_embeddings(self) -> Any: ...
  131. def get_input_embeddings(self) -> Any: ...
  132. def set_output_embeddings(self, value: Any) -> None: ...
  133. def set_input_embeddings(self, value: Any) -> None: ...
  134. def get_compiled_call(self, compile_config: Any) -> Any: ...
  135. def set_experts_implementation(self, *args: Any, **kwargs: Any) -> Any: ...
  136. def _supports_logits_to_keep(self) -> bool: ...
  137. class WhisperGenerationConfigLike(Protocol):
  138. """Protocol for Whisper-specific generation config fields accessed in generation internals."""
  139. no_timestamps_token_id: int