| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906 |
- import logging
- from functools import partial
- from typing import List, Optional, Type, Union
- import gymnasium as gym
- import numpy as np
- import tree # pip install dm_tree
- from gymnasium.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple
- from ray._common.deprecation import (
- DEPRECATED_VALUE,
- deprecation_warning,
- )
- from ray.rllib.models.action_dist import ActionDistribution
- from ray.rllib.models.modelv2 import ModelV2
- from ray.rllib.models.preprocessors import Preprocessor, get_preprocessor
- from ray.rllib.models.tf.tf_action_dist import (
- Categorical,
- Deterministic,
- DiagGaussian,
- Dirichlet,
- MultiActionDistribution,
- MultiCategorical,
- )
- from ray.rllib.models.torch.torch_action_dist import (
- TorchCategorical,
- TorchDeterministic,
- TorchDiagGaussian,
- TorchDirichlet,
- TorchMultiActionDistribution,
- TorchMultiCategorical,
- )
- from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
- from ray.rllib.utils.error import UnsupportedSpaceException
- from ray.rllib.utils.framework import try_import_tf, try_import_torch
- from ray.rllib.utils.from_config import from_config
- from ray.rllib.utils.spaces.simplex import Simplex
- from ray.rllib.utils.spaces.space_utils import flatten_space
- from ray.rllib.utils.typing import ModelConfigDict, TensorType
- from ray.tune.registry import (
- RLLIB_ACTION_DIST,
- RLLIB_MODEL,
- _global_registry,
- )
# Framework handles resolved once at import time. Other code in this file
# guards on `tf is not None` (see `ModelCatalog.register_custom_model`), so
# `tf` may be None here — presumably when TensorFlow is not installed.
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
logger = logging.getLogger(__name__)
# fmt: off
# __sphinx_doc_begin__
# Default values for every recognized model-config key.
# `ModelCatalog.get_preprocessor_for_space()` rejects any option key that is
# not present in this dict, and `ModelCatalog.get_action_dist()` falls back to
# this dict when it is given an empty/None config.
MODEL_DEFAULTS: ModelConfigDict = {
    # --- `fcnet_*` keys ---
    "fcnet_hiddens": [256, 256],
    "fcnet_activation": "tanh",
    "fcnet_weights_initializer": None,
    "fcnet_weights_initializer_config": None,
    "fcnet_bias_initializer": None,
    "fcnet_bias_initializer_config": None,
    # --- `conv_*` keys ---
    "conv_filters": None,
    "conv_activation": "relu",
    "conv_kernel_initializer": None,
    "conv_kernel_initializer_config": None,
    "conv_bias_initializer": None,
    "conv_bias_initializer_config": None,
    "conv_transpose_kernel_initializer": None,
    "conv_transpose_kernel_initializer_config": None,
    "conv_transpose_bias_initializer": None,
    "conv_transpose_bias_initializer_config": None,
    # --- `post_fcnet_*` keys ---
    "post_fcnet_hiddens": [],
    "post_fcnet_activation": "relu",
    "post_fcnet_weights_initializer": None,
    "post_fcnet_weights_initializer_config": None,
    "post_fcnet_bias_initializer": None,
    "post_fcnet_bias_initializer_config": None,
    "free_log_std": False,
    "log_std_clip_param": 20.0,
    "no_final_linear": False,
    "vf_share_layers": True,
    # --- LSTM keys ---
    # `use_lstm=True` makes `ModelCatalog.get_model_v2()` wrap the model class
    # with an `LSTMWrapper`. Mutually exclusive with `use_attention`.
    "use_lstm": False,
    "max_seq_len": 20,
    "lstm_cell_size": 256,
    # Rejected by `ModelCatalog._validate_config()` for Tuple/Dict action
    # spaces unless `_disable_action_flattening` is set.
    "lstm_use_prev_action": False,
    "lstm_use_prev_reward": False,
    "lstm_weights_initializer": None,
    "lstm_weights_initializer_config": None,
    "lstm_bias_initializer": None,
    "lstm_bias_initializer_config": None,
    "_time_major": False,
    # --- Attention keys ---
    # `use_attention=True` makes `ModelCatalog.get_model_v2()` wrap the model
    # class with an `AttentionWrapper`. Mutually exclusive with `use_lstm`.
    "use_attention": False,
    "attention_num_transformer_units": 1,
    "attention_dim": 64,
    "attention_num_heads": 1,
    "attention_head_dim": 32,
    "attention_memory_inference": 50,
    "attention_memory_training": 50,
    "attention_position_wise_mlp_dim": 32,
    "attention_init_gru_gate_bias": 2.0,
    # A value > 0 is rejected for Tuple/Dict action spaces unless
    # `_disable_action_flattening` is set (see `_validate_config()`).
    "attention_use_n_prev_actions": 0,
    "attention_use_n_prev_rewards": 0,
    # --- Remaining observation-related keys (NOTE(review): presumably consumed
    # by the preprocessors; not read in this file — confirm) ---
    "framestack": True,
    "dim": 84,
    "grayscale": False,
    "zero_mean": True,
    # --- Custom components ---
    # A class, a dotted-path string (resolved via `from_config`), or a name
    # registered through `ModelCatalog.register_custom_model()`.
    "custom_model": None,
    # Merged over the `**model_kwargs` handed to `get_model_v2()` and passed
    # to the custom model's constructor.
    "custom_model_config": {},
    # Name looked up in the action-distribution registry by
    # `ModelCatalog.get_action_dist()`.
    "custom_action_dist": None,
    # Any non-None value triggers `deprecation_warning(..., error=True)` in
    # `ModelCatalog._validate_config()`.
    "custom_preprocessor": None,
    "encoder_latent_dim": None,
    "always_check_shapes": False,
    # Deprecated keys:
    "lstm_use_prev_action_reward": DEPRECATED_VALUE,
    "_use_default_native_models": DEPRECATED_VALUE,
    # Underscore-prefixed flags. Unlike the two keys directly above, these
    # hold real defaults (False) rather than DEPRECATED_VALUE.
    "_disable_preprocessor_api": False,
    "_disable_action_flattening": False,
}
# __sphinx_doc_end__
# fmt: on
@DeveloperAPI
class ModelCatalog:
    """Registry of models, preprocessors, and action distributions for envs.

    All functionality is exposed through static methods.

    .. testcode::
        :skipif: True

        prep = ModelCatalog.get_preprocessor(env)
        observation = prep.transform(raw_observation)

        dist_class, dist_dim = ModelCatalog.get_action_dist(
            env.action_space, {})
        model = ModelCatalog.get_model_v2(
            obs_space, action_space, num_outputs, options)
        dist = dist_class(model.outputs, model)
        action = dist.sample()
    """

    @staticmethod
    @DeveloperAPI
    def get_action_dist(
        action_space: gym.Space,
        config: ModelConfigDict,
        dist_type: Optional[Union[str, Type[ActionDistribution]]] = None,
        framework: str = "tf",
        **kwargs
    ) -> (type, int):
        """Returns a distribution class and size for the given action space.

        Args:
            action_space: Action space of the target gym env.
            config (Optional[dict]): Optional model config. An empty/None
                value falls back to `MODEL_DEFAULTS`.
            dist_type (Optional[Union[str, Type[ActionDistribution]]]):
                Identifier of the action distribution (str) interpreted as a
                hint or the actual ActionDistribution class to use.
            framework: One of "tf2", "tf", "torch", or "jax".
            kwargs: Accepted for backward compatibility; not read by this
                method.

        Returns:
            Tuple:
                - dist_class (ActionDistribution): Python class of the
                  distribution (possibly a `functools.partial` with some
                  constructor arguments pre-bound).
                - dist_dim (int): The size of the input vector to the
                  distribution.

        Raises:
            UnsupportedSpaceException: For float Box spaces with more than
                one dimension.
            NotImplementedError: If no distribution can be determined for the
                given `action_space` / `dist_type` combination.
        """
        dist_cls = None
        config = config or MODEL_DEFAULTS
        use_torch = framework == "torch"

        # Custom distribution given.
        if config.get("custom_action_dist"):
            custom_action_config = config.copy()
            action_dist_name = custom_action_config.pop("custom_action_dist")
            logger.debug("Using custom action distribution {}".format(action_dist_name))
            dist_cls = _global_registry.get(RLLIB_ACTION_DIST, action_dist_name)
            return ModelCatalog._get_multi_action_distribution(
                dist_cls, action_space, custom_action_config, framework
            )

        # Dist_type is given directly as a class.
        elif (
            isinstance(dist_type, type)
            and issubclass(dist_type, ActionDistribution)
            and dist_type not in (MultiActionDistribution, TorchMultiActionDistribution)
        ):
            dist_cls = dist_type
        # Box space -> DiagGaussian OR Deterministic.
        elif isinstance(action_space, Box):
            if action_space.dtype.name.startswith("int"):
                # Convert the bounds to Python ints *before* doing arithmetic:
                # NumPy scalars of a narrow dtype (e.g. int8) would overflow
                # when computing `high - low + 1`.
                low_ = int(np.min(action_space.low))
                high_ = int(np.max(action_space.high))
                num_values = high_ - low_ + 1
                dist_cls = TorchMultiCategorical if use_torch else MultiCategorical
                num_cats = int(np.prod(action_space.shape))
                return (
                    partial(
                        dist_cls,
                        input_lens=[num_values] * num_cats,
                        action_space=action_space,
                    ),
                    num_cats * num_values,
                )
            else:
                if len(action_space.shape) > 1:
                    raise UnsupportedSpaceException(
                        "Action space has multiple dimensions "
                        "{}. ".format(action_space.shape)
                        + "Consider reshaping this into a single dimension, "
                        "using a custom action distribution, "
                        "using a Tuple action space, or the multi-agent API."
                    )
                # TODO(sven): Check for bounds and return SquashedNormal, etc..
                if dist_type is None:
                    return (
                        partial(
                            TorchDiagGaussian if use_torch else DiagGaussian,
                            action_space=action_space,
                        ),
                        int(
                            DiagGaussian.required_model_output_shape(
                                action_space, config
                            )
                        ),
                    )
                elif dist_type == "deterministic":
                    dist_cls = TorchDeterministic if use_torch else Deterministic
                else:
                    # Without this branch `dist_cls` would stay None and the
                    # final return would fail with an opaque AttributeError.
                    raise NotImplementedError(
                        "Unsupported args: {} {}".format(action_space, dist_type)
                    )
        # Discrete Space -> Categorical.
        elif isinstance(action_space, Discrete):
            if use_torch:
                dist_cls = TorchCategorical
            elif framework == "jax":
                from ray.rllib.models.jax.jax_action_dist import JAXCategorical

                dist_cls = JAXCategorical
            else:
                dist_cls = Categorical
        # Tuple/Dict Spaces -> MultiAction.
        elif dist_type in (
            MultiActionDistribution,
            TorchMultiActionDistribution,
        ) or isinstance(action_space, (Tuple, Dict)):
            # Select on `framework == "torch"` like every other branch, so
            # that "tf2" gets the TF distribution rather than the torch one.
            return ModelCatalog._get_multi_action_distribution(
                (
                    TorchMultiActionDistribution
                    if use_torch
                    else MultiActionDistribution
                ),
                action_space,
                config,
                framework,
            )
        # Simplex -> Dirichlet.
        elif isinstance(action_space, Simplex):
            dist_cls = TorchDirichlet if use_torch else Dirichlet
        # MultiDiscrete -> MultiCategorical.
        elif isinstance(action_space, MultiDiscrete):
            dist_cls = TorchMultiCategorical if use_torch else MultiCategorical
            return partial(dist_cls, input_lens=action_space.nvec), int(
                sum(action_space.nvec)
            )
        # Unknown type -> Error.
        else:
            raise NotImplementedError(
                "Unsupported args: {} {}".format(action_space, dist_type)
            )

        return dist_cls, int(dist_cls.required_model_output_shape(action_space, config))

    @staticmethod
    @DeveloperAPI
    def get_action_shape(
        action_space: gym.Space, framework: str = "tf"
    ) -> (np.dtype, List[int]):
        """Returns action tensor dtype and shape for the action space.

        Args:
            action_space: Action space of the target gym env.
            framework: The framework identifier. One of "tf" or "torch".

        Returns:
            (dtype, shape): Dtype and shape of the actions tensor. The leading
            `None` entry of the shape is the batch dimension. For Discrete and
            MultiDiscrete spaces the space's own dtype is returned; for all
            other supported spaces a dtype of the selected framework library.

        Raises:
            ValueError: For Box/Simplex spaces that are neither float nor int.
            NotImplementedError: For unsupported space types.
        """
        dl_lib = torch if framework == "torch" else tf

        if isinstance(action_space, Discrete):
            return action_space.dtype, (None,)
        elif isinstance(action_space, (Box, Simplex)):
            if np.issubdtype(action_space.dtype, np.floating):
                return dl_lib.float32, (None,) + action_space.shape
            elif np.issubdtype(action_space.dtype, np.integer):
                return dl_lib.int32, (None,) + action_space.shape
            else:
                raise ValueError("RLlib doesn't support non int or float box spaces")
        elif isinstance(action_space, MultiDiscrete):
            return action_space.dtype, (None,) + action_space.shape
        elif isinstance(action_space, (Tuple, Dict)):
            # Complex spaces are described by one flat vector: one slot per
            # Discrete component, prod(shape) slots for everything else.
            size = 0
            all_discrete = True
            for component in flatten_space(action_space):
                if isinstance(component, Discrete):
                    size += 1
                else:
                    all_discrete = False
                    size += np.prod(component.shape)
            size = int(size)
            return dl_lib.int32 if all_discrete else dl_lib.float32, (None, size)
        else:
            raise NotImplementedError(
                "Action space {} not supported".format(action_space)
            )

    @staticmethod
    @DeveloperAPI
    def get_action_placeholder(
        action_space: gym.Space, name: str = "action"
    ) -> TensorType:
        """Returns an action placeholder consistent with the action space

        Args:
            action_space: Action space of the target gym env.
            name: An optional string to name the placeholder by.
                Default: "action".

        Returns:
            action_placeholder: A `tf1.placeholder` whose dtype and shape come
            from `get_action_shape(action_space, framework="tf")`.
        """
        dtype, shape = ModelCatalog.get_action_shape(action_space, framework="tf")

        return tf1.placeholder(dtype, shape=shape, name=name)

    @staticmethod
    @DeveloperAPI
    def get_model_v2(
        obs_space: gym.Space,
        action_space: gym.Space,
        num_outputs: int,
        model_config: ModelConfigDict,
        framework: str = "tf",
        name: str = "default_model",
        model_interface: type = None,
        default_model: type = None,
        **model_kwargs
    ) -> ModelV2:
        """Returns a suitable model compatible with given spaces and output.

        Args:
            obs_space: Observation space of the target gym env. This
                may have an `original_space` attribute that specifies how to
                unflatten the tensor into a ragged tensor.
            action_space: Action space of the target gym env.
            num_outputs: The size of the output vector of the model.
            model_config: The "model" sub-config dict
                within the Algorithm's config dict.
            framework: One of "tf2", "tf", "torch", or "jax".
            name: Name (scope) for the model.
            model_interface: Interface required for the model
            default_model: Override the default class for the model. This
                only has an effect when not using a custom model
            model_kwargs: Args to pass to the ModelV2 constructor

        Returns:
            model (ModelV2): Model to use for the policy.

        Raises:
            ValueError: If the config is invalid, the custom model class is
                of an unsupported type, or a custom TF model registered only
                some of its created variables.
            NotImplementedError: If `framework` is not supported.
        """

        # Validate the given config dict.
        ModelCatalog._validate_config(
            config=model_config, action_space=action_space, framework=framework
        )

        if model_config.get("custom_model"):
            # Allow model kwargs to be overridden / augmented by
            # custom_model_config.
            customized_model_kwargs = dict(
                model_kwargs, **model_config.get("custom_model_config", {})
            )
            custom_model = model_config["custom_model"]

            if isinstance(custom_model, type):
                model_cls = custom_model
            elif isinstance(custom_model, str) and "." in custom_model:
                # Dotted path: let `from_config` resolve and construct it.
                return from_config(
                    cls=custom_model,
                    obs_space=obs_space,
                    action_space=action_space,
                    num_outputs=num_outputs,
                    model_config=customized_model_kwargs,
                    name=name,
                )
            else:
                model_cls = _global_registry.get(RLLIB_MODEL, custom_model)

            # Only allow ModelV2 or native keras Models.
            if not issubclass(model_cls, ModelV2):
                if framework not in ["tf", "tf2"] or not issubclass(
                    model_cls, tf.keras.Model
                ):
                    raise ValueError(
                        "`model_cls` must be a ModelV2 sub-class, but is"
                        " {}!".format(model_cls)
                    )

            logger.info("Wrapping {} as {}".format(model_cls, model_interface))
            model_cls = ModelCatalog._wrap_if_needed(model_cls, model_interface)

            if framework in ["tf2", "tf"]:
                # Try wrapping custom model with LSTM/attention, if required.
                model_cls = ModelCatalog._maybe_wrap_recurrent(
                    model_cls, model_config, framework
                )

                # Obsolete: Track and warn if vars were created but not
                # registered. Only still do this, if users do register their
                # variables. If not (which they shouldn't), don't check here.
                created = set()

                def track_var_creation(next_creator, **kw):
                    v = next_creator(**kw)
                    created.add(v.ref())
                    return v

                with tf.variable_creator_scope(track_var_creation):
                    if issubclass(model_cls, tf.keras.Model):
                        instance = model_cls(
                            input_space=obs_space,
                            action_space=action_space,
                            num_outputs=num_outputs,
                            name=name,
                            **customized_model_kwargs,
                        )
                    else:
                        instance = ModelCatalog._instantiate_model_v2(
                            model_cls,
                            obs_space,
                            action_space,
                            num_outputs,
                            model_config,
                            name,
                            customized_model_kwargs,
                            model_kwargs,
                        )

                # User still registered TFModelV2's variables: Check, whether
                # ok.
                registered = []
                if not isinstance(instance, tf.keras.Model):
                    registered = set(instance.var_list)
                if registered:
                    not_registered = {
                        var for var in created if var not in registered
                    }
                    if not_registered:
                        raise ValueError(
                            "It looks like you are still using "
                            "`{}.register_variables()` to register your "
                            "model's weights. This is no longer required, but "
                            "if you are still calling this method at least "
                            "once, you must make sure to register all created "
                            "variables properly. The missing variables are {},"
                            " and you only registered {}. "
                            "Did you forget to call `register_variables()` on "
                            "some of the variables in question?".format(
                                instance, not_registered, registered
                            )
                        )
            elif framework == "torch":
                # Try wrapping custom model with LSTM/attention, if required.
                model_cls = ModelCatalog._maybe_wrap_recurrent(
                    model_cls, model_config, framework
                )

                # PyTorch automatically tracks nn.Modules inside the parent
                # nn.Module's constructor.
                instance = ModelCatalog._instantiate_model_v2(
                    model_cls,
                    obs_space,
                    action_space,
                    num_outputs,
                    model_config,
                    name,
                    customized_model_kwargs,
                    model_kwargs,
                )
            else:
                raise NotImplementedError(
                    "`framework` must be 'tf2|tf|torch', but is "
                    "{}!".format(framework)
                )

            return instance

        # From here on no custom model is configured (that path always
        # returns or raises above).

        # Find a default TFModelV2 and wrap with model_interface.
        if framework in ["tf", "tf2"]:
            v2_class = default_model or ModelCatalog._get_v2_model_class(
                obs_space, model_config, framework=framework
            )
            if not v2_class:
                raise ValueError("ModelV2 class could not be determined!")

            v2_class = ModelCatalog._maybe_wrap_recurrent(
                v2_class, model_config, framework
            )

            # Wrap in the requested interface.
            wrapper = ModelCatalog._wrap_if_needed(v2_class, model_interface)

            if issubclass(wrapper, tf.keras.Model):
                # Native keras models take the model config as flat kwargs.
                return wrapper(
                    input_space=obs_space,
                    action_space=action_space,
                    num_outputs=num_outputs,
                    name=name,
                    **dict(model_kwargs, **model_config),
                )

            return wrapper(
                obs_space, action_space, num_outputs, model_config, name, **model_kwargs
            )

        # Find a default TorchModelV2 and wrap with model_interface.
        elif framework == "torch":
            v2_class = default_model or ModelCatalog._get_v2_model_class(
                obs_space, model_config, framework=framework
            )
            if not v2_class:
                raise ValueError("ModelV2 class could not be determined!")

            v2_class = ModelCatalog._maybe_wrap_recurrent(
                v2_class, model_config, framework
            )

            # Wrap in the requested interface.
            wrapper = ModelCatalog._wrap_if_needed(v2_class, model_interface)
            return wrapper(
                obs_space, action_space, num_outputs, model_config, name, **model_kwargs
            )

        # Find a default JAXModelV2 and wrap with model_interface.
        elif framework == "jax":
            v2_class = default_model or ModelCatalog._get_v2_model_class(
                obs_space, model_config, framework=framework
            )
            # Wrap in the requested interface.
            wrapper = ModelCatalog._wrap_if_needed(v2_class, model_interface)
            return wrapper(
                obs_space, action_space, num_outputs, model_config, name, **model_kwargs
            )
        else:
            raise NotImplementedError(
                "`framework` must be 'tf2|tf|torch', but is {}!".format(framework)
            )

    @staticmethod
    def _maybe_wrap_recurrent(
        model_cls: type, model_config: ModelConfigDict, framework: str
    ) -> type:
        """Wraps `model_cls` with an LSTM or attention wrapper, if configured.

        Args:
            model_cls: The model class to (possibly) wrap.
            model_config: The model config; only `use_lstm` and
                `use_attention` are read.
            framework: "torch" selects the torch wrappers; any other value
                selects the tf wrappers.

        Returns:
            `model_cls` unchanged if neither `use_lstm` nor `use_attention`
            is set; otherwise the wrapped class, with the original class'
            `forward` stored on it as `_wrapped_forward`.
        """
        use_lstm = model_config.get("use_lstm")
        if not (use_lstm or model_config.get("use_attention")):
            return model_cls

        if framework == "torch":
            from ray.rllib.models.torch.attention_net import AttentionWrapper
            from ray.rllib.models.torch.recurrent_net import LSTMWrapper
        else:
            from ray.rllib.models.tf.attention_net import AttentionWrapper
            from ray.rllib.models.tf.recurrent_net import LSTMWrapper

        # Grab the un-wrapped forward before building the wrapper class.
        forward = model_cls.forward
        wrapped_cls = ModelCatalog._wrap_if_needed(
            model_cls, LSTMWrapper if use_lstm else AttentionWrapper
        )
        wrapped_cls._wrapped_forward = forward
        return wrapped_cls

    @staticmethod
    def _instantiate_model_v2(
        model_cls: type,
        obs_space: gym.Space,
        action_space: gym.Space,
        num_outputs: int,
        model_config: ModelConfigDict,
        name: str,
        customized_model_kwargs: dict,
        model_kwargs: dict,
    ) -> ModelV2:
        """Constructs a custom ModelV2, falling back to the legacy signature.

        First tries to pass `customized_model_kwargs` (model kwargs merged
        with `custom_model_config`) to the constructor. If the constructor
        rejects an unexpected keyword argument, retries with only
        `model_kwargs` and logs a warning.

        Raises:
            TypeError: Any TypeError from the constructor that is not an
                "unexpected keyword argument" error is re-raised.
        """
        # Try calling with kwargs first (custom ModelV2 should
        # accept these as kwargs, not get them from
        # config["custom_model_config"] anymore).
        try:
            return model_cls(
                obs_space,
                action_space,
                num_outputs,
                model_config,
                name,
                **customized_model_kwargs,
            )
        except TypeError as e:
            # Other error -> re-raise. `str(e)` (rather than `e.args[0]`) is
            # safe for TypeErrors without args or with non-string args.
            if "__init__() got an unexpected " not in str(e):
                raise
            # Keyword error: Try old way w/o kwargs.
            instance = model_cls(
                obs_space,
                action_space,
                num_outputs,
                model_config,
                name,
                **model_kwargs,
            )
            logger.warning(
                "Custom ModelV2 should accept all custom "
                "options as **kwargs, instead of expecting"
                " them in config['custom_model_config']!"
            )
            return instance

    @staticmethod
    @DeveloperAPI
    def get_preprocessor(
        env: gym.Env, options: Optional[dict] = None, include_multi_binary: bool = False
    ) -> Preprocessor:
        """Returns a suitable preprocessor for the given env.

        This is a wrapper for get_preprocessor_for_space().
        """

        return ModelCatalog.get_preprocessor_for_space(
            env.observation_space, options, include_multi_binary
        )

    @staticmethod
    @DeveloperAPI
    def get_preprocessor_for_space(
        observation_space: gym.Space,
        options: dict = None,
        include_multi_binary: bool = False,
    ) -> Preprocessor:
        """Returns a suitable preprocessor for the given observation space.

        Args:
            observation_space: The input observation space.
            options: Options to pass to the preprocessor. An empty/None value
                falls back to `MODEL_DEFAULTS`.
            include_multi_binary: Whether to include the MultiBinaryPreprocessor in
                the possible preprocessors returned by this method.

        Returns:
            preprocessor: Preprocessor for the observations.

        Raises:
            Exception: If `options` contains a key not in `MODEL_DEFAULTS`.
        """

        options = options or MODEL_DEFAULTS
        for k in options:
            if k not in MODEL_DEFAULTS:
                raise Exception(
                    "Unknown config key `{}`, all keys: {}".format(
                        k, list(MODEL_DEFAULTS)
                    )
                )

        cls = get_preprocessor(
            observation_space, include_multi_binary=include_multi_binary
        )
        prep = cls(observation_space, options)

        if prep is not None:
            logger.debug(
                "Created preprocessor {}: {} -> {}".format(
                    prep, observation_space, prep.shape
                )
            )
        return prep

    @staticmethod
    @PublicAPI
    def register_custom_model(model_name: str, model_class: type) -> None:
        """Register a custom model class by name.

        The model can be later used by specifying {"custom_model": model_name}
        in the model config.

        Args:
            model_name: Name to register the model under.
            model_class: Python class of the model.
        """
        # Only emit the deprecation warning for native keras models, and only
        # if TensorFlow could be imported at all.
        if tf is not None and issubclass(model_class, tf.keras.Model):
            deprecation_warning(old="register_custom_model", error=False)
        _global_registry.register(RLLIB_MODEL, model_name, model_class)

    @staticmethod
    @PublicAPI
    def register_custom_action_dist(
        action_dist_name: str, action_dist_class: type
    ) -> None:
        """Register a custom action distribution class by name.

        The distribution can be later used by specifying
        {"custom_action_dist": action_dist_name} in the model config.

        Args:
            action_dist_name: Name to register the action distribution under.
            action_dist_class: Python class of the action distribution.
        """
        _global_registry.register(
            RLLIB_ACTION_DIST, action_dist_name, action_dist_class
        )

    @staticmethod
    def _wrap_if_needed(model_cls: type, model_interface: type) -> type:
        """Returns `model_cls`, mixed with `model_interface` if necessary.

        Args:
            model_cls: The model class to wrap.
            model_interface: The required interface class, or None.

        Returns:
            `model_cls` itself if no interface is given or it already
            subclasses the interface; otherwise a new class deriving from
            `(model_interface, model_cls)` and named
            "<model_cls>_as_<model_interface>".
        """
        if not model_interface or issubclass(model_cls, model_interface):
            return model_cls

        assert issubclass(model_cls, ModelV2), model_cls

        class wrapper(model_interface, model_cls):
            pass

        name = "{}_as_{}".format(model_cls.__name__, model_interface.__name__)
        wrapper.__name__ = name
        wrapper.__qualname__ = name

        return wrapper

    @staticmethod
    def _get_v2_model_class(
        input_space: gym.Space, model_config: ModelConfigDict, framework: str = "tf"
    ) -> Type[ModelV2]:
        """Picks the default ModelV2 class for an input space and framework.

        Args:
            input_space: The (possibly flattened) input space. If it has an
                `original_space` attribute, that space is inspected as well.
            model_config: The model config (not read by this method).
            framework: One of "tf2", "tf", "torch", or "jax".

        Returns:
            The vision net for 3D Box inputs, the fully connected net for 1D
            Box inputs whose original space contains no Box component of
            rank >= 2, and the complex-input net otherwise.

        Raises:
            ValueError: If `framework` is not supported.
            NotImplementedError: If `framework` is "jax" and anything other
                than the fully connected net would be required.
        """
        VisionNet = None
        ComplexNet = None

        if framework in ["tf2", "tf"]:
            from ray.rllib.models.tf.complex_input_net import (
                ComplexInputNetwork as ComplexNet,
            )
            from ray.rllib.models.tf.fcnet import (
                FullyConnectedNetwork as FCNet,
            )
            from ray.rllib.models.tf.visionnet import (
                VisionNetwork as VisionNet,
            )
        elif framework == "torch":
            from ray.rllib.models.torch.complex_input_net import (
                ComplexInputNetwork as ComplexNet,
            )
            from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as FCNet
            from ray.rllib.models.torch.visionnet import VisionNetwork as VisionNet
        elif framework == "jax":
            from ray.rllib.models.jax.fcnet import FullyConnectedNetwork as FCNet
        else:
            raise ValueError(
                "framework={} not supported in `ModelCatalog._get_v2_model_"
                "class`!".format(framework)
            )

        orig_space = (
            input_space
            if not hasattr(input_space, "original_space")
            else input_space.original_space
        )

        # `input_space` is 3D Box -> VisionNet.
        if isinstance(input_space, Box) and len(input_space.shape) == 3:
            if framework == "jax":
                raise NotImplementedError("No non-FC default net for JAX yet!")
            return VisionNet
        # `input_space` is 1D Box -> FCNet.
        elif (
            isinstance(input_space, Box)
            and len(input_space.shape) == 1
            and (
                not isinstance(orig_space, (Dict, Tuple))
                or not any(
                    isinstance(s, Box) and len(s.shape) >= 2
                    for s in flatten_space(orig_space)
                )
            )
        ):
            return FCNet
        # Complex (Dict, Tuple, 2D Box (flatten), Discrete, MultiDiscrete).
        else:
            if framework == "jax":
                raise NotImplementedError("No non-FC default net for JAX yet!")
            return ComplexNet

    @staticmethod
    def _get_multi_action_distribution(dist_class, action_space, config, framework):
        """Returns `(dist_class, input_size)`, expanding multi-action dists.

        Args:
            dist_class: The distribution class to use.
            action_space: The action space the distribution is built for.
            config: The model config passed on to child lookups.
            framework: The framework identifier passed on to child lookups.

        Returns:
            For subclasses of the (Torch)MultiActionDistribution: a
            `functools.partial` of `dist_class` with `action_space`,
            `child_distributions` and `input_lens` bound, plus the summed
            input size. For any other class: the class itself plus its
            `required_model_output_shape(action_space, config)`.
        """
        # In case the custom distribution is a child of MultiActionDistr.
        # If users want to completely ignore the suggested child
        # distributions, they should simply do so in their custom class'
        # constructor.
        if issubclass(
            dist_class, (MultiActionDistribution, TorchMultiActionDistribution)
        ):
            flat_action_space = flatten_space(action_space)
            child_dists_and_in_lens = tree.map_structure(
                lambda s: ModelCatalog.get_action_dist(s, config, framework=framework),
                flat_action_space,
            )
            child_dists = [e[0] for e in child_dists_and_in_lens]
            input_lens = [int(e[1]) for e in child_dists_and_in_lens]
            return (
                partial(
                    dist_class,
                    action_space=action_space,
                    child_distributions=child_dists,
                    input_lens=input_lens,
                ),
                int(sum(input_lens)),
            )
        return dist_class, dist_class.required_model_output_shape(action_space, config)

    @staticmethod
    def _validate_config(
        config: ModelConfigDict, action_space: gym.spaces.Space, framework: str
    ) -> None:
        """Validates a given model config dict.

        Args:
            config: The "model" sub-config dict
                within the Algorithm's config dict.
            action_space: The action space of the model, whose config are
                    validated.
            framework: One of "jax", "tf2", "tf", or "torch".

        Raises:
            ValueError: If something is wrong with the given config.
        """
        # Soft-deprecate custom preprocessors.
        if config.get("custom_preprocessor") is not None:
            deprecation_warning(
                old="model.custom_preprocessor",
                new="gym.ObservationWrapper around your env or handle complex "
                "inputs inside your Model",
                error=True,
            )

        if config.get("use_attention") and config.get("use_lstm"):
            raise ValueError(
                "Only one of `use_lstm` or `use_attention` may be set to True!"
            )

        # For complex action spaces, only allow prev action inputs to
        # LSTMs and attention nets iff `_disable_action_flattening=True`.
        # TODO: `_disable_action_flattening=True` will be the default in
        #  the future.
        if (
            (
                config.get("lstm_use_prev_action")
                or config.get("attention_use_n_prev_actions", 0) > 0
            )
            and not config.get("_disable_action_flattening")
            and isinstance(action_space, (Tuple, Dict))
        ):
            raise ValueError(
                "For your complex action space (Tuple|Dict) and your model's "
                "`prev-actions` setup of your model, you must set "
                "`_disable_action_flattening=True` in your main config dict!"
            )

        if framework == "jax":
            if config.get("use_attention"):
                raise ValueError(
                    "`use_attention` not available for framework=jax so far!"
                )
            elif config.get("use_lstm"):
                raise ValueError("`use_lstm` not available for framework=jax so far!")
|