fcnet.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. from typing import Dict
  2. import gymnasium as gym
  3. import numpy as np
  4. from ray.rllib.models.tf.misc import normc_initializer
  5. from ray.rllib.models.tf.tf_modelv2 import TFModelV2
  6. from ray.rllib.models.utils import get_activation_fn
  7. from ray.rllib.utils.annotations import OldAPIStack
  8. from ray.rllib.utils.framework import try_import_tf
  9. from ray.rllib.utils.typing import List, ModelConfigDict, TensorType
# Obtain TensorFlow through RLlib's `try_import_tf` helper instead of a direct
# `import tensorflow`. Only `tf` is used by the model code in this file.
# NOTE(review): presumably `tf1` is the TF1-compat module and `tfv` the major
# version number -- confirm against the helper's documentation.
tf1, tf, tfv = try_import_tf()
  11. @OldAPIStack
  12. class FullyConnectedNetwork(TFModelV2):
  13. """Generic fully connected network implemented in ModelV2 API."""
  14. def __init__(
  15. self,
  16. obs_space: gym.spaces.Space,
  17. action_space: gym.spaces.Space,
  18. num_outputs: int,
  19. model_config: ModelConfigDict,
  20. name: str,
  21. ):
  22. super(FullyConnectedNetwork, self).__init__(
  23. obs_space, action_space, num_outputs, model_config, name
  24. )
  25. hiddens = list(model_config.get("fcnet_hiddens", [])) + list(
  26. model_config.get("post_fcnet_hiddens", [])
  27. )
  28. activation = model_config.get("fcnet_activation")
  29. if not model_config.get("fcnet_hiddens", []):
  30. activation = model_config.get("post_fcnet_activation")
  31. activation = get_activation_fn(activation)
  32. no_final_linear = model_config.get("no_final_linear")
  33. vf_share_layers = model_config.get("vf_share_layers")
  34. free_log_std = model_config.get("free_log_std")
  35. # Generate free-floating bias variables for the second half of
  36. # the outputs.
  37. if free_log_std:
  38. assert num_outputs % 2 == 0, (
  39. "num_outputs must be divisible by two",
  40. num_outputs,
  41. )
  42. num_outputs = num_outputs // 2
  43. self.log_std_var = tf.Variable(
  44. [0.0] * num_outputs, dtype=tf.float32, name="log_std"
  45. )
  46. # We are using obs_flat, so take the flattened shape as input.
  47. inputs = tf.keras.layers.Input(
  48. shape=(int(np.prod(obs_space.shape)),), name="observations"
  49. )
  50. # Last hidden layer output (before logits outputs).
  51. last_layer = inputs
  52. # The action distribution outputs.
  53. logits_out = None
  54. i = 1
  55. # Create layers 0 to second-last.
  56. for size in hiddens[:-1]:
  57. last_layer = tf.keras.layers.Dense(
  58. size,
  59. name="fc_{}".format(i),
  60. activation=activation,
  61. kernel_initializer=normc_initializer(1.0),
  62. )(last_layer)
  63. i += 1
  64. # The last layer is adjusted to be of size num_outputs, but it's a
  65. # layer with activation.
  66. if no_final_linear and num_outputs:
  67. logits_out = tf.keras.layers.Dense(
  68. num_outputs,
  69. name="fc_out",
  70. activation=activation,
  71. kernel_initializer=normc_initializer(1.0),
  72. )(last_layer)
  73. # Finish the layers with the provided sizes (`hiddens`), plus -
  74. # iff num_outputs > 0 - a last linear layer of size num_outputs.
  75. else:
  76. if len(hiddens) > 0:
  77. last_layer = tf.keras.layers.Dense(
  78. hiddens[-1],
  79. name="fc_{}".format(i),
  80. activation=activation,
  81. kernel_initializer=normc_initializer(1.0),
  82. )(last_layer)
  83. if num_outputs:
  84. logits_out = tf.keras.layers.Dense(
  85. num_outputs,
  86. name="fc_out",
  87. activation=None,
  88. kernel_initializer=normc_initializer(0.01),
  89. )(last_layer)
  90. # Adjust num_outputs to be the number of nodes in the last layer.
  91. else:
  92. self.num_outputs = ([int(np.prod(obs_space.shape))] + hiddens[-1:])[-1]
  93. # Concat the log std vars to the end of the state-dependent means.
  94. if free_log_std and logits_out is not None:
  95. def tiled_log_std(x):
  96. return tf.tile(tf.expand_dims(self.log_std_var, 0), [tf.shape(x)[0], 1])
  97. log_std_out = tf.keras.layers.Lambda(tiled_log_std)(inputs)
  98. logits_out = tf.keras.layers.Concatenate(axis=1)([logits_out, log_std_out])
  99. last_vf_layer = None
  100. if not vf_share_layers:
  101. # Build a parallel set of hidden layers for the value net.
  102. last_vf_layer = inputs
  103. i = 1
  104. for size in hiddens:
  105. last_vf_layer = tf.keras.layers.Dense(
  106. size,
  107. name="fc_value_{}".format(i),
  108. activation=activation,
  109. kernel_initializer=normc_initializer(1.0),
  110. )(last_vf_layer)
  111. i += 1
  112. value_out = tf.keras.layers.Dense(
  113. 1,
  114. name="value_out",
  115. activation=None,
  116. kernel_initializer=normc_initializer(0.01),
  117. )(last_vf_layer if last_vf_layer is not None else last_layer)
  118. self.base_model = tf.keras.Model(
  119. inputs, [(logits_out if logits_out is not None else last_layer), value_out]
  120. )
  121. def forward(
  122. self,
  123. input_dict: Dict[str, TensorType],
  124. state: List[TensorType],
  125. seq_lens: TensorType,
  126. ) -> (TensorType, List[TensorType]):
  127. model_out, self._value_out = self.base_model(input_dict["obs_flat"])
  128. return model_out, state
  129. def value_function(self) -> TensorType:
  130. return tf.reshape(self._value_out, [-1])