typing.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378
  1. from typing import (
  2. TYPE_CHECKING,
  3. Any,
  4. Callable,
  5. Dict,
  6. Hashable,
  7. List,
  8. Optional,
  9. Sequence,
  10. Tuple,
  11. Type,
  12. TypeVar,
  13. Union,
  14. )
  15. import gymnasium as gym
  16. from ray.rllib.utils.annotations import OldAPIStack
  17. if TYPE_CHECKING:
  18. # Modules might be missing but supply users with type hints if they are installed.
  19. import jax.numpy as jnp
  20. import keras
  21. import tensorflow as tf
  22. import torch
  23. from numpy.typing import NDArray
  24. from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
  25. from ray.rllib.core.rl_module.rl_module import RLModuleSpec
  26. from ray.rllib.env.env_context import EnvContext
  27. from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
  28. from ray.rllib.env.single_agent_episode import SingleAgentEpisode
  29. from ray.rllib.policy.dynamic_tf_policy_v2 import DynamicTFPolicyV2
  30. from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2
  31. from ray.rllib.policy.policy import PolicySpec
  32. from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
  33. from ray.rllib.policy.view_requirement import ViewRequirement
  34. TensorType = Union["NDArray[Any]", "jnp.ndarray", "tf.Tensor", "torch.Tensor"]
  35. """
  36. Represents a generic tensor type.
  37. This could be an np.ndarray, jnp.ndarray, tf.Tensor, or a torch.Tensor.
  38. """
  39. TensorStructType = Union[TensorType, dict, tuple]
  40. """Either a plain tensor, or a dict or tuple of tensors (or StructTensors)."""
  41. # A shape of a tensor.
  42. TensorShape = Union[Tuple[int, ...], List[int]]
  43. NetworkType = Union["torch.nn.Module", "keras.Model"]
  44. """A neural network."""
  45. DeviceType = Union[str, "torch.device", "int"]
  46. """
  47. A device identifier, which can be a string (e.g. 'cpu', 'cuda:0'),
  48. a torch.device object, or other types supported by torch.
  49. """
  50. RLModuleSpecType = Union["RLModuleSpec", "MultiRLModuleSpec"]
  51. """An RLModule spec (single-agent or multi-agent)."""
  52. StateDict = Dict[str, Any]
  53. """A state dict of an RLlib component (e.g. EnvRunner, Learner, RLModule)."""
  54. AlgorithmConfigDict = dict # @OldAPIStack
  55. """
  56. Represents a fully filled out config of a Algorithm class.
  57. Note:
  58. Policy config dicts are usually the same as AlgorithmConfigDict, but
  59. parts of it may sometimes be altered in e.g. a multi-agent setup,
  60. where we have >1 Policies in the same Algorithm.
  61. """
  62. PartialAlgorithmConfigDict = dict # @OldAPIStack
  63. """
  64. An algorithm config dict that only has overrides. It needs to be combined with
  65. the default algorithm config to be used.
  66. """
  67. ModelConfigDict = dict # @OldAPIStack
  68. """
  69. Represents the model config sub-dict of the algo config that is passed to the
  70. model catalog.
  71. """
  72. ConvFilterSpec = List[
  73. Tuple[int, Union[int, Tuple[int, int]], Union[int, Tuple[int, int]]]
  74. ]
  75. """
  76. Conv2D configuration format. Each entry in the outer list represents one Conv2D
  77. layer. Each inner list has the format: [num_output_filters, kernel, stride], where
  78. kernel and stride may be single ints (width and height are the same) or 2-tuples
  79. (int, int) for width and height (different values).
  80. """
  81. FromConfigSpec = Union[Dict[str, Union[Any, type, str]], type, str]
  82. """
  83. Objects that can be created through the `from_config()` util method
  84. need a config dict with a "type" key, a class path (str), or a type directly.
  85. """
  86. EnvConfigDict = dict
  87. """
  88. Represents the env_config sub-dict of the algo config that is passed to
  89. the env constructor.
  90. """
  91. EnvID = Union[int, str]
  92. """
  93. Represents an environment id. These could be:
  94. - An int index for a sub-env within a vectorized env.
  95. - An external env ID (str), which changes(!) each episode.
  96. """
  97. # TODO (sven): Specify this type more strictly (it should just be gym.Env).
  98. EnvType = Union[Any, gym.Env]
  99. """
  100. Represents a BaseEnv, MultiAgentEnv, ExternalEnv, ExternalMultiAgentEnv,
  101. VectorEnv, gym.Env, or ActorHandle.
  102. """
  103. EnvCreator = Callable[["EnvContext"], Optional[EnvType]]
  104. """
  105. A callable, taking a EnvContext object
  106. (config dict + properties: `worker_index`, `vector_index`, `num_workers`,
  107. and `remote`) and returning an env object (or None if no env is used).
  108. """
  109. AgentID = Hashable
  110. """Represents a generic identifier for an agent (e.g., "agent1")."""
  111. PolicyID = str # @OldAPIStack
  112. """Represents a generic identifier for a policy (e.g., "pol1")."""
  113. ModuleID = str
  114. """Represents a generic identifier for a (single-agent) RLModule."""
  115. MultiAgentPolicyConfigDict = Dict[PolicyID, "PolicySpec"] # @OldAPIStack
  116. """Type of the config.policies dict for multi-agent training."""
  117. EpisodeType = Union["SingleAgentEpisode", "MultiAgentEpisode"]
  118. """A new stack Episode type: Either single-agent or multi-agent."""
  119. # @ OldAPIStack
  120. IsPolicyToTrain = Callable[[PolicyID, Optional["MultiAgentBatch"]], bool]
  121. """Is Policy to train callable."""
  122. AgentToModuleMappingFn = Callable[[AgentID, EpisodeType], ModuleID]
  123. """Function describing an agent to module mapping."""
  124. ShouldModuleBeUpdatedFn = Union[
  125. Sequence[ModuleID],
  126. Callable[[ModuleID, Optional["MultiAgentBatch"]], bool],
  127. ]
  128. """
  129. ModuleIDs that should be updated
  130. or a callable to return whether a module should be updated.
  131. """
  132. PolicyState = Dict[str, TensorStructType] # @OldAPIStack
  133. """
  134. State dict of a Policy, mapping strings (e.g. "weights")
  135. to some state data (TensorStructType).
  136. """
  137. TFPolicyV2Type = Type[Union["DynamicTFPolicyV2", "EagerTFPolicyV2"]] # @OldAPIStack
  138. """Any tf Policy type (static-graph or eager Policy)."""
  139. EpisodeID = Union[int, str]
  140. """Represents an episode id (old and new API stack)."""
  141. UnrollID = int # @OldAPIStack
  142. """Represents an "unroll" (maybe across different sub-envs in a vector env)."""
  143. MultiAgentDict = Dict[AgentID, Any]
  144. """A dict keyed by agent ids, e.g. {"agent-1": value}."""
  145. MultiEnvDict = Dict[EnvID, MultiAgentDict]
  146. """
  147. A dict keyed by env ids that contain further nested dictionaries keyed by agent
  148. ids. e.g., {"env-1": {"agent-1": value}}.
  149. """
  150. EnvObsType = Any
  151. """Represents an observation returned from the env. (Any alias)"""
  152. EnvActionType = Any
  153. """Represents an action passed to the env. (Any alias)"""
  154. EnvInfoDict = dict
  155. """
  156. Info dictionary returned by calling `reset()` or `step()` on `gymnasium.Env`
  157. instances. Might be an empty dict.
  158. """
  159. FileType = Any
  160. """Represents a File object. (Any alias)"""
  161. ViewRequirementsDict = Dict[str, "ViewRequirement"] # @OldAPIStack
  162. """
  163. Represents a ViewRequirements dict mapping column names (str) to ViewRequirement
  164. objects.
  165. """
  166. ResultDict = Dict
  167. """
  168. Represents the result dict returned by Algorithm.train() and algorithm components,
  169. such as EnvRunners, LearnerGroup, etc.. Also, the MetricsLogger used by all these
  170. components returns this upon its `reduce()` method call, so a ResultDict can further
  171. be accumulated (and reduced again) by downstream components.
  172. """
  173. LocalOptimizer = Union["torch.optim.Optimizer", "keras.optimizers.Optimizer"]
  174. """A tf or torch local optimizer object."""
  175. Optimizer = LocalOptimizer
  176. """A tf or torch optimizer object."""
  177. Param = Union["torch.Tensor", "tf.Variable"]
  178. """A parameter, either a torch.Tensor or tf.Variable."""
  179. ParamRef = Hashable
  180. """A reference to a parameter. (Hashable alias)"""
  181. ParamDict = Dict[ParamRef, Param]
  182. """A dictionary mapping parameter references to parameters."""
  183. ParamList = List[Param]
  184. """A list of parameters."""
  185. NamedParamDict = Dict[str, Param]
  186. """A dictionary mapping parameter names to parameters."""
  187. LearningRateOrSchedule = Union[
  188. float,
  189. List[List[Union[int, float]]],
  190. List[Tuple[int, Union[int, float]]],
  191. ]
  192. """
  193. A single learning rate or a learning rate schedule (list of sub-lists, each of
  194. the format: [ts (int), lr_to_reach_by_ts (float)]).
  195. """
  196. GradInfoDict = dict
  197. """
  198. Dict of tensors returned by compute gradients on the policy, e.g.,
  199. {"td_error": [...], "learner_stats": {"vf_loss": ..., ...}},
  200. for multi-agent, {"policy1": {"learner_stats": ..., }, "policy2": ...}.
  201. """
  202. LearnerStatsDict = dict
  203. """
  204. Dict of learner stats returned by compute gradients on the policy, e.g.,
  205. {"vf_loss": ..., ...}. This will always be nested under the "learner_stats" key(s)
  206. of a GradInfoDict. In the multi-agent case, this will be keyed by policy id.
  207. """
  208. ModelGradients = Union[List[Tuple[TensorType, TensorType]], List[TensorType]]
  209. """
  210. List of grads+var tuples (tf) or list of gradient tensors (torch) representing
  211. model gradients and returned by compute_gradients().
  212. """
  213. ModelWeights = dict
  214. """Type of dict returned by get_weights() representing model weights."""
  215. ModelInputDict = Dict[str, TensorType]
  216. """An input dict used for direct ModelV2 calls."""
  217. SampleBatchType = Union["SampleBatch", "MultiAgentBatch", Dict[str, Any]]
  218. """Some kind of sample batch."""
  219. SpaceStruct = Union[
  220. gym.spaces.Space, Dict[str, gym.spaces.Space], Tuple[gym.spaces.Space, ...]
  221. ]
  222. """
  223. A (possibly nested) space struct: Either a gym.spaces.Space or a (possibly
  224. nested) dict|tuple of gym.space.Spaces.
  225. """
  226. StateBatches = List[List[Any]] # @OldAPIStack
  227. """
  228. A list of batches of RNN states.
  229. Each item in this list has dimension [B, S] (S=state vector size)
  230. """
  231. # __sphinx_doc_begin_policy_output_type__
  232. PolicyOutputType = Tuple[TensorStructType, StateBatches, Dict] # @OldAPIStack
  233. """Format of data output from policy forward pass."""
  234. # __sphinx_doc_end_policy_output_type__
  235. # __sphinx_doc_begin_agent_connector_data_type__
  236. @OldAPIStack
  237. class AgentConnectorDataType:
  238. """Data type that is fed into and yielded from agent connectors.
  239. Args:
  240. env_id: ID of the environment.
  241. agent_id: ID to help identify the agent from which the data is received.
  242. data: A payload (``data``). With RLlib's default sampler, the payload
  243. is a dictionary of arbitrary data columns (obs, rewards, terminateds,
  244. truncateds, etc).
  245. """
  246. def __init__(self, env_id: str, agent_id: str, data: Any):
  247. self.env_id = env_id
  248. self.agent_id = agent_id
  249. self.data = data
  250. # __sphinx_doc_end_agent_connector_data_type__
  251. # __sphinx_doc_begin_action_connector_output__
  252. @OldAPIStack
  253. class ActionConnectorDataType:
  254. """Data type that is fed into and yielded from agent connectors.
  255. Args:
  256. env_id: ID of the environment.
  257. agent_id: ID to help identify the agent from which the data is received.
  258. input_dict: Input data that was passed into the policy.
  259. Sometimes output must be adapted based on the input, for example
  260. action masking. So the entire input data structure is provided here.
  261. output: An object of PolicyOutputType. It is is composed of the
  262. action output, the internal state output, and additional data fetches.
  263. """
  264. def __init__(
  265. self,
  266. env_id: str,
  267. agent_id: str,
  268. input_dict: TensorStructType,
  269. output: PolicyOutputType,
  270. ):
  271. self.env_id = env_id
  272. self.agent_id = agent_id
  273. self.input_dict = input_dict
  274. self.output = output
  275. # __sphinx_doc_end_action_connector_output__
  276. # __sphinx_doc_begin_agent_connector_output__
  277. @OldAPIStack
  278. class AgentConnectorsOutput:
  279. """Final output data type of agent connectors.
  280. Args are populated depending on the AgentConnector settings.
  281. The branching happens in ViewRequirementAgentConnector.
  282. Args:
  283. raw_dict: The raw input dictionary that sampler can use to
  284. build episodes and training batches.
  285. This raw dict also gets passed into ActionConnectors in case
  286. it contains data useful for action adaptation (e.g. action masks).
  287. sample_batch: The SampleBatch that can be immediately used for
  288. querying the policy for next action.
  289. """
  290. def __init__(
  291. self, raw_dict: Dict[str, TensorStructType], sample_batch: "SampleBatch"
  292. ):
  293. self.raw_dict = raw_dict
  294. self.sample_batch = sample_batch
  295. # __sphinx_doc_end_agent_connector_output__
  296. # Generic type var.
  297. T = TypeVar("T")