__init__.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. from __future__ import annotations
  2. import re
  3. from typing import Literal
  4. import wandb
  5. import wandb.util
  6. _gym_version_lt_0_26: bool | None = None
  7. _gymnasium_version_lt_1_0_0: bool | None = None
  8. _required_error_msg = (
  9. "Couldn't import the gymnasium python package, install with `pip install gymnasium`"
  10. )
  11. GymLib = Literal["gym", "gymnasium"]
  12. def monitor():
  13. """Monitor a gym environment.
  14. Supports both gym and gymnasium.
  15. """
  16. gym_lib: GymLib | None = None
  17. # gym is not maintained anymore, gymnasium is the drop-in replacement - prefer it
  18. if wandb.util.get_module("gymnasium") is not None:
  19. gym_lib = "gymnasium"
  20. elif wandb.util.get_module("gym") is not None:
  21. gym_lib = "gym"
  22. if gym_lib is None:
  23. raise wandb.Error(_required_error_msg)
  24. global _gym_version_lt_0_26
  25. global _gymnasium_version_lt_1_0_0
  26. if _gym_version_lt_0_26 is None or _gymnasium_version_lt_1_0_0 is None:
  27. if gym_lib == "gym":
  28. import gym
  29. else:
  30. import gymnasium as gym # type: ignore
  31. from packaging.version import parse
  32. gym_lib_version = parse(gym.__version__)
  33. _gym_version_lt_0_26 = gym_lib_version < parse("0.26.0")
  34. _gymnasium_version_lt_1_0_0 = gym_lib_version < parse("1.0.0a1")
  35. path = "path" # Default path
  36. if gym_lib == "gymnasium" and not _gymnasium_version_lt_1_0_0:
  37. vcr_recorder_attribute = "RecordVideo"
  38. wrappers = wandb.util.get_module(
  39. f"{gym_lib}.wrappers",
  40. required=_required_error_msg,
  41. )
  42. recorder = getattr(wrappers, vcr_recorder_attribute)
  43. else:
  44. vcr = wandb.util.get_module(
  45. f"{gym_lib}.wrappers.monitoring.video_recorder",
  46. required=_required_error_msg,
  47. )
  48. # Breaking change in gym 0.26.0
  49. if _gym_version_lt_0_26:
  50. vcr_recorder_attribute = "ImageEncoder"
  51. recorder = getattr(vcr, vcr_recorder_attribute)
  52. path = "output_path" # Override path for older gym versions
  53. else:
  54. vcr_recorder_attribute = "VideoRecorder"
  55. recorder = getattr(vcr, vcr_recorder_attribute)
  56. recorder.orig_close = recorder.close
  57. def close(self):
  58. recorder.orig_close(self)
  59. if not self.enabled:
  60. return
  61. if wandb.run:
  62. m = re.match(r".+(video\.\d+).+", getattr(self, path))
  63. key = m.group(1) if m else "videos"
  64. wandb.log({key: wandb.Video(getattr(self, path))})
  65. def del_(self):
  66. self.orig_close()
  67. if not _gym_version_lt_0_26:
  68. recorder.__del__ = del_
  69. recorder.close = close
  70. if gym_lib == "gymnasium" and not _gymnasium_version_lt_1_0_0:
  71. wrapper_name = vcr_recorder_attribute
  72. else:
  73. wrapper_name = f"monitoring.video_recorder.{vcr_recorder_attribute}"
  74. wandb.patched["gym"].append(
  75. [
  76. f"{gym_lib}.wrappers.{wrapper_name}",
  77. "close",
  78. ]
  79. )