| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574 |
- """
- [1] - Attention Is All You Need - Vaswani, Jones, Shazeer, Parmar,
- Uszkoreit, Gomez, Kaiser - Google Brain/Research, U Toronto - 2017.
- https://arxiv.org/pdf/1706.03762.pdf
- [2] - Stabilizing Transformers for Reinforcement Learning - E. Parisotto
- et al. - DeepMind - 2019. https://arxiv.org/pdf/1910.06764.pdf
- [3] - Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context.
- Z. Dai, Z. Yang, et al. - Carnegie Mellon U - 2019.
- https://www.aclweb.org/anthology/P19-1285.pdf
- """
- from typing import Any, Dict, Optional, Union
- import gymnasium as gym
- import numpy as np
- import tree # pip install dm_tree
- from gymnasium.spaces import Box, Discrete, MultiDiscrete
- from ray._common.deprecation import deprecation_warning
- from ray.rllib.models.modelv2 import ModelV2
- from ray.rllib.models.tf.layers import (
- GRUGate,
- RelativeMultiHeadAttention,
- SkipConnection,
- )
- from ray.rllib.models.tf.recurrent_net import RecurrentNetwork
- from ray.rllib.models.tf.tf_modelv2 import TFModelV2
- from ray.rllib.policy.sample_batch import SampleBatch
- from ray.rllib.policy.view_requirement import ViewRequirement
- from ray.rllib.utils.annotations import OldAPIStack, override
- from ray.rllib.utils.framework import try_import_tf
- from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
- from ray.rllib.utils.tf_utils import flatten_inputs_to_1d_tensor, one_hot
- from ray.rllib.utils.typing import List, ModelConfigDict, TensorType
- from ray.util import log_once
- tf1, tf, tfv = try_import_tf()
@OldAPIStack
class PositionwiseFeedforward(tf.keras.layers.Layer if tf else object):
    """Two stacked dense layers with a ReLU in between, as described in [1].

    Every timestep coming out of the attention head is pushed through this
    layer on its own.
    """

    def __init__(
        self,
        out_dim: int,
        hidden_dim: int,
        output_activation: Optional[Any] = None,
        **kwargs,
    ):
        """Creates the hidden and the output dense layer.

        Args:
            out_dim: Number of units of the second (output) dense layer.
            hidden_dim: Number of units of the first (hidden) dense layer.
                This layer always uses a ReLU activation.
            output_activation: Activation handed to the output dense layer.
            **kwargs: Passed through to the parent class' constructor.
        """
        super().__init__(**kwargs)

        make_dense = tf.keras.layers.Dense
        # Hidden layer first, output layer second (creation order is kept
        # stable on purpose).
        self._hidden_layer = make_dense(hidden_dim, activation=tf.nn.relu)
        self._output_layer = make_dense(out_dim, activation=output_activation)

        # Only warn about the deprecation the first time this is hit.
        warn_now = log_once("positionwise_feedforward_tf")
        if warn_now:
            deprecation_warning(
                old="rllib.models.tf.attention_net.PositionwiseFeedforward",
            )

    def call(self, inputs: TensorType, **kwargs) -> TensorType:
        """Runs `inputs` through the hidden layer, then the output layer.

        Any additional keyword arguments are accepted, but ignored.
        """
        return self._output_layer(self._hidden_layer(inputs))
@OldAPIStack
class TrXLNet(RecurrentNetwork):
    """A TrXL net model (see [1])."""

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config: ModelConfigDict,
        name: str,
        num_transformer_units: int,
        attention_dim: int,
        num_heads: int,
        head_dim: int,
        position_wise_mlp_dim: int,
    ):
        """Builds the TrXL keras model.

        Args:
            observation_space: The observation space. Only the size of its
                first dimension is used (as the per-timestep input size).
            action_space: The action space (passed on to the parent class).
            num_outputs: Size of the final `logits` layer.
            model_config: The model config dict. Must contain the key
                `max_seq_len`.
            name: Name of the model (passed on to the parent class).
            num_transformer_units: How many times the Transformer unit is
                repeated (denoted L in [2]).
            attention_dim: Input and output size of a single Transformer
                unit.
            num_heads: Number of attention heads used in parallel
                (`H` in [3]).
            head_dim: Size of one single(!) attention head inside a
                multi-head attention unit (`d` in [3]).
            position_wise_mlp_dim: Size of the hidden (first) layer of the
                PositionwiseFeedforward that follows the multi-head
                attention block in each Transformer unit. The second layer
                always has size=`attention_dim`.
        """
        # One-time deprecation notice.
        if log_once("trxl_net_tf"):
            deprecation_warning(old="rllib.models.tf.attention_net.TrXLNet")

        super().__init__(
            observation_space, action_space, num_outputs, model_config, name
        )

        self.num_transformer_units = num_transformer_units
        self.attention_dim = attention_dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.max_seq_len = model_config["max_seq_len"]
        self.obs_dim = observation_space.shape[0]

        keras_layers = tf.keras.layers

        # A fixed-length window of observations is the only model input.
        obs_window = keras_layers.Input(
            shape=(self.max_seq_len, self.obs_dim), name="inputs"
        )
        # Project the observations into the attention dimension.
        hidden = keras_layers.Dense(attention_dim)(obs_window)

        for _ in range(self.num_transformer_units):
            # Relative multi-head attention with a skip connection.
            attention_block = SkipConnection(
                RelativeMultiHeadAttention(
                    out_dim=attention_dim,
                    num_heads=num_heads,
                    head_dim=head_dim,
                    input_layernorm=False,
                    output_activation=None,
                ),
                fan_in_layer=None,
            )
            attended = attention_block(hidden)

            # Position-wise MLP with a skip connection.
            mlp_block = SkipConnection(
                PositionwiseFeedforward(attention_dim, position_wise_mlp_dim)
            )
            mlp_out = mlp_block(attended)

            # Layer-normalize the output of this Transformer unit.
            hidden = keras_layers.LayerNormalization(axis=-1)(mlp_out)

        # Final linear layer producing `num_outputs` values per timestep.
        logits_layer = keras_layers.Dense(
            self.num_outputs, activation=tf.keras.activations.linear, name="logits"
        )
        logits = logits_layer(hidden)

        self.base_model = tf.keras.models.Model([obs_window], [logits])

    @override(RecurrentNetwork)
    def forward_rnn(
        self, inputs: TensorType, state: List[TensorType], seq_lens: TensorType
    ) -> (TensorType, List[TensorType]):
        """Appends `inputs` to the observation history and runs the model.

        `state[0]` is treated as the history of recent observations. The new
        `inputs` are concatenated to its end (along axis 1) and only the most
        recent `max_seq_len` entries are kept. Handling the state this way
        lets the same code serve both single-timestep inference and
        full-sequence training with RLlib's ModelV2 API.

        Args:
            inputs: The current input segment; axis 1 is used as time axis.
            state: List whose first item is the observation history.
            seq_lens: Not used by this method.

        Returns:
            The logits for the timesteps of `inputs` only, plus a
            single-item list holding the updated observation history.
        """
        history = state[0]
        joined = tf.concat((history, inputs), axis=1)
        window = joined[:, -self.max_seq_len :]

        all_logits = self.base_model([window])

        # Number of timesteps in the incoming segment.
        segment_len = tf.shape(inputs)[1]
        return all_logits[:, -segment_len:], [window]

    @override(RecurrentNetwork)
    def get_initial_state(self) -> List[np.ndarray]:
        """Returns one all-zeros float32 array of shape (max_seq_len, obs_dim).

        This is the (empty) history of observations that `forward_rnn`
        expects as `state[0]`.
        """
        empty_history = np.zeros(
            shape=(self.max_seq_len, self.obs_dim), dtype=np.float32
        )
        return [empty_history]
class GTrXLNet(RecurrentNetwork):
    """A GTrXL net Model described in [2].

    This is still in an experimental phase.
    Can be used as a drop-in replacement for LSTMs in PPO and IMPALA.

    To use this network as a replacement for an RNN, configure your Algorithm
    as follows:

    Examples:
        >> config["model"]["custom_model"] = GTrXLNet
        >> config["model"]["max_seq_len"] = 10
        >> config["model"]["custom_model_config"] = {
        >>     "num_transformer_units": 1,
        >>     "attention_dim": 32,
        >>     "num_heads": 2,
        >>     "memory_inference": 100,
        >>     "memory_training": 50,
        >>     # etc..
        >> }
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: Optional[int],
        model_config: ModelConfigDict,
        name: str,
        *,
        num_transformer_units: int = 1,
        attention_dim: int = 64,
        num_heads: int = 2,
        memory_inference: int = 50,
        memory_training: int = 50,
        head_dim: int = 32,
        position_wise_mlp_dim: int = 32,
        init_gru_gate_bias: float = 2.0,
    ):
        """Initializes a GTrXLNet instance.

        Args:
            observation_space: The observation space. Only the size of its
                first dimension is used (as the per-timestep input size).
            action_space: The action space (passed on to the parent class).
            num_outputs: Size of the `logits` output. If None, no logits and
                value heads are created; the model then outputs the last
                Transformer unit's output and `self.num_outputs` is set to
                `attention_dim`.
            model_config: The model config dict. Must contain the key
                `max_seq_len`.
            name: Name of the model (passed on to the parent class).
            num_transformer_units: The number of Transformer repeats to
                use (denoted L in [2]).
            attention_dim: The input and output dimensions of one
                Transformer unit.
            num_heads: The number of attention heads to use in parallel.
                Denoted as `H` in [3].
            memory_inference: The number of timesteps to concat (time
                axis) and feed into the next transformer unit as inference
                input. The first transformer unit will receive this number of
                past observations (plus the current one), instead.
            memory_training: The number of timesteps to concat (time
                axis) and feed into the next transformer unit as training
                input (plus the actual input sequence of len=max_seq_len).
                The first transformer unit will receive this number of
                past observations (plus the input sequence), instead.
            head_dim: The dimension of a single(!) attention head within
                a multi-head attention unit. Denoted as `d` in [3].
            position_wise_mlp_dim: The dimension of the hidden layer
                within the position-wise MLP (after the multi-head attention
                block within one Transformer unit). This is the size of the
                first of the two layers within the PositionwiseFeedforward. The
                second layer always has size=`attention_dim`.
            init_gru_gate_bias: Initial bias values for the GRU gates
                (two GRUs per Transformer unit, one after the MHA, one after
                the position-wise MLP).
        """
        super().__init__(
            observation_space, action_space, num_outputs, model_config, name
        )

        self.num_transformer_units = num_transformer_units
        self.attention_dim = attention_dim
        self.num_heads = num_heads
        self.memory_inference = memory_inference
        self.memory_training = memory_training
        self.head_dim = head_dim
        self.max_seq_len = model_config["max_seq_len"]
        self.obs_dim = observation_space.shape[0]

        # Raw observation input (plus (None) time axis).
        input_layer = tf.keras.layers.Input(shape=(None, self.obs_dim), name="inputs")
        # One memory input per Transformer unit.
        memory_ins = [
            tf.keras.layers.Input(
                shape=(None, self.attention_dim),
                dtype=tf.float32,
                name="memory_in_{}".format(i),
            )
            for i in range(self.num_transformer_units)
        ]

        # Map observation dim to input/output transformer (attention) dim.
        E_out = tf.keras.layers.Dense(self.attention_dim)(input_layer)
        # Output, collected and concat'd to build the internal, tau-len
        # Memory units used for additional contextual information.
        memory_outs = [E_out]

        # 2) Create L Transformer blocks according to [2].
        for i in range(self.num_transformer_units):
            # RelativeMultiHeadAttention part.
            MHA_out = SkipConnection(
                RelativeMultiHeadAttention(
                    out_dim=self.attention_dim,
                    num_heads=num_heads,
                    head_dim=head_dim,
                    input_layernorm=True,
                    output_activation=tf.nn.relu,
                ),
                fan_in_layer=GRUGate(init_gru_gate_bias),
                name="mha_{}".format(i + 1),
            )(E_out, memory=memory_ins[i])
            # Position-wise MLP part.
            E_out = SkipConnection(
                tf.keras.Sequential(
                    (
                        tf.keras.layers.LayerNormalization(axis=-1),
                        PositionwiseFeedforward(
                            out_dim=self.attention_dim,
                            hidden_dim=position_wise_mlp_dim,
                            output_activation=tf.nn.relu,
                        ),
                    )
                ),
                fan_in_layer=GRUGate(init_gru_gate_bias),
                name="pos_wise_mlp_{}".format(i + 1),
            )(MHA_out)
            # Output of position-wise MLP == E(l-1), which is concat'd
            # to the current Mem block (M(l-1)) to yield E~(l-1), which is then
            # used by the next transformer block.
            memory_outs.append(E_out)

        self._logits = None
        self._value_out = None

        # Postprocess TrXL output with another hidden layer and compute values.
        if num_outputs is not None:
            self._logits = tf.keras.layers.Dense(
                self.num_outputs, activation=None, name="logits"
            )(E_out)
            values_out = tf.keras.layers.Dense(1, activation=None, name="values")(E_out)
            outs = [self._logits, values_out]
        else:
            # No heads requested: expose the last Transformer unit's output.
            outs = [E_out]
            self.num_outputs = self.attention_dim

        # The last collected output is not returned as a memory output (it is
        # the input to the heads / the model output itself).
        self.trxl_model = tf.keras.Model(
            inputs=[input_layer] + memory_ins, outputs=outs + memory_outs[:-1]
        )

        self.trxl_model.summary()

        # __sphinx_doc_begin__
        # Setup trajectory views (`memory-inference` x past memory outs).
        for i in range(self.num_transformer_units):
            space = Box(-1.0, 1.0, shape=(self.attention_dim,))
            self.view_requirements["state_in_{}".format(i)] = ViewRequirement(
                "state_out_{}".format(i),
                shift="-{}:-1".format(self.memory_inference),
                # Repeat the incoming state every max-seq-len times.
                batch_repeat_value=self.max_seq_len,
                space=space,
            )
            self.view_requirements["state_out_{}".format(i)] = ViewRequirement(
                space=space, used_for_training=False
            )
        # __sphinx_doc_end__

    @override(ModelV2)
    def forward(
        self, input_dict, state: List[TensorType], seq_lens: TensorType
    ) -> (TensorType, List[TensorType]):
        """Runs the observations (plus memory states) through the GTrXL model.

        Args:
            input_dict: Dict holding the observations under
                `SampleBatch.OBS`, with batch and time folded into the
                first axis.
            state: The memory inputs, one per Transformer unit.
            seq_lens: Sequence lengths; its length is used as batch size B.
                Must not be None.

        Returns:
            The model output, reshaped to [-1, num_outputs] (or
            [-1, attention_dim] if no logits head exists), and the list of
            memory outputs, each reshaped to [-1, attention_dim].
        """
        assert seq_lens is not None

        # Add the time dim to observations.
        B = tf.shape(seq_lens)[0]

        observations = input_dict[SampleBatch.OBS]

        shape = tf.shape(observations)
        T = shape[0] // B
        observations = tf.reshape(observations, tf.concat([[-1, T], shape[1:]], axis=0))

        all_out = self.trxl_model([observations] + state)

        if self._logits is not None:
            # Outputs are: [logits, values, memory_out_0, memory_out_1, ...].
            out = tf.reshape(all_out[0], [-1, self.num_outputs])
            self._value_out = all_out[1]
            memory_outs = all_out[2:]
        else:
            # Outputs are: [features, memory_out_0, memory_out_1, ...].
            out = tf.reshape(all_out[0], [-1, self.attention_dim])
            memory_outs = all_out[1:]

        return out, [tf.reshape(m, [-1, self.attention_dim]) for m in memory_outs]

    @override(RecurrentNetwork)
    def get_initial_state(self) -> List[np.ndarray]:
        """Returns one all-zeros float32 array per Transformer unit.

        Each array has the shape of the corresponding `state_in_{i}` view
        requirement's space. NumPy arrays are returned (not TF tensors) to
        honor the annotated return type and to match the other models in
        this file.
        """
        return [
            np.zeros(
                self.view_requirements["state_in_{}".format(i)].space.shape,
                dtype=np.float32,
            )
            for i in range(self.num_transformer_units)
        ]

    @override(ModelV2)
    def value_function(self) -> TensorType:
        """Returns the value output of the last `forward()` call, flattened.

        Only meaningful if the model was built with `num_outputs` not None
        and `forward()` has been called before.
        """
        return tf.reshape(self._value_out, [-1])
class AttentionWrapper(TFModelV2):
    """GTrXL wrapper serving as interface for ModelV2s that set use_attention.

    The wrapped model's output (optionally concatenated with the n previous
    actions and/or rewards) is fed into a GTrXLNet, whose output in turn is
    pushed through a logits- and a value branch created by this wrapper.
    """

    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config: ModelConfigDict,
        name: str,
    ):
        """Initializes an AttentionWrapper instance.

        Args:
            obs_space: The observation space.
            action_space: The action space.
            num_outputs: Size of the final logits output of this wrapper.
            model_config: The model config dict. The `attention_*` keys (and
                `max_seq_len`) are read from it.
            name: Name of the model (passed on to the parent class).
        """
        if log_once("attention_wrapper_tf_deprecation"):
            deprecation_warning(
                old="ray.rllib.models.tf.attention_net.AttentionWrapper"
            )

        # NOTE(review): `num_outputs=None` is passed on here, yet
        # `self.num_outputs` may be incremented below. This presumably relies
        # on the next class in the MRO (the wrapped model) setting
        # `self.num_outputs` to its own output size - confirm.
        super().__init__(obs_space, action_space, None, model_config, name)

        self.use_n_prev_actions = model_config["attention_use_n_prev_actions"]
        self.use_n_prev_rewards = model_config["attention_use_n_prev_rewards"]

        self.action_space_struct = get_base_struct_from_space(self.action_space)
        # Total flat size of one action (one-hot sizes for discrete
        # components). Always accumulated as a plain Python int.
        self.action_dim = 0

        for space in tree.flatten(self.action_space_struct):
            if isinstance(space, Discrete):
                self.action_dim += space.n
            elif isinstance(space, MultiDiscrete):
                self.action_dim += int(np.sum(space.nvec))
            elif space.shape is not None:
                self.action_dim += int(np.prod(space.shape))
            else:
                self.action_dim += int(len(space))

        # Add prev-action/reward nodes to input to GTrXL.
        if self.use_n_prev_actions:
            self.num_outputs += self.use_n_prev_actions * self.action_dim
        if self.use_n_prev_rewards:
            self.num_outputs += self.use_n_prev_rewards

        cfg = model_config

        self.attention_dim = cfg["attention_dim"]

        if self.num_outputs is not None:
            in_space = gym.spaces.Box(
                float("-inf"), float("inf"), shape=(self.num_outputs,), dtype=np.float32
            )
        else:
            in_space = obs_space

        # Construct GTrXL sub-module w/ num_outputs=None (so it does not
        # create a logits/value output; we'll do this ourselves in this wrapper
        # here).
        self.gtrxl = GTrXLNet(
            in_space,
            action_space,
            None,
            model_config,
            "gtrxl",
            num_transformer_units=cfg["attention_num_transformer_units"],
            attention_dim=self.attention_dim,
            num_heads=cfg["attention_num_heads"],
            head_dim=cfg["attention_head_dim"],
            memory_inference=cfg["attention_memory_inference"],
            memory_training=cfg["attention_memory_training"],
            position_wise_mlp_dim=cfg["attention_position_wise_mlp_dim"],
            init_gru_gate_bias=cfg["attention_init_gru_gate_bias"],
        )

        # `self.gtrxl.num_outputs` is the number of nodes coming from the
        # attention net.
        input_ = tf.keras.layers.Input(shape=(self.gtrxl.num_outputs,))

        # Set final num_outputs to correct value (depending on action space).
        self.num_outputs = num_outputs

        # Postprocess GTrXL output with another hidden layer and compute
        # values.
        out = tf.keras.layers.Dense(self.num_outputs, activation=None)(input_)
        self._logits_branch = tf.keras.models.Model([input_], [out])

        out = tf.keras.layers.Dense(1, activation=None)(input_)
        self._value_branch = tf.keras.models.Model([input_], [out])

        # Share the GTrXL's view requirements, but keep the original
        # observation space for "obs".
        self.view_requirements = self.gtrxl.view_requirements
        self.view_requirements["obs"].space = self.obs_space

        # Add prev-a/r to this model's view, if required.
        if self.use_n_prev_actions:
            self.view_requirements[SampleBatch.PREV_ACTIONS] = ViewRequirement(
                SampleBatch.ACTIONS,
                space=self.action_space,
                shift="-{}:-1".format(self.use_n_prev_actions),
            )
        if self.use_n_prev_rewards:
            self.view_requirements[SampleBatch.PREV_REWARDS] = ViewRequirement(
                SampleBatch.REWARDS, shift="-{}:-1".format(self.use_n_prev_rewards)
            )

    @override(ModelV2)
    def forward(
        self,
        input_dict: Dict[str, TensorType],
        state: List[TensorType],
        seq_lens: TensorType,
    ) -> (TensorType, List[TensorType]):
        """Runs wrapped model -> (prev. actions/rewards concat) -> GTrXL -> logits.

        Note that `input_dict` is modified in place: its "obs" and
        "obs_flat" entries are overwritten with the GTrXL input.

        Args:
            input_dict: The input dict. Must contain the prev. actions and/or
                rewards if this wrapper is configured to use them.
            state: The GTrXL memory inputs.
            seq_lens: The sequence lengths. Must not be None.

        Returns:
            The logits and the list of GTrXL memory outputs.
        """
        assert seq_lens is not None
        # Push obs through "unwrapped" net's `forward()` first.
        # `_wrapped_forward` is not defined in this class; presumably it is
        # attached by whoever builds the wrapper class - confirm.
        wrapped_out, _ = self._wrapped_forward(input_dict, [], None)

        # Concat. prev-action/reward if required.
        prev_a_r = []

        # Prev actions.
        if self.use_n_prev_actions:
            prev_n_actions = input_dict[SampleBatch.PREV_ACTIONS]
            # If actions are not processed yet (in their original form as
            # have been sent to environment):
            # Flatten/one-hot into 1D array.
            if self.model_config["_disable_action_flattening"]:
                # Merge prev n actions into flat tensor.
                flat = flatten_inputs_to_1d_tensor(
                    prev_n_actions,
                    spaces_struct=self.action_space_struct,
                    time_axis=True,
                )
                # Fold time-axis into flattened data.
                flat = tf.reshape(flat, [tf.shape(flat)[0], -1])
                prev_a_r.append(flat)
            # If actions are already flattened (but not one-hot'd yet!),
            # one-hot discrete/multi-discrete actions here and concatenate the
            # n most recent actions together.
            else:
                if isinstance(self.action_space, Discrete):
                    for i in range(self.use_n_prev_actions):
                        prev_a_r.append(
                            one_hot(prev_n_actions[:, i], self.action_space)
                        )
                elif isinstance(self.action_space, MultiDiscrete):
                    # NOTE(review): this loop steps by the number of
                    # MultiDiscrete components while being bounded by the
                    # number of previous actions - confirm this matches the
                    # layout of PREV_ACTIONS.
                    for i in range(
                        0, self.use_n_prev_actions, self.action_space.shape[0]
                    ):
                        prev_a_r.append(
                            one_hot(
                                tf.cast(
                                    prev_n_actions[
                                        :, i : i + self.action_space.shape[0]
                                    ],
                                    tf.float32,
                                ),
                                space=self.action_space,
                            )
                        )
                else:
                    prev_a_r.append(
                        tf.reshape(
                            tf.cast(prev_n_actions, tf.float32),
                            [-1, self.use_n_prev_actions * self.action_dim],
                        )
                    )
        # Prev rewards.
        if self.use_n_prev_rewards:
            prev_a_r.append(
                tf.reshape(
                    tf.cast(input_dict[SampleBatch.PREV_REWARDS], tf.float32),
                    [-1, self.use_n_prev_rewards],
                )
            )

        # Concat prev. actions + rewards to the "main" input.
        if prev_a_r:
            wrapped_out = tf.concat([wrapped_out] + prev_a_r, axis=1)

        # Then through our GTrXL.
        input_dict["obs_flat"] = input_dict["obs"] = wrapped_out

        # Keep the features around for a subsequent `value_function()` call.
        self._features, memory_outs = self.gtrxl(input_dict, state, seq_lens)
        model_out = self._logits_branch(self._features)
        return model_out, memory_outs

    @override(ModelV2)
    def value_function(self) -> TensorType:
        """Returns the flattened value-branch output for the last `forward()`.

        Raises:
            AssertionError: If `forward()` has not been called yet.
        """
        # `_features` is only set inside `forward()`; use getattr so that a
        # premature call hits the assertion below instead of an
        # AttributeError.
        features = getattr(self, "_features", None)
        assert features is not None, "Must call forward() first!"
        return tf.reshape(self._value_branch(features), [-1])

    @override(ModelV2)
    def get_initial_state(self) -> Union[List[np.ndarray], List[TensorType]]:
        """Returns one all-zeros array per Transformer unit of the GTrXL.

        Each array has the shape of the GTrXL's corresponding `state_in_{i}`
        view requirement's space.
        """
        return [
            np.zeros(self.gtrxl.view_requirements["state_in_{}".format(i)].space.shape)
            for i in range(self.gtrxl.num_transformer_units)
        ]
|