fcnet.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. import logging
  2. import gymnasium as gym
  3. import numpy as np
  4. from ray.rllib.models.torch.misc import AppendBiasLayer, SlimFC, normc_initializer
  5. from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
  6. from ray.rllib.utils.annotations import OldAPIStack, override
  7. from ray.rllib.utils.framework import try_import_torch
  8. from ray.rllib.utils.typing import Dict, List, ModelConfigDict, TensorType
  9. torch, nn = try_import_torch()
  10. logger = logging.getLogger(__name__)
  11. @OldAPIStack
  12. class FullyConnectedNetwork(TorchModelV2, nn.Module):
  13. """Generic fully connected network."""
  14. def __init__(
  15. self,
  16. obs_space: gym.spaces.Space,
  17. action_space: gym.spaces.Space,
  18. num_outputs: int,
  19. model_config: ModelConfigDict,
  20. name: str,
  21. ):
  22. TorchModelV2.__init__(
  23. self, obs_space, action_space, num_outputs, model_config, name
  24. )
  25. nn.Module.__init__(self)
  26. hiddens = list(model_config.get("fcnet_hiddens", [])) + list(
  27. model_config.get("post_fcnet_hiddens", [])
  28. )
  29. activation = model_config.get("fcnet_activation")
  30. if not model_config.get("fcnet_hiddens", []):
  31. activation = model_config.get("post_fcnet_activation")
  32. no_final_linear = model_config.get("no_final_linear")
  33. self.vf_share_layers = model_config.get("vf_share_layers")
  34. self.free_log_std = model_config.get("free_log_std")
  35. # Generate free-floating bias variables for the second half of
  36. # the outputs.
  37. if self.free_log_std:
  38. assert num_outputs % 2 == 0, (
  39. "num_outputs must be divisible by two",
  40. num_outputs,
  41. )
  42. num_outputs = num_outputs // 2
  43. layers = []
  44. prev_layer_size = int(np.prod(obs_space.shape))
  45. self._logits = None
  46. # Create layers 0 to second-last.
  47. for size in hiddens[:-1]:
  48. layers.append(
  49. SlimFC(
  50. in_size=prev_layer_size,
  51. out_size=size,
  52. initializer=normc_initializer(1.0),
  53. activation_fn=activation,
  54. )
  55. )
  56. prev_layer_size = size
  57. # The last layer is adjusted to be of size num_outputs, but it's a
  58. # layer with activation.
  59. if no_final_linear and num_outputs:
  60. layers.append(
  61. SlimFC(
  62. in_size=prev_layer_size,
  63. out_size=num_outputs,
  64. initializer=normc_initializer(1.0),
  65. activation_fn=activation,
  66. )
  67. )
  68. prev_layer_size = num_outputs
  69. # Finish the layers with the provided sizes (`hiddens`), plus -
  70. # iff num_outputs > 0 - a last linear layer of size num_outputs.
  71. else:
  72. if len(hiddens) > 0:
  73. layers.append(
  74. SlimFC(
  75. in_size=prev_layer_size,
  76. out_size=hiddens[-1],
  77. initializer=normc_initializer(1.0),
  78. activation_fn=activation,
  79. )
  80. )
  81. prev_layer_size = hiddens[-1]
  82. if num_outputs:
  83. self._logits = SlimFC(
  84. in_size=prev_layer_size,
  85. out_size=num_outputs,
  86. initializer=normc_initializer(0.01),
  87. activation_fn=None,
  88. )
  89. else:
  90. self.num_outputs = ([int(np.prod(obs_space.shape))] + hiddens[-1:])[-1]
  91. # Layer to add the log std vars to the state-dependent means.
  92. if self.free_log_std and self._logits:
  93. self._append_free_log_std = AppendBiasLayer(num_outputs)
  94. self._hidden_layers = nn.Sequential(*layers)
  95. self._value_branch_separate = None
  96. if not self.vf_share_layers:
  97. # Build a parallel set of hidden layers for the value net.
  98. prev_vf_layer_size = int(np.prod(obs_space.shape))
  99. vf_layers = []
  100. for size in hiddens:
  101. vf_layers.append(
  102. SlimFC(
  103. in_size=prev_vf_layer_size,
  104. out_size=size,
  105. activation_fn=activation,
  106. initializer=normc_initializer(1.0),
  107. )
  108. )
  109. prev_vf_layer_size = size
  110. self._value_branch_separate = nn.Sequential(*vf_layers)
  111. self._value_branch = SlimFC(
  112. in_size=prev_layer_size,
  113. out_size=1,
  114. initializer=normc_initializer(0.01),
  115. activation_fn=None,
  116. )
  117. # Holds the current "base" output (before logits layer).
  118. self._features = None
  119. # Holds the last input, in case value branch is separate.
  120. self._last_flat_in = None
  121. @override(TorchModelV2)
  122. def forward(
  123. self,
  124. input_dict: Dict[str, TensorType],
  125. state: List[TensorType],
  126. seq_lens: TensorType,
  127. ) -> (TensorType, List[TensorType]):
  128. obs = input_dict["obs_flat"].float()
  129. self._last_flat_in = obs.reshape(obs.shape[0], -1)
  130. self._features = self._hidden_layers(self._last_flat_in)
  131. logits = self._logits(self._features) if self._logits else self._features
  132. if self.free_log_std:
  133. logits = self._append_free_log_std(logits)
  134. return logits, state
  135. @override(TorchModelV2)
  136. def value_function(self) -> TensorType:
  137. assert self._features is not None, "must call forward() first"
  138. if self._value_branch_separate:
  139. out = self._value_branch(
  140. self._value_branch_separate(self._last_flat_in)
  141. ).squeeze(1)
  142. else:
  143. out = self._value_branch(self._features).squeeze(1)
  144. return out