| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161 |
- import logging
- import gymnasium as gym
- import numpy as np
- from ray.rllib.models.torch.misc import AppendBiasLayer, SlimFC, normc_initializer
- from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
- from ray.rllib.utils.annotations import OldAPIStack, override
- from ray.rllib.utils.framework import try_import_torch
- from ray.rllib.utils.typing import Dict, List, ModelConfigDict, TensorType
- torch, nn = try_import_torch()
- logger = logging.getLogger(__name__)
@OldAPIStack
class FullyConnectedNetwork(TorchModelV2, nn.Module):
    """Generic fully connected network.

    Builds a stack of ``SlimFC`` layers producing the model output and a
    scalar value head. The value head either consumes the features of the
    main stack or, if ``vf_share_layers`` is falsy, the output of its own
    parallel stack of hidden layers.

    ``model_config`` keys read by this class:

    * ``fcnet_hiddens`` / ``post_fcnet_hiddens``: hidden layer sizes; the two
      lists are concatenated.
    * ``fcnet_activation``: activation for all layers built here
      (``post_fcnet_activation`` is used instead if ``fcnet_hiddens`` is
      empty).
    * ``no_final_linear``: if truthy (and ``num_outputs`` is non-zero), the
      last layer has ``num_outputs`` units *with* activation and no separate
      linear output layer is created.
    * ``vf_share_layers``: if falsy, a separate hidden stack is built for the
      value function.
    * ``free_log_std``: if truthy, only the first half of the outputs is
      computed from the input; the second half is appended by an
      ``AppendBiasLayer``.
    """

    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config: ModelConfigDict,
        name: str,
    ):
        """Build the policy stack, the optional output layer and value head.

        Args:
            obs_space: Observation space; only its ``shape`` is used here, to
                derive the flattened input size.
            action_space: Action space, forwarded to ``TorchModelV2``.
            num_outputs: Size of the model output. If falsy, no output layer
                is built and ``self.num_outputs`` is set to the size of the
                last hidden layer (or the flattened observation size if there
                are no hidden layers).
            model_config: Model configuration dict (see class docstring for
                the keys that are read).
            name: Model name, forwarded to ``TorchModelV2``.

        Raises:
            AssertionError: If ``free_log_std`` is set and ``num_outputs`` is
                not divisible by two.
        """
        TorchModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name
        )
        nn.Module.__init__(self)

        fcnet_hiddens = list(model_config.get("fcnet_hiddens", []))
        hiddens = fcnet_hiddens + list(model_config.get("post_fcnet_hiddens", []))
        activation = model_config.get("fcnet_activation")
        if not fcnet_hiddens:
            activation = model_config.get("post_fcnet_activation")
        no_final_linear = model_config.get("no_final_linear")
        self.vf_share_layers = model_config.get("vf_share_layers")
        self.free_log_std = model_config.get("free_log_std")

        # Generate free-floating bias variables for the second half of
        # the outputs.
        if self.free_log_std:
            assert num_outputs % 2 == 0, (
                "num_outputs must be divisible by two",
                num_outputs,
            )
            num_outputs = num_outputs // 2

        # Flattened observation size; input width of both hidden stacks.
        obs_size = int(np.prod(obs_space.shape))

        layers = []
        prev_layer_size = obs_size
        self._logits = None

        # Create layers 0 to second-last.
        for size in hiddens[:-1]:
            layers.append(
                SlimFC(
                    in_size=prev_layer_size,
                    out_size=size,
                    initializer=normc_initializer(1.0),
                    activation_fn=activation,
                )
            )
            prev_layer_size = size

        # The last layer is adjusted to be of size num_outputs, but it's a
        # layer with activation.
        if no_final_linear and num_outputs:
            layers.append(
                SlimFC(
                    in_size=prev_layer_size,
                    out_size=num_outputs,
                    initializer=normc_initializer(1.0),
                    activation_fn=activation,
                )
            )
            prev_layer_size = num_outputs
        # Finish the layers with the provided sizes (`hiddens`), plus -
        # iff num_outputs > 0 - a last linear layer of size num_outputs.
        else:
            if hiddens:
                layers.append(
                    SlimFC(
                        in_size=prev_layer_size,
                        out_size=hiddens[-1],
                        initializer=normc_initializer(1.0),
                        activation_fn=activation,
                    )
                )
                prev_layer_size = hiddens[-1]
            if num_outputs:
                self._logits = SlimFC(
                    in_size=prev_layer_size,
                    out_size=num_outputs,
                    initializer=normc_initializer(0.01),
                    activation_fn=None,
                )
            else:
                # No output layer: the model emits the last hidden layer (or
                # the flattened observation if there are no hidden layers).
                self.num_outputs = hiddens[-1] if hiddens else obs_size

        # Layer to add the log std vars to the state-dependent means.
        # Always define the attribute so that `forward()` can test for it, and
        # build the layer whenever the model actually emits `num_outputs`
        # means - via `self._logits` *or* via the activated final layer of the
        # `no_final_linear` path. Previously the latter case left the
        # attribute undefined and `forward()` raised an AttributeError.
        self._append_free_log_std = None
        if self.free_log_std and num_outputs:
            self._append_free_log_std = AppendBiasLayer(num_outputs)

        self._hidden_layers = nn.Sequential(*layers)

        self._value_branch_separate = None
        # Width of whatever `value_function()` feeds into the value head.
        vf_in_size = prev_layer_size
        if not self.vf_share_layers:
            # Build a parallel set of hidden layers for the value net.
            prev_vf_layer_size = obs_size
            vf_layers = []
            for size in hiddens:
                vf_layers.append(
                    SlimFC(
                        in_size=prev_vf_layer_size,
                        out_size=size,
                        activation_fn=activation,
                        initializer=normc_initializer(1.0),
                    )
                )
                prev_vf_layer_size = size
            self._value_branch_separate = nn.Sequential(*vf_layers)
            # A non-empty separate stack is what feeds the value head, so the
            # head must match *its* output width. This differs from
            # `prev_layer_size` when `no_final_linear` resized the last policy
            # layer to `num_outputs`. An empty stack is not used by
            # `value_function()`, so the shared-features width stays in place.
            if vf_layers:
                vf_in_size = prev_vf_layer_size

        self._value_branch = SlimFC(
            in_size=vf_in_size,
            out_size=1,
            initializer=normc_initializer(0.01),
            activation_fn=None,
        )
        # Holds the current "base" output (before logits layer).
        self._features = None
        # Holds the last input, in case value branch is separate.
        self._last_flat_in = None

    @override(TorchModelV2)
    def forward(
        self,
        input_dict: Dict[str, TensorType],
        state: List[TensorType],
        seq_lens: TensorType,
    ) -> (TensorType, List[TensorType]):
        """Compute the model output for a batch of flattened observations.

        Caches the flattened input and the hidden-stack features on ``self``
        so that ``value_function()`` can be evaluated afterwards.

        Args:
            input_dict: Must contain ``"obs_flat"``; it is cast to float and
                reshaped to ``(obs.shape[0], -1)``.
            state: Passed through unchanged.
            seq_lens: Unused by this model.

        Returns:
            Tuple of the model output and the unchanged ``state``.
        """
        obs = input_dict["obs_flat"].float()
        self._last_flat_in = obs.reshape(obs.shape[0], -1)
        self._features = self._hidden_layers(self._last_flat_in)
        if self._logits is not None:
            logits = self._logits(self._features)
        else:
            logits = self._features
        # Only present when `free_log_std` is set and there are outputs to
        # extend (see `__init__`).
        if self._append_free_log_std is not None:
            logits = self._append_free_log_std(logits)
        return logits, state

    @override(TorchModelV2)
    def value_function(self) -> TensorType:
        """Return the value estimates for the most recent ``forward()`` call.

        Returns:
            The value head's output with dimension 1 squeezed out.

        Raises:
            AssertionError: If ``forward()`` has not been called yet.
        """
        assert self._features is not None, "must call forward() first"
        separate = self._value_branch_separate
        # An empty `nn.Sequential` (no hidden layers configured) has length 0;
        # in that case fall back to the shared features, for which the value
        # head was sized in `__init__`.
        if separate is not None and len(separate) > 0:
            vf_in = separate(self._last_flat_in)
        else:
            vf_in = self._features
        return self._value_branch(vf_in).squeeze(1)
|