| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256 |
- # mypy: allow-untyped-defs
- import copy
- import warnings
- from collections.abc import Callable
- from typing import Any
- import torch
- import torch.nn.functional as F
- from torch import Tensor
- from torch.nn.init import xavier_uniform_
- from .activation import MultiheadAttention
- from .container import ModuleList
- from .dropout import Dropout
- from .linear import Linear
- from .module import Module
- from .normalization import LayerNorm
- __all__ = [
- "Transformer",
- "TransformerEncoder",
- "TransformerDecoder",
- "TransformerEncoderLayer",
- "TransformerDecoderLayer",
- ]
- def _generate_square_subsequent_mask(
- sz: int,
- device: torch.device | None = None,
- dtype: torch.dtype | None = None,
- ) -> Tensor:
- r"""Generate a square causal mask for the sequence.
- The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0).
- """
- return torch.triu(
- torch.full((sz, sz), float("-inf"), dtype=dtype, device=device),
- diagonal=1,
- )
- def _get_seq_len(src: Tensor, batch_first: bool) -> int | None:
- if src.is_nested:
- return None
- else:
- src_size = src.size()
- if len(src_size) == 2:
- # unbatched: S, E
- return src_size[0]
- else:
- # batched: B, S, E if batch_first else S, B, E
- seq_len_pos = 1 if batch_first else 0
- return src_size[seq_len_pos]
- class Transformer(Module):
- r"""A basic transformer layer.
- This Transformer layer implements the original Transformer architecture described
- in the `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_ paper. The
- intent of this layer is as a reference implementation for foundational understanding
- and thus it contains only limited features relative to newer Transformer architectures.
- Given the fast pace of innovation in transformer-like architectures, we recommend
- exploring this `tutorial <https://pytorch.org/tutorials/intermediate/transformer_building_blocks.html>`_
- to build an efficient transformer layer from building blocks in core or using higher
- level libraries from the `PyTorch Ecosystem <https://landscape.pytorch.org/>`_.
- Args:
- d_model: the number of expected features in the encoder/decoder inputs (default=512).
- nhead: the number of heads in the multiheadattention models (default=8).
- num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
- num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
- dim_feedforward: the dimension of the feedforward network model (default=2048).
- dropout: the dropout value (default=0.1).
- activation: the activation function of encoder/decoder intermediate layer, can be a string
- ("relu" or "gelu") or a unary callable. Default: relu
- custom_encoder: custom encoder (default=None).
- custom_decoder: custom decoder (default=None).
- layer_norm_eps: the eps value in layer normalization components (default=1e-5).
- batch_first: If ``True``, then the input and output tensors are provided
- as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
- norm_first: if ``True``, encoder and decoder layers will perform LayerNorms before
- other attention and feedforward operations, otherwise after. Default: ``False`` (after).
- bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
- bias. Default: ``True``.
- Examples:
- >>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
- >>> src = torch.rand((10, 32, 512))
- >>> tgt = torch.rand((20, 32, 512))
- >>> out = transformer_model(src, tgt)
- Note: A full example to apply nn.Transformer module for the word language model is available in
- https://github.com/pytorch/examples/tree/master/word_language_model
- """
- def __init__(
- self,
- d_model: int = 512,
- nhead: int = 8,
- num_encoder_layers: int = 6,
- num_decoder_layers: int = 6,
- dim_feedforward: int = 2048,
- dropout: float = 0.1,
- activation: str | Callable[[Tensor], Tensor] = F.relu,
- custom_encoder: Any | None = None,
- custom_decoder: Any | None = None,
- layer_norm_eps: float = 1e-5,
- batch_first: bool = False,
- norm_first: bool = False,
- bias: bool = True,
- device=None,
- dtype=None,
- ) -> None:
- factory_kwargs = {"device": device, "dtype": dtype}
- super().__init__()
- torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
- if custom_encoder is not None:
- self.encoder = custom_encoder
- else:
- encoder_layer = TransformerEncoderLayer(
- d_model,
- nhead,
- dim_feedforward,
- dropout,
- activation,
- layer_norm_eps,
- batch_first,
- norm_first,
- bias,
- **factory_kwargs,
- )
- encoder_norm = LayerNorm(
- d_model,
- eps=layer_norm_eps,
- bias=bias,
- # pyrefly: ignore [bad-argument-type]
- **factory_kwargs,
- )
- self.encoder = TransformerEncoder(
- encoder_layer, num_encoder_layers, encoder_norm
- )
- if custom_decoder is not None:
- self.decoder = custom_decoder
- else:
- decoder_layer = TransformerDecoderLayer(
- d_model,
- nhead,
- dim_feedforward,
- dropout,
- activation,
- layer_norm_eps,
- batch_first,
- norm_first,
- bias,
- **factory_kwargs,
- )
- decoder_norm = LayerNorm(
- d_model,
- eps=layer_norm_eps,
- bias=bias,
- # pyrefly: ignore [bad-argument-type]
- **factory_kwargs,
- )
- self.decoder = TransformerDecoder(
- decoder_layer, num_decoder_layers, decoder_norm
- )
- self._reset_parameters()
- self.d_model = d_model
- self.nhead = nhead
- self.batch_first = batch_first
- def forward(
- self,
- src: Tensor,
- tgt: Tensor,
- src_mask: Tensor | None = None,
- tgt_mask: Tensor | None = None,
- memory_mask: Tensor | None = None,
- src_key_padding_mask: Tensor | None = None,
- tgt_key_padding_mask: Tensor | None = None,
- memory_key_padding_mask: Tensor | None = None,
- src_is_causal: bool | None = None,
- tgt_is_causal: bool | None = None,
- memory_is_causal: bool = False,
- ) -> Tensor:
- r"""Take in and process masked source/target sequences.
- .. note::
- If a boolean tensor is provided for any of the [src/tgt/memory]_mask arguments, positions with a ``True`` value are
- not allowed to participate in the attention,
- which is the opposite of the definition for :attr:`attn_mask`
- in :func:`torch.nn.functional.scaled_dot_product_attention`.
- Args:
- src: the sequence to the encoder (required).
- tgt: the sequence to the decoder (required).
- src_mask: the additive mask for the src sequence (optional).
- tgt_mask: the additive mask for the tgt sequence (optional).
- memory_mask: the additive mask for the encoder output (optional).
- src_key_padding_mask: the Tensor mask for src keys per batch (optional).
- tgt_key_padding_mask: the Tensor mask for tgt keys per batch (optional).
- memory_key_padding_mask: the Tensor mask for memory keys per batch (optional).
- src_is_causal: If specified, applies a causal mask as ``src_mask``.
- Default: ``None``; try to detect a causal mask.
- Warning:
- ``src_is_causal`` provides a hint that ``src_mask`` is
- the causal mask. Providing incorrect hints can result in
- incorrect execution, including forward and backward
- compatibility.
- tgt_is_causal: If specified, applies a causal mask as ``tgt_mask``.
- Default: ``None``; try to detect a causal mask.
- Warning:
- ``tgt_is_causal`` provides a hint that ``tgt_mask`` is
- the causal mask. Providing incorrect hints can result in
- incorrect execution, including forward and backward
- compatibility.
- memory_is_causal: If specified, applies a causal mask as
- ``memory_mask``.
- Default: ``False``.
- Warning:
- ``memory_is_causal`` provides a hint that
- ``memory_mask`` is the causal mask. Providing incorrect
- hints can result in incorrect execution, including
- forward and backward compatibility.
- Shape:
- - src: :math:`(S, E)` for unbatched input, :math:`(S, N, E)` if `batch_first=False` or
- `(N, S, E)` if `batch_first=True`.
- - tgt: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or
- `(N, T, E)` if `batch_first=True`.
- - src_mask: :math:`(S, S)` or :math:`(N\cdot\text{num\_heads}, S, S)`.
- - tgt_mask: :math:`(T, T)` or :math:`(N\cdot\text{num\_heads}, T, T)`.
- - memory_mask: :math:`(T, S)`.
- - src_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`.
- - tgt_key_padding_mask: :math:`(T)` for unbatched input otherwise :math:`(N, T)`.
- - memory_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`.
- Note: [src/tgt/memory]_mask ensures that position :math:`i` is allowed to attend the unmasked
- positions. If a BoolTensor is provided, positions with ``True``
- are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
- is provided, it will be added to the attention weight.
- [src/tgt/memory]_key_padding_mask provides specified elements in the key to be ignored by
- the attention. If a BoolTensor is provided, the positions with the
- value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
- - output: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or
- `(N, T, E)` if `batch_first=True`.
- Note: Due to the multi-head attention architecture in the transformer model,
- the output sequence length of a transformer is same as the input sequence
- (i.e. target) length of the decoder.
- where :math:`S` is the source sequence length, :math:`T` is the target sequence length, :math:`N` is the
- batch size, :math:`E` is the feature number
- Examples:
- >>> # xdoctest: +SKIP
- >>> output = transformer_model(
- ... src, tgt, src_mask=src_mask, tgt_mask=tgt_mask
- ... )
- """
- is_batched = src.dim() == 3
- if not self.batch_first and src.size(1) != tgt.size(1) and is_batched:
- raise RuntimeError("the batch number of src and tgt must be equal")
- elif self.batch_first and src.size(0) != tgt.size(0) and is_batched:
- raise RuntimeError("the batch number of src and tgt must be equal")
- if src.size(-1) != self.d_model or tgt.size(-1) != self.d_model:
- raise RuntimeError(
- "the feature number of src and tgt must be equal to d_model"
- )
- memory = self.encoder(
- src,
- mask=src_mask,
- src_key_padding_mask=src_key_padding_mask,
- is_causal=src_is_causal,
- )
- output = self.decoder(
- tgt,
- memory,
- tgt_mask=tgt_mask,
- memory_mask=memory_mask,
- tgt_key_padding_mask=tgt_key_padding_mask,
- memory_key_padding_mask=memory_key_padding_mask,
- tgt_is_causal=tgt_is_causal,
- memory_is_causal=memory_is_causal,
- )
- return output
- @staticmethod
- def generate_square_subsequent_mask(
- sz: int,
- device: torch.device | None = None,
- dtype: torch.dtype | None = None,
- ) -> Tensor:
- r"""Generate a square causal mask for the sequence.
- The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0).
- """
- return _generate_square_subsequent_mask(sz, dtype=dtype, device=device)
- def _reset_parameters(self) -> None:
- r"""Initiate parameters in the transformer model."""
- for p in self.parameters():
- if p.dim() > 1:
- xavier_uniform_(p)
- class TransformerEncoder(Module):
- r"""TransformerEncoder is a stack of N encoder layers.
- This TransformerEncoder layer implements the original architecture described
- in the `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_ paper. The
- intent of this layer is as a reference implementation for foundational understanding
- and thus it contains only limited features relative to newer Transformer architectures.
- Given the fast pace of innovation in transformer-like architectures, we recommend
- exploring this `tutorial <https://pytorch.org/tutorials/intermediate/transformer_building_blocks.html>`_
- to build efficient layers from building blocks in core or using higher
- level libraries from the `PyTorch Ecosystem <https://landscape.pytorch.org/>`_.
- .. warning::
- All layers in the TransformerEncoder are initialized with the same parameters.
- It is recommended to manually initialize the layers after creating the TransformerEncoder instance.
- Args:
- encoder_layer: an instance of the TransformerEncoderLayer() class (required).
- num_layers: the number of sub-encoder-layers in the encoder (required).
- norm: the layer normalization component (optional).
- enable_nested_tensor: if True, input will automatically convert to nested tensor
- (and convert back on output). This will improve the overall performance of
- TransformerEncoder when padding rate is high. Default: ``True`` (enabled).
- Examples:
- >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
- >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
- >>> src = torch.rand(10, 32, 512)
- >>> out = transformer_encoder(src)
- """
- __constants__ = ["norm"]
- def __init__(
- self,
- encoder_layer: "TransformerEncoderLayer",
- num_layers: int,
- norm: Module | None = None,
- enable_nested_tensor: bool = True,
- mask_check: bool = True,
- ) -> None:
- super().__init__()
- torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
- self.layers = _get_clones(encoder_layer, num_layers)
- self.num_layers = num_layers
- self.norm = norm
- # this attribute saves the value providedat object construction
- self.enable_nested_tensor = enable_nested_tensor
- # this attribute controls whether nested tensors are used
- self.use_nested_tensor = enable_nested_tensor
- self.mask_check = mask_check
- enc_layer = "encoder_layer"
- why_not_sparsity_fast_path = ""
- if not isinstance(encoder_layer, torch.nn.TransformerEncoderLayer):
- why_not_sparsity_fast_path = f"{enc_layer} was not TransformerEncoderLayer"
- elif encoder_layer.norm_first:
- why_not_sparsity_fast_path = f"{enc_layer}.norm_first was True"
- elif not encoder_layer.self_attn.batch_first:
- why_not_sparsity_fast_path = (
- f"{enc_layer}.self_attn.batch_first was not True"
- + "(use batch_first for better inference performance)"
- )
- elif not encoder_layer.self_attn._qkv_same_embed_dim:
- why_not_sparsity_fast_path = (
- f"{enc_layer}.self_attn._qkv_same_embed_dim was not True"
- )
- elif encoder_layer.self_attn.in_proj_bias is None:
- why_not_sparsity_fast_path = f"{enc_layer}.self_attn was passed bias=False"
- elif not encoder_layer.activation_relu_or_gelu:
- why_not_sparsity_fast_path = (
- f"{enc_layer}.activation_relu_or_gelu was not True"
- )
- elif encoder_layer.norm1.eps != encoder_layer.norm2.eps:
- why_not_sparsity_fast_path = (
- f"{enc_layer}.norm1.eps was not equal to {enc_layer}.norm2.eps"
- )
- elif encoder_layer.self_attn.num_heads % 2 == 1:
- why_not_sparsity_fast_path = f"{enc_layer}.self_attn.num_heads is odd"
- if enable_nested_tensor and why_not_sparsity_fast_path:
- warnings.warn(
- f"enable_nested_tensor is True, but self.use_nested_tensor is False because {why_not_sparsity_fast_path}",
- stacklevel=2,
- )
- self.use_nested_tensor = False
- def forward(
- self,
- src: Tensor,
- mask: Tensor | None = None,
- src_key_padding_mask: Tensor | None = None,
- is_causal: bool | None = None,
- ) -> Tensor:
- r"""Pass the input through the encoder layers in turn.
- Args:
- src: the sequence to the encoder (required).
- mask: the mask for the src sequence (optional).
- src_key_padding_mask: the mask for the src keys per batch (optional).
- is_causal: If specified, applies a causal mask as ``mask``.
- Default: ``None``; try to detect a causal mask.
- Warning:
- ``is_causal`` provides a hint that ``mask`` is the
- causal mask. Providing incorrect hints can result in
- incorrect execution, including forward and backward
- compatibility.
- Shape:
- see the docs in :class:`~torch.nn.Transformer`.
- """
- src_key_padding_mask = F._canonical_mask(
- mask=src_key_padding_mask,
- mask_name="src_key_padding_mask",
- other_type=F._none_or_dtype(mask),
- other_name="mask",
- target_type=src.dtype,
- )
- mask = F._canonical_mask(
- mask=mask,
- mask_name="mask",
- other_type=None,
- other_name="",
- target_type=src.dtype,
- check_other=False,
- )
- output = src
- convert_to_nested = False
- first_layer = self.layers[0]
- src_key_padding_mask_for_layers = src_key_padding_mask
- why_not_sparsity_fast_path = ""
- str_first_layer = "self.layers[0]"
- batch_first = first_layer.self_attn.batch_first
- is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled()
- do_mask_check = getattr(self, "mask_check", True)
- if not is_fastpath_enabled:
- why_not_sparsity_fast_path = (
- "torch.backends.mha.get_fastpath_enabled() was not True"
- )
- elif not hasattr(self, "use_nested_tensor"):
- why_not_sparsity_fast_path = "use_nested_tensor attribute not present"
- elif not self.use_nested_tensor:
- why_not_sparsity_fast_path = (
- "self.use_nested_tensor (set in init) was not True"
- )
- elif first_layer.training:
- why_not_sparsity_fast_path = f"{str_first_layer} was in training mode"
- elif src.dim() != 3:
- why_not_sparsity_fast_path = (
- f"input not batched; expected src.dim() of 3 but got {src.dim()}"
- )
- elif src_key_padding_mask is None:
- why_not_sparsity_fast_path = "src_key_padding_mask was None"
- # This check avoids a call to torch._nested_tensor_from_mask_left_aligned() that
- # breaks in torch.compile.
- elif do_mask_check and torch.compiler.is_compiling():
- why_not_sparsity_fast_path = (
- "mask_check enabled with torch.compile or torch.export"
- )
- elif do_mask_check and not torch._nested_tensor_from_mask_left_aligned(
- src, src_key_padding_mask.logical_not()
- ):
- why_not_sparsity_fast_path = "mask_check enabled, and src and src_key_padding_mask was not left aligned"
- elif output.is_nested:
- why_not_sparsity_fast_path = "NestedTensor input is not supported"
- elif mask is not None:
- why_not_sparsity_fast_path = (
- "src_key_padding_mask and mask were both supplied"
- )
- elif torch.is_autocast_enabled():
- why_not_sparsity_fast_path = "autocast is enabled"
- if not why_not_sparsity_fast_path:
- tensor_args = (
- src,
- first_layer.self_attn.in_proj_weight,
- first_layer.self_attn.in_proj_bias,
- first_layer.self_attn.out_proj.weight,
- first_layer.self_attn.out_proj.bias,
- first_layer.norm1.weight,
- first_layer.norm1.bias,
- first_layer.norm2.weight,
- first_layer.norm2.bias,
- first_layer.linear1.weight,
- first_layer.linear1.bias,
- first_layer.linear2.weight,
- first_layer.linear2.bias,
- )
- _supported_device_type = [
- "cpu",
- "cuda",
- "xpu",
- torch.utils.backend_registration._privateuse1_backend_name,
- ]
- if torch.overrides.has_torch_function(tensor_args):
- why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
- elif src.device.type not in _supported_device_type:
- why_not_sparsity_fast_path = (
- f"src device is neither one of {_supported_device_type}"
- )
- elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args):
- why_not_sparsity_fast_path = (
- "grad is enabled and at least one of query or the "
- "input/output projection weights or biases requires_grad"
- )
- if (not why_not_sparsity_fast_path) and (src_key_padding_mask is not None):
- convert_to_nested = True
- output = torch._nested_tensor_from_mask(
- output, src_key_padding_mask.logical_not(), mask_check=False
- )
- src_key_padding_mask_for_layers = None
- seq_len = _get_seq_len(src, batch_first)
- is_causal = _detect_is_causal_mask(mask, is_causal, seq_len)
- for mod in self.layers:
- output = mod(
- output,
- src_mask=mask,
- is_causal=is_causal,
- src_key_padding_mask=src_key_padding_mask_for_layers,
- )
- if convert_to_nested:
- output = output.to_padded_tensor(0.0, src.size())
- if self.norm is not None:
- output = self.norm(output)
- return output
- class TransformerDecoder(Module):
- r"""TransformerDecoder is a stack of N decoder layers.
- This TransformerDecoder layer implements the original architecture described
- in the `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_ paper. The
- intent of this layer is as a reference implementation for foundational understanding
- and thus it contains only limited features relative to newer Transformer architectures.
- Given the fast pace of innovation in transformer-like architectures, we recommend
- exploring this `tutorial <https://pytorch.org/tutorials/intermediate/transformer_building_blocks.html>`_
- to build efficient layers from building blocks in core or using higher
- level libraries from the `PyTorch Ecosystem <https://landscape.pytorch.org/>`_.
- .. warning::
- All layers in the TransformerDecoder are initialized with the same parameters.
- It is recommended to manually initialize the layers after creating the TransformerDecoder instance.
- Args:
- decoder_layer: an instance of the TransformerDecoderLayer() class (required).
- num_layers: the number of sub-decoder-layers in the decoder (required).
- norm: the layer normalization component (optional).
- Examples:
- >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
- >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
- >>> memory = torch.rand(10, 32, 512)
- >>> tgt = torch.rand(20, 32, 512)
- >>> out = transformer_decoder(tgt, memory)
- """
- __constants__ = ["norm"]
- def __init__(
- self,
- decoder_layer: "TransformerDecoderLayer",
- num_layers: int,
- norm: Module | None = None,
- ) -> None:
- super().__init__()
- torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
- self.layers = _get_clones(decoder_layer, num_layers)
- self.num_layers = num_layers
- self.norm = norm
- def forward(
- self,
- tgt: Tensor,
- memory: Tensor,
- tgt_mask: Tensor | None = None,
- memory_mask: Tensor | None = None,
- tgt_key_padding_mask: Tensor | None = None,
- memory_key_padding_mask: Tensor | None = None,
- tgt_is_causal: bool | None = None,
- memory_is_causal: bool = False,
- ) -> Tensor:
- r"""Pass the inputs (and mask) through the decoder layer in turn.
- Args:
- tgt: the sequence to the decoder (required).
- memory: the sequence from the last layer of the encoder (required).
- tgt_mask: the mask for the tgt sequence (optional).
- memory_mask: the mask for the memory sequence (optional).
- tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
- memory_key_padding_mask: the mask for the memory keys per batch (optional).
- tgt_is_causal: If specified, applies a causal mask as ``tgt mask``.
- Default: ``None``; try to detect a causal mask.
- Warning:
- ``tgt_is_causal`` provides a hint that ``tgt_mask`` is
- the causal mask. Providing incorrect hints can result in
- incorrect execution, including forward and backward
- compatibility.
- memory_is_causal: If specified, applies a causal mask as
- ``memory mask``.
- Default: ``False``.
- Warning:
- ``memory_is_causal`` provides a hint that
- ``memory_mask`` is the causal mask. Providing incorrect
- hints can result in incorrect execution, including
- forward and backward compatibility.
- Shape:
- see the docs in :class:`~torch.nn.Transformer`.
- """
- output = tgt
- seq_len = _get_seq_len(tgt, self.layers[0].self_attn.batch_first)
- tgt_is_causal = _detect_is_causal_mask(tgt_mask, tgt_is_causal, seq_len)
- for mod in self.layers:
- output = mod(
- output,
- memory,
- tgt_mask=tgt_mask,
- memory_mask=memory_mask,
- tgt_key_padding_mask=tgt_key_padding_mask,
- memory_key_padding_mask=memory_key_padding_mask,
- tgt_is_causal=tgt_is_causal,
- memory_is_causal=memory_is_causal,
- )
- if self.norm is not None:
- output = self.norm(output)
- return output
- class TransformerEncoderLayer(Module):
- r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
- This TransformerEncoderLayer implements the original architecture described
- in the `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_ paper. The
- intent of this layer is as a reference implementation for foundational understanding
- and thus it contains only limited features relative to newer Transformer architectures.
- Given the fast pace of innovation in transformer-like architectures, we recommend
- exploring this `tutorial <https://pytorch.org/tutorials/intermediate/transformer_building_blocks.html>`_
- to build efficient layers from building blocks in core or using higher
- level libraries from the `PyTorch Ecosystem <https://landscape.pytorch.org/>`_.
- TransformerEncoderLayer can handle either traditional torch.tensor inputs,
- or Nested Tensor inputs. Derived classes are expected to similarly accept
- both input formats. (Not all combinations of inputs are currently
- supported by TransformerEncoderLayer while Nested Tensor is in prototype
- state.)
- If you are implementing a custom layer, you may derive it either from
- the Module or TransformerEncoderLayer class. If your custom layer
- supports both torch.Tensors and Nested Tensors inputs, make its
- implementation a derived class of TransformerEncoderLayer. If your custom
- Layer supports only torch.Tensor inputs, derive its implementation from
- Module.
- Args:
- d_model: the number of expected features in the input (required).
- nhead: the number of heads in the multiheadattention models (required).
- dim_feedforward: the dimension of the feedforward network model (default=2048).
- dropout: the dropout value (default=0.1).
- activation: the activation function of the intermediate layer, can be a string
- ("relu" or "gelu") or a unary callable. Default: relu
- layer_norm_eps: the eps value in layer normalization components (default=1e-5).
- batch_first: If ``True``, then the input and output tensors are provided
- as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
- norm_first: if ``True``, layer norm is done prior to attention and feedforward
- operations, respectively. Otherwise it's done after. Default: ``False`` (after).
- bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
- bias. Default: ``True``.
- Examples:
- >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
- >>> src = torch.rand(10, 32, 512)
- >>> out = encoder_layer(src)
- Alternatively, when ``batch_first`` is ``True``:
- >>> encoder_layer = nn.TransformerEncoderLayer(
- ... d_model=512, nhead=8, batch_first=True
- ... )
- >>> src = torch.rand(32, 10, 512)
- >>> out = encoder_layer(src)
- Fast path:
- forward() will use a special optimized implementation described in
- `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`_ if all of the following
- conditions are met:
- - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor
- argument ``requires_grad``
- - training is disabled (using ``.eval()``)
- - batch_first is ``True`` and the input is batched (i.e., ``src.dim() == 3``)
- - activation is one of: ``"relu"``, ``"gelu"``, ``torch.functional.relu``, or ``torch.functional.gelu``
- - at most one of ``src_mask`` and ``src_key_padding_mask`` is passed
- - if src is a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_, neither ``src_mask``
- nor ``src_key_padding_mask`` is passed
- - the two ``LayerNorm`` instances have a consistent ``eps`` value (this will naturally be the case
- unless the caller has manually modified one without modifying the other)
- If the optimized implementation is in use, a
- `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be
- passed for ``src`` to represent padding more efficiently than using a padding
- mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ will be
- returned, and an additional speedup proportional to the fraction of the input that
- is padding can be expected.
- .. _`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`:
- https://arxiv.org/abs/2205.14135
- """
- __constants__ = ["norm_first"]
def __init__(
    self,
    d_model: int,
    nhead: int,
    dim_feedforward: int = 2048,
    dropout: float = 0.1,
    activation: str | Callable[[Tensor], Tensor] = F.relu,
    layer_norm_eps: float = 1e-5,
    batch_first: bool = False,
    norm_first: bool = False,
    bias: bool = True,
    device=None,
    dtype=None,
) -> None:
    """Create the attention, feed-forward, norm and dropout sub-modules.

    ``device`` and ``dtype`` are forwarded to every sub-module that is
    constructed with them (attention, both linear layers, both layer
    norms). A string ``activation`` is resolved to a callable through
    ``_get_activation_fn``.
    """
    super().__init__()
    tensor_opts = {"device": device, "dtype": dtype}
    self.norm_first = norm_first

    self.self_attn = MultiheadAttention(
        d_model,
        nhead,
        dropout=dropout,
        bias=bias,
        batch_first=batch_first,
        **tensor_opts,
    )

    # Feed-forward network: d_model -> dim_feedforward -> d_model.
    self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **tensor_opts)
    self.dropout = Dropout(dropout)
    self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **tensor_opts)

    # pyrefly: ignore [bad-argument-type]
    self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **tensor_opts)
    # pyrefly: ignore [bad-argument-type]
    self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **tensor_opts)
    self.dropout1 = Dropout(dropout)
    self.dropout2 = Dropout(dropout)

    # Legacy support: the activation may be named by a string.
    act_fn = (
        _get_activation_fn(activation)
        if isinstance(activation, str)
        else activation
    )

    # forward() cannot inspect ``self.activation`` under TorchScript, so the
    # kind of activation is recorded as an integer code instead:
    # 1 = ReLU, 2 = GELU, 0 = anything else.
    if act_fn is F.relu or isinstance(act_fn, torch.nn.ReLU):
        act_code = 1
    elif act_fn is F.gelu or isinstance(act_fn, torch.nn.GELU):
        act_code = 2
    else:
        act_code = 0
    self.activation_relu_or_gelu = act_code
    self.activation = act_fn
def __setstate__(self, state):
    """Restore pickled state and back-fill attributes absent from old pickles.

    Args:
        state: the attribute dictionary captured when the layer was pickled.
    """
    super().__setstate__(state)
    # Pickles without any activation fall back to ReLU.
    if not hasattr(self, "activation"):
        self.activation = F.relu
    # forward() reads ``activation_relu_or_gelu`` while deciding whether the
    # fast path applies. A pickle that lacks it would otherwise fail there
    # with AttributeError, so derive it with the same rule __init__ uses:
    # 1 = ReLU, 2 = GELU, 0 = anything else.
    if not hasattr(self, "activation_relu_or_gelu"):
        activation = self.activation
        if activation is F.relu or isinstance(activation, torch.nn.ReLU):
            self.activation_relu_or_gelu = 1
        elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
            self.activation_relu_or_gelu = 2
        else:
            self.activation_relu_or_gelu = 0
- def forward(
- self,
- src: Tensor,
- src_mask: Tensor | None = None,
- src_key_padding_mask: Tensor | None = None,
- is_causal: bool = False,
- ) -> Tensor:
- r"""Pass the input through the encoder layer.
- Args:
- src: the sequence to the encoder layer (required).
- src_mask: the mask for the src sequence (optional).
- src_key_padding_mask: the mask for the src keys per batch (optional).
- is_causal: If specified, applies a causal mask as ``src mask``.
- Default: ``False``.
- Warning:
- ``is_causal`` provides a hint that ``src_mask`` is the
- causal mask. Providing incorrect hints can result in
- incorrect execution, including forward and backward
- compatibility.
- Shape:
- see the docs in :class:`~torch.nn.Transformer`.
- """
- src_key_padding_mask = F._canonical_mask(
- mask=src_key_padding_mask,
- mask_name="src_key_padding_mask",
- other_type=F._none_or_dtype(src_mask),
- other_name="src_mask",
- target_type=src.dtype,
- )
- src_mask = F._canonical_mask(
- mask=src_mask,
- mask_name="src_mask",
- other_type=None,
- other_name="",
- target_type=src.dtype,
- check_other=False,
- )
- is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled()
- why_not_sparsity_fast_path = ""
- if not is_fastpath_enabled:
- why_not_sparsity_fast_path = (
- "torch.backends.mha.get_fastpath_enabled() was not True"
- )
- elif src.dim() != 3:
- why_not_sparsity_fast_path = (
- f"input not batched; expected src.dim() of 3 but got {src.dim()}"
- )
- elif self.training:
- why_not_sparsity_fast_path = "training is enabled"
- elif not self.self_attn.batch_first:
- why_not_sparsity_fast_path = "self_attn.batch_first was not True"
- elif self.self_attn.in_proj_bias is None:
- why_not_sparsity_fast_path = "self_attn was passed bias=False"
- elif not self.self_attn._qkv_same_embed_dim:
- why_not_sparsity_fast_path = "self_attn._qkv_same_embed_dim was not True"
- elif not self.activation_relu_or_gelu:
- why_not_sparsity_fast_path = "activation_relu_or_gelu was not True"
- elif self.norm1.eps != self.norm2.eps:
- why_not_sparsity_fast_path = "norm1.eps is not equal to norm2.eps"
- elif src.is_nested and (
- src_key_padding_mask is not None or src_mask is not None
- ):
- why_not_sparsity_fast_path = "neither src_key_padding_mask nor src_mask are not supported with NestedTensor input"
- elif self.self_attn.num_heads % 2 == 1:
- why_not_sparsity_fast_path = "num_head is odd"
- elif torch.is_autocast_enabled():
- why_not_sparsity_fast_path = "autocast is enabled"
- elif any(
- len(getattr(m, "_forward_hooks", {}))
- + len(getattr(m, "_forward_pre_hooks", {}))
- for m in self.modules()
- ):
- why_not_sparsity_fast_path = "forward pre-/hooks are attached to the module"
- if not why_not_sparsity_fast_path:
- tensor_args = (
- src,
- self.self_attn.in_proj_weight,
- self.self_attn.in_proj_bias,
- self.self_attn.out_proj.weight,
- self.self_attn.out_proj.bias,
- self.norm1.weight,
- self.norm1.bias,
- self.norm2.weight,
- self.norm2.bias,
- self.linear1.weight,
- self.linear1.bias,
- self.linear2.weight,
- self.linear2.bias,
- )
- # We have to use list comprehensions below because TorchScript does not support
- # generator expressions.
- _supported_device_type = [
- "cpu",
- "cuda",
- "xpu",
- torch.utils.backend_registration._privateuse1_backend_name,
- ]
- if torch.overrides.has_torch_function(tensor_args):
- why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
- elif not all(
- (x.device.type in _supported_device_type) for x in tensor_args
- ):
- why_not_sparsity_fast_path = (
- "some Tensor argument's device is neither one of "
- f"{_supported_device_type}"
- )
- elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args):
- why_not_sparsity_fast_path = (
- "grad is enabled and at least one of query or the "
- "input/output projection weights or biases requires_grad"
- )
- if not why_not_sparsity_fast_path:
- merged_mask, mask_type = self.self_attn.merge_masks(
- src_mask, src_key_padding_mask, src
- )
- return torch._transformer_encoder_layer_fwd(
- src,
- self.self_attn.embed_dim,
- self.self_attn.num_heads,
- self.self_attn.in_proj_weight,
- self.self_attn.in_proj_bias,
- self.self_attn.out_proj.weight,
- self.self_attn.out_proj.bias,
- self.activation_relu_or_gelu == 2,
- self.norm_first,
- self.norm1.eps,
- self.norm1.weight,
- self.norm1.bias,
- self.norm2.weight,
- self.norm2.bias,
- self.linear1.weight,
- self.linear1.bias,
- self.linear2.weight,
- self.linear2.bias,
- merged_mask,
- mask_type,
- )
- # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
- x = src
- if self.norm_first:
- x = x + self._sa_block(
- self.norm1(x), src_mask, src_key_padding_mask, is_causal=is_causal
- )
- x = x + self._ff_block(self.norm2(x))
- else:
- x = self.norm1(
- x
- + self._sa_block(x, src_mask, src_key_padding_mask, is_causal=is_causal)
- )
- x = self.norm2(x + self._ff_block(x))
- return x
- # self-attention block
- def _sa_block(
- self,
- x: Tensor,
- attn_mask: Tensor | None,
- key_padding_mask: Tensor | None,
- is_causal: bool = False,
- ) -> Tensor:
- x = self.self_attn(
- x,
- x,
- x,
- attn_mask=attn_mask,
- key_padding_mask=key_padding_mask,
- need_weights=False,
- is_causal=is_causal,
- )[0]
- return self.dropout1(x)
- # feed forward block
def _ff_block(self, x: Tensor) -> Tensor:
    """Run ``x`` through linear1 -> activation -> dropout -> linear2 -> dropout2."""
    hidden = self.linear1(x)
    hidden = self.activation(hidden)
    hidden = self.dropout(hidden)
    projected = self.linear2(hidden)
    return self.dropout2(projected)
- class TransformerDecoderLayer(Module):
- r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
- This TransformerDecoderLayer implements the original architecture described
- in the `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_ paper. The
- intent of this layer is as a reference implementation for foundational understanding
- and thus it contains only limited features relative to newer Transformer architectures.
- Given the fast pace of innovation in transformer-like architectures, we recommend
- exploring this `tutorial <https://pytorch.org/tutorials/intermediate/transformer_building_blocks.html>`_
- to build efficient layers from building blocks in core or using higher
- level libraries from the `PyTorch Ecosystem <https://landscape.pytorch.org/>`_.
- Args:
- d_model: the number of expected features in the input (required).
- nhead: the number of heads in the multiheadattention models (required).
- dim_feedforward: the dimension of the feedforward network model (default=2048).
- dropout: the dropout value (default=0.1).
- activation: the activation function of the intermediate layer, can be a string
- ("relu" or "gelu") or a unary callable. Default: relu
- layer_norm_eps: the eps value in layer normalization components (default=1e-5).
- batch_first: If ``True``, then the input and output tensors are provided
- as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
- norm_first: if ``True``, layer norm is done prior to self attention, multihead
- attention and feedforward operations, respectively. Otherwise it's done after.
- Default: ``False`` (after).
- bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
- bias. Default: ``True``.
- Examples:
- >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
- >>> memory = torch.rand(10, 32, 512)
- >>> tgt = torch.rand(20, 32, 512)
- >>> out = decoder_layer(tgt, memory)
- Alternatively, when ``batch_first`` is ``True``:
- >>> decoder_layer = nn.TransformerDecoderLayer(
- ... d_model=512, nhead=8, batch_first=True
- ... )
- >>> memory = torch.rand(32, 10, 512)
- >>> tgt = torch.rand(32, 20, 512)
- >>> out = decoder_layer(tgt, memory)
- """
- __constants__ = ["norm_first"]
- def __init__(
- self,
- d_model: int,
- nhead: int,
- dim_feedforward: int = 2048,
- dropout: float = 0.1,
- activation: str | Callable[[Tensor], Tensor] = F.relu,
- layer_norm_eps: float = 1e-5,
- batch_first: bool = False,
- norm_first: bool = False,
- bias: bool = True,
- device=None,
- dtype=None,
- ) -> None:
- factory_kwargs = {"device": device, "dtype": dtype}
- super().__init__()
- self.self_attn = MultiheadAttention(
- d_model,
- nhead,
- dropout=dropout,
- batch_first=batch_first,
- bias=bias,
- **factory_kwargs,
- )
- self.multihead_attn = MultiheadAttention(
- d_model,
- nhead,
- dropout=dropout,
- batch_first=batch_first,
- bias=bias,
- **factory_kwargs,
- )
- # Implementation of Feedforward model
- self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs)
- self.dropout = Dropout(dropout)
- self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)
- self.norm_first = norm_first
- # pyrefly: ignore [bad-argument-type]
- self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
- # pyrefly: ignore [bad-argument-type]
- self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
- # pyrefly: ignore [bad-argument-type]
- self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
- self.dropout1 = Dropout(dropout)
- self.dropout2 = Dropout(dropout)
- self.dropout3 = Dropout(dropout)
- # Legacy string support for activation function.
- if isinstance(activation, str):
- self.activation = _get_activation_fn(activation)
- else:
- self.activation = activation
- def __setstate__(self, state):
- if "activation" not in state:
- state["activation"] = F.relu
- super().__setstate__(state)
- def forward(
- self,
- tgt: Tensor,
- memory: Tensor,
- tgt_mask: Tensor | None = None,
- memory_mask: Tensor | None = None,
- tgt_key_padding_mask: Tensor | None = None,
- memory_key_padding_mask: Tensor | None = None,
- tgt_is_causal: bool = False,
- memory_is_causal: bool = False,
- ) -> Tensor:
- r"""Pass the inputs (and mask) through the decoder layer.
- Args:
- tgt: the sequence to the decoder layer (required).
- memory: the sequence from the last layer of the encoder (required).
- tgt_mask: the mask for the tgt sequence (optional).
- memory_mask: the mask for the memory sequence (optional).
- tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
- memory_key_padding_mask: the mask for the memory keys per batch (optional).
- tgt_is_causal: If specified, applies a causal mask as ``tgt mask``.
- Default: ``False``.
- Warning:
- ``tgt_is_causal`` provides a hint that ``tgt_mask`` is
- the causal mask. Providing incorrect hints can result in
- incorrect execution, including forward and backward
- compatibility.
- memory_is_causal: If specified, applies a causal mask as
- ``memory mask``.
- Default: ``False``.
- Warning:
- ``memory_is_causal`` provides a hint that
- ``memory_mask`` is the causal mask. Providing incorrect
- hints can result in incorrect execution, including
- forward and backward compatibility.
- Shape:
- see the docs in :class:`~torch.nn.Transformer`.
- """
- # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
- x = tgt
- if self.norm_first:
- x = x + self._sa_block(
- self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal
- )
- x = x + self._mha_block(
- self.norm2(x),
- memory,
- memory_mask,
- memory_key_padding_mask,
- memory_is_causal,
- )
- x = x + self._ff_block(self.norm3(x))
- else:
- x = self.norm1(
- x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal)
- )
- x = self.norm2(
- x
- + self._mha_block(
- x, memory, memory_mask, memory_key_padding_mask, memory_is_causal
- )
- )
- x = self.norm3(x + self._ff_block(x))
- return x
- # self-attention block
- def _sa_block(
- self,
- x: Tensor,
- attn_mask: Tensor | None,
- key_padding_mask: Tensor | None,
- is_causal: bool = False,
- ) -> Tensor:
- x = self.self_attn(
- x,
- x,
- x,
- attn_mask=attn_mask,
- key_padding_mask=key_padding_mask,
- is_causal=is_causal,
- need_weights=False,
- )[0]
- return self.dropout1(x)
- # multihead attention block
- def _mha_block(
- self,
- x: Tensor,
- mem: Tensor,
- attn_mask: Tensor | None,
- key_padding_mask: Tensor | None,
- is_causal: bool = False,
- ) -> Tensor:
- x = self.multihead_attn(
- x,
- mem,
- mem,
- attn_mask=attn_mask,
- key_padding_mask=key_padding_mask,
- is_causal=is_causal,
- need_weights=False,
- )[0]
- return self.dropout2(x)
- # feed forward block
- def _ff_block(self, x: Tensor) -> Tensor:
- x = self.linear2(self.dropout(self.activation(self.linear1(x))))
- return self.dropout3(x)
def _get_clones(module, N):
    """Return a ``ModuleList`` holding ``N`` deep copies of ``module``."""
    # FIXME: copy.deepcopy() is not defined on nn.module
    clones = ModuleList()
    for _ in range(N):
        clones.append(copy.deepcopy(module))
    return clones
def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
    """Map the name ``"relu"`` or ``"gelu"`` to the matching functional activation.

    Raises:
        RuntimeError: if ``activation`` is neither of the two supported names.
    """
    # Compared with ``==`` one name at a time, so any value that matches
    # neither name reaches the RuntimeError below.
    for name, fn in (("relu", F.relu), ("gelu", F.gelu)):
        if activation == name:
            return fn
    raise RuntimeError(f"activation should be relu/gelu, not {activation}")
- def _detect_is_causal_mask(
- mask: Tensor | None,
- is_causal: bool | None = None,
- size: int | None = None,
- ) -> bool:
- """Return whether the given attention mask is causal.
- Warning:
- If ``is_causal`` is not ``None``, its value will be returned as is. If a
- user supplies an incorrect ``is_causal`` hint,
- ``is_causal=False`` when the mask is in fact a causal attention.mask
- may lead to reduced performance relative to what would be achievable
- with ``is_causal=True``;
- ``is_causal=True`` when the mask is in fact not a causal attention.mask
- may lead to incorrect and unpredictable execution - in some scenarios,
- a causal mask may be applied based on the hint, in other execution
- scenarios the specified mask may be used. The choice may not appear
- to be deterministic, in that a number of factors like alignment,
- hardware SKU, etc influence the decision whether to use a mask or
- rely on the hint.
- ``size`` if not None, check whether the mask is a causal mask of the provided size
- Otherwise, checks for any causal mask.
- """
- # Prevent type refinement
- make_causal = is_causal is True
- if is_causal is None and mask is not None:
- sz = size if size is not None else mask.size(-2)
- causal_comparison = _generate_square_subsequent_mask(
- sz, device=mask.device, dtype=mask.dtype
- )
- # Do not use `torch.equal` so we handle batched masks by
- # broadcasting the comparison.
- if mask.size() == causal_comparison.size():
- make_causal = bool((mask == causal_comparison).all())
- else:
- make_causal = False
- return make_causal
|