view_requirement.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. import dataclasses
  2. from typing import Dict, List, Optional, Union
  3. import gymnasium as gym
  4. import numpy as np
  5. from ray.rllib.utils.annotations import OldAPIStack
  6. from ray.rllib.utils.framework import try_import_torch
  7. from ray.rllib.utils.serialization import (
  8. gym_space_from_dict,
  9. gym_space_to_dict,
  10. )
  11. torch, _ = try_import_torch()
  12. @OldAPIStack
  13. @dataclasses.dataclass
  14. class ViewRequirement:
  15. """Single view requirement (for one column in an SampleBatch/input_dict).
  16. Policies and ModelV2s return a Dict[str, ViewRequirement] upon calling
  17. their `[train|inference]_view_requirements()` methods, where the str key
  18. represents the column name (C) under which the view is available in the
  19. input_dict/SampleBatch and ViewRequirement specifies the actual underlying
  20. column names (in the original data buffer), timestep shifts, and other
  21. options to build the view.
  22. .. testcode::
  23. :skipif: True
  24. from ray.rllib.models.modelv2 import ModelV2
  25. # The default ViewRequirement for a Model is:
  26. req = ModelV2(...).view_requirements
  27. print(req)
  28. .. testoutput::
  29. {"obs": ViewRequirement(shift=0)}
  30. Args:
  31. data_col: The data column name from the SampleBatch
  32. (str key). If None, use the dict key under which this
  33. ViewRequirement resides.
  34. space: The gym Space used in case we need to pad data
  35. in inaccessible areas of the trajectory (t<0 or t>H).
  36. Default: Simple box space, e.g. rewards.
  37. shift: Single shift value or
  38. list of relative positions to use (relative to the underlying
  39. `data_col`).
  40. Example: For a view column "prev_actions", you can set
  41. `data_col="actions"` and `shift=-1`.
  42. Example: For a view column "obs" in an Atari framestacking
  43. fashion, you can set `data_col="obs"` and
  44. `shift=[-3, -2, -1, 0]`.
  45. Example: For the obs input to an attention net, you can specify
  46. a range via a str: `shift="-100:0"`, which will pass in
  47. the past 100 observations plus the current one.
  48. index: An optional absolute position arg,
  49. used e.g. for the location of a requested inference dict within
  50. the trajectory. Negative values refer to counting from the end
  51. of a trajectory. (#TODO: Is this still used?)
  52. batch_repeat_value: determines how many time steps we should skip
  53. before we repeat the view indexing for the next timestep. For RNNs this
  54. number is usually the sequence length that we will rollout over.
  55. Example:
  56. view_col = "state_in_0", data_col = "state_out_0"
  57. batch_repeat_value = 5, shift = -1
  58. buffer["state_out_0"] = [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
  59. output["state_in_0"] = [-1, 4, 9]
  60. Explanation: For t=0, we output buffer["state_out_0"][-1]. We then skip 5
  61. time steps and repeat the view. for t=5, we output buffer["state_out_0"][4]
  62. . Continuing on this pattern, for t=10, we output buffer["state_out_0"][9].
  63. used_for_compute_actions: Whether the data will be used for
  64. creating input_dicts for `Policy.compute_actions()` calls (or
  65. `Policy.compute_actions_from_input_dict()`).
  66. used_for_training: Whether the data will be used for
  67. training. If False, the column will not be copied into the
  68. final train batch.
  69. """
  70. data_col: Optional[str] = None
  71. space: gym.Space = None
  72. shift: Union[int, str, List[int]] = 0
  73. index: Optional[int] = None
  74. batch_repeat_value: int = 1
  75. used_for_compute_actions: bool = True
  76. used_for_training: bool = True
  77. shift_arr: Optional[np.ndarray] = dataclasses.field(init=False)
  78. def __post_init__(self):
  79. """Initializes a ViewRequirement object.
  80. shift_arr is infered from the shift value.
  81. For example:
  82. - if shift is -1, then shift_arr is np.array([-1]).
  83. - if shift is [-1, -2], then shift_arr is np.array([-2, -1]).
  84. - if shift is "-2:2", then shift_arr is np.array([-2, -1, 0, 1, 2]).
  85. """
  86. if self.space is None:
  87. self.space = gym.spaces.Box(float("-inf"), float("inf"), shape=())
  88. # TODO: ideally we won't need shift_from and shift_to, and shift_step.
  89. # all of them should be captured within shift_arr.
  90. # Special case: Providing a (probably larger) range of indices, e.g.
  91. # "-100:0" (past 100 timesteps plus current one).
  92. self.shift_from = self.shift_to = self.shift_step = None
  93. if isinstance(self.shift, str):
  94. split = self.shift.split(":")
  95. assert len(split) in [2, 3], f"Invalid shift str format: {self.shift}"
  96. if len(split) == 2:
  97. f, t = split
  98. self.shift_step = 1
  99. else:
  100. f, t, s = split
  101. self.shift_step = int(s)
  102. self.shift_from = int(f)
  103. self.shift_to = int(t)
  104. shift = self.shift
  105. self.shfit_arr = None
  106. if self.shift_from:
  107. self.shift_arr = np.arange(
  108. self.shift_from, self.shift_to + 1, self.shift_step
  109. )
  110. else:
  111. if isinstance(shift, int):
  112. self.shift_arr = np.array([shift])
  113. elif isinstance(shift, list):
  114. self.shift_arr = np.array(shift)
  115. else:
  116. ValueError(f'unrecognized shift type: "{shift}"')
  117. def to_dict(self) -> Dict:
  118. """Return a dict for this ViewRequirement that can be JSON serialized."""
  119. return {
  120. "data_col": self.data_col,
  121. "space": gym_space_to_dict(self.space),
  122. "shift": self.shift,
  123. "index": self.index,
  124. "batch_repeat_value": self.batch_repeat_value,
  125. "used_for_training": self.used_for_training,
  126. "used_for_compute_actions": self.used_for_compute_actions,
  127. }
  128. @classmethod
  129. def from_dict(cls, d: Dict):
  130. """Construct a ViewRequirement instance from JSON deserialized dict."""
  131. d["space"] = gym_space_from_dict(d["space"])
  132. return cls(**d)