tf_mixins.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387
  1. import logging
  2. from typing import Dict, List
  3. import numpy as np
  4. from ray.rllib.models.modelv2 import ModelV2
  5. from ray.rllib.policy.eager_tf_policy import EagerTFPolicy
  6. from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2
  7. from ray.rllib.policy.policy import PolicyState
  8. from ray.rllib.policy.sample_batch import SampleBatch
  9. from ray.rllib.policy.tf_policy import TFPolicy
  10. from ray.rllib.utils.annotations import OldAPIStack
  11. from ray.rllib.utils.framework import get_variable, try_import_tf
  12. from ray.rllib.utils.schedules import PiecewiseSchedule
  13. from ray.rllib.utils.tf_utils import make_tf_callable
  14. from ray.rllib.utils.typing import (
  15. AlgorithmConfigDict,
  16. LocalOptimizer,
  17. ModelGradients,
  18. TensorType,
  19. )
  20. logger = logging.getLogger(__name__)
  21. tf1, tf, tfv = try_import_tf()
  22. @OldAPIStack
  23. class LearningRateSchedule:
  24. """Mixin for TFPolicy that adds a learning rate schedule."""
  25. def __init__(self, lr, lr_schedule):
  26. self._lr_schedule = None
  27. if lr_schedule is None:
  28. self.cur_lr = tf1.get_variable("lr", initializer=lr, trainable=False)
  29. else:
  30. self._lr_schedule = PiecewiseSchedule(
  31. lr_schedule, outside_value=lr_schedule[-1][-1], framework=None
  32. )
  33. self.cur_lr = tf1.get_variable(
  34. "lr", initializer=self._lr_schedule.value(0), trainable=False
  35. )
  36. if self.framework == "tf":
  37. self._lr_placeholder = tf1.placeholder(dtype=tf.float32, name="lr")
  38. self._lr_update = self.cur_lr.assign(
  39. self._lr_placeholder, read_value=False
  40. )
  41. def on_global_var_update(self, global_vars):
  42. super().on_global_var_update(global_vars)
  43. if self._lr_schedule is not None:
  44. new_val = self._lr_schedule.value(global_vars["timestep"])
  45. if self.framework == "tf":
  46. self.get_session().run(
  47. self._lr_update, feed_dict={self._lr_placeholder: new_val}
  48. )
  49. else:
  50. self.cur_lr.assign(new_val, read_value=False)
  51. # This property (self._optimizer) is (still) accessible for
  52. # both TFPolicy and any TFPolicy_eager.
  53. self._optimizer.learning_rate.assign(self.cur_lr)
  54. def optimizer(self):
  55. if self.framework == "tf":
  56. return tf1.train.AdamOptimizer(learning_rate=self.cur_lr)
  57. else:
  58. return tf.keras.optimizers.Adam(self.cur_lr)
  59. @OldAPIStack
  60. class EntropyCoeffSchedule:
  61. """Mixin for TFPolicy that adds entropy coeff decay."""
  62. def __init__(self, entropy_coeff, entropy_coeff_schedule):
  63. self._entropy_coeff_schedule = None
  64. if entropy_coeff_schedule is None:
  65. self.entropy_coeff = get_variable(
  66. entropy_coeff, framework="tf", tf_name="entropy_coeff", trainable=False
  67. )
  68. else:
  69. # Allows for custom schedule similar to lr_schedule format
  70. if isinstance(entropy_coeff_schedule, list):
  71. self._entropy_coeff_schedule = PiecewiseSchedule(
  72. entropy_coeff_schedule,
  73. outside_value=entropy_coeff_schedule[-1][-1],
  74. framework=None,
  75. )
  76. else:
  77. # Implements previous version but enforces outside_value
  78. self._entropy_coeff_schedule = PiecewiseSchedule(
  79. [[0, entropy_coeff], [entropy_coeff_schedule, 0.0]],
  80. outside_value=0.0,
  81. framework=None,
  82. )
  83. self.entropy_coeff = get_variable(
  84. self._entropy_coeff_schedule.value(0),
  85. framework="tf",
  86. tf_name="entropy_coeff",
  87. trainable=False,
  88. )
  89. if self.framework == "tf":
  90. self._entropy_coeff_placeholder = tf1.placeholder(
  91. dtype=tf.float32, name="entropy_coeff"
  92. )
  93. self._entropy_coeff_update = self.entropy_coeff.assign(
  94. self._entropy_coeff_placeholder, read_value=False
  95. )
  96. def on_global_var_update(self, global_vars):
  97. super().on_global_var_update(global_vars)
  98. if self._entropy_coeff_schedule is not None:
  99. new_val = self._entropy_coeff_schedule.value(global_vars["timestep"])
  100. if self.framework == "tf":
  101. self.get_session().run(
  102. self._entropy_coeff_update,
  103. feed_dict={self._entropy_coeff_placeholder: new_val},
  104. )
  105. else:
  106. self.entropy_coeff.assign(new_val, read_value=False)
  107. @OldAPIStack
  108. class KLCoeffMixin:
  109. """Assigns the `update_kl()` and other KL-related methods to a TFPolicy.
  110. This is used in Algorithms to update the KL coefficient after each
  111. learning step based on `config.kl_target` and the measured KL value
  112. (from the train_batch).
  113. """
  114. def __init__(self, config: AlgorithmConfigDict):
  115. # The current KL value (as python float).
  116. self.kl_coeff_val = config["kl_coeff"]
  117. # The current KL value (as tf Variable for in-graph operations).
  118. self.kl_coeff = get_variable(
  119. float(self.kl_coeff_val),
  120. tf_name="kl_coeff",
  121. trainable=False,
  122. framework=config["framework"],
  123. )
  124. # Constant target value.
  125. self.kl_target = config["kl_target"]
  126. if self.framework == "tf":
  127. self._kl_coeff_placeholder = tf1.placeholder(
  128. dtype=tf.float32, name="kl_coeff"
  129. )
  130. self._kl_coeff_update = self.kl_coeff.assign(
  131. self._kl_coeff_placeholder, read_value=False
  132. )
  133. def update_kl(self, sampled_kl):
  134. # Update the current KL value based on the recently measured value.
  135. # Increase.
  136. if sampled_kl > 2.0 * self.kl_target:
  137. self.kl_coeff_val *= 1.5
  138. # Decrease.
  139. elif sampled_kl < 0.5 * self.kl_target:
  140. self.kl_coeff_val *= 0.5
  141. # No change.
  142. else:
  143. return self.kl_coeff_val
  144. # Make sure, new value is also stored in graph/tf variable.
  145. self._set_kl_coeff(self.kl_coeff_val)
  146. # Return the current KL value.
  147. return self.kl_coeff_val
  148. def _set_kl_coeff(self, new_kl_coeff):
  149. # Set the (off graph) value.
  150. self.kl_coeff_val = new_kl_coeff
  151. # Update the tf/tf2 Variable (via session call for tf or `assign`).
  152. if self.framework == "tf":
  153. self.get_session().run(
  154. self._kl_coeff_update,
  155. feed_dict={self._kl_coeff_placeholder: self.kl_coeff_val},
  156. )
  157. else:
  158. self.kl_coeff.assign(self.kl_coeff_val, read_value=False)
  159. def get_state(self) -> PolicyState:
  160. state = super().get_state()
  161. # Add current kl-coeff value.
  162. state["current_kl_coeff"] = self.kl_coeff_val
  163. return state
  164. def set_state(self, state: PolicyState) -> None:
  165. # Set current kl-coeff value first.
  166. self._set_kl_coeff(state.pop("current_kl_coeff", self.config["kl_coeff"]))
  167. # Call super's set_state with rest of the state dict.
  168. super().set_state(state)
  169. @OldAPIStack
  170. class TargetNetworkMixin:
  171. """Assign the `update_target` method to the policy.
  172. The function is called every `target_network_update_freq` steps by the
  173. master learner.
  174. """
  175. def __init__(self):
  176. model_vars = self.model.trainable_variables()
  177. target_model_vars = self.target_model.trainable_variables()
  178. @make_tf_callable(self.get_session())
  179. def update_target_fn(tau):
  180. tau = tf.convert_to_tensor(tau, dtype=tf.float32)
  181. update_target_expr = []
  182. assert len(model_vars) == len(target_model_vars), (
  183. model_vars,
  184. target_model_vars,
  185. )
  186. for var, var_target in zip(model_vars, target_model_vars):
  187. update_target_expr.append(
  188. var_target.assign(tau * var + (1.0 - tau) * var_target)
  189. )
  190. logger.debug("Update target op {}".format(var_target))
  191. return tf.group(*update_target_expr)
  192. # Hard initial update.
  193. self._do_update = update_target_fn
  194. # TODO: The previous SAC implementation does an update(1.0) here.
  195. # If this is changed to tau != 1.0 the sac_loss_function test fails. Why?
  196. # Also the test is not very maintainable, we need to change that unittest
  197. # anyway.
  198. self.update_target(tau=1.0) # self.config.get("tau", 1.0))
  199. @property
  200. def q_func_vars(self):
  201. if not hasattr(self, "_q_func_vars"):
  202. self._q_func_vars = self.model.variables()
  203. return self._q_func_vars
  204. @property
  205. def target_q_func_vars(self):
  206. if not hasattr(self, "_target_q_func_vars"):
  207. self._target_q_func_vars = self.target_model.variables()
  208. return self._target_q_func_vars
  209. # Support both hard and soft sync.
  210. def update_target(self, tau: int = None) -> None:
  211. self._do_update(np.float32(tau or self.config.get("tau", 1.0)))
  212. def variables(self) -> List[TensorType]:
  213. return self.model.variables()
  214. def set_weights(self, weights):
  215. if isinstance(self, TFPolicy):
  216. TFPolicy.set_weights(self, weights)
  217. elif isinstance(self, EagerTFPolicyV2): # Handle TF2V2 policies.
  218. EagerTFPolicyV2.set_weights(self, weights)
  219. elif isinstance(self, EagerTFPolicy): # Handle TF2 policies.
  220. EagerTFPolicy.set_weights(self, weights)
  221. self.update_target(self.config.get("tau", 1.0))
  222. @OldAPIStack
  223. class ValueNetworkMixin:
  224. """Assigns the `_value()` method to a TFPolicy.
  225. This way, Policy can call `_value()` to get the current VF estimate on a
  226. single(!) observation (as done in `postprocess_trajectory_fn`).
  227. Note: When doing this, an actual forward pass is being performed.
  228. This is different from only calling `model.value_function()`, where
  229. the result of the most recent forward pass is being used to return an
  230. already calculated tensor.
  231. """
  232. def __init__(self, config):
  233. # When doing GAE or vtrace, we need the value function estimate on the
  234. # observation.
  235. if config.get("use_gae") or config.get("vtrace"):
  236. # Input dict is provided to us automatically via the Model's
  237. # requirements. It's a single-timestep (last one in trajectory)
  238. # input_dict.
  239. @make_tf_callable(self.get_session())
  240. def value(**input_dict):
  241. input_dict = SampleBatch(input_dict)
  242. if isinstance(self.model, tf.keras.Model):
  243. _, _, extra_outs = self.model(input_dict)
  244. return extra_outs[SampleBatch.VF_PREDS][0]
  245. else:
  246. model_out, _ = self.model(input_dict)
  247. # [0] = remove the batch dim.
  248. return self.model.value_function()[0]
  249. # When not doing GAE, we do not require the value function's output.
  250. else:
  251. @make_tf_callable(self.get_session())
  252. def value(*args, **kwargs):
  253. return tf.constant(0.0)
  254. self._value = value
  255. self._should_cache_extra_action = config["framework"] == "tf"
  256. self._cached_extra_action_fetches = None
  257. def _extra_action_out_impl(self) -> Dict[str, TensorType]:
  258. extra_action_out = super().extra_action_out_fn()
  259. # Keras models return values for each call in third return argument
  260. # (dict).
  261. if isinstance(self.model, tf.keras.Model):
  262. return extra_action_out
  263. # Return value function outputs. VF estimates will hence be added to the
  264. # SampleBatches produced by the sampler(s) to generate the train batches
  265. # going into the loss function.
  266. extra_action_out.update(
  267. {
  268. SampleBatch.VF_PREDS: self.model.value_function(),
  269. }
  270. )
  271. return extra_action_out
  272. def extra_action_out_fn(self) -> Dict[str, TensorType]:
  273. if not self._should_cache_extra_action:
  274. return self._extra_action_out_impl()
  275. # Note: there are 2 reasons we are caching the extra_action_fetches for
  276. # TF1 static graph here.
  277. # 1. for better performance, so we don't query base class and model for
  278. # extra fetches every single time.
  279. # 2. for correctness. TF1 is special because the static graph may contain
  280. # two logical graphs. One created by DynamicTFPolicy for action
  281. # computation, and one created by MultiGPUTower for GPU training.
  282. # Depending on which logical graph ran last time,
  283. # self.model.value_function() will point to the output tensor
  284. # of the specific logical graph, causing problem if we try to
  285. # fetch action (run inference) using the training output tensor.
  286. # For that reason, we cache the action output tensor from the
  287. # vanilla DynamicTFPolicy once and call it a day.
  288. if self._cached_extra_action_fetches is not None:
  289. return self._cached_extra_action_fetches
  290. self._cached_extra_action_fetches = self._extra_action_out_impl()
  291. return self._cached_extra_action_fetches
  292. @OldAPIStack
  293. class GradStatsMixin:
  294. def __init__(self):
  295. pass
  296. def grad_stats_fn(
  297. self, train_batch: SampleBatch, grads: ModelGradients
  298. ) -> Dict[str, TensorType]:
  299. # We have support for more than one loss (list of lists of grads).
  300. if self.config.get("_tf_policy_handles_more_than_one_loss"):
  301. grad_gnorm = [tf.linalg.global_norm(g) for g in grads]
  302. # Old case: We have a single list of grads (only one loss term and
  303. # optimizer).
  304. else:
  305. grad_gnorm = tf.linalg.global_norm(grads)
  306. return {
  307. "grad_gnorm": grad_gnorm,
  308. }
  309. def compute_gradients(
  310. policy, optimizer: LocalOptimizer, loss: TensorType
  311. ) -> ModelGradients:
  312. # Compute the gradients.
  313. variables = policy.model.trainable_variables
  314. if isinstance(policy.model, ModelV2):
  315. variables = variables()
  316. grads_and_vars = optimizer.compute_gradients(loss, variables)
  317. # Clip by global norm, if necessary.
  318. if policy.config.get("grad_clip") is not None:
  319. # Defuse inf gradients (due to super large losses).
  320. grads = [g for (g, v) in grads_and_vars]
  321. grads, _ = tf.clip_by_global_norm(grads, policy.config["grad_clip"])
  322. # If the global_norm is inf -> All grads will be NaN. Stabilize this
  323. # here by setting them to 0.0. This will simply ignore destructive loss
  324. # calculations.
  325. policy.grads = []
  326. for g in grads:
  327. if g is not None:
  328. policy.grads.append(tf.where(tf.math.is_nan(g), tf.zeros_like(g), g))
  329. else:
  330. policy.grads.append(None)
  331. clipped_grads_and_vars = list(zip(policy.grads, variables))
  332. return clipped_grads_and_vars
  333. else:
  334. return grads_and_vars