preprocessors.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448
  1. import logging
  2. from collections import OrderedDict
  3. from typing import Any, List
  4. import gymnasium as gym
  5. import numpy as np
  6. from ray.rllib.utils.annotations import OldAPIStack, override
  7. from ray.rllib.utils.images import resize
  8. from ray.rllib.utils.spaces.repeated import Repeated
  9. from ray.rllib.utils.spaces.space_utils import convert_element_to_space_type
  10. from ray.rllib.utils.typing import TensorType
# Observation-space shapes that `get_preprocessor()` compares against in order
# to select the GenericPixelPreprocessor (image) or the AtariRamPreprocessor
# (RAM) by default.
ATARI_OBS_SHAPE = (210, 160, 3)
ATARI_RAM_OBS_SHAPE = (128,)

# Only validate env observations vs the observation space every n times in a
# Preprocessor.
OBS_VALIDATION_INTERVAL = 100

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
  17. @OldAPIStack
  18. class Preprocessor:
  19. """Defines an abstract observation preprocessor function.
  20. Attributes:
  21. shape (List[int]): Shape of the preprocessed output.
  22. """
  23. def __init__(self, obs_space: gym.Space, options: dict = None):
  24. _legacy_patch_shapes(obs_space)
  25. self._obs_space = obs_space
  26. if not options:
  27. from ray.rllib.models.catalog import MODEL_DEFAULTS
  28. self._options = MODEL_DEFAULTS.copy()
  29. else:
  30. self._options = options
  31. self.shape = self._init_shape(obs_space, self._options)
  32. self._size = int(np.prod(self.shape))
  33. self._i = 0
  34. self._obs_for_type_matching = self._obs_space.sample()
  35. def _init_shape(self, obs_space: gym.Space, options: dict) -> List[int]:
  36. """Returns the shape after preprocessing."""
  37. raise NotImplementedError
  38. def transform(self, observation: TensorType) -> np.ndarray:
  39. """Returns the preprocessed observation."""
  40. raise NotImplementedError
  41. def write(self, observation: TensorType, array: np.ndarray, offset: int) -> None:
  42. """Alternative to transform for more efficient flattening."""
  43. array[offset : offset + self._size] = self.transform(observation)
  44. def check_shape(self, observation: Any) -> None:
  45. """Checks the shape of the given observation."""
  46. if self._i % OBS_VALIDATION_INTERVAL == 0:
  47. # Convert lists to np.ndarrays.
  48. if type(observation) is list and isinstance(
  49. self._obs_space, gym.spaces.Box
  50. ):
  51. observation = np.array(observation).astype(np.float32)
  52. if not self._obs_space.contains(observation):
  53. observation = convert_element_to_space_type(
  54. observation, self._obs_for_type_matching
  55. )
  56. try:
  57. if not self._obs_space.contains(observation):
  58. raise ValueError(
  59. "Observation ({} dtype={}) outside given space ({})!".format(
  60. observation,
  61. observation.dtype
  62. if isinstance(self._obs_space, gym.spaces.Box)
  63. else None,
  64. self._obs_space,
  65. )
  66. )
  67. except AttributeError as e:
  68. raise ValueError(
  69. "Observation for a Box/MultiBinary/MultiDiscrete space "
  70. "should be an np.array, not a Python list.",
  71. observation,
  72. ) from e
  73. self._i += 1
  74. @property
  75. def size(self) -> int:
  76. return self._size
  77. @property
  78. def observation_space(self) -> gym.Space:
  79. obs_space = gym.spaces.Box(-1.0, 1.0, self.shape, dtype=np.float32)
  80. # Stash the unwrapped space so that we can unwrap dict and tuple spaces
  81. # automatically in modelv2.py
  82. classes = (
  83. DictFlatteningPreprocessor,
  84. OneHotPreprocessor,
  85. RepeatedValuesPreprocessor,
  86. TupleFlatteningPreprocessor,
  87. AtariRamPreprocessor,
  88. GenericPixelPreprocessor,
  89. )
  90. if isinstance(self, classes):
  91. obs_space.original_space = self._obs_space
  92. return obs_space
  93. @OldAPIStack
  94. class GenericPixelPreprocessor(Preprocessor):
  95. """Generic image preprocessor.
  96. Note: for Atari games, use config {"preprocessor_pref": "deepmind"}
  97. instead for deepmind-style Atari preprocessing.
  98. """
  99. @override(Preprocessor)
  100. def _init_shape(self, obs_space: gym.Space, options: dict) -> List[int]:
  101. self._grayscale = options.get("grayscale")
  102. self._zero_mean = options.get("zero_mean")
  103. self._dim = options.get("dim")
  104. if self._grayscale:
  105. shape = (self._dim, self._dim, 1)
  106. else:
  107. shape = (self._dim, self._dim, 3)
  108. return shape
  109. @override(Preprocessor)
  110. def transform(self, observation: TensorType) -> np.ndarray:
  111. """Downsamples images from (210, 160, 3) by the configured factor."""
  112. self.check_shape(observation)
  113. scaled = observation[25:-25, :, :]
  114. if self._dim < 84:
  115. scaled = resize(scaled, height=84, width=84)
  116. # OpenAI: Resize by half, then down to 42x42 (essentially mipmapping).
  117. # If we resize directly we lose pixels that, when mapped to 42x42,
  118. # aren't close enough to the pixel boundary.
  119. scaled = resize(scaled, height=self._dim, width=self._dim)
  120. if self._grayscale:
  121. scaled = scaled.mean(2)
  122. scaled = scaled.astype(np.float32)
  123. # Rescale needed for maintaining 1 channel
  124. scaled = np.reshape(scaled, [self._dim, self._dim, 1])
  125. if self._zero_mean:
  126. scaled = (scaled - 128) / 128
  127. else:
  128. scaled *= 1.0 / 255.0
  129. return scaled
  130. @OldAPIStack
  131. class AtariRamPreprocessor(Preprocessor):
  132. @override(Preprocessor)
  133. def _init_shape(self, obs_space: gym.Space, options: dict) -> List[int]:
  134. return (128,)
  135. @override(Preprocessor)
  136. def transform(self, observation: TensorType) -> np.ndarray:
  137. self.check_shape(observation)
  138. return (observation.astype("float32") - 128) / 128
  139. @OldAPIStack
  140. class OneHotPreprocessor(Preprocessor):
  141. """One-hot preprocessor for Discrete and MultiDiscrete spaces.
  142. .. testcode::
  143. :skipif: True
  144. self.transform(Discrete(3).sample())
  145. .. testoutput::
  146. np.array([0.0, 1.0, 0.0])
  147. .. testcode::
  148. :skipif: True
  149. self.transform(MultiDiscrete([2, 3]).sample())
  150. .. testoutput::
  151. np.array([0.0, 1.0, 0.0, 0.0, 1.0])
  152. """
  153. @override(Preprocessor)
  154. def _init_shape(self, obs_space: gym.Space, options: dict) -> List[int]:
  155. if isinstance(obs_space, gym.spaces.Discrete):
  156. return (self._obs_space.n,)
  157. else:
  158. return (np.sum(self._obs_space.nvec),)
  159. @override(Preprocessor)
  160. def transform(self, observation: TensorType) -> np.ndarray:
  161. self.check_shape(observation)
  162. return gym.spaces.utils.flatten(self._obs_space, observation).astype(np.float32)
  163. @override(Preprocessor)
  164. def write(self, observation: TensorType, array: np.ndarray, offset: int) -> None:
  165. array[offset : offset + self.size] = self.transform(observation)
  166. @OldAPIStack
  167. class NoPreprocessor(Preprocessor):
  168. @override(Preprocessor)
  169. def _init_shape(self, obs_space: gym.Space, options: dict) -> List[int]:
  170. return self._obs_space.shape
  171. @override(Preprocessor)
  172. def transform(self, observation: TensorType) -> np.ndarray:
  173. self.check_shape(observation)
  174. return observation
  175. @override(Preprocessor)
  176. def write(self, observation: TensorType, array: np.ndarray, offset: int) -> None:
  177. array[offset : offset + self._size] = np.array(observation, copy=False).ravel()
  178. @property
  179. @override(Preprocessor)
  180. def observation_space(self) -> gym.Space:
  181. return self._obs_space
  182. @OldAPIStack
  183. class MultiBinaryPreprocessor(Preprocessor):
  184. """Preprocessor that turns a MultiBinary space into a Box.
  185. Note: Before RLModules were introduced, RLlib's ModelCatalogV2 would produce
  186. ComplexInputNetworks that treat MultiBinary spaces as Boxes. This preprocessor is
  187. needed to get rid of the ComplexInputNetworks and use RLModules instead because
  188. RLModules lack the logic to handle MultiBinary or other non-Box spaces.
  189. """
  190. @override(Preprocessor)
  191. def _init_shape(self, obs_space: gym.Space, options: dict) -> List[int]:
  192. return self._obs_space.shape
  193. @override(Preprocessor)
  194. def transform(self, observation: TensorType) -> np.ndarray:
  195. # The shape stays the same, but the dtype changes.
  196. self.check_shape(observation)
  197. return observation.astype(np.float32)
  198. @override(Preprocessor)
  199. def write(self, observation: TensorType, array: np.ndarray, offset: int) -> None:
  200. array[offset : offset + self._size] = np.array(observation, copy=False).ravel()
  201. @property
  202. @override(Preprocessor)
  203. def observation_space(self) -> gym.Space:
  204. obs_space = gym.spaces.Box(0.0, 1.0, self.shape, dtype=np.float32)
  205. obs_space.original_space = self._obs_space
  206. return obs_space
  207. @OldAPIStack
  208. class TupleFlatteningPreprocessor(Preprocessor):
  209. """Preprocesses each tuple element, then flattens it all into a vector.
  210. RLlib models will unpack the flattened output before _build_layers_v2().
  211. """
  212. @override(Preprocessor)
  213. def _init_shape(self, obs_space: gym.Space, options: dict) -> List[int]:
  214. assert isinstance(self._obs_space, gym.spaces.Tuple)
  215. size = 0
  216. self.preprocessors = []
  217. for i in range(len(self._obs_space.spaces)):
  218. space = self._obs_space.spaces[i]
  219. logger.debug("Creating sub-preprocessor for {}".format(space))
  220. preprocessor_class = get_preprocessor(space)
  221. if preprocessor_class is not None:
  222. preprocessor = preprocessor_class(space, self._options)
  223. size += preprocessor.size
  224. else:
  225. preprocessor = None
  226. size += int(np.prod(space.shape))
  227. self.preprocessors.append(preprocessor)
  228. return (size,)
  229. @override(Preprocessor)
  230. def transform(self, observation: TensorType) -> np.ndarray:
  231. self.check_shape(observation)
  232. array = np.zeros(self.shape, dtype=np.float32)
  233. self.write(observation, array, 0)
  234. return array
  235. @override(Preprocessor)
  236. def write(self, observation: TensorType, array: np.ndarray, offset: int) -> None:
  237. assert len(observation) == len(self.preprocessors), observation
  238. for o, p in zip(observation, self.preprocessors):
  239. p.write(o, array, offset)
  240. offset += p.size
  241. @OldAPIStack
  242. class DictFlatteningPreprocessor(Preprocessor):
  243. """Preprocesses each dict value, then flattens it all into a vector.
  244. RLlib models will unpack the flattened output before _build_layers_v2().
  245. """
  246. @override(Preprocessor)
  247. def _init_shape(self, obs_space: gym.Space, options: dict) -> List[int]:
  248. assert isinstance(self._obs_space, gym.spaces.Dict)
  249. size = 0
  250. self.preprocessors = []
  251. for space in self._obs_space.spaces.values():
  252. logger.debug("Creating sub-preprocessor for {}".format(space))
  253. preprocessor_class = get_preprocessor(space)
  254. if preprocessor_class is not None:
  255. preprocessor = preprocessor_class(space, self._options)
  256. size += preprocessor.size
  257. else:
  258. preprocessor = None
  259. size += int(np.prod(space.shape))
  260. self.preprocessors.append(preprocessor)
  261. return (size,)
  262. @override(Preprocessor)
  263. def transform(self, observation: TensorType) -> np.ndarray:
  264. self.check_shape(observation)
  265. array = np.zeros(self.shape, dtype=np.float32)
  266. self.write(observation, array, 0)
  267. return array
  268. @override(Preprocessor)
  269. def write(self, observation: TensorType, array: np.ndarray, offset: int) -> None:
  270. if not isinstance(observation, OrderedDict):
  271. observation = OrderedDict(sorted(observation.items()))
  272. assert len(observation) == len(self.preprocessors), (
  273. len(observation),
  274. len(self.preprocessors),
  275. )
  276. for o, p in zip(observation.values(), self.preprocessors):
  277. p.write(o, array, offset)
  278. offset += p.size
  279. @OldAPIStack
  280. class RepeatedValuesPreprocessor(Preprocessor):
  281. """Pads and batches the variable-length list value."""
  282. @override(Preprocessor)
  283. def _init_shape(self, obs_space: gym.Space, options: dict) -> List[int]:
  284. assert isinstance(self._obs_space, Repeated)
  285. child_space = obs_space.child_space
  286. self.child_preprocessor = get_preprocessor(child_space)(
  287. child_space, self._options
  288. )
  289. # The first slot encodes the list length.
  290. size = 1 + self.child_preprocessor.size * obs_space.max_len
  291. return (size,)
  292. @override(Preprocessor)
  293. def transform(self, observation: TensorType) -> np.ndarray:
  294. array = np.zeros(self.shape)
  295. if isinstance(observation, list):
  296. for elem in observation:
  297. self.child_preprocessor.check_shape(elem)
  298. else:
  299. pass # ValueError will be raised in write() below.
  300. self.write(observation, array, 0)
  301. return array
  302. @override(Preprocessor)
  303. def write(self, observation: TensorType, array: np.ndarray, offset: int) -> None:
  304. if not isinstance(observation, (list, np.ndarray)):
  305. raise ValueError(
  306. "Input for {} must be list type, got {}".format(self, observation)
  307. )
  308. elif len(observation) > self._obs_space.max_len:
  309. raise ValueError(
  310. "Input {} exceeds max len of space {}".format(
  311. observation, self._obs_space.max_len
  312. )
  313. )
  314. # The first slot encodes the list length.
  315. array[offset] = len(observation)
  316. for i, elem in enumerate(observation):
  317. offset_i = offset + 1 + i * self.child_preprocessor.size
  318. self.child_preprocessor.write(elem, array, offset_i)
  319. @OldAPIStack
  320. def get_preprocessor(space: gym.Space, include_multi_binary=False) -> type:
  321. """Returns an appropriate preprocessor class for the given space."""
  322. _legacy_patch_shapes(space)
  323. obs_shape = space.shape
  324. if isinstance(space, (gym.spaces.Discrete, gym.spaces.MultiDiscrete)):
  325. preprocessor = OneHotPreprocessor
  326. elif obs_shape == ATARI_OBS_SHAPE:
  327. logger.debug(
  328. "Defaulting to RLlib's GenericPixelPreprocessor because input "
  329. "space has the atari-typical shape {}. Turn this behaviour off by setting "
  330. "`preprocessor_pref=None` or "
  331. "`preprocessor_pref='deepmind'` or disabling the preprocessing API "
  332. "altogether with `_disable_preprocessor_api=True`.".format(ATARI_OBS_SHAPE)
  333. )
  334. preprocessor = GenericPixelPreprocessor
  335. elif obs_shape == ATARI_RAM_OBS_SHAPE:
  336. logger.debug(
  337. "Defaulting to RLlib's AtariRamPreprocessor because input "
  338. "space has the atari-typical shape {}. Turn this behaviour off by setting "
  339. "`preprocessor_pref=None` or "
  340. "`preprocessor_pref='deepmind' or disabling the preprocessing API "
  341. "altogether with `_disable_preprocessor_api=True`."
  342. "`.".format(ATARI_OBS_SHAPE)
  343. )
  344. preprocessor = AtariRamPreprocessor
  345. elif isinstance(space, gym.spaces.Tuple):
  346. preprocessor = TupleFlatteningPreprocessor
  347. elif isinstance(space, gym.spaces.Dict):
  348. preprocessor = DictFlatteningPreprocessor
  349. elif isinstance(space, Repeated):
  350. preprocessor = RepeatedValuesPreprocessor
  351. # We usually only want to include this when using RLModules
  352. elif isinstance(space, gym.spaces.MultiBinary) and include_multi_binary:
  353. preprocessor = MultiBinaryPreprocessor
  354. else:
  355. preprocessor = NoPreprocessor
  356. return preprocessor
  357. def _legacy_patch_shapes(space: gym.Space) -> List[int]:
  358. """Assigns shapes to spaces that don't have shapes.
  359. This is only needed for older gym versions that don't set shapes properly
  360. for Tuple and Discrete spaces.
  361. """
  362. if not hasattr(space, "shape"):
  363. if isinstance(space, gym.spaces.Discrete):
  364. space.shape = ()
  365. elif isinstance(space, gym.spaces.Tuple):
  366. shapes = []
  367. for s in space.spaces:
  368. shape = _legacy_patch_shapes(s)
  369. shapes.append(shape)
  370. space.shape = tuple(shapes)
  371. return space.shape