attention_net.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574
"""
[1] - Attention Is All You Need - Vaswani, Shazeer, Parmar, Uszkoreit,
Jones, Gomez, Kaiser, Polosukhin - Google Brain/Research, U Toronto - 2017.
https://arxiv.org/pdf/1706.03762.pdf
[2] - Stabilizing Transformers for Reinforcement Learning - E. Parisotto
et al. - DeepMind - 2019. https://arxiv.org/pdf/1910.06764.pdf
[3] - Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context.
Z. Dai, Z. Yang, et al. - Carnegie Mellon U - 2019.
https://www.aclweb.org/anthology/P19-1285.pdf
"""
  11. from typing import Any, Dict, Optional, Union
  12. import gymnasium as gym
  13. import numpy as np
  14. import tree # pip install dm_tree
  15. from gymnasium.spaces import Box, Discrete, MultiDiscrete
  16. from ray._common.deprecation import deprecation_warning
  17. from ray.rllib.models.modelv2 import ModelV2
  18. from ray.rllib.models.tf.layers import (
  19. GRUGate,
  20. RelativeMultiHeadAttention,
  21. SkipConnection,
  22. )
  23. from ray.rllib.models.tf.recurrent_net import RecurrentNetwork
  24. from ray.rllib.models.tf.tf_modelv2 import TFModelV2
  25. from ray.rllib.policy.sample_batch import SampleBatch
  26. from ray.rllib.policy.view_requirement import ViewRequirement
  27. from ray.rllib.utils.annotations import OldAPIStack, override
  28. from ray.rllib.utils.framework import try_import_tf
  29. from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
  30. from ray.rllib.utils.tf_utils import flatten_inputs_to_1d_tensor, one_hot
  31. from ray.rllib.utils.typing import List, ModelConfigDict, TensorType
  32. from ray.util import log_once
  33. tf1, tf, tfv = try_import_tf()
  34. @OldAPIStack
  35. class PositionwiseFeedforward(tf.keras.layers.Layer if tf else object):
  36. """A 2x linear layer with ReLU activation in between described in [1].
  37. Each timestep coming from the attention head will be passed through this
  38. layer separately.
  39. """
  40. def __init__(
  41. self,
  42. out_dim: int,
  43. hidden_dim: int,
  44. output_activation: Optional[Any] = None,
  45. **kwargs,
  46. ):
  47. super().__init__(**kwargs)
  48. self._hidden_layer = tf.keras.layers.Dense(
  49. hidden_dim,
  50. activation=tf.nn.relu,
  51. )
  52. self._output_layer = tf.keras.layers.Dense(
  53. out_dim, activation=output_activation
  54. )
  55. if log_once("positionwise_feedforward_tf"):
  56. deprecation_warning(
  57. old="rllib.models.tf.attention_net.PositionwiseFeedforward",
  58. )
  59. def call(self, inputs: TensorType, **kwargs) -> TensorType:
  60. del kwargs
  61. output = self._hidden_layer(inputs)
  62. return self._output_layer(output)
  63. @OldAPIStack
  64. class TrXLNet(RecurrentNetwork):
  65. """A TrXL net Model described in [1]."""
  66. def __init__(
  67. self,
  68. observation_space: gym.spaces.Space,
  69. action_space: gym.spaces.Space,
  70. num_outputs: int,
  71. model_config: ModelConfigDict,
  72. name: str,
  73. num_transformer_units: int,
  74. attention_dim: int,
  75. num_heads: int,
  76. head_dim: int,
  77. position_wise_mlp_dim: int,
  78. ):
  79. """Initializes a TrXLNet object.
  80. Args:
  81. num_transformer_units: The number of Transformer repeats to
  82. use (denoted L in [2]).
  83. attention_dim: The input and output dimensions of one
  84. Transformer unit.
  85. num_heads: The number of attention heads to use in parallel.
  86. Denoted as `H` in [3].
  87. head_dim: The dimension of a single(!) attention head within
  88. a multi-head attention unit. Denoted as `d` in [3].
  89. position_wise_mlp_dim: The dimension of the hidden layer
  90. within the position-wise MLP (after the multi-head attention
  91. block within one Transformer unit). This is the size of the
  92. first of the two layers within the PositionwiseFeedforward. The
  93. second layer always has size=`attention_dim`.
  94. """
  95. if log_once("trxl_net_tf"):
  96. deprecation_warning(
  97. old="rllib.models.tf.attention_net.TrXLNet",
  98. )
  99. super().__init__(
  100. observation_space, action_space, num_outputs, model_config, name
  101. )
  102. self.num_transformer_units = num_transformer_units
  103. self.attention_dim = attention_dim
  104. self.num_heads = num_heads
  105. self.head_dim = head_dim
  106. self.max_seq_len = model_config["max_seq_len"]
  107. self.obs_dim = observation_space.shape[0]
  108. inputs = tf.keras.layers.Input(
  109. shape=(self.max_seq_len, self.obs_dim), name="inputs"
  110. )
  111. E_out = tf.keras.layers.Dense(attention_dim)(inputs)
  112. for _ in range(self.num_transformer_units):
  113. MHA_out = SkipConnection(
  114. RelativeMultiHeadAttention(
  115. out_dim=attention_dim,
  116. num_heads=num_heads,
  117. head_dim=head_dim,
  118. input_layernorm=False,
  119. output_activation=None,
  120. ),
  121. fan_in_layer=None,
  122. )(E_out)
  123. E_out = SkipConnection(
  124. PositionwiseFeedforward(attention_dim, position_wise_mlp_dim)
  125. )(MHA_out)
  126. E_out = tf.keras.layers.LayerNormalization(axis=-1)(E_out)
  127. # Postprocess TrXL output with another hidden layer and compute values.
  128. logits = tf.keras.layers.Dense(
  129. self.num_outputs, activation=tf.keras.activations.linear, name="logits"
  130. )(E_out)
  131. self.base_model = tf.keras.models.Model([inputs], [logits])
  132. @override(RecurrentNetwork)
  133. def forward_rnn(
  134. self, inputs: TensorType, state: List[TensorType], seq_lens: TensorType
  135. ) -> (TensorType, List[TensorType]):
  136. # To make Attention work with current RLlib's ModelV2 API:
  137. # We assume `state` is the history of L recent observations (all
  138. # concatenated into one tensor) and append the current inputs to the
  139. # end and only keep the most recent (up to `max_seq_len`). This allows
  140. # us to deal with timestep-wise inference and full sequence training
  141. # within the same logic.
  142. observations = state[0]
  143. observations = tf.concat((observations, inputs), axis=1)[:, -self.max_seq_len :]
  144. logits = self.base_model([observations])
  145. T = tf.shape(inputs)[1] # Length of input segment (time).
  146. logits = logits[:, -T:]
  147. return logits, [observations]
  148. @override(RecurrentNetwork)
  149. def get_initial_state(self) -> List[np.ndarray]:
  150. # State is the T last observations concat'd together into one Tensor.
  151. # Plus all Transformer blocks' E(l) outputs concat'd together (up to
  152. # tau timesteps).
  153. return [np.zeros((self.max_seq_len, self.obs_dim), np.float32)]
  154. class GTrXLNet(RecurrentNetwork):
  155. """A GTrXL net Model described in [2].
  156. This is still in an experimental phase.
  157. Can be used as a drop-in replacement for LSTMs in PPO and IMPALA.
  158. To use this network as a replacement for an RNN, configure your Algorithm
  159. as follows:
  160. Examples:
  161. >> config["model"]["custom_model"] = GTrXLNet
  162. >> config["model"]["max_seq_len"] = 10
  163. >> config["model"]["custom_model_config"] = {
  164. >> num_transformer_units=1,
  165. >> attention_dim=32,
  166. >> num_heads=2,
  167. >> memory_inference=100,
  168. >> memory_training=50,
  169. >> etc..
  170. >> }
  171. """
  172. def __init__(
  173. self,
  174. observation_space: gym.spaces.Space,
  175. action_space: gym.spaces.Space,
  176. num_outputs: Optional[int],
  177. model_config: ModelConfigDict,
  178. name: str,
  179. *,
  180. num_transformer_units: int = 1,
  181. attention_dim: int = 64,
  182. num_heads: int = 2,
  183. memory_inference: int = 50,
  184. memory_training: int = 50,
  185. head_dim: int = 32,
  186. position_wise_mlp_dim: int = 32,
  187. init_gru_gate_bias: float = 2.0,
  188. ):
  189. """Initializes a GTrXLNet instance.
  190. Args:
  191. num_transformer_units: The number of Transformer repeats to
  192. use (denoted L in [2]).
  193. attention_dim: The input and output dimensions of one
  194. Transformer unit.
  195. num_heads: The number of attention heads to use in parallel.
  196. Denoted as `H` in [3].
  197. memory_inference: The number of timesteps to concat (time
  198. axis) and feed into the next transformer unit as inference
  199. input. The first transformer unit will receive this number of
  200. past observations (plus the current one), instead.
  201. memory_training: The number of timesteps to concat (time
  202. axis) and feed into the next transformer unit as training
  203. input (plus the actual input sequence of len=max_seq_len).
  204. The first transformer unit will receive this number of
  205. past observations (plus the input sequence), instead.
  206. head_dim: The dimension of a single(!) attention head within
  207. a multi-head attention unit. Denoted as `d` in [3].
  208. position_wise_mlp_dim: The dimension of the hidden layer
  209. within the position-wise MLP (after the multi-head attention
  210. block within one Transformer unit). This is the size of the
  211. first of the two layers within the PositionwiseFeedforward. The
  212. second layer always has size=`attention_dim`.
  213. init_gru_gate_bias: Initial bias values for the GRU gates
  214. (two GRUs per Transformer unit, one after the MHA, one after
  215. the position-wise MLP).
  216. """
  217. super().__init__(
  218. observation_space, action_space, num_outputs, model_config, name
  219. )
  220. self.num_transformer_units = num_transformer_units
  221. self.attention_dim = attention_dim
  222. self.num_heads = num_heads
  223. self.memory_inference = memory_inference
  224. self.memory_training = memory_training
  225. self.head_dim = head_dim
  226. self.max_seq_len = model_config["max_seq_len"]
  227. self.obs_dim = observation_space.shape[0]
  228. # Raw observation input (plus (None) time axis).
  229. input_layer = tf.keras.layers.Input(shape=(None, self.obs_dim), name="inputs")
  230. memory_ins = [
  231. tf.keras.layers.Input(
  232. shape=(None, self.attention_dim),
  233. dtype=tf.float32,
  234. name="memory_in_{}".format(i),
  235. )
  236. for i in range(self.num_transformer_units)
  237. ]
  238. # Map observation dim to input/output transformer (attention) dim.
  239. E_out = tf.keras.layers.Dense(self.attention_dim)(input_layer)
  240. # Output, collected and concat'd to build the internal, tau-len
  241. # Memory units used for additional contextual information.
  242. memory_outs = [E_out]
  243. # 2) Create L Transformer blocks according to [2].
  244. for i in range(self.num_transformer_units):
  245. # RelativeMultiHeadAttention part.
  246. MHA_out = SkipConnection(
  247. RelativeMultiHeadAttention(
  248. out_dim=self.attention_dim,
  249. num_heads=num_heads,
  250. head_dim=head_dim,
  251. input_layernorm=True,
  252. output_activation=tf.nn.relu,
  253. ),
  254. fan_in_layer=GRUGate(init_gru_gate_bias),
  255. name="mha_{}".format(i + 1),
  256. )(E_out, memory=memory_ins[i])
  257. # Position-wise MLP part.
  258. E_out = SkipConnection(
  259. tf.keras.Sequential(
  260. (
  261. tf.keras.layers.LayerNormalization(axis=-1),
  262. PositionwiseFeedforward(
  263. out_dim=self.attention_dim,
  264. hidden_dim=position_wise_mlp_dim,
  265. output_activation=tf.nn.relu,
  266. ),
  267. )
  268. ),
  269. fan_in_layer=GRUGate(init_gru_gate_bias),
  270. name="pos_wise_mlp_{}".format(i + 1),
  271. )(MHA_out)
  272. # Output of position-wise MLP == E(l-1), which is concat'd
  273. # to the current Mem block (M(l-1)) to yield E~(l-1), which is then
  274. # used by the next transformer block.
  275. memory_outs.append(E_out)
  276. self._logits = None
  277. self._value_out = None
  278. # Postprocess TrXL output with another hidden layer and compute values.
  279. if num_outputs is not None:
  280. self._logits = tf.keras.layers.Dense(
  281. self.num_outputs, activation=None, name="logits"
  282. )(E_out)
  283. values_out = tf.keras.layers.Dense(1, activation=None, name="values")(E_out)
  284. outs = [self._logits, values_out]
  285. else:
  286. outs = [E_out]
  287. self.num_outputs = self.attention_dim
  288. self.trxl_model = tf.keras.Model(
  289. inputs=[input_layer] + memory_ins, outputs=outs + memory_outs[:-1]
  290. )
  291. self.trxl_model.summary()
  292. # __sphinx_doc_begin__
  293. # Setup trajectory views (`memory-inference` x past memory outs).
  294. for i in range(self.num_transformer_units):
  295. space = Box(-1.0, 1.0, shape=(self.attention_dim,))
  296. self.view_requirements["state_in_{}".format(i)] = ViewRequirement(
  297. "state_out_{}".format(i),
  298. shift="-{}:-1".format(self.memory_inference),
  299. # Repeat the incoming state every max-seq-len times.
  300. batch_repeat_value=self.max_seq_len,
  301. space=space,
  302. )
  303. self.view_requirements["state_out_{}".format(i)] = ViewRequirement(
  304. space=space, used_for_training=False
  305. )
  306. # __sphinx_doc_end__
  307. @override(ModelV2)
  308. def forward(
  309. self, input_dict, state: List[TensorType], seq_lens: TensorType
  310. ) -> (TensorType, List[TensorType]):
  311. assert seq_lens is not None
  312. # Add the time dim to observations.
  313. B = tf.shape(seq_lens)[0]
  314. observations = input_dict[SampleBatch.OBS]
  315. shape = tf.shape(observations)
  316. T = shape[0] // B
  317. observations = tf.reshape(observations, tf.concat([[-1, T], shape[1:]], axis=0))
  318. all_out = self.trxl_model([observations] + state)
  319. if self._logits is not None:
  320. out = tf.reshape(all_out[0], [-1, self.num_outputs])
  321. self._value_out = all_out[1]
  322. memory_outs = all_out[2:]
  323. else:
  324. out = tf.reshape(all_out[0], [-1, self.attention_dim])
  325. memory_outs = all_out[1:]
  326. return out, [tf.reshape(m, [-1, self.attention_dim]) for m in memory_outs]
  327. @override(RecurrentNetwork)
  328. def get_initial_state(self) -> List[np.ndarray]:
  329. return [
  330. tf.zeros(self.view_requirements["state_in_{}".format(i)].space.shape)
  331. for i in range(self.num_transformer_units)
  332. ]
  333. @override(ModelV2)
  334. def value_function(self) -> TensorType:
  335. return tf.reshape(self._value_out, [-1])
  336. class AttentionWrapper(TFModelV2):
  337. """GTrXL wrapper serving as interface for ModelV2s that set use_attention."""
  338. def __init__(
  339. self,
  340. obs_space: gym.spaces.Space,
  341. action_space: gym.spaces.Space,
  342. num_outputs: int,
  343. model_config: ModelConfigDict,
  344. name: str,
  345. ):
  346. if log_once("attention_wrapper_tf_deprecation"):
  347. deprecation_warning(
  348. old="ray.rllib.models.tf.attention_net.AttentionWrapper"
  349. )
  350. super().__init__(obs_space, action_space, None, model_config, name)
  351. self.use_n_prev_actions = model_config["attention_use_n_prev_actions"]
  352. self.use_n_prev_rewards = model_config["attention_use_n_prev_rewards"]
  353. self.action_space_struct = get_base_struct_from_space(self.action_space)
  354. self.action_dim = 0
  355. for space in tree.flatten(self.action_space_struct):
  356. if isinstance(space, Discrete):
  357. self.action_dim += space.n
  358. elif isinstance(space, MultiDiscrete):
  359. self.action_dim += np.sum(space.nvec)
  360. elif space.shape is not None:
  361. self.action_dim += int(np.prod(space.shape))
  362. else:
  363. self.action_dim += int(len(space))
  364. # Add prev-action/reward nodes to input to LSTM.
  365. if self.use_n_prev_actions:
  366. self.num_outputs += self.use_n_prev_actions * self.action_dim
  367. if self.use_n_prev_rewards:
  368. self.num_outputs += self.use_n_prev_rewards
  369. cfg = model_config
  370. self.attention_dim = cfg["attention_dim"]
  371. if self.num_outputs is not None:
  372. in_space = gym.spaces.Box(
  373. float("-inf"), float("inf"), shape=(self.num_outputs,), dtype=np.float32
  374. )
  375. else:
  376. in_space = obs_space
  377. # Construct GTrXL sub-module w/ num_outputs=None (so it does not
  378. # create a logits/value output; we'll do this ourselves in this wrapper
  379. # here).
  380. self.gtrxl = GTrXLNet(
  381. in_space,
  382. action_space,
  383. None,
  384. model_config,
  385. "gtrxl",
  386. num_transformer_units=cfg["attention_num_transformer_units"],
  387. attention_dim=self.attention_dim,
  388. num_heads=cfg["attention_num_heads"],
  389. head_dim=cfg["attention_head_dim"],
  390. memory_inference=cfg["attention_memory_inference"],
  391. memory_training=cfg["attention_memory_training"],
  392. position_wise_mlp_dim=cfg["attention_position_wise_mlp_dim"],
  393. init_gru_gate_bias=cfg["attention_init_gru_gate_bias"],
  394. )
  395. # `self.num_outputs` right now is the number of nodes coming from the
  396. # attention net.
  397. input_ = tf.keras.layers.Input(shape=(self.gtrxl.num_outputs,))
  398. # Set final num_outputs to correct value (depending on action space).
  399. self.num_outputs = num_outputs
  400. # Postprocess GTrXL output with another hidden layer and compute
  401. # values.
  402. out = tf.keras.layers.Dense(self.num_outputs, activation=None)(input_)
  403. self._logits_branch = tf.keras.models.Model([input_], [out])
  404. out = tf.keras.layers.Dense(1, activation=None)(input_)
  405. self._value_branch = tf.keras.models.Model([input_], [out])
  406. self.view_requirements = self.gtrxl.view_requirements
  407. self.view_requirements["obs"].space = self.obs_space
  408. # Add prev-a/r to this model's view, if required.
  409. if self.use_n_prev_actions:
  410. self.view_requirements[SampleBatch.PREV_ACTIONS] = ViewRequirement(
  411. SampleBatch.ACTIONS,
  412. space=self.action_space,
  413. shift="-{}:-1".format(self.use_n_prev_actions),
  414. )
  415. if self.use_n_prev_rewards:
  416. self.view_requirements[SampleBatch.PREV_REWARDS] = ViewRequirement(
  417. SampleBatch.REWARDS, shift="-{}:-1".format(self.use_n_prev_rewards)
  418. )
  419. @override(RecurrentNetwork)
  420. def forward(
  421. self,
  422. input_dict: Dict[str, TensorType],
  423. state: List[TensorType],
  424. seq_lens: TensorType,
  425. ) -> (TensorType, List[TensorType]):
  426. assert seq_lens is not None
  427. # Push obs through "unwrapped" net's `forward()` first.
  428. wrapped_out, _ = self._wrapped_forward(input_dict, [], None)
  429. # Concat. prev-action/reward if required.
  430. prev_a_r = []
  431. # Prev actions.
  432. if self.use_n_prev_actions:
  433. prev_n_actions = input_dict[SampleBatch.PREV_ACTIONS]
  434. # If actions are not processed yet (in their original form as
  435. # have been sent to environment):
  436. # Flatten/one-hot into 1D array.
  437. if self.model_config["_disable_action_flattening"]:
  438. # Merge prev n actions into flat tensor.
  439. flat = flatten_inputs_to_1d_tensor(
  440. prev_n_actions,
  441. spaces_struct=self.action_space_struct,
  442. time_axis=True,
  443. )
  444. # Fold time-axis into flattened data.
  445. flat = tf.reshape(flat, [tf.shape(flat)[0], -1])
  446. prev_a_r.append(flat)
  447. # If actions are already flattened (but not one-hot'd yet!),
  448. # one-hot discrete/multi-discrete actions here and concatenate the
  449. # n most recent actions together.
  450. else:
  451. if isinstance(self.action_space, Discrete):
  452. for i in range(self.use_n_prev_actions):
  453. prev_a_r.append(
  454. one_hot(prev_n_actions[:, i], self.action_space)
  455. )
  456. elif isinstance(self.action_space, MultiDiscrete):
  457. for i in range(
  458. 0, self.use_n_prev_actions, self.action_space.shape[0]
  459. ):
  460. prev_a_r.append(
  461. one_hot(
  462. tf.cast(
  463. prev_n_actions[
  464. :, i : i + self.action_space.shape[0]
  465. ],
  466. tf.float32,
  467. ),
  468. space=self.action_space,
  469. )
  470. )
  471. else:
  472. prev_a_r.append(
  473. tf.reshape(
  474. tf.cast(prev_n_actions, tf.float32),
  475. [-1, self.use_n_prev_actions * self.action_dim],
  476. )
  477. )
  478. # Prev rewards.
  479. if self.use_n_prev_rewards:
  480. prev_a_r.append(
  481. tf.reshape(
  482. tf.cast(input_dict[SampleBatch.PREV_REWARDS], tf.float32),
  483. [-1, self.use_n_prev_rewards],
  484. )
  485. )
  486. # Concat prev. actions + rewards to the "main" input.
  487. if prev_a_r:
  488. wrapped_out = tf.concat([wrapped_out] + prev_a_r, axis=1)
  489. # Then through our GTrXL.
  490. input_dict["obs_flat"] = input_dict["obs"] = wrapped_out
  491. self._features, memory_outs = self.gtrxl(input_dict, state, seq_lens)
  492. model_out = self._logits_branch(self._features)
  493. return model_out, memory_outs
  494. @override(ModelV2)
  495. def value_function(self) -> TensorType:
  496. assert self._features is not None, "Must call forward() first!"
  497. return tf.reshape(self._value_branch(self._features), [-1])
  498. @override(ModelV2)
  499. def get_initial_state(self) -> Union[List[np.ndarray], List[TensorType]]:
  500. return [
  501. np.zeros(self.gtrxl.view_requirements["state_in_{}".format(i)].space.shape)
  502. for i in range(self.gtrxl.num_transformer_units)
  503. ]