| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149 |
- # -------------------------------------------------------------------------
- # Copyright (c) Microsoft Corporation. All rights reserved.
- # Licensed under the MIT License. See License.txt in the project root for
- # license information.
- # --------------------------------------------------------------------------
- import logging
- import torch
- logger = logging.getLogger(__name__)
class PastKeyValuesHelper:
    """Helper functions to process past key values for encoder-decoder model.

    Two layouts of the key/value cache are converted between:

    * grouped by layer: one 4-item group per layer, ordered
      (key_self, value_self, key_cross, value_cross);
    * grouped by self/cross attention: a flat sequence holding every
      self-attention key/value pair (layer 0, layer 1, ...) followed by every
      cross-attention key/value pair in the same layer order.
    """

    @staticmethod
    def get_past_names(num_layers, present: bool = False):
        """Return key/value cache names for `num_layers` layers, grouped by self/cross attention.

        Args:
            num_layers: number of layers to generate names for.
            present: if True, names use the ``present_`` prefix; otherwise ``past_``.

        Returns:
            list[str] of ``4 * num_layers`` names: all ``*_key_self_{i}`` / ``*_value_self_{i}``
            names for every layer, followed by all ``*_key_cross_{i}`` / ``*_value_cross_{i}`` names.
        """
        prefix = "present" if present else "past"
        self_names = []
        cross_names = []
        for i in range(num_layers):
            self_names.extend([f"{prefix}_key_self_{i}", f"{prefix}_value_self_{i}"])
            cross_names.extend([f"{prefix}_key_cross_{i}", f"{prefix}_value_cross_{i}"])
        return self_names + cross_names

    @staticmethod
    def group_by_self_or_cross(present_key_values):
        """Split present state from grouped by layer to grouped by self/cross attention.

        Before: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0),
                (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1), ...
        After:  [past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ...],
                [past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ...]

        Args:
            present_key_values: iterable of per-layer groups, each holding exactly four items.

        Returns:
            Tuple ``(present_self, present_cross)`` of two lists.

        Raises:
            AssertionError: if a layer group does not hold exactly four items.
        """
        present_self = []
        present_cross = []
        for present_layer_i in present_key_values:
            assert len(present_layer_i) == 4, f"Expected to have four items. Got {len(present_layer_i)}"
            key_self, value_self, key_cross, value_cross = present_layer_i
            present_self.extend([key_self, value_self])
            present_cross.extend([key_cross, value_cross])
        return present_self, present_cross

    @staticmethod
    def group_by_layer(past, num_layers):
        """Reorder past state from grouped by self/cross attention to grouped by layer.

        Before: past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ...,
                past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ...
        After:  [past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0],
                [past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1], ...

        Args:
            past: flat indexable sequence of ``4 * num_layers`` items.
            num_layers: number of layers.

        Returns:
            Tuple with one 4-item list per layer.

        Raises:
            AssertionError: if ``len(past) != 4 * num_layers``.
        """
        assert len(past) == 4 * num_layers, (
            f"Expected {4 * num_layers} items for {num_layers} layers. Got {len(past)}"
        )
        # Cross-attention entries start after all 2 * num_layers self-attention entries.
        cross_offset = 2 * num_layers
        return tuple(
            [
                past[2 * i],
                past[2 * i + 1],
                past[cross_offset + 2 * i],
                past[cross_offset + 2 * i + 1],
            ]
            for i in range(num_layers)
        )

    @staticmethod
    def back_group_by_layer(past_key_values: tuple[torch.Tensor, ...]):
        """Categorize present_key_values from self and cross attention to layer by layer.

        Reorder past state from grouped by self/cross attention to grouped by layer.

        Before: past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ...,
                past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ...
        After:  (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0),
                (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1), ...

        Args:
            past_key_values: flat sequence grouped by self and cross attention; its length
                must be a multiple of four (four entries per layer).

        Returns:
            past_tuples: tuple of 4-tuples, one per layer (empty input gives ``()``).

        Raises:
            AssertionError: if the length is not a multiple of four. Without this check the
                self/cross split point would be misaligned and cross slots would silently
                receive self-attention entries.
        """
        num_items = len(past_key_values)
        assert num_items % 4 == 0, f"Expected a multiple of four items. Got {num_items}"
        # The second half of the flat sequence holds the cross-attention entries.
        half_idx = num_items // 2
        return tuple(
            (
                past_key_values[2 * i],
                past_key_values[2 * i + 1],
                past_key_values[half_idx + 2 * i],
                past_key_values[half_idx + 2 * i + 1],
            )
            for i in range(num_items // 4)
        )

    @staticmethod
    def group_by_self_and_cross(present_key_values: tuple[tuple[torch.Tensor, ...], ...], concat: bool = False):
        """Categorize present_key_values into self and cross attention.

        Split present state from grouped by layer to grouped by self/cross attention.

        Before: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0),
                (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1), ...
        After:  [past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ...],
                [past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ...]

        Args:
            present_key_values: From past_key_values of a model (grouped by layer, four items each).
            concat: If True, return self attention key/values concatenated with cross attention ones.

        Returns:
            If ``concat``: a single list with all self-attention items followed by all
            cross-attention items.
            Otherwise: tuple ``(present_self, present_cross)`` of two lists.

        Raises:
            AssertionError: if a layer group does not hold exactly four items.
        """
        present_self, present_cross = PastKeyValuesHelper.group_by_self_or_cross(present_key_values)
        if concat:
            return present_self + present_cross
        return present_self, present_cross

    @staticmethod
    def get_input_names(past_key_values: tuple, encoder=True):
        """Process input names of model wrapper.

        Args:
            past_key_values: Consider `self` and `cross` past_key_values. When ``encoder`` is
                truthy it is expected to be flat with four entries per layer; otherwise it is
                expected to hold one entry (group) per layer.
            encoder: selects how the layer count is derived (see above) and the name prefix:
                ``present_`` when truthy, ``past_`` otherwise.

        Returns:
            names (List[string]): all self-attention key/value names for every layer,
            followed by all cross-attention key/value names.
        """
        num_layers = len(past_key_values) // 4 if encoder else len(past_key_values)
        # Same names and ordering as get_past_names; "present_" for encoder, "past_" otherwise.
        return PastKeyValuesHelper.get_past_names(num_layers, present=encoder)
|