visionnet.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. from typing import Dict, List
  2. import gymnasium as gym
  3. import numpy as np
  4. from ray.rllib.models.torch.misc import (
  5. SlimConv2d,
  6. SlimFC,
  7. normc_initializer,
  8. same_padding,
  9. )
  10. from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
  11. from ray.rllib.models.utils import get_activation_fn, get_filter_config
  12. from ray.rllib.utils.annotations import OldAPIStack, override
  13. from ray.rllib.utils.framework import try_import_torch
  14. from ray.rllib.utils.typing import ModelConfigDict, TensorType
  15. torch, nn = try_import_torch()
  16. @OldAPIStack
  17. class VisionNetwork(TorchModelV2, nn.Module):
  18. """Generic vision network."""
  19. def __init__(
  20. self,
  21. obs_space: gym.spaces.Space,
  22. action_space: gym.spaces.Space,
  23. num_outputs: int,
  24. model_config: ModelConfigDict,
  25. name: str,
  26. ):
  27. if not model_config.get("conv_filters"):
  28. model_config["conv_filters"] = get_filter_config(obs_space.shape)
  29. TorchModelV2.__init__(
  30. self, obs_space, action_space, num_outputs, model_config, name
  31. )
  32. nn.Module.__init__(self)
  33. activation = self.model_config.get("conv_activation")
  34. filters = self.model_config["conv_filters"]
  35. assert len(filters) > 0, "Must provide at least 1 entry in `conv_filters`!"
  36. # Post FC net config.
  37. post_fcnet_hiddens = model_config.get("post_fcnet_hiddens", [])
  38. post_fcnet_activation = get_activation_fn(
  39. model_config.get("post_fcnet_activation"), framework="torch"
  40. )
  41. no_final_linear = self.model_config.get("no_final_linear")
  42. vf_share_layers = self.model_config.get("vf_share_layers")
  43. # Whether the last layer is the output of a Flattened (rather than
  44. # a n x (1,1) Conv2D).
  45. self.last_layer_is_flattened = False
  46. self._logits = None
  47. layers = []
  48. (w, h, in_channels) = obs_space.shape
  49. in_size = [w, h]
  50. for out_channels, kernel, stride in filters[:-1]:
  51. padding, out_size = same_padding(in_size, kernel, stride)
  52. layers.append(
  53. SlimConv2d(
  54. in_channels,
  55. out_channels,
  56. kernel,
  57. stride,
  58. padding,
  59. activation_fn=activation,
  60. )
  61. )
  62. in_channels = out_channels
  63. in_size = out_size
  64. out_channels, kernel, stride = filters[-1]
  65. # No final linear: Last layer has activation function and exits with
  66. # num_outputs nodes (this could be a 1x1 conv or a FC layer, depending
  67. # on `post_fcnet_...` settings).
  68. if no_final_linear and num_outputs:
  69. out_channels = out_channels if post_fcnet_hiddens else num_outputs
  70. layers.append(
  71. SlimConv2d(
  72. in_channels,
  73. out_channels,
  74. kernel,
  75. stride,
  76. None, # padding=valid
  77. activation_fn=activation,
  78. )
  79. )
  80. # Add (optional) post-fc-stack after last Conv2D layer.
  81. layer_sizes = post_fcnet_hiddens[:-1] + (
  82. [num_outputs] if post_fcnet_hiddens else []
  83. )
  84. for i, out_size in enumerate(layer_sizes):
  85. layers.append(
  86. SlimFC(
  87. in_size=out_channels,
  88. out_size=out_size,
  89. activation_fn=post_fcnet_activation,
  90. initializer=normc_initializer(1.0),
  91. )
  92. )
  93. out_channels = out_size
  94. # Finish network normally (w/o overriding last layer size with
  95. # `num_outputs`), then add another linear one of size `num_outputs`.
  96. else:
  97. layers.append(
  98. SlimConv2d(
  99. in_channels,
  100. out_channels,
  101. kernel,
  102. stride,
  103. None, # padding=valid
  104. activation_fn=activation,
  105. )
  106. )
  107. # num_outputs defined. Use that to create an exact
  108. # `num_output`-sized (1,1)-Conv2D.
  109. if num_outputs:
  110. in_size = [
  111. np.ceil((in_size[0] - kernel[0]) / stride),
  112. np.ceil((in_size[1] - kernel[1]) / stride),
  113. ]
  114. padding, _ = same_padding(in_size, [1, 1], [1, 1])
  115. if post_fcnet_hiddens:
  116. layers.append(nn.Flatten())
  117. in_size = out_channels
  118. # Add (optional) post-fc-stack after last Conv2D layer.
  119. for i, out_size in enumerate(post_fcnet_hiddens + [num_outputs]):
  120. layers.append(
  121. SlimFC(
  122. in_size=in_size,
  123. out_size=out_size,
  124. activation_fn=post_fcnet_activation
  125. if i < len(post_fcnet_hiddens) - 1
  126. else None,
  127. initializer=normc_initializer(1.0),
  128. )
  129. )
  130. in_size = out_size
  131. # Last layer is logits layer.
  132. self._logits = layers.pop()
  133. else:
  134. self._logits = SlimConv2d(
  135. out_channels,
  136. num_outputs,
  137. [1, 1],
  138. 1,
  139. padding,
  140. activation_fn=None,
  141. )
  142. # num_outputs not known -> Flatten, then set self.num_outputs
  143. # to the resulting number of nodes.
  144. else:
  145. self.last_layer_is_flattened = True
  146. layers.append(nn.Flatten())
  147. self._convs = nn.Sequential(*layers)
  148. # If our num_outputs still unknown, we need to do a test pass to
  149. # figure out the output dimensions. This could be the case, if we have
  150. # the Flatten layer at the end.
  151. if self.num_outputs is None:
  152. # Create a B=1 dummy sample and push it through out conv-net.
  153. dummy_in = (
  154. torch.from_numpy(self.obs_space.sample())
  155. .permute(2, 0, 1)
  156. .unsqueeze(0)
  157. .float()
  158. )
  159. dummy_out = self._convs(dummy_in)
  160. self.num_outputs = dummy_out.shape[1]
  161. # Build the value layers
  162. self._value_branch_separate = self._value_branch = None
  163. if vf_share_layers:
  164. self._value_branch = SlimFC(
  165. out_channels, 1, initializer=normc_initializer(0.01), activation_fn=None
  166. )
  167. else:
  168. vf_layers = []
  169. (w, h, in_channels) = obs_space.shape
  170. in_size = [w, h]
  171. for out_channels, kernel, stride in filters[:-1]:
  172. padding, out_size = same_padding(in_size, kernel, stride)
  173. vf_layers.append(
  174. SlimConv2d(
  175. in_channels,
  176. out_channels,
  177. kernel,
  178. stride,
  179. padding,
  180. activation_fn=activation,
  181. )
  182. )
  183. in_channels = out_channels
  184. in_size = out_size
  185. out_channels, kernel, stride = filters[-1]
  186. vf_layers.append(
  187. SlimConv2d(
  188. in_channels,
  189. out_channels,
  190. kernel,
  191. stride,
  192. None,
  193. activation_fn=activation,
  194. )
  195. )
  196. vf_layers.append(
  197. SlimConv2d(
  198. in_channels=out_channels,
  199. out_channels=1,
  200. kernel=1,
  201. stride=1,
  202. padding=None,
  203. activation_fn=None,
  204. )
  205. )
  206. self._value_branch_separate = nn.Sequential(*vf_layers)
  207. # Holds the current "base" output (before logits layer).
  208. self._features = None
  209. @override(TorchModelV2)
  210. def forward(
  211. self,
  212. input_dict: Dict[str, TensorType],
  213. state: List[TensorType],
  214. seq_lens: TensorType,
  215. ) -> (TensorType, List[TensorType]):
  216. self._features = input_dict["obs"].float()
  217. # Permuate b/c data comes in as [B, dim, dim, channels]:
  218. self._features = self._features.permute(0, 3, 1, 2)
  219. conv_out = self._convs(self._features)
  220. # Store features to save forward pass when getting value_function out.
  221. if not self._value_branch_separate:
  222. self._features = conv_out
  223. if not self.last_layer_is_flattened:
  224. if self._logits:
  225. conv_out = self._logits(conv_out)
  226. if len(conv_out.shape) == 4:
  227. if conv_out.shape[2] != 1 or conv_out.shape[3] != 1:
  228. raise ValueError(
  229. "Given `conv_filters` ({}) do not result in a [B, {} "
  230. "(`num_outputs`), 1, 1] shape (but in {})! Please "
  231. "adjust your Conv2D stack such that the last 2 dims "
  232. "are both 1.".format(
  233. self.model_config["conv_filters"],
  234. self.num_outputs,
  235. list(conv_out.shape),
  236. )
  237. )
  238. logits = conv_out.squeeze(3)
  239. logits = logits.squeeze(2)
  240. else:
  241. logits = conv_out
  242. return logits, state
  243. else:
  244. return conv_out, state
  245. @override(TorchModelV2)
  246. def value_function(self) -> TensorType:
  247. assert self._features is not None, "must call forward() first"
  248. if self._value_branch_separate:
  249. value = self._value_branch_separate(self._features)
  250. value = value.squeeze(3)
  251. value = value.squeeze(2)
  252. return value.squeeze(1)
  253. else:
  254. if not self.last_layer_is_flattened:
  255. features = self._features.squeeze(3)
  256. features = features.squeeze(2)
  257. else:
  258. features = self._features
  259. return self._value_branch(features).squeeze(1)
  260. def _hidden_layers(self, obs: TensorType) -> TensorType:
  261. res = self._convs(obs.permute(0, 3, 1, 2)) # switch to channel-major
  262. res = res.squeeze(3)
  263. res = res.squeeze(2)
  264. return res