tf_policy_template.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366
  1. from typing import Callable, Dict, List, Optional, Tuple, Type, Union
  2. import gymnasium as gym
  3. from ray._common.deprecation import (
  4. DEPRECATED_VALUE,
  5. deprecation_warning,
  6. )
  7. from ray.rllib.models.modelv2 import ModelV2
  8. from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
  9. from ray.rllib.policy import eager_tf_policy
  10. from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy
  11. from ray.rllib.policy.policy import Policy
  12. from ray.rllib.policy.sample_batch import SampleBatch
  13. from ray.rllib.policy.tf_policy import TFPolicy
  14. from ray.rllib.utils import add_mixins, force_list
  15. from ray.rllib.utils.annotations import OldAPIStack, override
  16. from ray.rllib.utils.framework import try_import_tf
  17. from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
  18. from ray.rllib.utils.typing import (
  19. AlgorithmConfigDict,
  20. ModelGradients,
  21. TensorType,
  22. )
  23. tf1, tf, tfv = try_import_tf()
  24. @OldAPIStack
  25. def build_tf_policy(
  26. name: str,
  27. *,
  28. loss_fn: Callable[
  29. [Policy, ModelV2, Type[TFActionDistribution], SampleBatch],
  30. Union[TensorType, List[TensorType]],
  31. ],
  32. get_default_config: Optional[Callable[[None], AlgorithmConfigDict]] = None,
  33. postprocess_fn=None,
  34. stats_fn: Optional[Callable[[Policy, SampleBatch], Dict[str, TensorType]]] = None,
  35. optimizer_fn: Optional[
  36. Callable[[Policy, AlgorithmConfigDict], "tf.keras.optimizers.Optimizer"]
  37. ] = None,
  38. compute_gradients_fn: Optional[
  39. Callable[[Policy, "tf.keras.optimizers.Optimizer", TensorType], ModelGradients]
  40. ] = None,
  41. apply_gradients_fn: Optional[
  42. Callable[
  43. [Policy, "tf.keras.optimizers.Optimizer", ModelGradients], "tf.Operation"
  44. ]
  45. ] = None,
  46. grad_stats_fn: Optional[
  47. Callable[[Policy, SampleBatch, ModelGradients], Dict[str, TensorType]]
  48. ] = None,
  49. extra_action_out_fn: Optional[Callable[[Policy], Dict[str, TensorType]]] = None,
  50. extra_learn_fetches_fn: Optional[Callable[[Policy], Dict[str, TensorType]]] = None,
  51. validate_spaces: Optional[
  52. Callable[[Policy, gym.Space, gym.Space, AlgorithmConfigDict], None]
  53. ] = None,
  54. before_init: Optional[
  55. Callable[[Policy, gym.Space, gym.Space, AlgorithmConfigDict], None]
  56. ] = None,
  57. before_loss_init: Optional[
  58. Callable[
  59. [Policy, gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict], None
  60. ]
  61. ] = None,
  62. after_init: Optional[
  63. Callable[[Policy, gym.Space, gym.Space, AlgorithmConfigDict], None]
  64. ] = None,
  65. make_model: Optional[
  66. Callable[
  67. [Policy, gym.spaces.Space, gym.spaces.Space, AlgorithmConfigDict], ModelV2
  68. ]
  69. ] = None,
  70. action_sampler_fn: Optional[
  71. Callable[[TensorType, List[TensorType]], Tuple[TensorType, TensorType]]
  72. ] = None,
  73. action_distribution_fn: Optional[
  74. Callable[
  75. [Policy, ModelV2, TensorType, TensorType, TensorType],
  76. Tuple[TensorType, type, List[TensorType]],
  77. ]
  78. ] = None,
  79. mixins: Optional[List[type]] = None,
  80. get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None,
  81. # Deprecated args.
  82. obs_include_prev_action_reward=DEPRECATED_VALUE,
  83. extra_action_fetches_fn=None, # Use `extra_action_out_fn`.
  84. gradients_fn=None, # Use `compute_gradients_fn`.
  85. ) -> Type[DynamicTFPolicy]:
  86. """Helper function for creating a dynamic tf policy at runtime.
  87. Functions will be run in this order to initialize the policy:
  88. 1. Placeholder setup: postprocess_fn
  89. 2. Loss init: loss_fn, stats_fn
  90. 3. Optimizer init: optimizer_fn, gradients_fn, apply_gradients_fn,
  91. grad_stats_fn
  92. This means that you can e.g., depend on any policy attributes created in
  93. the running of `loss_fn` in later functions such as `stats_fn`.
  94. In eager mode, the following functions will be run repeatedly on each
  95. eager execution: loss_fn, stats_fn, gradients_fn, apply_gradients_fn,
  96. and grad_stats_fn.
  97. This means that these functions should not define any variables internally,
  98. otherwise they will fail in eager mode execution. Variable should only
  99. be created in make_model (if defined).
  100. Args:
  101. name: Name of the policy (e.g., "PPOTFPolicy").
  102. loss_fn (Callable[[
  103. Policy, ModelV2, Type[TFActionDistribution], SampleBatch],
  104. Union[TensorType, List[TensorType]]]): Callable for calculating a
  105. loss tensor.
  106. get_default_config (Optional[Callable[[None], AlgorithmConfigDict]]):
  107. Optional callable that returns the default config to merge with any
  108. overrides. If None, uses only(!) the user-provided
  109. PartialAlgorithmConfigDict as dict for this Policy.
  110. postprocess_fn (Optional[Callable[[Policy, SampleBatch,
  111. Optional[Dict[AgentID, SampleBatch]], Episode], None]]):
  112. Optional callable for post-processing experience batches (called
  113. after the parent class' `postprocess_trajectory` method).
  114. stats_fn (Optional[Callable[[Policy, SampleBatch],
  115. Dict[str, TensorType]]]): Optional callable that returns a dict of
  116. TF tensors to fetch given the policy and batch input tensors. If
  117. None, will not compute any stats.
  118. optimizer_fn (Optional[Callable[[Policy, AlgorithmConfigDict],
  119. "tf.keras.optimizers.Optimizer"]]): Optional callable that returns
  120. a tf.Optimizer given the policy and config. If None, will call
  121. the base class' `optimizer()` method instead (which returns a
  122. tf1.train.AdamOptimizer).
  123. compute_gradients_fn (Optional[Callable[[Policy,
  124. "tf.keras.optimizers.Optimizer", TensorType], ModelGradients]]):
  125. Optional callable that returns a list of gradients. If None,
  126. this defaults to optimizer.compute_gradients([loss]).
  127. apply_gradients_fn (Optional[Callable[[Policy,
  128. "tf.keras.optimizers.Optimizer", ModelGradients],
  129. "tf.Operation"]]): Optional callable that returns an apply
  130. gradients op given policy, tf-optimizer, and grads_and_vars. If
  131. None, will call the base class' `build_apply_op()` method instead.
  132. grad_stats_fn (Optional[Callable[[Policy, SampleBatch, ModelGradients],
  133. Dict[str, TensorType]]]): Optional callable that returns a dict of
  134. TF fetches given the policy, batch input, and gradient tensors. If
  135. None, will not collect any gradient stats.
  136. extra_action_out_fn (Optional[Callable[[Policy],
  137. Dict[str, TensorType]]]): Optional callable that returns
  138. a dict of TF fetches given the policy object. If None, will not
  139. perform any extra fetches.
  140. extra_learn_fetches_fn (Optional[Callable[[Policy],
  141. Dict[str, TensorType]]]): Optional callable that returns a dict of
  142. extra values to fetch and return when learning on a batch. If None,
  143. will call the base class' `extra_compute_grad_fetches()` method
  144. instead.
  145. validate_spaces (Optional[Callable[[Policy, gym.Space, gym.Space,
  146. AlgorithmConfigDict], None]]): Optional callable that takes the
  147. Policy, observation_space, action_space, and config to check
  148. the spaces for correctness. If None, no spaces checking will be
  149. done.
  150. before_init (Optional[Callable[[Policy, gym.Space, gym.Space,
  151. AlgorithmConfigDict], None]]): Optional callable to run at the
  152. beginning of policy init that takes the same arguments as the
  153. policy constructor. If None, this step will be skipped.
  154. before_loss_init (Optional[Callable[[Policy, gym.spaces.Space,
  155. gym.spaces.Space, AlgorithmConfigDict], None]]): Optional callable to
  156. run prior to loss init. If None, this step will be skipped.
  157. after_init (Optional[Callable[[Policy, gym.Space, gym.Space,
  158. AlgorithmConfigDict], None]]): Optional callable to run at the end of
  159. policy init. If None, this step will be skipped.
  160. make_model (Optional[Callable[[Policy, gym.spaces.Space,
  161. gym.spaces.Space, AlgorithmConfigDict], ModelV2]]): Optional callable
  162. that returns a ModelV2 object.
  163. All policy variables should be created in this function. If None,
  164. a default ModelV2 object will be created.
  165. action_sampler_fn (Optional[Callable[[TensorType, List[TensorType]],
  166. Tuple[TensorType, TensorType]]]): A callable returning a sampled
  167. action and its log-likelihood given observation and state inputs.
  168. If None, will either use `action_distribution_fn` or
  169. compute actions by calling self.model, then sampling from the
  170. so parameterized action distribution.
  171. action_distribution_fn (Optional[Callable[[Policy, ModelV2, TensorType,
  172. TensorType, TensorType],
  173. Tuple[TensorType, type, List[TensorType]]]]): Optional callable
  174. returning distribution inputs (parameters), a dist-class to
  175. generate an action distribution object from, and internal-state
  176. outputs (or an empty list if not applicable). If None, will either
  177. use `action_sampler_fn` or compute actions by calling self.model,
  178. then sampling from the so parameterized action distribution.
  179. mixins (Optional[List[type]]): Optional list of any class mixins for
  180. the returned policy class. These mixins will be applied in order
  181. and will have higher precedence than the DynamicTFPolicy class.
  182. get_batch_divisibility_req (Optional[Callable[[Policy], int]]):
  183. Optional callable that returns the divisibility requirement for
  184. sample batches. If None, will assume a value of 1.
  185. Returns:
  186. Type[DynamicTFPolicy]: A child class of DynamicTFPolicy based on the
  187. specified args.
  188. """
  189. original_kwargs = locals().copy()
  190. base = add_mixins(DynamicTFPolicy, mixins)
  191. if obs_include_prev_action_reward != DEPRECATED_VALUE:
  192. deprecation_warning(old="obs_include_prev_action_reward", error=True)
  193. if extra_action_fetches_fn is not None:
  194. deprecation_warning(
  195. old="extra_action_fetches_fn", new="extra_action_out_fn", error=True
  196. )
  197. if gradients_fn is not None:
  198. deprecation_warning(old="gradients_fn", new="compute_gradients_fn", error=True)
  199. class policy_cls(base):
  200. def __init__(
  201. self,
  202. obs_space,
  203. action_space,
  204. config,
  205. existing_model=None,
  206. existing_inputs=None,
  207. ):
  208. if validate_spaces:
  209. validate_spaces(self, obs_space, action_space, config)
  210. if before_init:
  211. before_init(self, obs_space, action_space, config)
  212. def before_loss_init_wrapper(policy, obs_space, action_space, config):
  213. if before_loss_init:
  214. before_loss_init(policy, obs_space, action_space, config)
  215. if extra_action_out_fn is None or policy._is_tower:
  216. extra_action_fetches = {}
  217. else:
  218. extra_action_fetches = extra_action_out_fn(policy)
  219. if hasattr(policy, "_extra_action_fetches"):
  220. policy._extra_action_fetches.update(extra_action_fetches)
  221. else:
  222. policy._extra_action_fetches = extra_action_fetches
  223. DynamicTFPolicy.__init__(
  224. self,
  225. obs_space=obs_space,
  226. action_space=action_space,
  227. config=config,
  228. loss_fn=loss_fn,
  229. stats_fn=stats_fn,
  230. grad_stats_fn=grad_stats_fn,
  231. before_loss_init=before_loss_init_wrapper,
  232. make_model=make_model,
  233. action_sampler_fn=action_sampler_fn,
  234. action_distribution_fn=action_distribution_fn,
  235. existing_inputs=existing_inputs,
  236. existing_model=existing_model,
  237. get_batch_divisibility_req=get_batch_divisibility_req,
  238. )
  239. if after_init:
  240. after_init(self, obs_space, action_space, config)
  241. # Got to reset global_timestep again after this fake run-through.
  242. self.global_timestep = 0
  243. @override(Policy)
  244. def postprocess_trajectory(
  245. self, sample_batch, other_agent_batches=None, episode=None
  246. ):
  247. # Call super's postprocess_trajectory first.
  248. sample_batch = Policy.postprocess_trajectory(self, sample_batch)
  249. if postprocess_fn:
  250. return postprocess_fn(self, sample_batch, other_agent_batches, episode)
  251. return sample_batch
  252. @override(TFPolicy)
  253. def optimizer(self):
  254. if optimizer_fn:
  255. optimizers = optimizer_fn(self, self.config)
  256. else:
  257. optimizers = base.optimizer(self)
  258. optimizers = force_list(optimizers)
  259. if self.exploration:
  260. optimizers = self.exploration.get_exploration_optimizer(optimizers)
  261. # No optimizers produced -> Return None.
  262. if not optimizers:
  263. return None
  264. # New API: Allow more than one optimizer to be returned.
  265. # -> Return list.
  266. elif self.config["_tf_policy_handles_more_than_one_loss"]:
  267. return optimizers
  268. # Old API: Return a single LocalOptimizer.
  269. else:
  270. return optimizers[0]
  271. @override(TFPolicy)
  272. def gradients(self, optimizer, loss):
  273. optimizers = force_list(optimizer)
  274. losses = force_list(loss)
  275. if compute_gradients_fn:
  276. # New API: Allow more than one optimizer -> Return a list of
  277. # lists of gradients.
  278. if self.config["_tf_policy_handles_more_than_one_loss"]:
  279. return compute_gradients_fn(self, optimizers, losses)
  280. # Old API: Return a single List of gradients.
  281. else:
  282. return compute_gradients_fn(self, optimizers[0], losses[0])
  283. else:
  284. return base.gradients(self, optimizers, losses)
  285. @override(TFPolicy)
  286. def build_apply_op(self, optimizer, grads_and_vars):
  287. if apply_gradients_fn:
  288. return apply_gradients_fn(self, optimizer, grads_and_vars)
  289. else:
  290. return base.build_apply_op(self, optimizer, grads_and_vars)
  291. @override(TFPolicy)
  292. def extra_compute_action_fetches(self):
  293. return dict(
  294. base.extra_compute_action_fetches(self), **self._extra_action_fetches
  295. )
  296. @override(TFPolicy)
  297. def extra_compute_grad_fetches(self):
  298. if extra_learn_fetches_fn:
  299. # TODO: (sven) in torch, extra_learn_fetches do not exist.
  300. # Hence, things like td_error are returned by the stats_fn
  301. # and end up under the LEARNER_STATS_KEY. We should
  302. # change tf to do this as well. However, this will confilct
  303. # the handling of LEARNER_STATS_KEY inside the multi-GPU
  304. # train op.
  305. # Auto-add empty learner stats dict if needed.
  306. return dict({LEARNER_STATS_KEY: {}}, **extra_learn_fetches_fn(self))
  307. else:
  308. return base.extra_compute_grad_fetches(self)
  309. def with_updates(**overrides):
  310. """Allows creating a TFPolicy cls based on settings of another one.
  311. Keyword Args:
  312. **overrides: The settings (passed into `build_tf_policy`) that
  313. should be different from the class that this method is called
  314. on.
  315. Returns:
  316. type: A new TFPolicy sub-class.
  317. Examples:
  318. >> MySpecialDQNPolicyClass = DQNTFPolicy.with_updates(
  319. .. name="MySpecialDQNPolicyClass",
  320. .. loss_function=[some_new_loss_function],
  321. .. )
  322. """
  323. return build_tf_policy(**dict(original_kwargs, **overrides))
  324. def as_eager():
  325. return eager_tf_policy._build_eager_tf_policy(**original_kwargs)
  326. policy_cls.with_updates = staticmethod(with_updates)
  327. policy_cls.as_eager = staticmethod(as_eager)
  328. policy_cls.__name__ = name
  329. policy_cls.__qualname__ = name
  330. return policy_cls