catalog.py 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906
  1. import logging
  2. from functools import partial
  3. from typing import List, Optional, Type, Union
  4. import gymnasium as gym
  5. import numpy as np
  6. import tree # pip install dm_tree
  7. from gymnasium.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple
  8. from ray._common.deprecation import (
  9. DEPRECATED_VALUE,
  10. deprecation_warning,
  11. )
  12. from ray.rllib.models.action_dist import ActionDistribution
  13. from ray.rllib.models.modelv2 import ModelV2
  14. from ray.rllib.models.preprocessors import Preprocessor, get_preprocessor
  15. from ray.rllib.models.tf.tf_action_dist import (
  16. Categorical,
  17. Deterministic,
  18. DiagGaussian,
  19. Dirichlet,
  20. MultiActionDistribution,
  21. MultiCategorical,
  22. )
  23. from ray.rllib.models.torch.torch_action_dist import (
  24. TorchCategorical,
  25. TorchDeterministic,
  26. TorchDiagGaussian,
  27. TorchDirichlet,
  28. TorchMultiActionDistribution,
  29. TorchMultiCategorical,
  30. )
  31. from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
  32. from ray.rllib.utils.error import UnsupportedSpaceException
  33. from ray.rllib.utils.framework import try_import_tf, try_import_torch
  34. from ray.rllib.utils.from_config import from_config
  35. from ray.rllib.utils.spaces.simplex import Simplex
  36. from ray.rllib.utils.spaces.space_utils import flatten_space
  37. from ray.rllib.utils.typing import ModelConfigDict, TensorType
  38. from ray.tune.registry import (
  39. RLLIB_ACTION_DIST,
  40. RLLIB_MODEL,
  41. _global_registry,
  42. )
  # Framework handles as returned by RLlib's import helpers.
  tf1, tf, tfv = try_import_tf()
  torch, _ = try_import_torch()

  logger = logging.getLogger(__name__)

  # fmt: off
  # __sphinx_doc_begin__
  # Default model config. `ModelCatalog` falls back to this dict whenever a
  # caller passes an empty/None config, and `get_preprocessor_for_space()`
  # rejects any option key that is not listed here.
  MODEL_DEFAULTS: ModelConfigDict = {
      # --- Fully connected net options (`fcnet_*`) ---
      "fcnet_hiddens": [256, 256],
      "fcnet_activation": "tanh",
      "fcnet_weights_initializer": None,
      "fcnet_weights_initializer_config": None,
      "fcnet_bias_initializer": None,
      "fcnet_bias_initializer_config": None,

      # --- Convolutional net options (`conv_*`) ---
      "conv_filters": None,
      "conv_activation": "relu",
      "conv_kernel_initializer": None,
      "conv_kernel_initializer_config": None,
      "conv_bias_initializer": None,
      "conv_bias_initializer_config": None,
      "conv_transpose_kernel_initializer": None,
      "conv_transpose_kernel_initializer_config": None,
      "conv_transpose_bias_initializer": None,
      "conv_transpose_bias_initializer_config": None,

      # --- Post-fcnet stack options (`post_fcnet_*`) ---
      "post_fcnet_hiddens": [],
      "post_fcnet_activation": "relu",
      "post_fcnet_weights_initializer": None,
      "post_fcnet_weights_initializer_config": None,
      "post_fcnet_bias_initializer": None,
      "post_fcnet_bias_initializer_config": None,

      # --- Output / value-function related options ---
      # NOTE(review): exact semantics are defined by the model and
      # distribution classes, not by this module.
      "free_log_std": False,
      "log_std_clip_param": 20.0,
      "no_final_linear": False,
      "vf_share_layers": True,

      # --- LSTM options ---
      # `use_lstm` makes `ModelCatalog.get_model_v2()` wrap the model class
      # with an `LSTMWrapper`. Mutually exclusive with `use_attention`.
      "use_lstm": False,
      "max_seq_len": 20,
      "lstm_cell_size": 256,
      "lstm_use_prev_action": False,
      "lstm_use_prev_reward": False,
      "lstm_weights_initializer": None,
      "lstm_weights_initializer_config": None,
      "lstm_bias_initializer": None,
      "lstm_bias_initializer_config": None,
      "_time_major": False,

      # --- Attention options ---
      # `use_attention` makes `ModelCatalog.get_model_v2()` wrap the model
      # class with an `AttentionWrapper`.
      "use_attention": False,
      "attention_num_transformer_units": 1,
      "attention_dim": 64,
      "attention_num_heads": 1,
      "attention_head_dim": 32,
      "attention_memory_inference": 50,
      "attention_memory_training": 50,
      "attention_position_wise_mlp_dim": 32,
      "attention_init_gru_gate_bias": 2.0,
      "attention_use_n_prev_actions": 0,
      "attention_use_n_prev_rewards": 0,

      # --- Observation preprocessing options ---
      # NOTE(review): presumably consumed by the image preprocessors; confirm
      # against ray.rllib.models.preprocessors.
      "framestack": True,
      "dim": 84,
      "grayscale": False,
      "zero_mean": True,

      # --- Custom component options ---
      # `custom_model`: a class, a dotted import path (str containing "."),
      # or a name registered via `ModelCatalog.register_custom_model()`.
      "custom_model": None,
      # Extra kwargs merged over the model kwargs for a custom model.
      "custom_model_config": {},
      # Name registered via `ModelCatalog.register_custom_action_dist()`.
      "custom_action_dist": None,
      # Deprecated: any non-None value triggers a deprecation error in
      # `ModelCatalog._validate_config()`.
      "custom_preprocessor": None,

      "encoder_latent_dim": None,
      "always_check_shapes": False,

      # Deprecated keys:
      "lstm_use_prev_action_reward": DEPRECATED_VALUE,
      "_use_default_native_models": DEPRECATED_VALUE,
      "_disable_preprocessor_api": False,
      # Must be True to combine Tuple/Dict action spaces with prev-action
      # inputs (checked in `ModelCatalog._validate_config()`).
      "_disable_action_flattening": False,
  }
  # __sphinx_doc_end__
  # fmt: on
  @DeveloperAPI
  class ModelCatalog:
      """Registry of models, preprocessors, and action distributions for envs.

      .. testcode::
          :skipif: True

          prep = ModelCatalog.get_preprocessor(env)
          observation = prep.transform(raw_observation)

          dist_class, dist_dim = ModelCatalog.get_action_dist(
              env.action_space, {})
          model = ModelCatalog.get_model_v2(
              obs_space, action_space, num_outputs, options)
          dist = dist_class(model.outputs, model)
          action = dist.sample()
      """

      @staticmethod
      @DeveloperAPI
      def get_action_dist(
          action_space: gym.Space,
          config: ModelConfigDict,
          dist_type: Optional[Union[str, Type[ActionDistribution]]] = None,
          framework: str = "tf",
          **kwargs
      ) -> (type, int):
          """Returns a distribution class and size for the given action space.

          Args:
              action_space: Action space of the target gym env.
              config (Optional[dict]): Optional model config. Falls back to
                  `MODEL_DEFAULTS` if empty or None.
              dist_type (Optional[Union[str, Type[ActionDistribution]]]):
                  Identifier of the action distribution (str) interpreted as a
                  hint or the actual ActionDistribution class to use.
              framework: One of "tf2", "tf", "torch", or "jax".
              kwargs: Accepted for backward compatibility; not used by this
                  method.

          Returns:
              Tuple:
                  - dist_class (ActionDistribution): Python class of the
                    distribution (possibly a `functools.partial` with some
                    constructor args pre-bound).
                  - dist_dim (int): The size of the input vector to the
                    distribution.

          Raises:
              UnsupportedSpaceException: For a float Box space with more than
                  one dimension.
              NotImplementedError: If no distribution can be determined for
                  the given `action_space` / `dist_type` combination.
          """
          dist_cls = None
          config = config or MODEL_DEFAULTS

          # Custom distribution given.
          if config.get("custom_action_dist"):
              # Work on a copy so that the caller's config keeps its
              # "custom_action_dist" entry; child distributions (if any) are
              # then resolved without the custom one.
              custom_action_config = config.copy()
              action_dist_name = custom_action_config.pop("custom_action_dist")
              logger.debug("Using custom action distribution {}".format(action_dist_name))
              dist_cls = _global_registry.get(RLLIB_ACTION_DIST, action_dist_name)
              return ModelCatalog._get_multi_action_distribution(
                  dist_cls, action_space, custom_action_config, framework
              )

          # Dist_type is given directly as a class.
          elif (
              type(dist_type) is type
              and issubclass(dist_type, ActionDistribution)
              and dist_type not in (MultiActionDistribution, TorchMultiActionDistribution)
          ):
              dist_cls = dist_type

          # Box space -> DiagGaussian OR Deterministic.
          elif isinstance(action_space, Box):
              if action_space.dtype.name.startswith("int"):
                  # Integer Box: one categorical per component, each covering
                  # the global [min(low), max(high)] range.
                  low_ = np.min(action_space.low)
                  high_ = np.max(action_space.high)
                  cats_per_dim = high_ - low_ + 1
                  dist_cls = (
                      TorchMultiCategorical if framework == "torch" else MultiCategorical
                  )
                  num_cats = int(np.prod(action_space.shape))
                  return (
                      partial(
                          dist_cls,
                          input_lens=[cats_per_dim for _ in range(num_cats)],
                          action_space=action_space,
                      ),
                      num_cats * cats_per_dim,
                  )
              else:
                  if len(action_space.shape) > 1:
                      raise UnsupportedSpaceException(
                          "Action space has multiple dimensions "
                          "{}. ".format(action_space.shape)
                          + "Consider reshaping this into a single dimension, "
                          "using a custom action distribution, "
                          "using a Tuple action space, or the multi-agent API."
                      )
                  # TODO(sven): Check for bounds and return SquashedNormal, etc..
                  if dist_type is None:
                      return (
                          partial(
                              TorchDiagGaussian if framework == "torch" else DiagGaussian,
                              action_space=action_space,
                          ),
                          DiagGaussian.required_model_output_shape(action_space, config),
                      )
                  elif dist_type == "deterministic":
                      dist_cls = (
                          TorchDeterministic if framework == "torch" else Deterministic
                      )
                  else:
                      # Previously fell through with `dist_cls=None` and crashed
                      # with an AttributeError on the final return.
                      raise NotImplementedError(
                          "Unsupported args: {} {}".format(action_space, dist_type)
                      )

          # Discrete Space -> Categorical.
          elif isinstance(action_space, Discrete):
              if framework == "torch":
                  dist_cls = TorchCategorical
              elif framework == "jax":
                  from ray.rllib.models.jax.jax_action_dist import JAXCategorical

                  dist_cls = JAXCategorical
              else:
                  dist_cls = Categorical

          # Tuple/Dict Spaces -> MultiAction.
          elif dist_type in (
              MultiActionDistribution,
              TorchMultiActionDistribution,
          ) or isinstance(action_space, (Tuple, Dict)):
              return ModelCatalog._get_multi_action_distribution(
                  (
                      # "tf2" must use the TF distribution as well (it used to
                      # get the torch one because only "tf" was checked).
                      MultiActionDistribution
                      if framework in ("tf", "tf2")
                      else TorchMultiActionDistribution
                  ),
                  action_space,
                  config,
                  framework,
              )

          # Simplex -> Dirichlet.
          elif isinstance(action_space, Simplex):
              dist_cls = TorchDirichlet if framework == "torch" else Dirichlet

          # MultiDiscrete -> MultiCategorical.
          elif isinstance(action_space, MultiDiscrete):
              dist_cls = (
                  TorchMultiCategorical if framework == "torch" else MultiCategorical
              )
              return partial(dist_cls, input_lens=action_space.nvec), int(
                  sum(action_space.nvec)
              )

          # Unknown type -> Error.
          else:
              raise NotImplementedError(
                  "Unsupported args: {} {}".format(action_space, dist_type)
              )

          return dist_cls, int(dist_cls.required_model_output_shape(action_space, config))

      @staticmethod
      @DeveloperAPI
      def get_action_shape(
          action_space: gym.Space, framework: str = "tf"
      ) -> (np.dtype, List[int]):
          """Returns action tensor dtype and shape for the action space.

          Args:
              action_space: Action space of the target gym env.
              framework: The framework identifier. One of "tf" or "torch".

          Returns:
              (dtype, shape): Dtype and shape of the actions tensor. The first
              shape entry is always None.

          Raises:
              ValueError: For Box/Simplex spaces that are neither int nor float.
              NotImplementedError: For unsupported action space types.
          """
          dl_lib = torch if framework == "torch" else tf

          if isinstance(action_space, Discrete):
              return action_space.dtype, (None,)

          if isinstance(action_space, (Box, Simplex)):
              if np.issubdtype(action_space.dtype, np.floating):
                  return dl_lib.float32, (None,) + action_space.shape
              if np.issubdtype(action_space.dtype, np.integer):
                  return dl_lib.int32, (None,) + action_space.shape
              raise ValueError("RLlib doesn't support non int or float box spaces")

          if isinstance(action_space, MultiDiscrete):
              return action_space.dtype, (None,) + action_space.shape

          if isinstance(action_space, (Tuple, Dict)):
              # Flattened representation: each Discrete component takes one
              # slot, every other component takes prod(shape) slots.
              size = 0
              all_discrete = True
              for sub_space in flatten_space(action_space):
                  if isinstance(sub_space, Discrete):
                      size += 1
                  else:
                      all_discrete = False
                      size += np.prod(sub_space.shape)
              size = int(size)
              return dl_lib.int32 if all_discrete else dl_lib.float32, (None, size)

          raise NotImplementedError(
              "Action space {} not supported".format(action_space)
          )

      @staticmethod
      @DeveloperAPI
      def get_action_placeholder(
          action_space: gym.Space, name: str = "action"
      ) -> TensorType:
          """Returns an action placeholder consistent with the action space.

          Args:
              action_space: Action space of the target gym env.
              name: An optional string to name the placeholder by.
                  Default: "action".

          Returns:
              action_placeholder: A `tf1.placeholder` whose dtype and shape come
              from `get_action_shape(action_space, framework="tf")`.
          """
          dtype, shape = ModelCatalog.get_action_shape(action_space, framework="tf")

          return tf1.placeholder(dtype, shape=shape, name=name)

      @staticmethod
      @DeveloperAPI
      def get_model_v2(
          obs_space: gym.Space,
          action_space: gym.Space,
          num_outputs: int,
          model_config: ModelConfigDict,
          framework: str = "tf",
          name: str = "default_model",
          model_interface: type = None,
          default_model: type = None,
          **model_kwargs
      ) -> ModelV2:
          """Returns a suitable model compatible with given spaces and output.

          Args:
              obs_space: Observation space of the target gym env. This
                  may have an `original_space` attribute that specifies how to
                  unflatten the tensor into a ragged tensor.
              action_space: Action space of the target gym env.
              num_outputs: The size of the output vector of the model.
              model_config: The "model" sub-config dict
                  within the Algorithm's config dict.
              framework: One of "tf2", "tf", "torch", or "jax".
              name: Name (scope) for the model.
              model_interface: Interface required for the model
              default_model: Override the default class for the model. This
                  only has an effect when not using a custom model
              model_kwargs: Args to pass to the ModelV2 constructor

          Returns:
              model (ModelV2): Model to use for the policy.

          Raises:
              ValueError: If the config is invalid, the custom model class is
                  of an unsupported type, or a custom TF model registered only
                  part of its variables.
              NotImplementedError: If `framework` is not supported for the
                  requested (custom or default) model.
          """
          # Validate the given config dict.
          ModelCatalog._validate_config(
              config=model_config, action_space=action_space, framework=framework
          )

          if model_config.get("custom_model"):
              # Allow model kwargs to be overridden / augmented by
              # custom_model_config (tolerate an explicit `None`).
              customized_model_kwargs = dict(
                  model_kwargs, **(model_config.get("custom_model_config") or {})
              )
              custom_model = model_config["custom_model"]

              if isinstance(custom_model, type):
                  model_cls = custom_model
              elif isinstance(custom_model, str) and "." in custom_model:
                  # Dotted path: let `from_config` import and construct it.
                  return from_config(
                      cls=custom_model,
                      obs_space=obs_space,
                      action_space=action_space,
                      num_outputs=num_outputs,
                      model_config=customized_model_kwargs,
                      name=name,
                  )
              else:
                  model_cls = _global_registry.get(RLLIB_MODEL, custom_model)

              # Only allow ModelV2 or native keras Models.
              if not issubclass(model_cls, ModelV2):
                  if framework not in ["tf", "tf2"] or not issubclass(
                      model_cls, tf.keras.Model
                  ):
                      raise ValueError(
                          "`model_cls` must be a ModelV2 sub-class, but is"
                          " {}!".format(model_cls)
                      )

              logger.info("Wrapping {} as {}".format(model_cls, model_interface))
              model_cls = ModelCatalog._wrap_if_needed(model_cls, model_interface)

              if framework in ["tf2", "tf"]:
                  # Try wrapping custom model with LSTM/attention, if required.
                  model_cls = ModelCatalog._wrap_with_memory(
                      model_cls, model_config, framework
                  )

                  # Obsolete: Track and warn if vars were created but not
                  # registered. Only still do this, if users do register their
                  # variables. If not (which they shouldn't), don't check here.
                  created = set()

                  def track_var_creation(next_creator, **kw):
                      v = next_creator(**kw)
                      created.add(v.ref())
                      return v

                  with tf.variable_creator_scope(track_var_creation):
                      if issubclass(model_cls, tf.keras.Model):
                          instance = model_cls(
                              input_space=obs_space,
                              action_space=action_space,
                              num_outputs=num_outputs,
                              name=name,
                              **customized_model_kwargs,
                          )
                      else:
                          instance = ModelCatalog._instantiate_model_v2(
                              model_cls,
                              obs_space,
                              action_space,
                              num_outputs,
                              model_config,
                              name,
                              customized_model_kwargs,
                              model_kwargs,
                          )

                  # User still registered TFModelV2's variables: Check, whether
                  # ok.
                  registered = []
                  if not isinstance(instance, tf.keras.Model):
                      registered = set(instance.var_list)
                  if len(registered) > 0:
                      not_registered = {
                          var for var in created if var not in registered
                      }
                      if not_registered:
                          raise ValueError(
                              "It looks like you are still using "
                              "`{}.register_variables()` to register your "
                              "model's weights. This is no longer required, but "
                              "if you are still calling this method at least "
                              "once, you must make sure to register all created "
                              "variables properly. The missing variables are {},"
                              " and you only registered {}. "
                              "Did you forget to call `register_variables()` on "
                              "some of the variables in question?".format(
                                  instance, not_registered, registered
                              )
                          )
              elif framework == "torch":
                  # Try wrapping custom model with LSTM/attention, if required.
                  model_cls = ModelCatalog._wrap_with_memory(
                      model_cls, model_config, framework
                  )

                  # PyTorch automatically tracks nn.Modules inside the parent
                  # nn.Module's constructor.
                  instance = ModelCatalog._instantiate_model_v2(
                      model_cls,
                      obs_space,
                      action_space,
                      num_outputs,
                      model_config,
                      name,
                      customized_model_kwargs,
                      model_kwargs,
                  )
              else:
                  raise NotImplementedError(
                      "`framework` must be 'tf2|tf|torch', but is "
                      "{}!".format(framework)
                  )

              return instance

          # No custom model from here on (the branch above always returns or
          # raises): find a default ModelV2 and wrap it with `model_interface`.
          if framework not in ("tf", "tf2", "torch", "jax"):
              raise NotImplementedError(
                  "`framework` must be 'tf2|tf|torch', but is {}!".format(framework)
              )

          v2_class = default_model or ModelCatalog._get_v2_model_class(
              obs_space, model_config, framework=framework
          )
          if not v2_class:
              raise ValueError("ModelV2 class could not be determined!")

          # LSTM/attention wrappers only exist for tf and torch
          # (`_validate_config` already rejects these options for jax).
          if framework != "jax":
              v2_class = ModelCatalog._wrap_with_memory(
                  v2_class, model_config, framework
              )

          # Wrap in the requested interface.
          wrapper = ModelCatalog._wrap_if_needed(v2_class, model_interface)

          # Native keras models take keyword args and the flattened config.
          if framework in ("tf", "tf2") and issubclass(wrapper, tf.keras.Model):
              return wrapper(
                  input_space=obs_space,
                  action_space=action_space,
                  num_outputs=num_outputs,
                  name=name,
                  **dict(model_kwargs, **model_config),
              )

          return wrapper(
              obs_space, action_space, num_outputs, model_config, name, **model_kwargs
          )

      @staticmethod
      def _wrap_with_memory(
          model_cls: type, model_config: ModelConfigDict, framework: str
      ) -> type:
          """Wraps `model_cls` with an LSTM/attention wrapper if configured.

          Returns `model_cls` unchanged if neither `use_lstm` nor `use_attention`
          is set. Otherwise, the (framework-specific) wrapper class is mixed in
          and the original `forward` is stored as `_wrapped_forward` on the
          resulting class. Only call this for framework "tf", "tf2" or "torch".
          """
          use_lstm = model_config.get("use_lstm")
          if not (use_lstm or model_config.get("use_attention")):
              return model_cls

          # Import lazily so that unused frameworks are never touched.
          if framework == "torch":
              from ray.rllib.models.torch.attention_net import AttentionWrapper
              from ray.rllib.models.torch.recurrent_net import LSTMWrapper
          else:
              from ray.rllib.models.tf.attention_net import (
                  AttentionWrapper,
              )
              from ray.rllib.models.tf.recurrent_net import (
                  LSTMWrapper,
              )

          forward = model_cls.forward
          wrapped_cls = ModelCatalog._wrap_if_needed(
              model_cls, LSTMWrapper if use_lstm else AttentionWrapper
          )
          wrapped_cls._wrapped_forward = forward
          return wrapped_cls

      @staticmethod
      def _instantiate_model_v2(
          model_cls: type,
          obs_space: gym.Space,
          action_space: gym.Space,
          num_outputs: int,
          model_config: ModelConfigDict,
          name: str,
          customized_model_kwargs: dict,
          model_kwargs: dict,
      ) -> ModelV2:
          """Constructs a custom ModelV2, falling back to the legacy signature.

          First tries to pass the custom options as kwargs (custom ModelV2
          should accept these as kwargs, not get them from
          config["custom_model_config"] anymore). If the constructor rejects an
          unexpected keyword, retries with only `model_kwargs` and logs a
          warning. Any other TypeError is re-raised unchanged.
          """
          try:
              return model_cls(
                  obs_space,
                  action_space,
                  num_outputs,
                  model_config,
                  name,
                  **customized_model_kwargs,
              )
          except TypeError as e:
              # Other error -> re-raise (bare raise keeps the traceback).
              if "__init__() got an unexpected " not in str(e):
                  raise
              # Keyword error: Try old way w/o kwargs.
              instance = model_cls(
                  obs_space,
                  action_space,
                  num_outputs,
                  model_config,
                  name,
                  **model_kwargs,
              )
              logger.warning(
                  "Custom ModelV2 should accept all custom "
                  "options as **kwargs, instead of expecting"
                  " them in config['custom_model_config']!"
              )
              return instance

      @staticmethod
      @DeveloperAPI
      def get_preprocessor(
          env: gym.Env, options: Optional[dict] = None, include_multi_binary: bool = False
      ) -> Preprocessor:
          """Returns a suitable preprocessor for the given env.

          This is a wrapper for get_preprocessor_for_space() that uses the
          env's `observation_space`.
          """
          return ModelCatalog.get_preprocessor_for_space(
              env.observation_space, options, include_multi_binary
          )

      @staticmethod
      @DeveloperAPI
      def get_preprocessor_for_space(
          observation_space: gym.Space,
          options: dict = None,
          include_multi_binary: bool = False,
      ) -> Preprocessor:
          """Returns a suitable preprocessor for the given observation space.

          Args:
              observation_space: The input observation space.
              options: Options to pass to the preprocessor. Falls back to
                  `MODEL_DEFAULTS` if empty or None.
              include_multi_binary: Whether to include the MultiBinaryPreprocessor in
                  the possible preprocessors returned by this method.

          Returns:
              preprocessor: Preprocessor for the observations.

          Raises:
              Exception: If `options` contains a key that is not in
                  `MODEL_DEFAULTS`.
          """
          options = options or MODEL_DEFAULTS
          for k in options:
              if k not in MODEL_DEFAULTS:
                  raise Exception(
                      "Unknown config key `{}`, all keys: {}".format(
                          k, list(MODEL_DEFAULTS)
                      )
                  )

          cls = get_preprocessor(
              observation_space, include_multi_binary=include_multi_binary
          )
          prep = cls(observation_space, options)

          if prep is not None:
              logger.debug(
                  "Created preprocessor {}: {} -> {}".format(
                      prep, observation_space, prep.shape
                  )
              )
          return prep

      @staticmethod
      @PublicAPI
      def register_custom_model(model_name: str, model_class: type) -> None:
          """Register a custom model class by name.

          The model can be later used by specifying {"custom_model": model_name}
          in the model config.

          Args:
              model_name: Name to register the model under.
              model_class: Python class of the model.
          """
          # Registering native keras models emits a (non-fatal) deprecation
          # warning; the class is registered either way.
          if tf is not None and issubclass(model_class, tf.keras.Model):
              deprecation_warning(old="register_custom_model", error=False)
          _global_registry.register(RLLIB_MODEL, model_name, model_class)

      @staticmethod
      @PublicAPI
      def register_custom_action_dist(
          action_dist_name: str, action_dist_class: type
      ) -> None:
          """Register a custom action distribution class by name.

          The action distribution can be later used by specifying
          {"custom_action_dist": action_dist_name} in the model config.

          Args:
              action_dist_name: Name to register the action distribution under.
              action_dist_class: Python class of the action distribution.
          """
          _global_registry.register(
              RLLIB_ACTION_DIST, action_dist_name, action_dist_class
          )

      @staticmethod
      def _wrap_if_needed(model_cls: type, model_interface: type) -> type:
          """Returns `model_cls`, mixed with `model_interface` if necessary.

          If no interface is given or `model_cls` already is a subclass of it,
          `model_cls` is returned as-is. Otherwise a new class deriving from
          `(model_interface, model_cls)` and named
          "<model_cls>_as_<model_interface>" is created.
          """
          if not model_interface or issubclass(model_cls, model_interface):
              return model_cls

          assert issubclass(model_cls, ModelV2), model_cls

          class wrapper(model_interface, model_cls):
              pass

          name = "{}_as_{}".format(model_cls.__name__, model_interface.__name__)
          wrapper.__name__ = name
          wrapper.__qualname__ = name

          return wrapper

      @staticmethod
      def _get_v2_model_class(
          input_space: gym.Space, model_config: ModelConfigDict, framework: str = "tf"
      ) -> Type[ModelV2]:
          """Picks the default ModelV2 class for the given input space.

          Args:
              input_space: The (possibly flattened) observation space. If it
                  has an `original_space` attribute, that space is inspected to
                  detect complex inputs.
              model_config: The model config dict (not read by this method).
              framework: One of "tf2", "tf", "torch", or "jax".

          Returns:
              A vision net for 3D Box spaces, a fully connected net for plain
              1D Box spaces, and a complex-input net for everything else.

          Raises:
              ValueError: If `framework` is not supported.
              NotImplementedError: If a non-FC net would be needed for jax.
          """
          # Only the FC net exists for jax, hence the None defaults.
          VisionNet = None
          ComplexNet = None

          if framework in ["tf2", "tf"]:
              from ray.rllib.models.tf.complex_input_net import (
                  ComplexInputNetwork as ComplexNet,
              )
              from ray.rllib.models.tf.fcnet import (
                  FullyConnectedNetwork as FCNet,
              )
              from ray.rllib.models.tf.visionnet import (
                  VisionNetwork as VisionNet,
              )
          elif framework == "torch":
              from ray.rllib.models.torch.complex_input_net import (
                  ComplexInputNetwork as ComplexNet,
              )
              from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as FCNet
              from ray.rllib.models.torch.visionnet import VisionNetwork as VisionNet
          elif framework == "jax":
              from ray.rllib.models.jax.fcnet import FullyConnectedNetwork as FCNet
          else:
              raise ValueError(
                  "framework={} not supported in `ModelCatalog._get_v2_model_"
                  "class`!".format(framework)
              )

          orig_space = (
              input_space
              if not hasattr(input_space, "original_space")
              else input_space.original_space
          )

          # `input_space` is 3D Box -> VisionNet.
          if isinstance(input_space, Box) and len(input_space.shape) == 3:
              if framework == "jax":
                  raise NotImplementedError("No non-FC default net for JAX yet!")
              return VisionNet

          # `input_space` is 1D Box -> FCNet, unless the original space is a
          # Dict/Tuple that contains at least one Box with 2 or more dims.
          is_flat_box = isinstance(input_space, Box) and len(input_space.shape) == 1
          if is_flat_box and (
              not isinstance(orig_space, (Dict, Tuple))
              or not any(
                  isinstance(s, Box) and len(s.shape) >= 2
                  for s in flatten_space(orig_space)
              )
          ):
              return FCNet

          # Complex (Dict, Tuple, 2D Box (flatten), Discrete, MultiDiscrete).
          if framework == "jax":
              raise NotImplementedError("No non-FC default net for JAX yet!")
          return ComplexNet

      @staticmethod
      def _get_multi_action_distribution(dist_class, action_space, config, framework):
          """Resolves `dist_class` into a (distribution, input size) pair.

          If `dist_class` is a (Torch)MultiActionDistribution subclass, child
          distributions are determined for every flattened sub-space and bound
          to the class via `functools.partial`. Otherwise, `dist_class` is
          returned together with its own required model output shape.
          """
          # In case the custom distribution is a child of MultiActionDistr.
          # If users want to completely ignore the suggested child
          # distributions, they should simply do so in their custom class'
          # constructor.
          if not issubclass(
              dist_class, (MultiActionDistribution, TorchMultiActionDistribution)
          ):
              return dist_class, dist_class.required_model_output_shape(
                  action_space, config
              )

          flat_action_space = flatten_space(action_space)
          child_dists_and_in_lens = tree.map_structure(
              lambda s: ModelCatalog.get_action_dist(s, config, framework=framework),
              flat_action_space,
          )
          child_dists = [e[0] for e in child_dists_and_in_lens]
          input_lens = [int(e[1]) for e in child_dists_and_in_lens]
          return (
              partial(
                  dist_class,
                  action_space=action_space,
                  child_distributions=child_dists,
                  input_lens=input_lens,
              ),
              int(sum(input_lens)),
          )

      @staticmethod
      def _validate_config(
          config: ModelConfigDict, action_space: gym.spaces.Space, framework: str
      ) -> None:
          """Validates a given model config dict.

          Args:
              config: The "model" sub-config dict
                  within the Algorithm's config dict.
              action_space: The action space of the model, whose config are
                  validated.
              framework: One of "jax", "tf2", "tf", or "torch".

          Raises:
              ValueError: If something is wrong with the given config.
          """
          # Soft-deprecate custom preprocessors.
          if config.get("custom_preprocessor") is not None:
              deprecation_warning(
                  old="model.custom_preprocessor",
                  new="gym.ObservationWrapper around your env or handle complex "
                  "inputs inside your Model",
                  error=True,
              )

          if config.get("use_attention") and config.get("use_lstm"):
              raise ValueError(
                  "Only one of `use_lstm` or `use_attention` may be set to True!"
              )

          # For complex action spaces, only allow prev action inputs to
          # LSTMs and attention nets iff `_disable_action_flattening=True`.
          # TODO: `_disable_action_flattening=True` will be the default in
          # the future.
          uses_prev_actions = (
              config.get("lstm_use_prev_action")
              or config.get("attention_use_n_prev_actions", 0) > 0
          )
          if (
              uses_prev_actions
              and not config.get("_disable_action_flattening")
              and isinstance(action_space, (Tuple, Dict))
          ):
              raise ValueError(
                  "For your complex action space (Tuple|Dict) and your model's "
                  "`prev-actions` setup of your model, you must set "
                  "`_disable_action_flattening=True` in your main config dict!"
              )

          if framework == "jax":
              if config.get("use_attention"):
                  raise ValueError(
                      "`use_attention` not available for framework=jax so far!"
                  )
              elif config.get("use_lstm"):
                  raise ValueError("`use_lstm` not available for framework=jax so far!")