recurrent_net.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293
  1. import logging
  2. from typing import Dict, List, Tuple
  3. import gymnasium as gym
  4. import numpy as np
  5. import tree # pip install dm_tree
  6. from gymnasium.spaces import Discrete, MultiDiscrete
  7. from ray._common.deprecation import deprecation_warning
  8. from ray.rllib.models.modelv2 import ModelV2
  9. from ray.rllib.models.tf.tf_modelv2 import TFModelV2
  10. from ray.rllib.policy.rnn_sequencing import add_time_dimension
  11. from ray.rllib.policy.sample_batch import SampleBatch
  12. from ray.rllib.policy.view_requirement import ViewRequirement
  13. from ray.rllib.utils.annotations import OldAPIStack, override
  14. from ray.rllib.utils.framework import try_import_tf
  15. from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
  16. from ray.rllib.utils.tf_utils import flatten_inputs_to_1d_tensor, one_hot
  17. from ray.rllib.utils.typing import ModelConfigDict, TensorType
  18. from ray.util.debug import log_once
  19. tf1, tf, tfv = try_import_tf()
  20. logger = logging.getLogger(__name__)
  21. @OldAPIStack
  22. class RecurrentNetwork(TFModelV2):
  23. """Helper class to simplify implementing RNN models with TFModelV2.
  24. Instead of implementing forward(), you can implement forward_rnn() which
  25. takes batches with the time dimension added already.
  26. Here is an example implementation for a subclass
  27. ``MyRNNClass(RecurrentNetwork)``::
  28. def __init__(self, *args, **kwargs):
  29. super(MyModelClass, self).__init__(*args, **kwargs)
  30. cell_size = 256
  31. # Define input layers
  32. input_layer = tf.keras.layers.Input(
  33. shape=(None, obs_space.shape[0]))
  34. state_in_h = tf.keras.layers.Input(shape=(256, ))
  35. state_in_c = tf.keras.layers.Input(shape=(256, ))
  36. seq_in = tf.keras.layers.Input(shape=(), dtype=tf.int32)
  37. # Send to LSTM cell
  38. lstm_out, state_h, state_c = tf.keras.layers.LSTM(
  39. cell_size, return_sequences=True, return_state=True,
  40. name="lstm")(
  41. inputs=input_layer,
  42. mask=tf.sequence_mask(seq_in),
  43. initial_state=[state_in_h, state_in_c])
  44. output_layer = tf.keras.layers.Dense(...)(lstm_out)
  45. # Create the RNN model
  46. self.rnn_model = tf.keras.Model(
  47. inputs=[input_layer, seq_in, state_in_h, state_in_c],
  48. outputs=[output_layer, state_h, state_c])
  49. self.rnn_model.summary()
  50. """
  51. @override(ModelV2)
  52. def forward(
  53. self,
  54. input_dict: Dict[str, TensorType],
  55. state: List[TensorType],
  56. seq_lens: TensorType,
  57. ) -> Tuple[TensorType, List[TensorType]]:
  58. """Adds time dimension to batch before sending inputs to forward_rnn().
  59. You should implement forward_rnn() in your subclass."""
  60. # Creating a __init__ function that acts as a passthrough and adding the warning
  61. # there led to errors probably due to the multiple inheritance. We encountered
  62. # the same error if we add the Deprecated decorator. We therefore add the
  63. # deprecation warning here.
  64. if log_once("recurrent_network_tf"):
  65. deprecation_warning(
  66. old="ray.rllib.models.tf.recurrent_net.RecurrentNetwork"
  67. )
  68. assert seq_lens is not None
  69. flat_inputs = input_dict["obs_flat"]
  70. inputs = add_time_dimension(
  71. padded_inputs=flat_inputs, seq_lens=seq_lens, framework="tf"
  72. )
  73. output, new_state = self.forward_rnn(
  74. inputs,
  75. state,
  76. seq_lens,
  77. )
  78. return tf.reshape(output, [-1, self.num_outputs]), new_state
  79. def forward_rnn(
  80. self, inputs: TensorType, state: List[TensorType], seq_lens: TensorType
  81. ) -> Tuple[TensorType, List[TensorType]]:
  82. """Call the model with the given input tensors and state.
  83. Args:
  84. inputs: observation tensor with shape [B, T, obs_size].
  85. state: list of state tensors, each with shape [B, T, size].
  86. seq_lens: 1d tensor holding input sequence lengths.
  87. Returns:
  88. (outputs, new_state): The model output tensor of shape
  89. [B, T, num_outputs] and the list of new state tensors each with
  90. shape [B, size].
  91. Sample implementation for the ``MyRNNClass`` example::
  92. def forward_rnn(self, inputs, state, seq_lens):
  93. model_out, h, c = self.rnn_model([inputs, seq_lens] + state)
  94. return model_out, [h, c]
  95. """
  96. raise NotImplementedError("You must implement this for a RNN model")
  97. def get_initial_state(self) -> List[TensorType]:
  98. """Get the initial recurrent state values for the model.
  99. Returns:
  100. list of np.array objects, if any
  101. Sample implementation for the ``MyRNNClass`` example::
  102. def get_initial_state(self):
  103. return [
  104. np.zeros(self.cell_size, np.float32),
  105. np.zeros(self.cell_size, np.float32),
  106. ]
  107. """
  108. raise NotImplementedError("You must implement this for a RNN model")
  109. @OldAPIStack
  110. class LSTMWrapper(RecurrentNetwork):
  111. """An LSTM wrapper serving as an interface for ModelV2s that set use_lstm."""
  112. def __init__(
  113. self,
  114. obs_space: gym.spaces.Space,
  115. action_space: gym.spaces.Space,
  116. num_outputs: int,
  117. model_config: ModelConfigDict,
  118. name: str,
  119. ):
  120. super(LSTMWrapper, self).__init__(
  121. obs_space, action_space, None, model_config, name
  122. )
  123. # At this point, self.num_outputs is the number of nodes coming
  124. # from the wrapped (underlying) model. In other words, self.num_outputs
  125. # is the input size for the LSTM layer.
  126. # If None, set it to the observation space.
  127. if self.num_outputs is None:
  128. self.num_outputs = int(np.prod(self.obs_space.shape))
  129. self.cell_size = model_config["lstm_cell_size"]
  130. self.use_prev_action = model_config["lstm_use_prev_action"]
  131. self.use_prev_reward = model_config["lstm_use_prev_reward"]
  132. self.action_space_struct = get_base_struct_from_space(self.action_space)
  133. self.action_dim = 0
  134. for space in tree.flatten(self.action_space_struct):
  135. if isinstance(space, Discrete):
  136. self.action_dim += space.n
  137. elif isinstance(space, MultiDiscrete):
  138. self.action_dim += np.sum(space.nvec)
  139. elif space.shape is not None:
  140. self.action_dim += int(np.prod(space.shape))
  141. else:
  142. self.action_dim += int(len(space))
  143. # Add prev-action/reward nodes to input to LSTM.
  144. if self.use_prev_action:
  145. self.num_outputs += self.action_dim
  146. if self.use_prev_reward:
  147. self.num_outputs += 1
  148. # Define input layers.
  149. input_layer = tf.keras.layers.Input(
  150. shape=(None, self.num_outputs), name="inputs"
  151. )
  152. # Set self.num_outputs to the number of output nodes desired by the
  153. # caller of this constructor.
  154. self.num_outputs = num_outputs
  155. state_in_h = tf.keras.layers.Input(shape=(self.cell_size,), name="h")
  156. state_in_c = tf.keras.layers.Input(shape=(self.cell_size,), name="c")
  157. seq_in = tf.keras.layers.Input(shape=(), name="seq_in", dtype=tf.int32)
  158. # Preprocess observation with a hidden layer and send to LSTM cell
  159. lstm_out, state_h, state_c = tf.keras.layers.LSTM(
  160. self.cell_size, return_sequences=True, return_state=True, name="lstm"
  161. )(
  162. inputs=input_layer,
  163. mask=tf.sequence_mask(seq_in),
  164. initial_state=[state_in_h, state_in_c],
  165. )
  166. # Postprocess LSTM output with another hidden layer and compute values
  167. logits = tf.keras.layers.Dense(
  168. self.num_outputs, activation=tf.keras.activations.linear, name="logits"
  169. )(lstm_out)
  170. values = tf.keras.layers.Dense(1, activation=None, name="values")(lstm_out)
  171. # Create the RNN model
  172. self._rnn_model = tf.keras.Model(
  173. inputs=[input_layer, seq_in, state_in_h, state_in_c],
  174. outputs=[logits, values, state_h, state_c],
  175. )
  176. # Print out model summary in INFO logging mode.
  177. if logger.isEnabledFor(logging.INFO):
  178. self._rnn_model.summary()
  179. # Add prev-a/r to this model's view, if required.
  180. if model_config["lstm_use_prev_action"]:
  181. self.view_requirements[SampleBatch.PREV_ACTIONS] = ViewRequirement(
  182. SampleBatch.ACTIONS, space=self.action_space, shift=-1
  183. )
  184. if model_config["lstm_use_prev_reward"]:
  185. self.view_requirements[SampleBatch.PREV_REWARDS] = ViewRequirement(
  186. SampleBatch.REWARDS, shift=-1
  187. )
  188. @override(RecurrentNetwork)
  189. def forward(
  190. self,
  191. input_dict: Dict[str, TensorType],
  192. state: List[TensorType],
  193. seq_lens: TensorType,
  194. ) -> Tuple[TensorType, List[TensorType]]:
  195. assert seq_lens is not None
  196. # Push obs through "unwrapped" net's `forward()` first.
  197. wrapped_out, _ = self._wrapped_forward(input_dict, [], None)
  198. # Concat. prev-action/reward if required.
  199. prev_a_r = []
  200. # Prev actions.
  201. if self.model_config["lstm_use_prev_action"]:
  202. prev_a = input_dict[SampleBatch.PREV_ACTIONS]
  203. # If actions are not processed yet (in their original form as
  204. # have been sent to environment):
  205. # Flatten/one-hot into 1D array.
  206. if self.model_config["_disable_action_flattening"]:
  207. prev_a_r.append(
  208. flatten_inputs_to_1d_tensor(
  209. prev_a,
  210. spaces_struct=self.action_space_struct,
  211. time_axis=False,
  212. )
  213. )
  214. # If actions are already flattened (but not one-hot'd yet!),
  215. # one-hot discrete/multi-discrete actions here.
  216. else:
  217. if isinstance(self.action_space, (Discrete, MultiDiscrete)):
  218. prev_a = one_hot(prev_a, self.action_space)
  219. prev_a_r.append(
  220. tf.reshape(tf.cast(prev_a, tf.float32), [-1, self.action_dim])
  221. )
  222. # Prev rewards.
  223. if self.model_config["lstm_use_prev_reward"]:
  224. prev_a_r.append(
  225. tf.reshape(
  226. tf.cast(input_dict[SampleBatch.PREV_REWARDS], tf.float32), [-1, 1]
  227. )
  228. )
  229. # Concat prev. actions + rewards to the "main" input.
  230. if prev_a_r:
  231. wrapped_out = tf.concat([wrapped_out] + prev_a_r, axis=1)
  232. # Push everything through our LSTM.
  233. input_dict["obs_flat"] = wrapped_out
  234. return super().forward(input_dict, state, seq_lens)
  235. @override(RecurrentNetwork)
  236. def forward_rnn(
  237. self, inputs: TensorType, state: List[TensorType], seq_lens: TensorType
  238. ) -> Tuple[TensorType, List[TensorType]]:
  239. model_out, self._value_out, h, c = self._rnn_model([inputs, seq_lens] + state)
  240. return model_out, [h, c]
  241. @override(ModelV2)
  242. def get_initial_state(self) -> List[np.ndarray]:
  243. return [
  244. np.zeros(self.cell_size, np.float32),
  245. np.zeros(self.cell_size, np.float32),
  246. ]
  247. @override(ModelV2)
  248. def value_function(self) -> TensorType:
  249. return tf.reshape(self._value_out, [-1])