modeling_layers.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288
  1. # Copyright 2025 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from functools import partial
  15. import torch
  16. import torch.nn as nn
  17. from .cache_utils import Cache
  18. from .modeling_outputs import (
  19. BaseModelOutputWithPast,
  20. QuestionAnsweringModelOutput,
  21. SequenceClassifierOutputWithPast,
  22. TokenClassifierOutput,
  23. )
  24. from .models.auto import AutoModel
  25. from .processing_utils import Unpack
  26. from .utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
# Module-level logger; the classes below use it for their one-time (`warning_once`) warnings.
logger = logging.get_logger(__name__)
  28. class GradientCheckpointingLayer(nn.Module):
  29. """Base class for layers with gradient checkpointing.
  30. This class enables gradient checkpointing functionality for a layer. By default, gradient checkpointing is disabled
  31. (`gradient_checkpointing = False`). When `model.set_gradient_checkpointing()` is called, gradient checkpointing is
  32. enabled by setting `gradient_checkpointing = True` and assigning a checkpointing function to `_gradient_checkpointing_func`.
  33. Important:
  34. When using gradient checkpointing with `use_reentrant=True`, inputs that require gradients (e.g. hidden states)
  35. must be passed as positional arguments (`*args`) rather than keyword arguments to properly propagate gradients.
  36. Example:
  37. ```python
  38. >>> # Correct - hidden_states passed as positional arg
  39. >>> out = self.layer(hidden_states, attention_mask=attention_mask)
  40. >>> # Incorrect - hidden_states passed as keyword arg
  41. >>> out = self.layer(hidden_states=hidden_states, attention_mask=attention_mask)
  42. ```
  43. """
  44. gradient_checkpointing = False
  45. def __call__(self, *args, **kwargs):
  46. if self.gradient_checkpointing and self.training:
  47. do_warn = False
  48. layer_name = self.__class__.__name__
  49. message = f"Caching is incompatible with gradient checkpointing in {layer_name}. Setting"
  50. if "use_cache" in kwargs and kwargs["use_cache"]:
  51. kwargs["use_cache"] = False
  52. message += " `use_cache=False`,"
  53. do_warn = True
  54. # different names for the same thing in different layers
  55. # TODO cyril: this one without `S` can be removed after deprecation cycle
  56. if "past_key_value" in kwargs and kwargs["past_key_value"] is not None:
  57. kwargs["past_key_value"] = None
  58. message += " `past_key_value=None`,"
  59. do_warn = True
  60. if "past_key_values" in kwargs and kwargs["past_key_values"] is not None:
  61. kwargs["past_key_values"] = None
  62. message += " `past_key_values=None`,"
  63. do_warn = True
  64. if "layer_past" in kwargs and kwargs["layer_past"] is not None:
  65. kwargs["layer_past"] = None
  66. message += " `layer_past=None`,"
  67. do_warn = True
  68. # warn if anything was changed
  69. if do_warn:
  70. message = message.rstrip(",") + "."
  71. logger.warning_once(message)
  72. return self._gradient_checkpointing_func(partial(super().__call__, **kwargs), *args)
  73. return super().__call__(*args, **kwargs)
  74. @auto_docstring
  75. class GenericForSequenceClassification:
  76. base_model_prefix = "model"
  77. def __init__(self, config):
  78. super().__init__(config)
  79. self.num_labels = config.num_labels
  80. # Similar to `self.model = AutoModel.from_config(config)` but allows to change the base model name if needed in the child class
  81. setattr(self, self.base_model_prefix, AutoModel.from_config(config))
  82. self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
  83. # Initialize weights and apply final processing
  84. self.post_init()
  85. @can_return_tuple
  86. @auto_docstring
  87. def forward(
  88. self,
  89. input_ids: torch.LongTensor | None = None,
  90. attention_mask: torch.Tensor | None = None,
  91. position_ids: torch.LongTensor | None = None,
  92. past_key_values: Cache | None = None,
  93. inputs_embeds: torch.FloatTensor | None = None,
  94. labels: torch.LongTensor | None = None,
  95. use_cache: bool | None = None,
  96. **kwargs: Unpack[TransformersKwargs],
  97. ) -> SequenceClassifierOutputWithPast:
  98. transformer_outputs: BaseModelOutputWithPast = getattr(self, self.base_model_prefix)(
  99. input_ids,
  100. attention_mask=attention_mask,
  101. position_ids=position_ids,
  102. past_key_values=past_key_values,
  103. inputs_embeds=inputs_embeds,
  104. use_cache=use_cache,
  105. **kwargs,
  106. )
  107. hidden_states = transformer_outputs.last_hidden_state
  108. logits = self.score(hidden_states)
  109. if input_ids is not None:
  110. batch_size = input_ids.shape[0]
  111. else:
  112. batch_size = inputs_embeds.shape[0]
  113. if self.config.pad_token_id is None and batch_size != 1:
  114. raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
  115. if self.config.pad_token_id is None:
  116. last_non_pad_token = -1
  117. elif input_ids is not None:
  118. # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
  119. non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
  120. token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
  121. last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
  122. else:
  123. last_non_pad_token = -1
  124. logger.warning_once(
  125. f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
  126. "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
  127. )
  128. pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
  129. loss = None
  130. if labels is not None:
  131. loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
  132. return SequenceClassifierOutputWithPast(
  133. loss=loss,
  134. logits=pooled_logits,
  135. past_key_values=transformer_outputs.past_key_values,
  136. hidden_states=transformer_outputs.hidden_states,
  137. attentions=transformer_outputs.attentions,
  138. )
  139. @auto_docstring
  140. class GenericForQuestionAnswering:
  141. base_model_prefix = "model"
  142. def __init__(self, config):
  143. super().__init__(config)
  144. # Similar to `self.model = AutoModel.from_config(config)` but allows to change the base model name if needed in the child class
  145. setattr(self, self.base_model_prefix, AutoModel.from_config(config))
  146. self.qa_outputs = nn.Linear(config.hidden_size, 2)
  147. # Initialize weights and apply final processing
  148. self.post_init()
  149. def get_input_embeddings(self):
  150. return getattr(self, self.base_model_prefix).embed_tokens
  151. def set_input_embeddings(self, value):
  152. getattr(self, self.base_model_prefix).embed_tokens = value
  153. @can_return_tuple
  154. @auto_docstring
  155. def forward(
  156. self,
  157. input_ids: torch.LongTensor | None = None,
  158. attention_mask: torch.Tensor | None = None,
  159. position_ids: torch.LongTensor | None = None,
  160. past_key_values: Cache | None = None,
  161. inputs_embeds: torch.FloatTensor | None = None,
  162. start_positions: torch.LongTensor | None = None,
  163. end_positions: torch.LongTensor | None = None,
  164. **kwargs: Unpack[TransformersKwargs],
  165. ) -> QuestionAnsweringModelOutput:
  166. outputs: BaseModelOutputWithPast = getattr(self, self.base_model_prefix)(
  167. input_ids,
  168. attention_mask=attention_mask,
  169. position_ids=position_ids,
  170. past_key_values=past_key_values,
  171. inputs_embeds=inputs_embeds,
  172. **kwargs,
  173. )
  174. sequence_output = outputs.last_hidden_state
  175. logits = self.qa_outputs(sequence_output)
  176. start_logits, end_logits = logits.split(1, dim=-1)
  177. start_logits = start_logits.squeeze(-1).contiguous()
  178. end_logits = end_logits.squeeze(-1).contiguous()
  179. loss = None
  180. if start_positions is not None and end_positions is not None:
  181. loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
  182. return QuestionAnsweringModelOutput(
  183. loss=loss,
  184. start_logits=start_logits,
  185. end_logits=end_logits,
  186. hidden_states=outputs.hidden_states,
  187. attentions=outputs.attentions,
  188. )
  189. @auto_docstring
  190. class GenericForTokenClassification:
  191. base_model_prefix = "model"
  192. def __init__(self, config):
  193. super().__init__(config)
  194. self.num_labels = config.num_labels
  195. # Similar to `self.model = AutoModel.from_config(config)` but allows to change the base model name if needed in the child class
  196. setattr(self, self.base_model_prefix, AutoModel.from_config(config))
  197. if getattr(config, "classifier_dropout", None) is not None:
  198. classifier_dropout = config.classifier_dropout
  199. elif getattr(config, "hidden_dropout", None) is not None:
  200. classifier_dropout = config.hidden_dropout
  201. else:
  202. classifier_dropout = 0.1
  203. self.dropout = nn.Dropout(classifier_dropout)
  204. self.score = nn.Linear(config.hidden_size, config.num_labels)
  205. # Initialize weights and apply final processing
  206. self.post_init()
  207. @can_return_tuple
  208. @auto_docstring
  209. def forward(
  210. self,
  211. input_ids: torch.LongTensor | None = None,
  212. attention_mask: torch.Tensor | None = None,
  213. position_ids: torch.LongTensor | None = None,
  214. past_key_values: Cache | None = None,
  215. inputs_embeds: torch.FloatTensor | None = None,
  216. labels: torch.LongTensor | None = None,
  217. use_cache: bool | None = None,
  218. **kwargs: Unpack[TransformersKwargs],
  219. ) -> TokenClassifierOutput:
  220. outputs: BaseModelOutputWithPast = getattr(self, self.base_model_prefix)(
  221. input_ids,
  222. attention_mask=attention_mask,
  223. position_ids=position_ids,
  224. past_key_values=past_key_values,
  225. inputs_embeds=inputs_embeds,
  226. use_cache=use_cache,
  227. **kwargs,
  228. )
  229. sequence_output = outputs.last_hidden_state
  230. sequence_output = self.dropout(sequence_output)
  231. logits = self.score(sequence_output)
  232. loss = None
  233. if labels is not None:
  234. loss = self.loss_function(logits, labels, self.config)
  235. return TokenClassifierOutput(
  236. loss=loss,
  237. logits=logits,
  238. hidden_states=outputs.hidden_states,
  239. attentions=outputs.attentions,
  240. )