past_helper.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. # -------------------------------------------------------------------------
  2. # Copyright (c) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License. See License.txt in the project root for
  4. # license information.
  5. # --------------------------------------------------------------------------
  6. import logging
  7. import torch
  8. logger = logging.getLogger(__name__)
  9. class PastKeyValuesHelper:
  10. """Helper functions to process past key values for encoder-decoder model"""
  11. @staticmethod
  12. def get_past_names(num_layers, present: bool = False):
  13. past_self_names = []
  14. past_cross_names = []
  15. for i in range(num_layers):
  16. past_self_names.extend(
  17. [f"present_key_self_{i}", f"present_value_self_{i}"]
  18. if present
  19. else [f"past_key_self_{i}", f"past_value_self_{i}"]
  20. )
  21. past_cross_names.extend(
  22. [f"present_key_cross_{i}", f"present_value_cross_{i}"]
  23. if present
  24. else [f"past_key_cross_{i}", f"past_value_cross_{i}"]
  25. )
  26. return past_self_names + past_cross_names
  27. @staticmethod
  28. def group_by_self_or_cross(present_key_values):
  29. """Split present state from grouped by layer to grouped by self/cross attention.
  30. Before: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0), (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1), ...
  31. After: (past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ...), (past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ...)
  32. """
  33. present_self = []
  34. present_cross = []
  35. for _i, present_layer_i in enumerate(present_key_values):
  36. assert len(present_layer_i) == 4, f"Expected to have four items. Got {len(present_layer_i)}"
  37. (
  38. present_key_self,
  39. present_value_self,
  40. present_key_cross,
  41. present_value_cross,
  42. ) = present_layer_i
  43. present_self.extend([present_key_self, present_value_self])
  44. present_cross.extend([present_key_cross, present_value_cross])
  45. return present_self, present_cross
  46. @staticmethod
  47. def group_by_layer(past, num_layers):
  48. """Reorder past state from grouped by self/cross attention to grouped by layer.
  49. Before: past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ..., past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ...
  50. After: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0), (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1),
  51. """
  52. assert len(past) == 4 * num_layers
  53. return tuple(
  54. [
  55. past[2 * i],
  56. past[2 * i + 1],
  57. past[2 * num_layers + 2 * i],
  58. past[2 * num_layers + 2 * i + 1],
  59. ]
  60. for i in range(num_layers)
  61. )
  62. @staticmethod
  63. def back_group_by_layer(past_key_values: tuple[tuple[torch.Tensor]]):
  64. """Categorize present_key_values from self and cross attention to layer by layer.
  65. Reorder past state from grouped by self/cross attention to grouped by layer.
  66. Before: past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ...,
  67. past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ...
  68. After: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0),
  69. (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1),
  70. Args:
  71. present_key_values: From past_key_values of a model (group by self and cross attention)
  72. Returns:
  73. past_tuples: present key and values grouped by layer.
  74. """
  75. past_tuples = ()
  76. half_idx = len(past_key_values) // 2
  77. for i in range(len(past_key_values) // 4):
  78. idx = 2 * i
  79. past_tuples += (
  80. (
  81. past_key_values[idx],
  82. past_key_values[idx + 1],
  83. past_key_values[half_idx + idx],
  84. past_key_values[half_idx + idx + 1],
  85. ),
  86. )
  87. return past_tuples
  88. @staticmethod
  89. def group_by_self_and_cross(present_key_values: tuple[torch.Tensor], concat: bool = False):
  90. """Categorize present_key_values into self and cross attention.
  91. Split present state from grouped by layer to grouped by self/cross attention.
  92. Before: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0),
  93. (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1), ...
  94. After: (past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ...),
  95. (past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ...)
  96. Args:
  97. present_key_values: From past_key_values of a model (group by layer)
  98. concat: If concat self attention with cross attention key/value to return
  99. Returns:
  100. present_self (Tuple[torch.Tensor]): present key and values from self attention
  101. present_cross (Tuple[torch.Tensor]): present key and values from cross attention
  102. """
  103. present_self: list[torch.Tensor] = []
  104. present_cross: list[torch.Tensor] = []
  105. for _, present_layer_i in enumerate(present_key_values):
  106. assert len(present_layer_i) == 4, f"Expected to have four items. Got {len(present_layer_i)}"
  107. present_key_self, present_value_self, present_key_cross, present_value_cross = present_layer_i
  108. present_self.extend([present_key_self, present_value_self])
  109. present_cross.extend([present_key_cross, present_value_cross])
  110. if concat:
  111. return present_self + present_cross
  112. else:
  113. return present_self, present_cross
  114. @staticmethod
  115. def get_input_names(past_key_values: tuple[tuple[torch.Tensor]], encoder=True):
  116. """Process input names of model wrapper.
  117. Args:
  118. past_key_values: Consider `self` and `cross` past_key_values
  119. Returns:
  120. names (List[string]): input names
  121. """
  122. names = []
  123. num_layers = len(past_key_values) // 4 if encoder else len(past_key_values)
  124. prefix = "past_" if not encoder else "present_"
  125. for i in range(num_layers):
  126. names.extend([prefix + s for s in [f"key_self_{i}", f"value_self_{i}"]])
  127. for i in range(num_layers):
  128. names.extend([prefix + s for s in [f"key_cross_{i}", f"value_cross_{i}"]])
  129. return names