modelv2.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472
  1. import contextlib
  2. from collections import OrderedDict
  3. from typing import Any, Dict, List, Union
  4. import gymnasium as gym
  5. import numpy as np
  6. from gymnasium.spaces import Space
  7. from ray._common.deprecation import Deprecated
  8. from ray.rllib.models.preprocessors import RepeatedValuesPreprocessor, get_preprocessor
  9. from ray.rllib.models.repeated_values import RepeatedValues
  10. from ray.rllib.policy.sample_batch import SampleBatch
  11. from ray.rllib.policy.view_requirement import ViewRequirement
  12. from ray.rllib.utils import NullContextManager
  13. from ray.rllib.utils.annotations import OldAPIStack
  14. from ray.rllib.utils.framework import TensorType, try_import_tf, try_import_torch
  15. from ray.rllib.utils.spaces.repeated import Repeated
  16. from ray.rllib.utils.typing import ModelConfigDict, ModelInputDict, TensorStructType
  17. tf1, tf, tfv = try_import_tf()
  18. torch, _ = try_import_torch()
  @OldAPIStack
  class ModelV2:
      r"""Abstract base class for neural network models used by RLlib.

      Custom models should not subclass this class directly; extend either
      TFModelV2 or TorchModelV2 instead.

      Data flow:
          obs -> forward() -> model_out
              \-> value_function() -> V(s)
      """

      def __init__(
          self,
          obs_space: Space,
          action_space: Space,
          num_outputs: int,
          model_config: ModelConfigDict,
          name: str,
          framework: str,
      ):
          """Initializes a ModelV2 instance.

          Subclasses should create any variables used by the model here.

          Args:
              obs_space: Observation space of the target gym env. This may
                  have an `original_space` attribute that specifies how to
                  unflatten the tensor into a ragged tensor.
              action_space: Action space of the target gym env.
              num_outputs: Number of output units of the model.
              model_config: Config for the model, documented in ModelCatalog.
              name: Name (scope) for the model. A falsy value is replaced by
                  "default_model".
              framework: Either "tf" or "torch".
          """
          self.obs_space: Space = obs_space
          self.action_space: Space = action_space
          self.num_outputs: int = num_outputs
          self.model_config: ModelConfigDict = model_config
          self.framework: str = framework
          # Fall back to a generic name if none (or an empty one) was given.
          self.name: str = name if name else "default_model"
          # Output of the most recent `__call__`; see `last_output()`.
          self._last_output = None
          self.time_major = self.model_config.get("_time_major")

          # Basic view requirement for all models: Use the observation as input.
          obs_view = ViewRequirement(shift=0, space=self.obs_space)
          self.view_requirements = {SampleBatch.OBS: obs_view}

      def get_initial_state(self) -> List[TensorType]:
          """Returns the initial recurrent state values for the model.

          The base implementation returns an empty list (no recurrent state).

          Returns:
              List of np.array (for tf) or Tensor (for torch) objects containing
              the initial hidden state of an RNN, if applicable.

          .. testcode::
              :skipif: True

              import numpy as np
              from ray.rllib.models.modelv2 import ModelV2
              class MyModel(ModelV2):
                  # ...
                  def get_initial_state(self):
                      return [
                          np.zeros(self.cell_size, np.float32),
                          np.zeros(self.cell_size, np.float32),
                      ]
          """
          return []

      def forward(
          self,
          input_dict: Dict[str, TensorType],
          state: List[TensorType],
          seq_lens: TensorType,
      ) -> (TensorType, List[TensorType]):
          """Runs the model on the given input tensors and state.

          Complex observations (dicts, tuples, etc.) are unpacked by
          `__call__` before being handed to `forward()`. The flattened
          observation tensor is available as `input_dict["obs_flat"]`.

          This method can be called any number of times. In eager execution,
          each call to forward() will eagerly evaluate the model. In symbolic
          execution, each call to forward creates a computation graph that
          operates over the variables of this model (i.e., shares weights).

          Custom models should override this method instead of `__call__`.

          Args:
              input_dict: Dictionary of input tensors, including "obs",
                  "obs_flat", "prev_action", "prev_reward", "is_training",
                  "eps_id", "agent_id", "infos", and "t".
              state: List of state tensors with sizes matching those
                  returned by get_initial_state + the batch dimension.
              seq_lens: 1D tensor holding input sequence lengths.

          Returns:
              A tuple consisting of the model output tensor of size
              [BATCH, num_outputs] and the list of new RNN state(s) if any.

          Raises:
              NotImplementedError: Always, in this base class.

          .. testcode::
              :skipif: True

              import numpy as np
              from ray.rllib.models.modelv2 import ModelV2
              class MyModel(ModelV2):
                  # ...
                  def forward(self, input_dict, state, seq_lens):
                      model_out, self._value_out = self.base_model(
                          input_dict["obs"])
                      return model_out, state
          """
          raise NotImplementedError

      def value_function(self) -> TensorType:
          """Returns the value function output for the most recent forward pass.

          A `forward` call has to be performed first, before this method can
          return anything; calling this method therefore does not cause an
          extra forward pass through the network.

          Returns:
              Value estimate tensor of shape [BATCH].

          Raises:
              NotImplementedError: Always, in this base class.
          """
          raise NotImplementedError

      def custom_loss(
          self, policy_loss: TensorType, loss_inputs: Dict[str, TensorType]
      ) -> Union[List[TensorType], TensorType]:
          """Override to customize the loss function used to optimize this model.

          This can be used to incorporate self-supervised losses (by defining
          a loss over existing input and output tensors of this model), and
          supervised losses (by defining losses over a variable-sharing copy of
          this model's layers).

          You can find a runnable example in examples/custom_loss.py.

          Args:
              policy_loss: List of or single policy loss(es) from the policy.
              loss_inputs: Map of input placeholders for rollout data.

          Returns:
              List of or scalar tensor for the customized loss(es) for this
              model. The base implementation returns `policy_loss` unchanged.
          """
          return policy_loss

      def metrics(self) -> Dict[str, TensorType]:
          """Override to return custom metrics from your model.

          The stats will be reported as part of the learner stats, i.e.,
          info.learner.[policy_id, e.g. "default_policy"].model.key1=metric1

          Returns:
              The custom metrics for this model (an empty dict by default).
          """
          return {}

      def __call__(
          self,
          input_dict: Union[SampleBatch, ModelInputDict],
          state: List[Any] = None,
          seq_lens: TensorType = None,
      ) -> (TensorType, List[TensorType]):
          """Calls the model with the given input tensors and state.

          This is the method used by RLlib to execute the forward pass. It calls
          forward() internally after unpacking nested observation tensors.

          Custom models should override forward() instead of __call__.

          Args:
              input_dict: Dictionary of input tensors.
              state: List of state tensors with sizes matching those
                  returned by get_initial_state + the batch dimension. If
                  empty or None, states are collected from the
                  "state_in_0", "state_in_1", ... keys of `input_dict`.
              seq_lens: 1D tensor holding input sequence lengths. If None, it
                  is looked up in `input_dict`.

          Returns:
              A tuple consisting of the model output tensor of size
              [BATCH, output_spec.size] or a list of tensors corresponding to
              output_spec.shape_list, and a list of state tensors of
              [BATCH, state_size_i] if any.

          Raises:
              ValueError: If forward() does not return a 2-element list/tuple,
                  or if the returned state is not a list.
          """
          # Original observations will be stored in "obs".
          # Flattened (preprocessed) obs will be stored in "obs_flat".

          # SampleBatch case: Models can now be called directly with a
          # SampleBatch (which also includes tracking-dict case (deprecated now),
          # where tensors get automatically converted).
          is_sample_batch = isinstance(input_dict, SampleBatch)
          if is_sample_batch:
              restored = input_dict.copy(shallow=True)
          else:
              restored = input_dict.copy()

          # Backward compatibility: gather state and seq_lens from the input
          # dict if they were not passed in explicitly.
          if not state:
              state = []
              while True:
                  state_key = "state_in_{}".format(len(state))
                  if state_key not in input_dict:
                      break
                  state.append(input_dict[state_key])
          if seq_lens is None:
              seq_lens = input_dict.get(SampleBatch.SEQ_LENS)

          # No Preprocessor used: `config._disable_preprocessor_api`=True.
          # TODO: This is unnecessary for when no preprocessor is used.
          #  Obs are not flat then anymore. However, we'll keep this
          #  here for backward-compatibility until Preprocessors have
          #  been fully deprecated.
          if self.model_config.get("_disable_preprocessor_api"):
              restored["obs_flat"] = input_dict["obs"]
          # Input to this Model went through a Preprocessor.
          # Generate extra keys: "obs_flat" (vs "obs", which will hold the
          # original obs).
          else:
              restored["obs"] = restore_original_dimensions(
                  input_dict["obs"], self.obs_space, self.framework
              )
              try:
                  if len(input_dict["obs"].shape) <= 2:
                      restored["obs_flat"] = input_dict["obs"]
                  else:
                      restored["obs_flat"] = flatten(
                          input_dict["obs"], self.framework
                      )
              except AttributeError:
                  # Obs without a usable `shape`: pass through unchanged.
                  restored["obs_flat"] = input_dict["obs"]

          with self.context():
              res = self.forward(restored, state or [], seq_lens)

          # Propagate the key-tracking info of the copy back to the original
          # batch, hiding the artificially added "obs_flat" key.
          if is_sample_batch:
              input_dict.accessed_keys = restored.accessed_keys - {"obs_flat"}
              input_dict.deleted_keys = restored.deleted_keys
              input_dict.added_keys = restored.added_keys - {"obs_flat"}

          if not isinstance(res, (list, tuple)) or len(res) != 2:
              raise ValueError(
                  "forward() must return a tuple of (output, state) tensors, "
                  "got {}".format(res)
              )
          outputs, state_out = res

          if not isinstance(state_out, list):
              raise ValueError("State output is not a list: {}".format(state_out))

          self._last_output = outputs
          # If forward() produced no state, hand back the incoming state.
          if len(state_out) > 0:
              return outputs, state_out
          return outputs, (state or [])

      def last_output(self) -> TensorType:
          """Returns the last output returned from calling the model."""
          return self._last_output

      def context(self) -> contextlib.AbstractContextManager:
          """Returns a contextmanager for the current forward pass."""
          return NullContextManager()

      def variables(
          self, as_dict: bool = False
      ) -> Union[List[TensorType], Dict[str, TensorType]]:
          """Returns the list (or a dict) of variables for this model.

          Args:
              as_dict: Whether variables should be returned as dict-values
                  (using descriptive str keys).

          Returns:
              The list (or dict if `as_dict` is True) of all variables of this
              ModelV2.

          Raises:
              NotImplementedError: Always, in this base class.
          """
          raise NotImplementedError

      def trainable_variables(
          self, as_dict: bool = False
      ) -> Union[List[TensorType], Dict[str, TensorType]]:
          """Returns the list of trainable variables for this model.

          Args:
              as_dict: Whether variables should be returned as dict-values
                  (using descriptive keys).

          Returns:
              The list (or dict if `as_dict` is True) of all trainable
              (tf)/requires_grad (torch) variables of this ModelV2.

          Raises:
              NotImplementedError: Always, in this base class.
          """
          raise NotImplementedError

      def is_time_major(self) -> bool:
          """If True, data for calling this ModelV2 must be in time-major format.

          Returns:
              Whether this ModelV2 requires a time-major (TxBx...) data
              format.
          """
          return self.time_major is True

      @Deprecated(error=True)
      def import_from_h5(self, *args, **kwargs):
          pass
  @OldAPIStack
  def flatten(obs: TensorType, framework: str) -> TensorType:
      """Flattens the given tensor using the given framework.

      Args:
          obs: The tensor to flatten.
          framework: One of "tf", "tf2", or "torch".

      Returns:
          The result of a Keras `Flatten` layer (tf) or of
          `torch.flatten(obs, start_dim=1)` (torch).

      Raises:
          NotImplementedError: If `framework` is none of the supported values.
      """
      if framework == "torch":
          assert torch is not None
          return torch.flatten(obs, start_dim=1)
      if framework in ("tf", "tf2"):
          return tf1.keras.layers.Flatten()(obs)
      raise NotImplementedError("flatten", framework)
  @OldAPIStack
  def restore_original_dimensions(
      obs: TensorType, obs_space: Space, tensorlib: Any = tf
  ) -> TensorStructType:
      """Restores the original structure of Dict and Tuple space observations.

      Dict and Tuple observations are flattened in transit within a
      SampleBatch. Before they are sent to the model, they should be
      unflattened into Dicts or Tuples of tensors again.

      Args:
          obs: The flattened observation tensor.
          obs_space: The flattened obs space. If this has the
              `original_space` attribute, we will unflatten the tensor to that
              shape.
          tensorlib: The library used to unflatten (reshape) the array/tensor.
              The strings "tf", "tf2", "torch", and "numpy" are translated into
              the respective module; any other value is passed on as-is.

      Returns:
          Single tensor or dict / tuple of tensors matching the original
          observation space.
      """
      # Translate a framework name into the actual module object.
      if tensorlib == "torch":
          assert torch is not None
          tensorlib = torch
      elif tensorlib == "numpy":
          assert np is not None
          tensorlib = np
      elif tensorlib in ["tf", "tf2"]:
          assert tf is not None
          tensorlib = tf

      # Spaces without an `original_space` attribute are used as they are.
      target_space = getattr(obs_space, "original_space", obs_space)
      return _unpack_obs(obs, target_space, tensorlib=tensorlib)
  # Cache of preprocessors (keyed by `id(space)`), so that callers unpacking
  # observations of the same space often do not rebuild the preprocessor.
  _cache = {}


  @OldAPIStack
  def _unpack_obs(
      obs: TensorType, space: Space, tensorlib: Any = tf
  ) -> TensorStructType:
      """Unpack a flattened Dict, Tuple, or Repeated observation array/tensor.

      Args:
          obs: The flattened observation tensor, with last dimension equal to
              the flat size and any number of batch dimensions. For example, for
              Box(4,), the obs may have shape [B, 4], or [B, N, M, 4] in case
              the Box was nested under two Repeated spaces.
          space: The original space prior to flattening.
          tensorlib: The library (module) used to unflatten (reshape) the
              array/tensor.

      Returns:
          `obs` unchanged if `space` is not a Dict, Tuple, or Repeated space, or
          if `obs` already is a dict (Dict space) or a list/tuple (Tuple space).
          Otherwise an OrderedDict (Dict space), a list (Tuple space), or a
          RepeatedValues object (Repeated space) holding the recursively
          unpacked children.

      Raises:
          ValueError: If `obs` has fewer than 2 dimensions or its last dimension
              differs from `shape[0]` of the space's preprocessor.
      """
      # Only container spaces need unpacking.
      if not isinstance(space, (gym.spaces.Dict, gym.spaces.Tuple, Repeated)):
          return obs

      # Already unpacked?
      if (isinstance(space, gym.spaces.Tuple) and isinstance(obs, (list, tuple))) or (
          isinstance(space, gym.spaces.Dict) and isinstance(obs, dict)
      ):
          return obs

      # Unpack using preprocessor.
      prep = _cache.get(id(space))
      if prep is None:
          prep = get_preprocessor(space)(space)
          # Make an attempt to cache the result, if enough space left.
          if len(_cache) < 999:
              _cache[id(space)] = prep

      if len(obs.shape) < 2 or obs.shape[-1] != prep.shape[0]:
          raise ValueError(
              "Expected flattened obs shape of [..., {}], got {}".format(
                  prep.shape[0], obs.shape
              )
          )

      if tensorlib is tf:

          def get_value(v):
              # Unknown dimensions (None) become -1 for reshape; dimension
              # objects exposing `.value` are unwrapped to plain ints.
              if v is None:
                  return -1
              elif isinstance(v, int):
                  return v
              elif v.value is None:
                  return -1
              else:
                  return v.value

          batch_dims = [get_value(v) for v in obs.shape[:-1]]
      else:
          batch_dims = list(obs.shape[:-1])

      def unpack_child(child_prep, child_space, start):
          # Cut the child's slice out of the flat last dimension, reshape it to
          # the child's preprocessed shape and recurse into the child space.
          obs_slice = obs[..., start : start + child_prep.size]
          return _unpack_obs(
              tensorlib.reshape(obs_slice, batch_dims + list(child_prep.shape)),
              child_space,
              tensorlib=tensorlib,
          )

      offset = 0
      if isinstance(space, gym.spaces.Tuple):
          assert len(prep.preprocessors) == len(space.spaces), (
              "Number of child preprocessors ({}) does not match number of "
              "sub-spaces ({})".format(len(prep.preprocessors), len(space.spaces))
          )
          u = []
          for p, v in zip(prep.preprocessors, space.spaces):
              u.append(unpack_child(p, v, offset))
              offset += p.size
      elif isinstance(space, gym.spaces.Dict):
          assert len(prep.preprocessors) == len(space.spaces), (
              "Number of child preprocessors ({}) does not match number of "
              "sub-spaces ({})".format(len(prep.preprocessors), len(space.spaces))
          )
          u = OrderedDict()
          for p, (k, v) in zip(prep.preprocessors, space.spaces.items()):
              u[k] = unpack_child(p, v, offset)
              offset += p.size
      # Repeated space.
      else:
          assert isinstance(prep, RepeatedValuesPreprocessor), prep
          child_size = prep.child_preprocessor.size
          # The list lengths are stored in the first slot of the flat obs.
          lengths = obs[..., 0]
          # [B, ..., 1 + max_len * child_sz] -> [B, ..., max_len, child_sz]
          with_repeat_dim = tensorlib.reshape(
              obs[..., 1:], batch_dims + [space.max_len, child_size]
          )
          # Retry the unpack, dropping the List container space.
          u = _unpack_obs(with_repeat_dim, space.child_space, tensorlib=tensorlib)
          return RepeatedValues(u, lengths=lengths, max_len=prep._obs_space.max_len)
      return u