observation_function.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. from typing import Dict
  2. from ray.rllib.env import BaseEnv
  3. from ray.rllib.evaluation import RolloutWorker
  4. from ray.rllib.policy import Policy
  5. from ray.rllib.utils.annotations import OldAPIStack
  6. from ray.rllib.utils.framework import TensorType
  7. from ray.rllib.utils.typing import AgentID, PolicyID
  8. @OldAPIStack
  9. class ObservationFunction:
  10. """Interceptor function for rewriting observations from the environment.
  11. These callbacks can be used for preprocessing of observations, especially
  12. in multi-agent scenarios.
  13. Observation functions can be specified in the multi-agent config by
  14. specifying ``{"observation_fn": your_obs_func}``. Note that
  15. ``your_obs_func`` can be a plain Python function.
  16. This API is **experimental**.
  17. """
  18. def __call__(
  19. self,
  20. agent_obs: Dict[AgentID, TensorType],
  21. worker: RolloutWorker,
  22. base_env: BaseEnv,
  23. policies: Dict[PolicyID, Policy],
  24. episode,
  25. **kw
  26. ) -> Dict[AgentID, TensorType]:
  27. """Callback run on each environment step to observe the environment.
  28. This method takes in the original agent observation dict returned by
  29. a MultiAgentEnv, and returns a possibly modified one. It can be
  30. thought of as a "wrapper" around the environment.
  31. TODO(ekl): allow end-to-end differentiation through the observation
  32. function and policy losses.
  33. TODO(ekl): enable batch processing.
  34. Args:
  35. agent_obs: Dictionary of default observations from the
  36. environment. The default implementation of observe() simply
  37. returns this dict.
  38. worker: Reference to the current rollout worker.
  39. base_env: BaseEnv running the episode. The underlying
  40. sub environment objects (BaseEnvs are vectorized) can be
  41. retrieved by calling `base_env.get_sub_environments()`.
  42. policies: Mapping of policy id to policy objects. In single
  43. agent mode there will only be a single "default" policy.
  44. episode: Episode state object.
  45. kwargs: Forward compatibility placeholder.
  46. Returns:
  47. new_agent_obs: copy of agent obs with updates. You can
  48. rewrite or drop data from the dict if needed (e.g., the env
  49. can have a dummy "global" observation, and the observer can
  50. merge the global state into individual observations.
  51. .. testcode::
  52. :skipif: True
  53. # Observer that merges global state into individual obs. It is
  54. # rewriting the discrete obs into a tuple with global state.
  55. example_obs_fn1({"a": 1, "b": 2, "global_state": 101}, ...)
  56. .. testoutput::
  57. {"a": [1, 101], "b": [2, 101]}
  58. .. testcode::
  59. :skipif: True
  60. # Observer for e.g., custom centralized critic model. It is
  61. # rewriting the discrete obs into a dict with more data.
  62. example_obs_fn2({"a": 1, "b": 2}, ...)
  63. .. testoutput::
  64. {"a": {"self": 1, "other": 2}, "b": {"self": 2, "other": 1}}
  65. """
  66. return agent_obs