repeated_values.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. from typing import List
  2. from ray.rllib.utils.annotations import OldAPIStack
  3. from ray.rllib.utils.typing import TensorStructType, TensorType
  4. @OldAPIStack
  5. class RepeatedValues:
  6. """Represents a variable-length list of items from spaces.Repeated.
  7. RepeatedValues are created when you use spaces.Repeated, and are
  8. accessible as part of input_dict["obs"] in ModelV2 forward functions.
  9. Example:
  10. Suppose the gym space definition was:
  11. Repeated(Repeated(Box(K), N), M)
  12. Then in the model forward function, input_dict["obs"] is of type:
  13. RepeatedValues(RepeatedValues(<Tensor shape=(B, M, N, K)>))
  14. The tensor is accessible via:
  15. input_dict["obs"].values.values
  16. And the actual data lengths via:
  17. # outer repetition, shape [B], range [0, M]
  18. input_dict["obs"].lengths
  19. -and-
  20. # inner repetition, shape [B, M], range [0, N]
  21. input_dict["obs"].values.lengths
  22. Attributes:
  23. values: The padded data tensor of shape [B, max_len, ..., sz],
  24. where B is the batch dimension, max_len is the max length of this
  25. list, followed by any number of sub list max lens, followed by the
  26. actual data size.
  27. lengths (List[int]): Tensor of shape [B, ...] that represents the
  28. number of valid items in each list. When the list is nested within
  29. other lists, there will be extra dimensions for the parent list
  30. max lens.
  31. max_len: The max number of items allowed in each list.
  32. TODO(ekl): support conversion to tf.RaggedTensor.
  33. """
  34. def __init__(self, values: TensorType, lengths: List[int], max_len: int):
  35. self.values = values
  36. self.lengths = lengths
  37. self.max_len = max_len
  38. self._unbatched_repr = None
  39. def unbatch_all(self) -> List[List[TensorType]]:
  40. """Unbatch both the repeat and batch dimensions into Python lists.
  41. This is only supported in PyTorch / TF eager mode.
  42. This lets you view the data unbatched in its original form, but is
  43. not efficient for processing.
  44. .. testcode::
  45. :skipif: True
  46. batch = RepeatedValues(<Tensor shape=(B, N, K)>)
  47. items = batch.unbatch_all()
  48. print(len(items) == B)
  49. .. testoutput::
  50. True
  51. .. testcode::
  52. :skipif: True
  53. print(max(len(x) for x in items) <= N)
  54. .. testoutput::
  55. True
  56. .. testcode::
  57. :skipif: True
  58. print(items)
  59. .. testoutput::
  60. [[<Tensor_1 shape=(K)>, ..., <Tensor_N, shape=(K)>],
  61. ...
  62. [<Tensor_1 shape=(K)>, <Tensor_2 shape=(K)>],
  63. ...
  64. [<Tensor_1 shape=(K)>],
  65. ...
  66. [<Tensor_1 shape=(K)>, ..., <Tensor_N shape=(K)>]]
  67. """
  68. if self._unbatched_repr is None:
  69. B = _get_batch_dim_helper(self.values)
  70. if B is None:
  71. raise ValueError(
  72. "Cannot call unbatch_all() when batch_dim is unknown. "
  73. "This is probably because you are using TF graph mode."
  74. )
  75. else:
  76. B = int(B)
  77. slices = self.unbatch_repeat_dim()
  78. result = []
  79. for i in range(B):
  80. if hasattr(self.lengths[i], "item"):
  81. dynamic_len = int(self.lengths[i].item())
  82. else:
  83. dynamic_len = int(self.lengths[i].numpy())
  84. dynamic_slice = []
  85. for j in range(dynamic_len):
  86. dynamic_slice.append(_batch_index_helper(slices, i, j))
  87. result.append(dynamic_slice)
  88. self._unbatched_repr = result
  89. return self._unbatched_repr
  90. def unbatch_repeat_dim(self) -> List[TensorType]:
  91. """Unbatches the repeat dimension (the one `max_len` in size).
  92. This removes the repeat dimension. The result will be a Python list of
  93. with length `self.max_len`. Note that the data is still padded.
  94. .. testcode::
  95. :skipif: True
  96. batch = RepeatedValues(<Tensor shape=(B, N, K)>)
  97. items = batch.unbatch()
  98. len(items) == batch.max_len
  99. .. testoutput::
  100. True
  101. .. testcode::
  102. :skipif: True
  103. print(items)
  104. .. testoutput::
  105. [<Tensor_1 shape=(B, K)>, ..., <Tensor_N shape=(B, K)>]
  106. """
  107. return _unbatch_helper(self.values, self.max_len)
  108. def __repr__(self):
  109. return "RepeatedValues(value={}, lengths={}, max_len={})".format(
  110. repr(self.values), repr(self.lengths), self.max_len
  111. )
  112. def __str__(self):
  113. return repr(self)
  114. def _get_batch_dim_helper(v: TensorStructType) -> int:
  115. """Tries to find the batch dimension size of v, or None."""
  116. if isinstance(v, dict):
  117. for u in v.values():
  118. return _get_batch_dim_helper(u)
  119. elif isinstance(v, tuple):
  120. return _get_batch_dim_helper(v[0])
  121. elif isinstance(v, RepeatedValues):
  122. return _get_batch_dim_helper(v.values)
  123. else:
  124. B = v.shape[0]
  125. if hasattr(B, "value"):
  126. B = B.value # TensorFlow
  127. return B
  128. def _unbatch_helper(v: TensorStructType, max_len: int) -> TensorStructType:
  129. """Recursively unpacks the repeat dimension (max_len)."""
  130. if isinstance(v, dict):
  131. return {k: _unbatch_helper(u, max_len) for (k, u) in v.items()}
  132. elif isinstance(v, tuple):
  133. return tuple(_unbatch_helper(u, max_len) for u in v)
  134. elif isinstance(v, RepeatedValues):
  135. unbatched = _unbatch_helper(v.values, max_len)
  136. return [
  137. RepeatedValues(u, v.lengths[:, i, ...], v.max_len)
  138. for i, u in enumerate(unbatched)
  139. ]
  140. else:
  141. return [v[:, i, ...] for i in range(max_len)]
  142. def _batch_index_helper(v: TensorStructType, i: int, j: int) -> TensorStructType:
  143. """Selects the item at the ith batch index and jth repetition."""
  144. if isinstance(v, dict):
  145. return {k: _batch_index_helper(u, i, j) for (k, u) in v.items()}
  146. elif isinstance(v, tuple):
  147. return tuple(_batch_index_helper(u, i, j) for u in v)
  148. elif isinstance(v, list):
  149. # This is the output of unbatch_repeat_dim(). Unfortunately we have to
  150. # process it here instead of in unbatch_all(), since it may be buried
  151. # under a dict / tuple.
  152. return _batch_index_helper(v[j], i, j)
  153. elif isinstance(v, RepeatedValues):
  154. unbatched = v.unbatch_all()
  155. # Don't need to select j here; that's already done in unbatch_all.
  156. return unbatched[i]
  157. else:
  158. return v[i, ...]