block.py 3.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. # Copyright (c) 2024, Tri Dao, Albert Gu.
  2. from typing import Optional
  3. import torch
  4. from torch import nn, Tensor
  5. from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn
  6. class Block(nn.Module):
  7. def __init__(
  8. self, dim, mixer_cls, mlp_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
  9. ):
  10. """
  11. Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection"
  12. This Block has a slightly different structure compared to a regular
  13. prenorm Transformer block.
  14. The standard block is: LN -> MHA/MLP -> Add.
  15. [Ref: https://arxiv.org/abs/2002.04745]
  16. Here we have: Add -> LN -> Mixer, returning both
  17. the hidden_states (output of the mixer) and the residual.
  18. This is purely for performance reasons, as we can fuse add and LayerNorm.
  19. The residual needs to be provided (except for the very first block).
  20. """
  21. super().__init__()
  22. self.residual_in_fp32 = residual_in_fp32
  23. self.fused_add_norm = fused_add_norm
  24. self.norm = norm_cls(dim)
  25. self.mixer = mixer_cls(dim)
  26. if mlp_cls is not nn.Identity:
  27. self.norm2 = norm_cls(dim)
  28. self.mlp = mlp_cls(dim)
  29. else:
  30. self.mlp = None
  31. if self.fused_add_norm:
  32. assert RMSNorm is not None, "RMSNorm import fails"
  33. assert isinstance(
  34. self.norm, (nn.LayerNorm, RMSNorm)
  35. ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
  36. def forward(
  37. self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None, **mixer_kwargs
  38. ):
  39. r"""Pass the input through the encoder layer.
  40. Args:
  41. hidden_states: the sequence to the encoder layer (required).
  42. residual: hidden_states = Mixer(LN(residual))
  43. """
  44. if not self.fused_add_norm:
  45. residual = (hidden_states + residual) if residual is not None else hidden_states
  46. hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
  47. if self.residual_in_fp32:
  48. residual = residual.to(torch.float32)
  49. else:
  50. hidden_states, residual = layer_norm_fn(
  51. hidden_states,
  52. self.norm.weight,
  53. self.norm.bias,
  54. residual=residual,
  55. prenorm=True,
  56. residual_in_fp32=self.residual_in_fp32,
  57. eps=self.norm.eps,
  58. is_rms_norm=isinstance(self.norm, RMSNorm)
  59. )
  60. hidden_states = self.mixer(hidden_states, inference_params=inference_params, **mixer_kwargs)
  61. if self.mlp is not None:
  62. if not self.fused_add_norm:
  63. residual = hidden_states + residual
  64. hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
  65. if self.residual_in_fp32:
  66. residual = residual.to(torch.float32)
  67. else:
  68. hidden_states, residual = layer_norm_fn(
  69. hidden_states,
  70. self.norm2.weight,
  71. self.norm2.bias,
  72. residual=residual,
  73. prenorm=True,
  74. residual_in_fp32=self.residual_in_fp32,
  75. eps=self.norm2.eps,
  76. is_rms_norm=isinstance(self.norm2, RMSNorm)
  77. )
  78. hidden_states = self.mlp(hidden_states)
  79. return hidden_states, residual
  80. def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
  81. return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)