policy_template.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448
  1. from typing import (
  2. Any,
  3. Callable,
  4. Dict,
  5. List,
  6. Optional,
  7. Tuple,
  8. Type,
  9. Union,
  10. )
  11. import gymnasium as gym
  12. from ray.rllib.models.catalog import ModelCatalog
  13. from ray.rllib.models.modelv2 import ModelV2
  14. from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
  15. from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
  16. from ray.rllib.policy.policy import Policy
  17. from ray.rllib.policy.sample_batch import SampleBatch
  18. from ray.rllib.policy.torch_policy import TorchPolicy
  19. from ray.rllib.utils import NullContextManager, add_mixins
  20. from ray.rllib.utils.annotations import OldAPIStack, override
  21. from ray.rllib.utils.framework import try_import_jax, try_import_torch
  22. from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
  23. from ray.rllib.utils.numpy import convert_to_numpy
  24. from ray.rllib.utils.typing import AlgorithmConfigDict, ModelGradients, TensorType
  25. jax, _ = try_import_jax()
  26. torch, _ = try_import_torch()
  27. @OldAPIStack
  28. def build_policy_class(
  29. name: str,
  30. framework: str,
  31. *,
  32. loss_fn: Optional[
  33. Callable[
  34. [Policy, ModelV2, Type[TorchDistributionWrapper], SampleBatch],
  35. Union[TensorType, List[TensorType]],
  36. ]
  37. ],
  38. get_default_config: Optional[Callable[[], AlgorithmConfigDict]] = None,
  39. stats_fn: Optional[Callable[[Policy, SampleBatch], Dict[str, TensorType]]] = None,
  40. postprocess_fn: Optional[
  41. Callable[
  42. [
  43. Policy,
  44. SampleBatch,
  45. Optional[Dict[Any, SampleBatch]],
  46. Optional[Any],
  47. ],
  48. SampleBatch,
  49. ]
  50. ] = None,
  51. extra_action_out_fn: Optional[
  52. Callable[
  53. [
  54. Policy,
  55. Dict[str, TensorType],
  56. List[TensorType],
  57. ModelV2,
  58. TorchDistributionWrapper,
  59. ],
  60. Dict[str, TensorType],
  61. ]
  62. ] = None,
  63. extra_grad_process_fn: Optional[
  64. Callable[[Policy, "torch.optim.Optimizer", TensorType], Dict[str, TensorType]]
  65. ] = None,
  66. # TODO: (sven) Replace "fetches" with "process".
  67. extra_learn_fetches_fn: Optional[Callable[[Policy], Dict[str, TensorType]]] = None,
  68. optimizer_fn: Optional[
  69. Callable[[Policy, AlgorithmConfigDict], "torch.optim.Optimizer"]
  70. ] = None,
  71. validate_spaces: Optional[
  72. Callable[[Policy, gym.Space, gym.Space, AlgorithmConfigDict], None]
  73. ] = None,
  74. before_init: Optional[
  75. Callable[[Policy, gym.Space, gym.Space, AlgorithmConfigDict], None]
  76. ] = None,
  77. before_loss_init: Optional[
  78. Callable[
  79. [Policy, gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict], None
  80. ]
  81. ] = None,
  82. after_init: Optional[
  83. Callable[[Policy, gym.Space, gym.Space, AlgorithmConfigDict], None]
  84. ] = None,
  85. _after_loss_init: Optional[
  86. Callable[
  87. [Policy, gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict], None
  88. ]
  89. ] = None,
  90. action_sampler_fn: Optional[
  91. Callable[[TensorType, List[TensorType]], Tuple[TensorType, TensorType]]
  92. ] = None,
  93. action_distribution_fn: Optional[
  94. Callable[
  95. [Policy, ModelV2, TensorType, TensorType, TensorType],
  96. Tuple[TensorType, type, List[TensorType]],
  97. ]
  98. ] = None,
  99. make_model: Optional[
  100. Callable[
  101. [Policy, gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict], ModelV2
  102. ]
  103. ] = None,
  104. make_model_and_action_dist: Optional[
  105. Callable[
  106. [Policy, gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict],
  107. Tuple[ModelV2, Type[TorchDistributionWrapper]],
  108. ]
  109. ] = None,
  110. compute_gradients_fn: Optional[
  111. Callable[[Policy, SampleBatch], Tuple[ModelGradients, dict]]
  112. ] = None,
  113. apply_gradients_fn: Optional[
  114. Callable[[Policy, "torch.optim.Optimizer"], None]
  115. ] = None,
  116. mixins: Optional[List[type]] = None,
  117. get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None,
  118. ) -> Type[TorchPolicy]:
  119. """Helper function for creating a new Policy class at runtime.
  120. Supports frameworks JAX and PyTorch.
  121. Args:
  122. name: name of the policy (e.g., "PPOTorchPolicy")
  123. framework: Either "jax" or "torch".
  124. loss_fn (Optional[Callable[[Policy, ModelV2,
  125. Type[TorchDistributionWrapper], SampleBatch], Union[TensorType,
  126. List[TensorType]]]]): Callable that returns a loss tensor.
  127. get_default_config (Optional[Callable[[None], AlgorithmConfigDict]]):
  128. Optional callable that returns the default config to merge with any
  129. overrides. If None, uses only(!) the user-provided
  130. PartialAlgorithmConfigDict as dict for this Policy.
  131. postprocess_fn (Optional[Callable[[Policy, SampleBatch,
  132. Optional[Dict[Any, SampleBatch]], Optional[Any]],
  133. SampleBatch]]): Optional callable for post-processing experience
  134. batches (called after the super's `postprocess_trajectory` method).
  135. stats_fn (Optional[Callable[[Policy, SampleBatch],
  136. Dict[str, TensorType]]]): Optional callable that returns a dict of
  137. values given the policy and training batch. If None,
  138. will use `TorchPolicy.extra_grad_info()` instead. The stats dict is
  139. used for logging (e.g. in TensorBoard).
  140. extra_action_out_fn (Optional[Callable[[Policy, Dict[str, TensorType],
  141. List[TensorType], ModelV2, TorchDistributionWrapper]], Dict[str,
  142. TensorType]]]): Optional callable that returns a dict of extra
  143. values to include in experiences. If None, no extra computations
  144. will be performed.
  145. extra_grad_process_fn (Optional[Callable[[Policy,
  146. "torch.optim.Optimizer", TensorType], Dict[str, TensorType]]]):
  147. Optional callable that is called after gradients are computed and
  148. returns a processing info dict. If None, will call the
  149. `TorchPolicy.extra_grad_process()` method instead.
  150. # TODO: (sven) dissolve naming mismatch between "learn" and "compute.."
  151. extra_learn_fetches_fn (Optional[Callable[[Policy],
  152. Dict[str, TensorType]]]): Optional callable that returns a dict of
  153. extra tensors from the policy after loss evaluation. If None,
  154. will call the `TorchPolicy.extra_compute_grad_fetches()` method
  155. instead.
  156. optimizer_fn (Optional[Callable[[Policy, AlgorithmConfigDict],
  157. "torch.optim.Optimizer"]]): Optional callable that returns a
  158. torch optimizer given the policy and config. If None, will call
  159. the `TorchPolicy.optimizer()` method instead (which returns a
  160. torch Adam optimizer).
  161. validate_spaces (Optional[Callable[[Policy, gym.Space, gym.Space,
  162. AlgorithmConfigDict], None]]): Optional callable that takes the
  163. Policy, observation_space, action_space, and config to check for
  164. correctness. If None, no spaces checking will be done.
  165. before_init (Optional[Callable[[Policy, gym.Space, gym.Space,
  166. AlgorithmConfigDict], None]]): Optional callable to run at the
  167. beginning of `Policy.__init__` that takes the same arguments as
  168. the Policy constructor. If None, this step will be skipped.
  169. before_loss_init (Optional[Callable[[Policy, gym.spaces.Space,
  170. gym.spaces.Space, AlgorithmConfigDict], None]]): Optional callable to
  171. run prior to loss init. If None, this step will be skipped.
  172. after_init (Optional[Callable[[Policy, gym.Space, gym.Space,
  173. AlgorithmConfigDict], None]]): DEPRECATED: Use `before_loss_init`
  174. instead.
  175. _after_loss_init (Optional[Callable[[Policy, gym.spaces.Space,
  176. gym.spaces.Space, AlgorithmConfigDict], None]]): Optional callable to
  177. run after the loss init. If None, this step will be skipped.
  178. This will be deprecated at some point and renamed into `after_init`
  179. to match `build_tf_policy()` behavior.
  180. action_sampler_fn (Optional[Callable[[TensorType, List[TensorType]],
  181. Tuple[TensorType, TensorType]]]): Optional callable returning a
  182. sampled action and its log-likelihood given some (obs and state)
  183. inputs. If None, will either use `action_distribution_fn` or
  184. compute actions by calling self.model, then sampling from the
  185. so parameterized action distribution.
  186. action_distribution_fn (Optional[Callable[[Policy, ModelV2, TensorType,
  187. TensorType, TensorType], Tuple[TensorType,
  188. Type[TorchDistributionWrapper], List[TensorType]]]]): A callable
  189. that takes the Policy, Model, the observation batch, an
  190. explore-flag, a timestep, and an is_training flag and returns a
  191. tuple of a) distribution inputs (parameters), b) a dist-class to
  192. generate an action distribution object from, and c) internal-state
  193. outputs (empty list if not applicable). If None, will either use
  194. `action_sampler_fn` or compute actions by calling self.model,
  195. then sampling from the parameterized action distribution.
  196. make_model (Optional[Callable[[Policy, gym.spaces.Space,
  197. gym.spaces.Space, AlgorithmConfigDict], ModelV2]]): Optional callable
  198. that takes the same arguments as Policy.__init__ and returns a
  199. model instance. The distribution class will be determined
  200. automatically. Note: Only one of `make_model` or
  201. `make_model_and_action_dist` should be provided. If both are None,
  202. a default Model will be created.
  203. make_model_and_action_dist (Optional[Callable[[Policy,
  204. gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict],
  205. Tuple[ModelV2, Type[TorchDistributionWrapper]]]]): Optional
  206. callable that takes the same arguments as Policy.__init__ and
  207. returns a tuple of model instance and torch action distribution
  208. class.
  209. Note: Only one of `make_model` or `make_model_and_action_dist`
  210. should be provided. If both are None, a default Model will be
  211. created.
  212. compute_gradients_fn (Optional[Callable[
  213. [Policy, SampleBatch], Tuple[ModelGradients, dict]]]): Optional
  214. callable that the sampled batch an computes the gradients w.r.
  215. to the loss function.
  216. If None, will call the `TorchPolicy.compute_gradients()` method
  217. instead.
  218. apply_gradients_fn (Optional[Callable[[Policy,
  219. "torch.optim.Optimizer"], None]]): Optional callable that
  220. takes a grads list and applies these to the Model's parameters.
  221. If None, will call the `TorchPolicy.apply_gradients()` method
  222. instead.
  223. mixins (Optional[List[type]]): Optional list of any class mixins for
  224. the returned policy class. These mixins will be applied in order
  225. and will have higher precedence than the TorchPolicy class.
  226. get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
  227. Optional callable that returns the divisibility requirement for
  228. sample batches. If None, will assume a value of 1.
  229. Returns:
  230. Type[TorchPolicy]: TorchPolicy child class constructed from the
  231. specified args.
  232. """
  233. original_kwargs = locals().copy()
  234. parent_cls = TorchPolicy
  235. base = add_mixins(parent_cls, mixins)
  236. class policy_cls(base):
  237. def __init__(self, obs_space, action_space, config):
  238. self.config = config
  239. # Set the DL framework for this Policy.
  240. self.framework = self.config["framework"] = framework
  241. # Validate observation- and action-spaces.
  242. if validate_spaces:
  243. validate_spaces(self, obs_space, action_space, self.config)
  244. # Do some pre-initialization steps.
  245. if before_init:
  246. before_init(self, obs_space, action_space, self.config)
  247. # Model is customized (use default action dist class).
  248. if make_model:
  249. assert make_model_and_action_dist is None, (
  250. "Either `make_model` or `make_model_and_action_dist`"
  251. " must be None!"
  252. )
  253. self.model = make_model(self, obs_space, action_space, config)
  254. dist_class, _ = ModelCatalog.get_action_dist(
  255. action_space, self.config["model"], framework=framework
  256. )
  257. # Model and action dist class are customized.
  258. elif make_model_and_action_dist:
  259. self.model, dist_class = make_model_and_action_dist(
  260. self, obs_space, action_space, config
  261. )
  262. # Use default model and default action dist.
  263. else:
  264. dist_class, logit_dim = ModelCatalog.get_action_dist(
  265. action_space, self.config["model"], framework=framework
  266. )
  267. self.model = ModelCatalog.get_model_v2(
  268. obs_space=obs_space,
  269. action_space=action_space,
  270. num_outputs=logit_dim,
  271. model_config=self.config["model"],
  272. framework=framework,
  273. )
  274. # Make sure, we passed in a correct Model factory.
  275. model_cls = TorchModelV2
  276. assert isinstance(
  277. self.model, model_cls
  278. ), "ERROR: Generated Model must be a TorchModelV2 object!"
  279. # Call the framework-specific Policy constructor.
  280. self.parent_cls = parent_cls
  281. self.parent_cls.__init__(
  282. self,
  283. observation_space=obs_space,
  284. action_space=action_space,
  285. config=config,
  286. model=self.model,
  287. loss=None if self.config["in_evaluation"] else loss_fn,
  288. action_distribution_class=dist_class,
  289. action_sampler_fn=action_sampler_fn,
  290. action_distribution_fn=action_distribution_fn,
  291. max_seq_len=config["model"]["max_seq_len"],
  292. get_batch_divisibility_req=get_batch_divisibility_req,
  293. )
  294. # Merge Model's view requirements into Policy's.
  295. self.view_requirements.update(self.model.view_requirements)
  296. _before_loss_init = before_loss_init or after_init
  297. if _before_loss_init:
  298. _before_loss_init(
  299. self, self.observation_space, self.action_space, config
  300. )
  301. # Perform test runs through postprocessing- and loss functions.
  302. self._initialize_loss_from_dummy_batch(
  303. auto_remove_unneeded_view_reqs=True,
  304. stats_fn=None if self.config["in_evaluation"] else stats_fn,
  305. )
  306. if _after_loss_init:
  307. _after_loss_init(self, obs_space, action_space, config)
  308. # Got to reset global_timestep again after this fake run-through.
  309. self.global_timestep = 0
  310. @override(Policy)
  311. def postprocess_trajectory(
  312. self, sample_batch, other_agent_batches=None, episode=None
  313. ):
  314. # Do all post-processing always with no_grad().
  315. # Not using this here will introduce a memory leak
  316. # in torch (issue #6962).
  317. with self._no_grad_context():
  318. # Call super's postprocess_trajectory first.
  319. sample_batch = super().postprocess_trajectory(
  320. sample_batch, other_agent_batches, episode
  321. )
  322. if postprocess_fn:
  323. return postprocess_fn(
  324. self, sample_batch, other_agent_batches, episode
  325. )
  326. return sample_batch
  327. @override(parent_cls)
  328. def extra_grad_process(self, optimizer, loss):
  329. """Called after optimizer.zero_grad() and loss.backward() calls.
  330. Allows for gradient processing before optimizer.step() is called.
  331. E.g. for gradient clipping.
  332. """
  333. if extra_grad_process_fn:
  334. return extra_grad_process_fn(self, optimizer, loss)
  335. else:
  336. return parent_cls.extra_grad_process(self, optimizer, loss)
  337. @override(parent_cls)
  338. def extra_compute_grad_fetches(self):
  339. if extra_learn_fetches_fn:
  340. fetches = convert_to_numpy(extra_learn_fetches_fn(self))
  341. # Auto-add empty learner stats dict if needed.
  342. return dict({LEARNER_STATS_KEY: {}}, **fetches)
  343. else:
  344. return parent_cls.extra_compute_grad_fetches(self)
  345. @override(parent_cls)
  346. def compute_gradients(self, batch):
  347. if compute_gradients_fn:
  348. return compute_gradients_fn(self, batch)
  349. else:
  350. return parent_cls.compute_gradients(self, batch)
  351. @override(parent_cls)
  352. def apply_gradients(self, gradients):
  353. if apply_gradients_fn:
  354. apply_gradients_fn(self, gradients)
  355. else:
  356. parent_cls.apply_gradients(self, gradients)
  357. @override(parent_cls)
  358. def extra_action_out(self, input_dict, state_batches, model, action_dist):
  359. with self._no_grad_context():
  360. if extra_action_out_fn:
  361. stats_dict = extra_action_out_fn(
  362. self, input_dict, state_batches, model, action_dist
  363. )
  364. else:
  365. stats_dict = parent_cls.extra_action_out(
  366. self, input_dict, state_batches, model, action_dist
  367. )
  368. return self._convert_to_numpy(stats_dict)
  369. @override(parent_cls)
  370. def optimizer(self):
  371. if optimizer_fn:
  372. optimizers = optimizer_fn(self, self.config)
  373. else:
  374. optimizers = parent_cls.optimizer(self)
  375. return optimizers
  376. @override(parent_cls)
  377. def extra_grad_info(self, train_batch):
  378. with self._no_grad_context():
  379. if stats_fn:
  380. stats_dict = stats_fn(self, train_batch)
  381. else:
  382. stats_dict = self.parent_cls.extra_grad_info(self, train_batch)
  383. return self._convert_to_numpy(stats_dict)
  384. def _no_grad_context(self):
  385. if self.framework == "torch":
  386. return torch.no_grad()
  387. return NullContextManager()
  388. def _convert_to_numpy(self, data):
  389. if self.framework == "torch":
  390. return convert_to_numpy(data)
  391. return data
  392. def with_updates(**overrides):
  393. """Creates a Torch|JAXPolicy cls based on settings of another one.
  394. Keyword Args:
  395. **overrides: The settings (passed into `build_torch_policy`) that
  396. should be different from the class that this method is called
  397. on.
  398. Returns:
  399. type: A new Torch|JAXPolicy sub-class.
  400. Examples:
  401. >> MySpecialDQNPolicyClass = DQNTorchPolicy.with_updates(
  402. .. name="MySpecialDQNPolicyClass",
  403. .. loss_function=[some_new_loss_function],
  404. .. )
  405. """
  406. return build_policy_class(**dict(original_kwargs, **overrides))
  407. policy_cls.with_updates = staticmethod(with_updates)
  408. policy_cls.__name__ = name
  409. policy_cls.__qualname__ = name
  410. return policy_cls