| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091 |
- # Copyright (c) 2024, Tri Dao, Albert Gu.
- from typing import Optional
- import torch
- from torch import nn, Tensor
- from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn
class Block(nn.Module):
    """Residual block computing Add -> Norm -> Mixer, optionally followed by Add -> Norm -> MLP.

    The block returns both the branch output and the running residual, so the
    addition can be performed (and optionally fused with the normalization)
    at the start of the next block.
    """

    def __init__(
        self, dim, mixer_cls, mlp_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
    ):
        """
        Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection.

        This Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA/MLP -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Add -> LN -> Mixer, returning both
        the hidden_states (output of the mixer) and the residual.
        This is purely for performance reasons, as we can fuse add and LayerNorm.
        The residual needs to be provided (except for the very first block).

        Args:
            dim: value passed as the single argument to ``norm_cls``,
                ``mixer_cls`` and ``mlp_cls`` when building the submodules.
            mixer_cls: callable ``mixer_cls(dim)`` returning the mixer module.
            mlp_cls: callable ``mlp_cls(dim)`` returning the MLP module. Passing
                the class ``nn.Identity`` itself disables the MLP branch; in that
                case ``self.mlp`` is ``None`` and ``self.norm2`` is not created.
            norm_cls: callable ``norm_cls(dim)`` returning a normalization module.
            fused_add_norm: if True, the add and the normalization are done by
                ``layer_norm_fn`` instead of separate PyTorch ops.
            residual_in_fp32: if True, the residual is kept in ``torch.float32``.
        """
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        # Submodules are created in this order on purpose: it fixes the order of
        # parameter registration and of any RNG use during initialization.
        self.norm = norm_cls(dim)
        self.mixer = mixer_cls(dim)
        if mlp_cls is not nn.Identity:
            self.norm2 = norm_cls(dim)
            self.mlp = mlp_cls(dim)
        else:
            self.mlp = None
        if self.fused_add_norm:
            assert RMSNorm is not None, "RMSNorm import fails"
            assert isinstance(
                self.norm, (nn.LayerNorm, RMSNorm)
            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"

    def _add_norm(self, norm, hidden_states, residual):
        """Add ``hidden_states`` to ``residual`` and normalize the sum with ``norm``.

        Args:
            norm: the normalization module to apply (``self.norm`` or ``self.norm2``).
            hidden_states: output of the previous branch.
            residual: running residual, or None when there is nothing to add yet.

        Returns:
            Tuple ``(normalized, residual)`` where ``residual`` is the updated
            (pre-normalization) sum.
        """
        if self.fused_add_norm:
            # Single call doing add + norm; with prenorm=True the result is
            # unpacked by the callers as (normalized, residual).
            return layer_norm_fn(
                hidden_states,
                norm.weight,
                norm.bias,
                residual=residual,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=norm.eps,
                is_rms_norm=isinstance(norm, RMSNorm),
            )
        # Unfused path: without an incoming residual, the input itself becomes the residual.
        residual = (hidden_states + residual) if residual is not None else hidden_states
        # Normalize in the dtype of the norm's weight.
        hidden_states = norm(residual.to(dtype=norm.weight.dtype))
        if self.residual_in_fp32:
            residual = residual.to(torch.float32)
        return hidden_states, residual

    def forward(
        self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None, **mixer_kwargs
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: hidden_states = Mixer(LN(residual))
            inference_params: forwarded unchanged to the mixer as a keyword argument.
            **mixer_kwargs: extra keyword arguments forwarded to the mixer.

        Returns:
            Tuple ``(hidden_states, residual)``: the output of the last branch
            (the MLP when present, otherwise the mixer) and the updated residual,
            which has not yet had that output added to it.
        """
        hidden_states, residual = self._add_norm(self.norm, hidden_states, residual)
        hidden_states = self.mixer(hidden_states, inference_params=inference_params, **mixer_kwargs)
        if self.mlp is not None:
            # residual is always set by the first add/norm step at this point.
            hidden_states, residual = self._add_norm(self.norm2, hidden_states, residual)
            hidden_states = self.mlp(hidden_states)
        return hidden_states, residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Delegate inference-cache allocation to the mixer and return its result."""
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
|