action_dist.py 3.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. import gymnasium as gym
  2. import numpy as np
  3. from ray.rllib.models.modelv2 import ModelV2
  4. from ray.rllib.utils.annotations import OldAPIStack
  5. from ray.rllib.utils.typing import List, ModelConfigDict, TensorType, Union
  6. @OldAPIStack
  7. class ActionDistribution:
  8. """The policy action distribution of an agent.
  9. Attributes:
  10. inputs: input vector to compute samples from.
  11. model (ModelV2): reference to model producing the inputs.
  12. """
  13. def __init__(self, inputs: List[TensorType], model: ModelV2):
  14. """Initializes an ActionDist object.
  15. Args:
  16. inputs: input vector to compute samples from.
  17. model (ModelV2): reference to model producing the inputs. This
  18. is mainly useful if you want to use model variables to compute
  19. action outputs (i.e., for autoregressive action distributions,
  20. see examples/autoregressive_action_dist.py).
  21. """
  22. self.inputs = inputs
  23. self.model = model
  24. def sample(self) -> TensorType:
  25. """Draw a sample from the action distribution."""
  26. raise NotImplementedError
  27. def deterministic_sample(self) -> TensorType:
  28. """
  29. Get the deterministic "sampling" output from the distribution.
  30. This is usually the max likelihood output, i.e. mean for Normal, argmax
  31. for Categorical, etc..
  32. """
  33. raise NotImplementedError
  34. def sampled_action_logp(self) -> TensorType:
  35. """Returns the log probability of the last sampled action."""
  36. raise NotImplementedError
  37. def logp(self, x: TensorType) -> TensorType:
  38. """The log-likelihood of the action distribution."""
  39. raise NotImplementedError
  40. def kl(self, other: "ActionDistribution") -> TensorType:
  41. """The KL-divergence between two action distributions."""
  42. raise NotImplementedError
  43. def entropy(self) -> TensorType:
  44. """The entropy of the action distribution."""
  45. raise NotImplementedError
  46. def multi_kl(self, other: "ActionDistribution") -> TensorType:
  47. """The KL-divergence between two action distributions.
  48. This differs from kl() in that it can return an array for
  49. MultiDiscrete. TODO(ekl) consider removing this.
  50. """
  51. return self.kl(other)
  52. def multi_entropy(self) -> TensorType:
  53. """The entropy of the action distribution.
  54. This differs from entropy() in that it can return an array for
  55. MultiDiscrete. TODO(ekl) consider removing this.
  56. """
  57. return self.entropy()
  58. @staticmethod
  59. @OldAPIStack
  60. def required_model_output_shape(
  61. action_space: gym.Space, model_config: ModelConfigDict
  62. ) -> Union[int, np.ndarray]:
  63. """Returns the required shape of an input parameter tensor for a
  64. particular action space and an optional dict of distribution-specific
  65. options.
  66. Args:
  67. action_space (gym.Space): The action space this distribution will
  68. be used for, whose shape attributes will be used to determine
  69. the required shape of the input parameter tensor.
  70. model_config: Model's config dict (as defined in catalog.py)
  71. Returns:
  72. model_output_shape (int or np.ndarray of ints): size of the
  73. required input vector (minus leading batch dimension).
  74. """
  75. raise NotImplementedError