pytorch_utils.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260
  1. # Copyright 2022 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import annotations
  15. import inspect
  16. from collections.abc import Callable
  17. from functools import lru_cache, wraps
  18. import torch
  19. from safetensors.torch import storage_ptr, storage_size
  20. from torch import nn
  21. from .utils import (
  22. is_torch_greater_or_equal,
  23. is_torch_xla_available,
  24. is_torchdynamo_compiling,
  25. logging,
  26. )
# List holding the `nn.LayerNorm` class.
# NOTE(review): presumably consumed elsewhere to recognise layer-norm modules — confirm against callers.
ALL_LAYERNORM_LAYERS = [nn.LayerNorm]

# Module-level logger obtained from the package's own `logging` helper.
logger = logging.get_logger(__name__)

# Torch version flags, each evaluated once when this module is imported.
# NOTE(review): `accept_dev=True` presumably lets dev/nightly builds satisfy the check — confirm in `.utils`.
is_torch_greater_or_equal_than_2_8 = is_torch_greater_or_equal("2.8", accept_dev=True)
is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True)

# For backwards compatibility (e.g. some remote codes on Hub using those variables).
is_torch_greater_or_equal_than_2_4 = is_torch_greater_or_equal("2.4", accept_dev=True)
is_torch_greater_or_equal_than_2_3 = is_torch_greater_or_equal("2.3", accept_dev=True)
is_torch_greater_or_equal_than_2_2 = is_torch_greater_or_equal("2.2", accept_dev=True)
is_torch_greater_or_equal_than_2_1 = is_torch_greater_or_equal("2.1", accept_dev=True)
is_torch_greater_or_equal_than_2_0 = is_torch_greater_or_equal("2.0", accept_dev=True)
is_torch_greater_or_equal_than_1_13 = is_torch_greater_or_equal("1.13", accept_dev=True)
is_torch_greater_or_equal_than_1_12 = is_torch_greater_or_equal("1.12", accept_dev=True)

# Cache this result as it's a C FFI call which can be pretty time-consuming
_torch_distributed_available = torch.distributed.is_available()
  41. def softmax_backward_data(parent, grad_output, output):
  42. """
  43. A function that calls the internal `_softmax_backward_data` PyTorch method and that adjusts the arguments according
  44. to the torch version detected.
  45. """
  46. from torch import _softmax_backward_data
  47. return _softmax_backward_data(grad_output, output, parent.dim, output.dtype)
  48. def prune_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: int = 0) -> nn.Linear:
  49. """
  50. Prune a linear layer to keep only entries in index.
  51. Used to remove heads.
  52. Args:
  53. layer (`torch.nn.Linear`): The layer to prune.
  54. index (`torch.LongTensor`): The indices to keep in the layer.
  55. dim (`int`, *optional*, defaults to 0): The dimension on which to keep the indices.
  56. Returns:
  57. `torch.nn.Linear`: The pruned layer as a new layer with `requires_grad=True`.
  58. """
  59. index = index.to(layer.weight.device)
  60. W = layer.weight.index_select(dim, index).detach().clone()
  61. if layer.bias is not None:
  62. if dim == 1:
  63. b = layer.bias.detach().clone()
  64. else:
  65. b = layer.bias[index].detach().clone()
  66. new_size = list(layer.weight.size())
  67. new_size[dim] = len(index)
  68. new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
  69. new_layer.weight.requires_grad = False
  70. new_layer.weight.copy_(W.contiguous())
  71. new_layer.weight.requires_grad = True
  72. if layer.bias is not None:
  73. new_layer.bias.requires_grad = False
  74. new_layer.bias.copy_(b.contiguous())
  75. new_layer.bias.requires_grad = True
  76. return new_layer
  77. class Conv1D(nn.Module):
  78. """
  79. 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
  80. Basically works like a linear layer but the weights are transposed.
  81. Args:
  82. nf (`int`): The number of output features.
  83. nx (`int`): The number of input features.
  84. """
  85. def __init__(self, nf, nx):
  86. super().__init__()
  87. self.nf = nf
  88. self.nx = nx
  89. self.weight = nn.Parameter(torch.empty(nx, nf))
  90. self.bias = nn.Parameter(torch.zeros(nf))
  91. nn.init.normal_(self.weight, std=0.02)
  92. def __repr__(self) -> str:
  93. return "Conv1D(nf={nf}, nx={nx})".format(**self.__dict__)
  94. def forward(self, x):
  95. size_out = x.size()[:-1] + (self.nf,)
  96. x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
  97. x = x.view(size_out)
  98. return x
  99. def apply_chunking_to_forward(
  100. forward_fn: Callable[..., torch.Tensor],
  101. chunk_size: int,
  102. chunk_dim: int,
  103. *input_tensors,
  104. ) -> torch.Tensor:
  105. """
  106. This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension
  107. `chunk_dim`. It then applies a layer `forward_fn` to each chunk independently to save memory.
  108. If the `forward_fn` is independent across the `chunk_dim` this function will yield the same result as directly
  109. applying `forward_fn` to `input_tensors`.
  110. Args:
  111. forward_fn (`Callable[..., torch.Tensor]`):
  112. The forward function of the model.
  113. chunk_size (`int`):
  114. The chunk size of a chunked tensor: `num_chunks = len(input_tensors[0]) / chunk_size`.
  115. chunk_dim (`int`):
  116. The dimension over which the `input_tensors` should be chunked.
  117. input_tensors (`tuple[torch.Tensor]`):
  118. The input tensors of `forward_fn` which will be chunked
  119. Returns:
  120. `torch.Tensor`: A tensor with the same shape as the `forward_fn` would have given if applied`.
  121. Examples:
  122. ```python
  123. # rename the usual forward() fn to forward_chunk()
  124. def forward_chunk(self, hidden_states):
  125. hidden_states = self.decoder(hidden_states)
  126. return hidden_states
  127. # implement a chunked forward function
  128. def forward(self, hidden_states):
  129. return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states)
  130. ```"""
  131. assert len(input_tensors) > 0, f"{input_tensors} has to be a tuple/list of tensors"
  132. # inspect.signature exist since python 3.5 and is a python method -> no problem with backward compatibility
  133. num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)
  134. if num_args_in_forward_chunk_fn != len(input_tensors):
  135. raise ValueError(
  136. f"forward_chunk_fn expects {num_args_in_forward_chunk_fn} arguments, but only {len(input_tensors)} input "
  137. "tensors are given"
  138. )
  139. if chunk_size > 0:
  140. tensor_shape = input_tensors[0].shape[chunk_dim]
  141. for input_tensor in input_tensors:
  142. if input_tensor.shape[chunk_dim] != tensor_shape:
  143. raise ValueError(
  144. f"All input tenors have to be of the same shape: {tensor_shape}, "
  145. f"found shape {input_tensor.shape[chunk_dim]}"
  146. )
  147. if input_tensors[0].shape[chunk_dim] % chunk_size != 0:
  148. raise ValueError(
  149. f"The dimension to be chunked {input_tensors[0].shape[chunk_dim]} has to be a multiple of the chunk "
  150. f"size {chunk_size}"
  151. )
  152. num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
  153. # chunk input tensor into tuples
  154. input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors)
  155. # apply forward fn to every tuple
  156. output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks))
  157. # concatenate output at same dimension
  158. return torch.cat(output_chunks, dim=chunk_dim)
  159. return forward_fn(*input_tensors)
  160. def meshgrid(*tensors: torch.Tensor | list[torch.Tensor], indexing: str | None = None) -> tuple[torch.Tensor, ...]:
  161. """
  162. Wrapper around torch.meshgrid to avoid warning messages about the introduced `indexing` argument.
  163. Reference: https://pytorch.org/docs/1.13/generated/torch.meshgrid.html
  164. """
  165. return torch.meshgrid(*tensors, indexing=indexing)
  166. def id_tensor_storage(tensor: torch.Tensor) -> tuple[torch.device, int, int]:
  167. """
  168. Unique identifier to a tensor storage. Multiple different tensors can share the same underlying storage. For
  169. example, "meta" tensors all share the same storage, and thus their identifier will all be equal. This identifier is
  170. guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
  171. non-overlapping lifetimes may have the same id.
  172. """
  173. if _torch_distributed_available and is_torch_greater_or_equal("2.5"):
  174. from torch.distributed.tensor import DTensor
  175. if isinstance(tensor, DTensor):
  176. local_tensor = tensor.to_local()
  177. return tensor.device, local_tensor.storage().data_ptr(), tensor.nbytes
  178. if tensor.device.type == "xla" and is_torch_xla_available():
  179. # NOTE: xla tensors dont have storage
  180. # use some other unique id to distinguish.
  181. # this is a XLA tensor, it must be created using torch_xla's
  182. # device. So the following import is safe:
  183. import torch_xla
  184. unique_id = torch_xla._XLAC._xla_get_tensor_id(tensor)
  185. else:
  186. unique_id = storage_ptr(tensor)
  187. return tensor.device, unique_id, storage_size(tensor)
  188. @wraps(lru_cache)
  189. def compile_compatible_method_lru_cache(*lru_args, **lru_kwargs):
  190. """
  191. LRU cache decorator from standard functools library, but with a workaround to disable
  192. caching when torchdynamo is compiling. Expected to work with class methods.
  193. """
  194. def decorator(func):
  195. func_with_cache = lru_cache(*lru_args, **lru_kwargs)(func)
  196. @wraps(func)
  197. def wrapper(*args, **kwargs):
  198. if is_torchdynamo_compiling():
  199. return func(*args, **kwargs)
  200. else:
  201. return func_with_cache(*args, **kwargs)
  202. return wrapper
  203. return decorator