tf_utils.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054
  1. import logging
  2. from collections import OrderedDict, deque
  3. from typing import TYPE_CHECKING, Any, Callable, List, Optional, Type, Union
  4. import gymnasium as gym
  5. import numpy as np
  6. import tree # pip install dm_tree
  7. from gymnasium.spaces import Discrete, MultiDiscrete
  8. from ray.rllib.utils import force_list
  9. from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
  10. from ray.rllib.utils.framework import try_import_tf
  11. from ray.rllib.utils.numpy import SMALL_NUMBER
  12. from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
  13. from ray.rllib.utils.typing import (
  14. LocalOptimizer,
  15. ModelGradients,
  16. NetworkType,
  17. PartialAlgorithmConfigDict,
  18. SpaceStruct,
  19. TensorStructType,
  20. TensorType,
  21. )
  22. if TYPE_CHECKING:
  23. from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
  24. from ray.rllib.core.learner.learner import ParamDict
  25. from ray.rllib.policy.eager_tf_policy import EagerTFPolicy
  26. from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2
  27. from ray.rllib.policy.tf_policy import TFPolicy
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# TensorFlow handles shared by all helpers in this module. `tfv` is compared
# against 1 to pick the tf1 code path and a falsy `tf1` is treated as "TensorFlow
# could not be imported" further below.
# NOTE(review): presumably (tf.compat.v1 module, tf module, major version) --
# confirm against `try_import_tf()`.
tf1, tf, tfv = try_import_tf()
  30. @PublicAPI
  31. def clip_gradients(
  32. gradients_dict: "ParamDict",
  33. *,
  34. grad_clip: Optional[float] = None,
  35. grad_clip_by: str,
  36. ) -> Optional[float]:
  37. """Performs gradient clipping on a grad-dict based on a clip value and clip mode.
  38. Changes the provided gradient dict in place.
  39. Args:
  40. gradients_dict: The gradients dict, mapping str to gradient tensors.
  41. grad_clip: The value to clip with. The way gradients are clipped is defined
  42. by the `grad_clip_by` arg (see below).
  43. grad_clip_by: One of 'value', 'norm', or 'global_norm'.
  44. Returns:
  45. If `grad_clip_by`="global_norm" and `grad_clip` is not None, returns the global
  46. norm of all tensors, otherwise returns None.
  47. """
  48. # No clipping, return.
  49. if grad_clip is None:
  50. return
  51. # Clip by value (each gradient individually).
  52. if grad_clip_by == "value":
  53. for k, v in gradients_dict.copy().items():
  54. gradients_dict[k] = tf.clip_by_value(v, -grad_clip, grad_clip)
  55. # Clip by L2-norm (per gradient tensor).
  56. elif grad_clip_by == "norm":
  57. for k, v in gradients_dict.copy().items():
  58. gradients_dict[k] = tf.clip_by_norm(v, grad_clip)
  59. # Clip by global L2-norm (across all gradient tensors).
  60. else:
  61. assert grad_clip_by == "global_norm"
  62. clipped_grads, global_norm = tf.clip_by_global_norm(
  63. list(gradients_dict.values()), grad_clip
  64. )
  65. for k, v in zip(gradients_dict.copy().keys(), clipped_grads):
  66. gradients_dict[k] = v
  67. # Return the computed global norm scalar.
  68. return global_norm
  69. @PublicAPI
  70. def explained_variance(y: TensorType, pred: TensorType) -> TensorType:
  71. """Computes the explained variance for a pair of labels and predictions.
  72. The formula used is:
  73. max(-1.0, 1.0 - (std(y - pred)^2 / std(y)^2))
  74. Args:
  75. y: The labels.
  76. pred: The predictions.
  77. Returns:
  78. The explained variance given a pair of labels and predictions.
  79. """
  80. _, y_var = tf.nn.moments(y, axes=[0])
  81. _, diff_var = tf.nn.moments(y - pred, axes=[0])
  82. return tf.maximum(-1.0, 1 - (diff_var / (y_var + SMALL_NUMBER)))
  83. @PublicAPI
  84. def flatten_inputs_to_1d_tensor(
  85. inputs: TensorStructType,
  86. spaces_struct: Optional[SpaceStruct] = None,
  87. time_axis: bool = False,
  88. ) -> TensorType:
  89. """Flattens arbitrary input structs according to the given spaces struct.
  90. Returns a single 1D tensor resulting from the different input
  91. components' values.
  92. Thereby:
  93. - Boxes (any shape) get flattened to (B, [T]?, -1). Note that image boxes
  94. are not treated differently from other types of Boxes and get
  95. flattened as well.
  96. - Discrete (int) values are one-hot'd, e.g. a batch of [1, 0, 3] (B=3 with
  97. Discrete(4) space) results in [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1]].
  98. - MultiDiscrete values are multi-one-hot'd, e.g. a batch of
  99. [[0, 2], [1, 4]] (B=2 with MultiDiscrete([2, 5]) space) results in
  100. [[1, 0, 0, 0, 1, 0, 0], [0, 1, 0, 0, 0, 0, 1]].
  101. Args:
  102. inputs: The inputs to be flattened.
  103. spaces_struct: The structure of the spaces that behind the input
  104. time_axis: Whether all inputs have a time-axis (after the batch axis).
  105. If True, will keep not only the batch axis (0th), but the time axis
  106. (1st) as-is and flatten everything from the 2nd axis up.
  107. Returns:
  108. A single 1D tensor resulting from concatenating all
  109. flattened/one-hot'd input components. Depending on the time_axis flag,
  110. the shape is (B, n) or (B, T, n).
  111. .. testcode::
  112. :skipif: True
  113. # B=2
  114. from ray.rllib.utils.tf_utils import flatten_inputs_to_1d_tensor
  115. from gymnasium.spaces import Discrete, Box
  116. out = flatten_inputs_to_1d_tensor(
  117. {"a": [1, 0], "b": [[[0.0], [0.1]], [1.0], [1.1]]},
  118. spaces_struct=dict(a=Discrete(2), b=Box(shape=(2, 1)))
  119. )
  120. print(out)
  121. # B=2; T=2
  122. out = flatten_inputs_to_1d_tensor(
  123. ([[1, 0], [0, 1]],
  124. [[[0.0, 0.1], [1.0, 1.1]], [[2.0, 2.1], [3.0, 3.1]]]),
  125. spaces_struct=tuple([Discrete(2), Box(shape=(2, ))]),
  126. time_axis=True
  127. )
  128. print(out)
  129. .. testoutput::
  130. [[0.0, 1.0, 0.0, 0.1], [1.0, 0.0, 1.0, 1.1]] # B=2 n=4
  131. [[[0.0, 1.0, 0.0, 0.1], [1.0, 0.0, 1.0, 1.1]],
  132. [[1.0, 0.0, 2.0, 2.1], [0.0, 1.0, 3.0, 3.1]]] # B=2 T=2 n=4
  133. """
  134. flat_inputs = tree.flatten(inputs)
  135. flat_spaces = (
  136. tree.flatten(spaces_struct)
  137. if spaces_struct is not None
  138. else [None] * len(flat_inputs)
  139. )
  140. B = None
  141. T = None
  142. out = []
  143. for input_, space in zip(flat_inputs, flat_spaces):
  144. input_ = tf.convert_to_tensor(input_)
  145. shape = tf.shape(input_)
  146. # Store batch and (if applicable) time dimension.
  147. if B is None:
  148. B = shape[0]
  149. if time_axis:
  150. T = shape[1]
  151. # One-hot encoding.
  152. if isinstance(space, Discrete):
  153. if time_axis:
  154. input_ = tf.reshape(input_, [B * T])
  155. out.append(tf.cast(one_hot(input_, space), tf.float32))
  156. elif isinstance(space, MultiDiscrete):
  157. if time_axis:
  158. input_ = tf.reshape(input_, [B * T, -1])
  159. out.append(tf.cast(one_hot(input_, space), tf.float32))
  160. # Flatten.
  161. else:
  162. if time_axis:
  163. input_ = tf.reshape(input_, [B * T, -1])
  164. else:
  165. input_ = tf.reshape(input_, [B, -1])
  166. out.append(tf.cast(input_, tf.float32))
  167. merged = tf.concat(out, axis=-1)
  168. # Restore the time-dimension, if applicable.
  169. if time_axis:
  170. merged = tf.reshape(merged, [B, T, -1])
  171. return merged
  172. @PublicAPI
  173. def get_gpu_devices() -> List[str]:
  174. """Returns a list of GPU device names, e.g. ["/gpu:0", "/gpu:1"].
  175. Supports both tf1.x and tf2.x.
  176. Returns:
  177. List of GPU device names (str).
  178. """
  179. if tfv == 1:
  180. from tensorflow.python.client import device_lib
  181. devices = device_lib.list_local_devices()
  182. else:
  183. try:
  184. devices = tf.config.list_physical_devices()
  185. except Exception:
  186. devices = tf.config.experimental.list_physical_devices()
  187. # Expect "GPU", but also stuff like: "XLA_GPU".
  188. return [d.name for d in devices if "GPU" in d.device_type]
  189. @PublicAPI
  190. def get_placeholder(
  191. *,
  192. space: Optional[gym.Space] = None,
  193. value: Optional[Any] = None,
  194. name: Optional[str] = None,
  195. time_axis: bool = False,
  196. flatten: bool = True,
  197. ) -> "tf1.placeholder":
  198. """Returns a tf1.placeholder object given optional hints, such as a space.
  199. Note that the returned placeholder will always have a leading batch
  200. dimension (None).
  201. Args:
  202. space: An optional gym.Space to hint the shape and dtype of the
  203. placeholder.
  204. value: An optional value to hint the shape and dtype of the
  205. placeholder.
  206. name: An optional name for the placeholder.
  207. time_axis: Whether the placeholder should also receive a time
  208. dimension (None).
  209. flatten: Whether to flatten the given space into a plain Box space
  210. and then create the placeholder from the resulting space.
  211. Returns:
  212. The tf1 placeholder.
  213. """
  214. from ray.rllib.models.catalog import ModelCatalog
  215. if space is not None:
  216. if isinstance(space, (gym.spaces.Dict, gym.spaces.Tuple)):
  217. if flatten:
  218. return ModelCatalog.get_action_placeholder(space, None)
  219. else:
  220. return tree.map_structure_with_path(
  221. lambda path, component: get_placeholder(
  222. space=component,
  223. name=name + "." + ".".join([str(p) for p in path]),
  224. ),
  225. get_base_struct_from_space(space),
  226. )
  227. return tf1.placeholder(
  228. shape=(None,) + ((None,) if time_axis else ()) + space.shape,
  229. dtype=tf.float32 if space.dtype == np.float64 else space.dtype,
  230. name=name,
  231. )
  232. else:
  233. assert value is not None
  234. shape = value.shape[1:]
  235. return tf1.placeholder(
  236. shape=(None,)
  237. + ((None,) if time_axis else ())
  238. + (shape if isinstance(shape, tuple) else tuple(shape.as_list())),
  239. dtype=tf.float32 if value.dtype == np.float64 else value.dtype,
  240. name=name,
  241. )
  242. @PublicAPI
  243. def get_tf_eager_cls_if_necessary(
  244. orig_cls: Type["TFPolicy"],
  245. config: Union["AlgorithmConfig", PartialAlgorithmConfigDict],
  246. ) -> Type[Union["TFPolicy", "EagerTFPolicy", "EagerTFPolicyV2"]]:
  247. """Returns the corresponding tf-eager class for a given TFPolicy class.
  248. Args:
  249. orig_cls: The original TFPolicy class to get the corresponding tf-eager
  250. class for.
  251. config: The Algorithm config dict or AlgorithmConfig object.
  252. Returns:
  253. The tf eager policy class corresponding to the given TFPolicy class.
  254. """
  255. cls = orig_cls
  256. framework = config.get("framework", "tf")
  257. if framework in ["tf2", "tf"] and not tf1:
  258. raise ImportError("Could not import tensorflow!")
  259. if framework == "tf2":
  260. if not tf1.executing_eagerly():
  261. tf1.enable_eager_execution()
  262. assert tf1.executing_eagerly()
  263. from ray.rllib.policy.eager_tf_policy import EagerTFPolicy
  264. from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2
  265. from ray.rllib.policy.tf_policy import TFPolicy
  266. # Create eager-class (if not already one).
  267. if hasattr(orig_cls, "as_eager") and not issubclass(orig_cls, EagerTFPolicy):
  268. cls = orig_cls.as_eager()
  269. # Could be some other type of policy or already
  270. # eager-ized.
  271. elif not issubclass(orig_cls, TFPolicy):
  272. pass
  273. else:
  274. raise ValueError(
  275. "This policy does not support eager execution: {}".format(orig_cls)
  276. )
  277. # Now that we know, policy is an eager one, add tracing, if necessary.
  278. if config.get("eager_tracing") and issubclass(
  279. cls, (EagerTFPolicy, EagerTFPolicyV2)
  280. ):
  281. cls = cls.with_tracing()
  282. return cls
  283. @PublicAPI
  284. def huber_loss(x: TensorType, delta: float = 1.0) -> TensorType:
  285. """Computes the huber loss for a given term and delta parameter.
  286. Reference: https://en.wikipedia.org/wiki/Huber_loss
  287. Note that the factor of 0.5 is implicitly included in the calculation.
  288. Formula:
  289. L = 0.5 * x^2 for small abs x (delta threshold)
  290. L = delta * (abs(x) - 0.5*delta) for larger abs x (delta threshold)
  291. Args:
  292. x: The input term, e.g. a TD error.
  293. delta: The delta parmameter in the above formula.
  294. Returns:
  295. The Huber loss resulting from `x` and `delta`.
  296. """
  297. return tf.where(
  298. tf.abs(x) < delta, # for small x -> apply the Huber correction
  299. tf.math.square(x) * 0.5,
  300. delta * (tf.abs(x) - 0.5 * delta),
  301. )
  302. @PublicAPI
  303. def l2_loss(x: TensorType) -> TensorType:
  304. """Computes half the L2 norm over a tensor's values without the sqrt.
  305. output = 0.5 * sum(x ** 2)
  306. Args:
  307. x: The input tensor.
  308. Returns:
  309. 0.5 times the L2 norm over the given tensor's values (w/o sqrt).
  310. """
  311. return 0.5 * tf.reduce_sum(tf.pow(x, 2.0))
@PublicAPI
def make_tf_callable(
    session_or_none: Optional["tf1.Session"], dynamic_shape: bool = False
) -> Callable:
    """Returns a function that can be executed in either graph or eager mode.

    The returned object is a decorator: apply it to a function `fn` to get
    the callable version of `fn`.

    If eager is enabled, this will act as just a function. Otherwise, it
    will build a function that executes a session run with placeholders
    internally. In that case, the placeholders (and the symbolic output of
    `fn`) are built only once, on the first call, from the shapes and dtypes
    of that call's arguments, and are reused on all subsequent calls.

    Args:
        session_or_none: tf.Session if in graph mode, else None.
        dynamic_shape: True if the placeholders should have a dynamic
            batch dimension. Otherwise they will be fixed shape.

    Returns:
        A function that can be called in either eager or static-graph mode.
    """

    # Sanity check: A session must be given if (and only if) we are not eager.
    if tf.executing_eagerly():
        assert session_or_none is None
    else:
        assert session_or_none is not None

    def make_wrapper(fn):
        # Static-graph mode: Create placeholders and make a session call each
        # time the wrapped function is called. Returns the output of this
        # session call.
        if session_or_none is not None:
            # State shared across all calls of `call` (filled on first call).
            args_placeholders = []
            kwargs_placeholders = {}

            # One-element list, so `call` can assign to it from its closure.
            symbolic_out = [None]

            def call(*args, **kwargs):
                # Unpack positional args that are (exactly) lists by one level,
                # such that each list item becomes its own positional arg.
                args_flat = []
                for a in args:
                    if type(a) is list:
                        args_flat.extend(a)
                    else:
                        args_flat.append(a)
                args = args_flat

                # We have not built any placeholders yet: Do this once here,
                # then reuse the same placeholders each time we call this
                # function again.
                if symbolic_out[0] is None:
                    with session_or_none.graph.as_default():

                        def _create_placeholders(path, value):
                            # With `dynamic_shape`, replace the 0th dim of
                            # non-scalar values by None; scalars stay scalars.
                            if dynamic_shape:
                                if len(value.shape) > 0:
                                    shape = (None,) + value.shape[1:]
                                else:
                                    shape = ()
                            else:
                                shape = value.shape
                            # Name the placeholder after its path in the struct.
                            return tf1.placeholder(
                                dtype=value.dtype,
                                shape=shape,
                                name=".".join([str(p) for p in path]),
                            )

                        # One placeholder per leaf of the positional args.
                        placeholders = tree.map_structure_with_path(
                            _create_placeholders, args
                        )
                        for ph in tree.flatten(placeholders):
                            args_placeholders.append(ph)

                        # Placeholder struct(s) per keyword arg.
                        placeholders = tree.map_structure_with_path(
                            _create_placeholders, kwargs
                        )
                        for k, ph in placeholders.items():
                            kwargs_placeholders[k] = ph

                        # Build the symbolic output of `fn` exactly once.
                        symbolic_out[0] = fn(*args_placeholders, **kwargs_placeholders)
                # Map each placeholder to this call's concrete value.
                feed_dict = dict(zip(args_placeholders, tree.flatten(args)))
                tree.map_structure(
                    lambda ph, v: feed_dict.__setitem__(ph, v),
                    kwargs_placeholders,
                    kwargs,
                )
                ret = session_or_none.run(symbolic_out[0], feed_dict)
                return ret

            return call
        # Eager mode (call function as is).
        else:
            return fn

    return make_wrapper
  390. # TODO (sven): Deprecate this function once we have moved completely to the Learner API.
  391. # Replaced with `clip_gradients()`.
  392. @PublicAPI
  393. def minimize_and_clip(
  394. optimizer: LocalOptimizer,
  395. objective: TensorType,
  396. var_list: List["tf.Variable"],
  397. clip_val: float = 10.0,
  398. ) -> ModelGradients:
  399. """Computes, then clips gradients using objective, optimizer and var list.
  400. Ensures the norm of the gradients for each variable is clipped to
  401. `clip_val`.
  402. Args:
  403. optimizer: Either a shim optimizer (tf eager) containing a
  404. tf.GradientTape under `self.tape` or a tf1 local optimizer
  405. object.
  406. objective: The loss tensor to calculate gradients on.
  407. var_list: The list of tf.Variables to compute gradients over.
  408. clip_val: The global norm clip value. Will clip around -clip_val and
  409. +clip_val.
  410. Returns:
  411. The resulting model gradients (list or tuples of grads + vars)
  412. corresponding to the input `var_list`.
  413. """
  414. # Accidentally passing values < 0.0 will break all gradients.
  415. assert clip_val is None or clip_val > 0.0, clip_val
  416. if tf.executing_eagerly():
  417. tape = optimizer.tape
  418. grads_and_vars = list(zip(list(tape.gradient(objective, var_list)), var_list))
  419. else:
  420. grads_and_vars = optimizer.compute_gradients(objective, var_list=var_list)
  421. return [
  422. (tf.clip_by_norm(g, clip_val) if clip_val is not None else g, v)
  423. for (g, v) in grads_and_vars
  424. if g is not None
  425. ]
  426. @PublicAPI
  427. def one_hot(x: TensorType, space: gym.Space) -> TensorType:
  428. """Returns a one-hot tensor, given and int tensor and a space.
  429. Handles the MultiDiscrete case as well.
  430. Args:
  431. x: The input tensor.
  432. space: The space to use for generating the one-hot tensor.
  433. Returns:
  434. The resulting one-hot tensor.
  435. Raises:
  436. ValueError: If the given space is not a discrete one.
  437. .. testcode::
  438. :skipif: True
  439. import gymnasium as gym
  440. import tensorflow as tf
  441. from ray.rllib.utils.tf_utils import one_hot
  442. x = tf.Variable([0, 3], dtype=tf.int32) # batch-dim=2
  443. # Discrete space with 4 (one-hot) slots per batch item.
  444. s = gym.spaces.Discrete(4)
  445. one_hot(x, s)
  446. .. testoutput::
  447. <tf.Tensor 'one_hot:0' shape=(2, 4) dtype=float32>
  448. .. testcode::
  449. :skipif: True
  450. x = tf.Variable([[0, 1, 2, 3]], dtype=tf.int32) # batch-dim=1
  451. # MultiDiscrete space with 5 + 4 + 4 + 7 = 20 (one-hot) slots
  452. # per batch item.
  453. s = gym.spaces.MultiDiscrete([5, 4, 4, 7])
  454. one_hot(x, s)
  455. .. testoutput::
  456. <tf.Tensor 'concat:0' shape=(1, 20) dtype=float32>
  457. """
  458. if isinstance(space, Discrete):
  459. return tf.one_hot(x, space.n, dtype=tf.float32)
  460. elif isinstance(space, MultiDiscrete):
  461. if isinstance(space.nvec[0], np.ndarray):
  462. nvec = np.ravel(space.nvec)
  463. x = tf.reshape(x, (x.shape[0], -1))
  464. else:
  465. nvec = space.nvec
  466. return tf.concat(
  467. [tf.one_hot(x[:, i], n, dtype=tf.float32) for i, n in enumerate(nvec)],
  468. axis=-1,
  469. )
  470. else:
  471. raise ValueError("Unsupported space for `one_hot`: {}".format(space))
  472. @PublicAPI
  473. def reduce_mean_ignore_inf(x: TensorType, axis: Optional[int] = None) -> TensorType:
  474. """Same as tf.reduce_mean() but ignores -inf values.
  475. Args:
  476. x: The input tensor to reduce mean over.
  477. axis: The axis over which to reduce. None for all axes.
  478. Returns:
  479. The mean reduced inputs, ignoring inf values.
  480. """
  481. mask = tf.not_equal(x, tf.float32.min)
  482. x_zeroed = tf.where(mask, x, tf.zeros_like(x))
  483. return tf.math.reduce_sum(x_zeroed, axis) / tf.math.reduce_sum(
  484. tf.cast(mask, tf.float32), axis
  485. )
  486. @PublicAPI
  487. def scope_vars(
  488. scope: Union[str, "tf1.VariableScope"], trainable_only: bool = False
  489. ) -> List["tf.Variable"]:
  490. """Get variables inside a given scope.
  491. Args:
  492. scope: Scope in which the variables reside.
  493. trainable_only: Whether or not to return only the variables that were
  494. marked as trainable.
  495. Returns:
  496. The list of variables in the given `scope`.
  497. """
  498. return tf1.get_collection(
  499. tf1.GraphKeys.TRAINABLE_VARIABLES
  500. if trainable_only
  501. else tf1.GraphKeys.VARIABLES,
  502. scope=scope if isinstance(scope, str) else scope.name,
  503. )
  504. @PublicAPI
  505. def symlog(x: "tf.Tensor") -> "tf.Tensor":
  506. """The symlog function as described in [1]:
  507. [1] Mastering Diverse Domains through World Models - 2023
  508. D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap
  509. https://arxiv.org/pdf/2301.04104v1.pdf
  510. """
  511. return tf.math.sign(x) * tf.math.log(tf.math.abs(x) + 1)
  512. @PublicAPI
  513. def inverse_symlog(y: "tf.Tensor") -> "tf.Tensor":
  514. """Inverse of the `symlog` function as desribed in [1]:
  515. [1] Mastering Diverse Domains through World Models - 2023
  516. D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap
  517. https://arxiv.org/pdf/2301.04104v1.pdf
  518. """
  519. # To get to symlog inverse, we solve the symlog equation for x:
  520. # y = sign(x) * log(|x| + 1)
  521. # <=> y / sign(x) = log(|x| + 1)
  522. # <=> y = log( x + 1) V x >= 0
  523. # -y = log(-x + 1) V x < 0
  524. # <=> exp(y) = x + 1 V x >= 0
  525. # exp(-y) = -x + 1 V x < 0
  526. # <=> exp(y) - 1 = x V x >= 0
  527. # exp(-y) - 1 = -x V x < 0
  528. # <=> exp(y) - 1 = x V x >= 0 (if x >= 0, then y must also be >= 0)
  529. # -exp(-y) - 1 = x V x < 0 (if x < 0, then y must also be < 0)
  530. # <=> sign(y) * (exp(|y|) - 1) = x
  531. return tf.math.sign(y) * (tf.math.exp(tf.math.abs(y)) - 1)
  532. @PublicAPI
  533. def two_hot(
  534. value: "tf.Tensor",
  535. num_buckets: int = 255,
  536. lower_bound: float = -20.0,
  537. upper_bound: float = 20.0,
  538. dtype=None,
  539. ):
  540. """Returns a two-hot vector of dim=num_buckets with two entries that are non-zero.
  541. See [1] for more details:
  542. [1] Mastering Diverse Domains through World Models - 2023
  543. D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap
  544. https://arxiv.org/pdf/2301.04104v1.pdf
  545. Entries in the vector represent equally sized buckets within some fixed range
  546. (`lower_bound` to `upper_bound`).
  547. Those entries not 0.0 at positions k and k+1 encode the actual `value` and sum
  548. up to 1.0. They are the weights multiplied by the buckets values at k and k+1 for
  549. retrieving `value`.
  550. Example:
  551. num_buckets=11
  552. lower_bound=-5
  553. upper_bound=5
  554. value=2.5
  555. -> [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.5, 0.0, 0.0]
  556. -> [-5 -4 -3 -2 -1 0 1 2 3 4 5] (0.5*2 + 0.5*3=2.5)
  557. Example:
  558. num_buckets=5
  559. lower_bound=-1
  560. upper_bound=1
  561. value=0.1
  562. -> [0.0, 0.0, 0.8, 0.2, 0.0]
  563. -> [-1 -0.5 0 0.5 1] (0.2*0.5 + 0.8*0=0.1)
  564. Args:
  565. value: The input tensor of shape (B,) to be two-hot encoded.
  566. num_buckets: The number of buckets to two-hot encode into.
  567. lower_bound: The lower bound value used for the encoding. If input values are
  568. lower than this boundary, they will be encoded as `lower_bound`.
  569. upper_bound: The upper bound value used for the encoding. If input values are
  570. higher than this boundary, they will be encoded as `upper_bound`.
  571. Returns:
  572. The two-hot encoded tensor of shape (B, num_buckets).
  573. """
  574. # First make sure, values are clipped.
  575. value = tf.clip_by_value(value, lower_bound, upper_bound)
  576. # Tensor of batch indices: [0, B=batch size).
  577. batch_indices = tf.cast(
  578. tf.range(0, tf.shape(value)[0]),
  579. dtype=dtype or tf.float32,
  580. )
  581. # Calculate the step deltas (how much space between each bucket's central value?).
  582. bucket_delta = (upper_bound - lower_bound) / (num_buckets - 1)
  583. # Compute the float indices (might be non-int numbers: sitting between two buckets).
  584. idx = (-lower_bound + value) / bucket_delta
  585. # k
  586. k = tf.math.floor(idx)
  587. # k+1
  588. kp1 = tf.math.ceil(idx)
  589. # In case k == kp1 (idx is exactly on the bucket boundary), move kp1 up by 1.0.
  590. # Otherwise, this would result in a NaN in the returned two-hot tensor.
  591. kp1 = tf.where(tf.equal(k, kp1), kp1 + 1.0, kp1)
  592. # Iff `kp1` is one beyond our last index (because incoming value is larger than
  593. # `upper_bound`), move it to one before k (kp1's weight is going to be 0.0 anyways,
  594. # so it doesn't matter where it points to; we are just avoiding an index error
  595. # with this).
  596. kp1 = tf.where(tf.equal(kp1, num_buckets), kp1 - 2.0, kp1)
  597. # The actual values found at k and k+1 inside the set of buckets.
  598. values_k = lower_bound + k * bucket_delta
  599. values_kp1 = lower_bound + kp1 * bucket_delta
  600. # Compute the two-hot weights (adding up to 1.0) to use at index k and k+1.
  601. weights_k = (value - values_kp1) / (values_k - values_kp1)
  602. weights_kp1 = 1.0 - weights_k
  603. # Compile a tensor of full paths (indices from batch index to feature index) to
  604. # use for the scatter_nd op.
  605. indices_k = tf.stack([batch_indices, k], -1)
  606. indices_kp1 = tf.stack([batch_indices, kp1], -1)
  607. indices = tf.concat([indices_k, indices_kp1], 0)
  608. # The actual values (weights adding up to 1.0) to place at the computed indices.
  609. updates = tf.concat([weights_k, weights_kp1], 0)
  610. # Call the actual scatter update op, returning a zero-filled tensor, only changed
  611. # at the given indices.
  612. return tf.scatter_nd(
  613. tf.cast(indices, tf.int32),
  614. updates,
  615. shape=(tf.shape(value)[0], num_buckets),
  616. )
  617. @PublicAPI
  618. def update_target_network(
  619. main_net: NetworkType,
  620. target_net: NetworkType,
  621. tau: float,
  622. ) -> None:
  623. """Updates a keras.Model target network using Polyak averaging.
  624. new_target_net_weight = (
  625. tau * main_net_weight + (1.0 - tau) * current_target_net_weight
  626. )
  627. Args:
  628. main_net: The keras.Model to update from.
  629. target_net: The target network to update.
  630. tau: The tau value to use in the Polyak averaging formula.
  631. """
  632. for old_var, current_var in zip(target_net.variables, main_net.variables):
  633. updated_var = tau * current_var + (1.0 - tau) * old_var
  634. old_var.assign(updated_var)
  635. @PublicAPI
  636. def zero_logps_from_actions(actions: TensorStructType) -> TensorType:
  637. """Helper function useful for returning dummy logp's (0) for some actions.
  638. Args:
  639. actions: The input actions. This can be any struct
  640. of complex action components or a simple tensor of different
  641. dimensions, e.g. [B], [B, 2], or {"a": [B, 4, 5], "b": [B]}.
  642. Returns:
  643. A 1D tensor of 0.0 (dummy logp's) matching the batch
  644. dim of `actions` (shape=[B]).
  645. """
  646. # Need to flatten `actions` in case we have a complex action space.
  647. # Take the 0th component to extract the batch dim.
  648. action_component = tree.flatten(actions)[0]
  649. logp_ = tf.zeros_like(action_component, dtype=tf.float32)
  650. # Logp's should be single values (but with the same batch dim as
  651. # `deterministic_actions` or `stochastic_actions`). In case
  652. # actions are just [B], zeros_like works just fine here, but if
  653. # actions are [B, ...], we have to reduce logp back to just [B].
  654. while len(logp_.shape) > 1:
  655. logp_ = logp_[:, 0]
  656. return logp_
  657. @DeveloperAPI
  658. def warn_if_infinite_kl_divergence(
  659. policy: Type["TFPolicy"], mean_kl: TensorType
  660. ) -> None:
  661. def print_warning():
  662. logger.warning(
  663. "KL divergence is non-finite, this will likely destabilize your model and"
  664. " the training process. Action(s) in a specific state have near-zero"
  665. " probability. This can happen naturally in deterministic environments"
  666. " where the optimal policy has zero mass for a specific action. To fix this"
  667. " issue, consider setting the coefficient for the KL loss term to zero or"
  668. " increasing policy entropy."
  669. )
  670. return tf.constant(0.0)
  671. if policy.loss_initialized():
  672. tf.cond(
  673. tf.math.is_inf(mean_kl),
  674. false_fn=lambda: tf.constant(0.0),
  675. true_fn=lambda: print_warning(),
  676. )
  677. def _unflatten(vector, shapes):
  678. i = 0
  679. arrays = []
  680. for shape in shapes:
  681. size = np.prod(shape, dtype=np.int_)
  682. array = vector[i : (i + size)].reshape(shape)
  683. arrays.append(array)
  684. i += size
  685. assert len(vector) == i, "Passed weight does not have the correct shape."
  686. return arrays
  687. @DeveloperAPI
  688. class TensorFlowVariables:
  689. """A class used to set and get weights for Tensorflow networks.
  690. Attributes:
  691. sess (tf.Session): The tensorflow session used to run assignment.
  692. variables (Dict[str, tf.Variable]): Extracted variables from the loss
  693. or additional variables that are passed in.
  694. placeholders (Dict[str, tf.placeholders]): Placeholders for weights.
  695. assignment_nodes (Dict[str, tf.Tensor]): Nodes that assign weights.
  696. """
  697. def __init__(self, output, sess=None, input_variables=None):
  698. """Creates TensorFlowVariables containing extracted variables.
  699. The variables are extracted by performing a BFS search on the
  700. dependency graph with loss as the root node. After the tree is
  701. traversed and those variables are collected, we append input_variables
  702. to the collected variables. For each variable in the list, the
  703. variable has a placeholder and assignment operation created for it.
  704. Args:
  705. output (tf.Operation, List[tf.Operation]): The tensorflow
  706. operation to extract all variables from.
  707. sess (Optional[tf.Session]): Optional tf.Session used for running
  708. the get and set methods in tf graph mode.
  709. Use None for tf eager.
  710. input_variables (List[tf.Variables]): Variables to include in the
  711. list.
  712. """
  713. self.sess = sess
  714. output = force_list(output)
  715. queue = deque(output)
  716. variable_names = []
  717. explored_inputs = set(output)
  718. # We do a BFS on the dependency graph of the input function to find
  719. # the variables.
  720. while len(queue) != 0:
  721. tf_obj = queue.popleft()
  722. if tf_obj is None:
  723. continue
  724. # The object put into the queue is not necessarily an operation,
  725. # so we want the op attribute to get the operation underlying the
  726. # object. Only operations contain the inputs that we can explore.
  727. if hasattr(tf_obj, "op"):
  728. tf_obj = tf_obj.op
  729. for input_op in tf_obj.inputs:
  730. if input_op not in explored_inputs:
  731. queue.append(input_op)
  732. explored_inputs.add(input_op)
  733. # Tensorflow control inputs can be circular, so we keep track of
  734. # explored operations.
  735. for control in tf_obj.control_inputs:
  736. if control not in explored_inputs:
  737. queue.append(control)
  738. explored_inputs.add(control)
  739. if "Variable" in tf_obj.node_def.op or "VarHandle" in tf_obj.node_def.op:
  740. variable_names.append(tf_obj.node_def.name)
  741. self.variables = OrderedDict()
  742. variable_list = [
  743. v for v in tf1.global_variables() if v.op.node_def.name in variable_names
  744. ]
  745. if input_variables is not None:
  746. variable_list += input_variables
  747. def _get_var_name(v):
  748. """Get variable name, supporting both TF1 ResourceVariable and
  749. Keras 3 Variable objects."""
  750. if hasattr(v, "op"):
  751. return v.op.node_def.name
  752. return v.name
  753. if not tf1.executing_eagerly():
  754. for v in variable_list:
  755. self.variables[_get_var_name(v)] = v
  756. self.placeholders = {}
  757. self.assignment_nodes = {}
  758. # Create new placeholders to put in custom weights.
  759. for k, var in self.variables.items():
  760. dtype = var.value().dtype if hasattr(var, "op") else var.dtype
  761. shape = (
  762. var.get_shape().as_list()
  763. if hasattr(var, "get_shape")
  764. else list(var.shape)
  765. )
  766. self.placeholders[k] = tf1.placeholder(
  767. dtype,
  768. shape,
  769. name="Placeholder_" + k,
  770. )
  771. self.assignment_nodes[k] = var.assign(self.placeholders[k])
  772. else:
  773. for v in variable_list:
  774. self.variables[v.name] = v
  775. def get_flat_size(self):
  776. """Returns the total length of all of the flattened variables.
  777. Returns:
  778. The length of all flattened variables concatenated.
  779. """
  780. return sum(np.prod(v.get_shape().as_list()) for v in self.variables.values())
  781. def get_flat(self):
  782. """Gets the weights and returns them as a flat array.
  783. Returns:
  784. 1D Array containing the flattened weights.
  785. """
  786. # Eager mode.
  787. if not self.sess:
  788. return np.concatenate(
  789. [v.numpy().flatten() for v in self.variables.values()]
  790. )
  791. # Graph mode.
  792. return np.concatenate(
  793. [v.eval(session=self.sess).flatten() for v in self.variables.values()]
  794. )
  795. def set_flat(self, new_weights):
  796. """Sets the weights to new_weights, converting from a flat array.
  797. Note:
  798. You can only set all weights in the network using this function,
  799. i.e., the length of the array must match get_flat_size.
  800. Args:
  801. new_weights (np.ndarray): Flat array containing weights.
  802. """
  803. shapes = [v.get_shape().as_list() for v in self.variables.values()]
  804. arrays = _unflatten(new_weights, shapes)
  805. if not self.sess:
  806. for v, a in zip(self.variables.values(), arrays):
  807. v.assign(a)
  808. else:
  809. placeholders = [self.placeholders[k] for k, v in self.variables.items()]
  810. self.sess.run(
  811. list(self.assignment_nodes.values()),
  812. feed_dict=dict(zip(placeholders, arrays)),
  813. )
  814. def get_weights(self):
  815. """Returns a dictionary containing the weights of the network.
  816. Returns:
  817. Dictionary mapping variable names to their weights.
  818. """
  819. # Eager mode.
  820. if not self.sess:
  821. return self.variables
  822. # Graph mode.
  823. return self.sess.run(self.variables)
  824. def set_weights(self, new_weights: dict):
  825. """Sets the weights to new_weights.
  826. Note:
  827. Can set subsets of variables as well, by only passing in the
  828. variables you want to be set.
  829. Args:
  830. new_weights: Dictionary mapping variable names to their
  831. weights.
  832. """
  833. if self.sess is None:
  834. for name, var in self.variables.items():
  835. var.assign(new_weights[name])
  836. else:
  837. assign_list, feed_dict = self._assign_weights(new_weights)
  838. self.sess.run(assign_list, feed_dict=feed_dict)
  839. def _assign_weights(self, weights):
  840. """Sets weigths using exact or closest assignable variable name
  841. Args:
  842. weights: Dictionary mapping variable names to their
  843. weights.
  844. Returns:
  845. Tuple[List, Dict]: assigned variables list, dict of
  846. placeholders and weights
  847. """
  848. assigned = []
  849. feed_dict = {}
  850. assignable = set(self.assignment_nodes.keys())
  851. def nb_common_elem(l1, l2):
  852. return len([e for e in l1 if e in l2])
  853. def assign(name, value):
  854. feed_dict[self.placeholders[name]] = value
  855. assigned.append(name)
  856. assignable.remove(name)
  857. for name, value in weights.items():
  858. if name in assignable:
  859. assign(name, value)
  860. else:
  861. common = {
  862. var: nb_common_elem(name.split("/"), var.split("/"))
  863. for var in assignable
  864. }
  865. select = [
  866. close_var
  867. for close_var, cn in sorted(common.items(), key=lambda i: -i[1])
  868. if cn > 0 and value.shape == self.assignment_nodes[close_var].shape
  869. ]
  870. if select:
  871. assign(select[0], value)
  872. assert assigned, (
  873. "No variables in the input matched those in the network. "
  874. "Possible cause: Two networks were defined in the same "
  875. "TensorFlow graph. To fix this, place each network "
  876. "definition in its own tf.Graph."
  877. )
  878. assert len(assigned) == len(weights), (
  879. "All weights couldn't be assigned because no variable "
  880. "had an exact/close name or had same shape"
  881. )
  882. return [self.assignment_nodes[v] for v in assigned], feed_dict