atari_wrappers.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412
  1. from collections import deque
  2. from typing import Optional, Union
  3. import gymnasium as gym
  4. import numpy as np
  5. from gymnasium import spaces
  6. from ray.rllib.utils.annotations import PublicAPI
  7. from ray.rllib.utils.images import resize, rgb2gray
  8. @PublicAPI
  9. def is_atari(env: Union[gym.Env, str]) -> bool:
  10. """Returns, whether a given env object or env descriptor (str) is an Atari env.
  11. Args:
  12. env: The gym.Env object or a string descriptor of the env (for example,
  13. "ale_py:ALE/Pong-v5").
  14. Returns:
  15. Whether `env` is an Atari environment.
  16. """
  17. # If a gym.Env, check proper spaces as well as occurrence of the "Atari<ALE" string
  18. # in the class name.
  19. if not isinstance(env, str):
  20. if (
  21. hasattr(env.observation_space, "shape")
  22. and env.observation_space.shape is not None
  23. and len(env.observation_space.shape) <= 2
  24. ):
  25. return False
  26. return "AtariEnv<ALE" in str(env)
  27. # If string, check for "ale_py:ALE/" prefix.
  28. else:
  29. return env.startswith("ALE/") or env.startswith("ale_py:")
  30. @PublicAPI
  31. def get_wrapper_by_cls(env, cls):
  32. """Returns the gym env wrapper of the given class, or None."""
  33. currentenv = env
  34. while True:
  35. if isinstance(currentenv, cls):
  36. return currentenv
  37. elif isinstance(currentenv, gym.Wrapper):
  38. currentenv = currentenv.env
  39. else:
  40. return None
  41. @PublicAPI
  42. class ClipRewardEnv(gym.RewardWrapper):
  43. def __init__(self, env):
  44. gym.RewardWrapper.__init__(self, env)
  45. def reward(self, reward):
  46. """Bin reward to {+1, 0, -1} by its sign."""
  47. return np.sign(reward)
  48. @PublicAPI
  49. class EpisodicLifeEnv(gym.Wrapper):
  50. def __init__(self, env):
  51. """Make end-of-life == end-of-episode, but only reset on true game over.
  52. Done by DeepMind for the DQN and co. since it helps value estimation.
  53. """
  54. gym.Wrapper.__init__(self, env)
  55. self.lives = 0
  56. self.was_real_terminated = True
  57. def step(self, action):
  58. obs, reward, terminated, truncated, info = self.env.step(action)
  59. self.was_real_terminated = terminated
  60. # check current lives, make loss of life terminal,
  61. # then update lives to handle bonus lives
  62. lives = self.env.unwrapped.ale.lives()
  63. if lives < self.lives and lives > 0:
  64. # for Qbert sometimes we stay in lives == 0 condtion for a few fr
  65. # so its important to keep lives > 0, so that we only reset once
  66. # the environment advertises `terminated`.
  67. terminated = True
  68. self.lives = lives
  69. return obs, reward, terminated, truncated, info
  70. def reset(self, **kwargs):
  71. """Reset only when lives are exhausted.
  72. This way all states are still reachable even though lives are episodic,
  73. and the learner need not know about any of this behind-the-scenes.
  74. """
  75. if self.was_real_terminated:
  76. obs, info = self.env.reset(**kwargs)
  77. else:
  78. # no-op step to advance from terminal/lost life state
  79. obs, _, _, _, info = self.env.step(0)
  80. self.lives = self.env.unwrapped.ale.lives()
  81. return obs, info
  82. @PublicAPI
  83. class FireResetEnv(gym.Wrapper):
  84. def __init__(self, env):
  85. """Take action on reset.
  86. For environments that are fixed until firing."""
  87. gym.Wrapper.__init__(self, env)
  88. assert env.unwrapped.get_action_meanings()[1] == "FIRE"
  89. assert len(env.unwrapped.get_action_meanings()) >= 3
  90. def reset(self, **kwargs):
  91. self.env.reset(**kwargs)
  92. obs, _, terminated, truncated, _ = self.env.step(1)
  93. if terminated or truncated:
  94. self.env.reset(**kwargs)
  95. obs, _, terminated, truncated, info = self.env.step(2)
  96. if terminated or truncated:
  97. self.env.reset(**kwargs)
  98. return obs, info
  99. def step(self, ac):
  100. return self.env.step(ac)
  101. @PublicAPI
  102. class FrameStack(gym.Wrapper):
  103. def __init__(self, env, k):
  104. """Stack k last frames."""
  105. gym.Wrapper.__init__(self, env)
  106. self.k = k
  107. self.frames = deque([], maxlen=k)
  108. shp = env.observation_space.shape
  109. self.observation_space = spaces.Box(
  110. low=np.repeat(env.observation_space.low, repeats=k, axis=-1),
  111. high=np.repeat(env.observation_space.high, repeats=k, axis=-1),
  112. shape=(shp[0], shp[1], shp[2] * k),
  113. dtype=env.observation_space.dtype,
  114. )
  115. def reset(self, *, seed=None, options=None):
  116. ob, infos = self.env.reset(seed=seed, options=options)
  117. for _ in range(self.k):
  118. self.frames.append(ob)
  119. return self._get_ob(), infos
  120. def step(self, action):
  121. ob, reward, terminated, truncated, info = self.env.step(action)
  122. self.frames.append(ob)
  123. return self._get_ob(), reward, terminated, truncated, info
  124. def _get_ob(self):
  125. assert len(self.frames) == self.k
  126. return np.concatenate(self.frames, axis=2)
  127. @PublicAPI
  128. class FrameStackTrajectoryView(gym.ObservationWrapper):
  129. def __init__(self, env):
  130. """No stacking. Trajectory View API takes care of this."""
  131. gym.Wrapper.__init__(self, env)
  132. shp = env.observation_space.shape
  133. assert shp[2] == 1
  134. self.observation_space = spaces.Box(
  135. low=0, high=255, shape=(shp[0], shp[1]), dtype=env.observation_space.dtype
  136. )
  137. def observation(self, observation):
  138. return np.squeeze(observation, axis=-1)
  139. @PublicAPI
  140. class MaxAndSkipEnv(gym.Wrapper):
  141. def __init__(self, env, skip=4):
  142. """Return only every `skip`-th frame"""
  143. gym.Wrapper.__init__(self, env)
  144. # most recent raw observations (for max pooling across time steps)
  145. self._obs_buffer = np.zeros(
  146. (2,) + env.observation_space.shape, dtype=env.observation_space.dtype
  147. )
  148. self._skip = skip
  149. def step(self, action):
  150. """Repeat action, sum reward, and max over last observations."""
  151. total_reward = 0.0
  152. terminated = truncated = info = None
  153. for i in range(self._skip):
  154. obs, reward, terminated, truncated, info = self.env.step(action)
  155. if i == self._skip - 2:
  156. self._obs_buffer[0] = obs
  157. if i == self._skip - 1:
  158. self._obs_buffer[1] = obs
  159. total_reward += reward
  160. if terminated or truncated:
  161. break
  162. # Note that the observation on the terminated|truncated=True frame
  163. # doesn't matter
  164. max_frame = self._obs_buffer.max(axis=0)
  165. return max_frame, total_reward, terminated, truncated, info
  166. def reset(self, **kwargs):
  167. return self.env.reset(**kwargs)
  168. @PublicAPI
  169. class MonitorEnv(gym.Wrapper):
  170. def __init__(self, env=None):
  171. """Record episodes stats prior to EpisodicLifeEnv, etc."""
  172. gym.Wrapper.__init__(self, env)
  173. self._current_reward = None
  174. self._num_steps = None
  175. self._total_steps = None
  176. self._episode_rewards = []
  177. self._episode_lengths = []
  178. self._num_episodes = 0
  179. self._num_returned = 0
  180. def reset(self, **kwargs):
  181. obs, info = self.env.reset(**kwargs)
  182. if self._total_steps is None:
  183. self._total_steps = sum(self._episode_lengths)
  184. if self._current_reward is not None:
  185. self._episode_rewards.append(self._current_reward)
  186. self._episode_lengths.append(self._num_steps)
  187. self._num_episodes += 1
  188. self._current_reward = 0
  189. self._num_steps = 0
  190. return obs, info
  191. def step(self, action):
  192. obs, rew, terminated, truncated, info = self.env.step(action)
  193. self._current_reward += rew
  194. self._num_steps += 1
  195. self._total_steps += 1
  196. return obs, rew, terminated, truncated, info
  197. def get_episode_rewards(self):
  198. return self._episode_rewards
  199. def get_episode_lengths(self):
  200. return self._episode_lengths
  201. def get_total_steps(self):
  202. return self._total_steps
  203. def next_episode_results(self):
  204. for i in range(self._num_returned, len(self._episode_rewards)):
  205. yield (self._episode_rewards[i], self._episode_lengths[i])
  206. self._num_returned = len(self._episode_rewards)
  207. @PublicAPI
  208. class NoopResetEnv(gym.Wrapper):
  209. def __init__(self, env, noop_max=30):
  210. """Sample initial states by taking random number of no-ops on reset.
  211. No-op is assumed to be action 0.
  212. """
  213. gym.Wrapper.__init__(self, env)
  214. self.noop_max = noop_max
  215. self.override_num_noops = None
  216. self.noop_action = 0
  217. assert env.unwrapped.get_action_meanings()[0] == "NOOP"
  218. def reset(self, **kwargs):
  219. """Do no-op action for a number of steps in [1, noop_max]."""
  220. self.env.reset(**kwargs)
  221. if self.override_num_noops is not None:
  222. noops = self.override_num_noops
  223. else:
  224. # This environment now uses the pcg64 random number generator which
  225. # does not have `randint` as an attribute only has `integers`.
  226. try:
  227. noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
  228. # Also still support older versions.
  229. except AttributeError:
  230. noops = self.unwrapped.np_random.randint(1, self.noop_max + 1)
  231. assert noops > 0
  232. obs = None
  233. for _ in range(noops):
  234. obs, _, terminated, truncated, info = self.env.step(self.noop_action)
  235. if terminated or truncated:
  236. obs, info = self.env.reset(**kwargs)
  237. return obs, info
  238. def step(self, ac):
  239. return self.env.step(ac)
  240. @PublicAPI
  241. class NormalizedImageEnv(gym.ObservationWrapper):
  242. def __init__(self, *args, **kwargs):
  243. super().__init__(*args, **kwargs)
  244. self.observation_space = gym.spaces.Box(
  245. -1.0,
  246. 1.0,
  247. shape=self.observation_space.shape,
  248. dtype=np.float32,
  249. )
  250. # Divide by scale and center around 0.0, such that observations are in the range
  251. # of -1.0 and 1.0.
  252. def observation(self, observation):
  253. return (observation.astype(np.float32) / 128.0) - 1.0
  254. @PublicAPI
  255. class GrayScaleAndResize(gym.ObservationWrapper):
  256. def __init__(self, env, dim, grayscale: bool = True):
  257. """Warp frames to the specified size (dim x dim)."""
  258. gym.ObservationWrapper.__init__(self, env)
  259. self.width = dim
  260. self.height = dim
  261. self.grayscale = grayscale
  262. self.observation_space = spaces.Box(
  263. low=0,
  264. high=255,
  265. shape=(self.height, self.width, 1 if grayscale else 3),
  266. dtype=np.uint8,
  267. )
  268. def observation(self, frame):
  269. if self.grayscale:
  270. frame = rgb2gray(frame)
  271. frame = resize(frame, height=self.height, width=self.width)
  272. return frame[:, :, None]
  273. else:
  274. return resize(frame, height=self.height, width=self.width)
# Alias: `WarpFrame` refers to the very same class as `GrayScaleAndResize`
# (the `wrap_*` helper functions in this file use the `WarpFrame` name).
WarpFrame = GrayScaleAndResize
  276. @PublicAPI
  277. def wrap_atari_for_new_api_stack(
  278. env: gym.Env,
  279. dim: int = 64,
  280. frameskip: int = 4,
  281. framestack: Optional[int] = None,
  282. grayscale: bool = True,
  283. # TODO (sven): Add option to NOT grayscale, in which case framestack must be None
  284. # (b/c we are using the 3 color channels already as stacking frames).
  285. ) -> gym.Env:
  286. """Wraps `env` for new-API-stack-friendly RLlib Atari experiments.
  287. Note that we assume reward clipping is done outside the wrapper.
  288. Args:
  289. env: The env object to wrap.
  290. dim: Dimension to resize observations to (dim x dim).
  291. frameskip: Whether to skip n frames and max over them (keep brightest pixels).
  292. framestack: Whether to stack the last n (grayscaled) frames. Note that this
  293. step happens after(!) a possible frameskip step, meaning that if
  294. frameskip=4 and framestack=2, we would perform the following over this
  295. trajectory:
  296. actual env timesteps: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 -> ...
  297. frameskip: ( max ) ( max ) ( max ) ( max )
  298. framestack: ( stack ) (stack )
  299. Returns:
  300. The wrapped gym.Env.
  301. """
  302. # Time limit.
  303. env = gym.wrappers.TimeLimit(env, max_episode_steps=108000)
  304. # Grayscale + resize.
  305. env = WarpFrame(env, dim=dim, grayscale=grayscale)
  306. # Normalize the image.
  307. env = NormalizedImageEnv(env)
  308. # Frameskip: Take max over these n frames.
  309. if frameskip > 1:
  310. assert env.spec is not None
  311. env = MaxAndSkipEnv(env, skip=frameskip)
  312. # Send n noop actions into env after reset to increase variance in the
  313. # "start states" of the trajectories. These dummy steps are NOT included in the
  314. # sampled data used for learning.
  315. env = NoopResetEnv(env, noop_max=30)
  316. # Each life is one episode.
  317. env = EpisodicLifeEnv(env)
  318. # Some envs only start playing after pressing fire. Unblock those.
  319. if "FIRE" in env.unwrapped.get_action_meanings():
  320. env = FireResetEnv(env)
  321. # Framestack.
  322. if framestack:
  323. env = FrameStack(env, k=framestack)
  324. return env
  325. @PublicAPI
  326. def wrap_deepmind(env, dim=84, framestack=True, noframeskip=False):
  327. """Configure environment for DeepMind-style Atari.
  328. Note that we assume reward clipping is done outside the wrapper.
  329. Args:
  330. env: The env object to wrap.
  331. dim: Dimension to resize observations to (dim x dim).
  332. framestack: Whether to framestack observations.
  333. """
  334. env = MonitorEnv(env)
  335. env = NoopResetEnv(env, noop_max=30)
  336. if env.spec is not None and noframeskip is True:
  337. env = MaxAndSkipEnv(env, skip=4)
  338. env = EpisodicLifeEnv(env)
  339. if "FIRE" in env.unwrapped.get_action_meanings():
  340. env = FireResetEnv(env)
  341. env = WarpFrame(env, dim)
  342. # env = ClipRewardEnv(env) # reward clipping is handled by policy eval
  343. # 4x image framestacking.
  344. if framestack is True:
  345. env = FrameStack(env, 4)
  346. return env