torch_utils.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886
  1. import logging
  2. import os
  3. import warnings
  4. from typing import TYPE_CHECKING, Dict, List, Optional, Union
  5. import gymnasium as gym
  6. import numpy as np
  7. import tree # pip install dm_tree
  8. from gymnasium.spaces import Discrete, MultiDiscrete
  9. from packaging import version
  10. from ray.rllib.models.repeated_values import RepeatedValues
  11. from ray.rllib.utils.annotations import DeveloperAPI, OldAPIStack, PublicAPI
  12. from ray.rllib.utils.framework import try_import_torch
  13. from ray.rllib.utils.numpy import SMALL_NUMBER
  14. from ray.rllib.utils.typing import (
  15. LocalOptimizer,
  16. NetworkType,
  17. SpaceStruct,
  18. TensorStructType,
  19. TensorType,
  20. )
  21. if TYPE_CHECKING:
  22. from ray.rllib.core.learner.learner import ParamDict, ParamList
  23. from ray.rllib.policy.torch_policy import TorchPolicy
  24. from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
  25. logger = logging.getLogger(__name__)
  26. torch, nn = try_import_torch()
  27. # Limit values suitable for use as close to a -inf logit. These are useful
  28. # since -inf / inf cause NaNs during backprop.
  29. FLOAT_MIN = -3.4e38
  30. FLOAT_MAX = 3.4e38
  31. if torch:
  32. TORCH_COMPILE_REQUIRED_VERSION = version.parse("2.0.0")
  33. else:
  34. TORCH_COMPILE_REQUIRED_VERSION = ValueError(
  35. "torch is not installed. TORCH_COMPILE_REQUIRED_VERSION is not defined."
  36. )
  37. @OldAPIStack
  38. def apply_grad_clipping(
  39. policy: "TorchPolicy", optimizer: LocalOptimizer, loss: TensorType
  40. ) -> Dict[str, TensorType]:
  41. """Applies gradient clipping to already computed grads inside `optimizer`.
  42. Note: This function does NOT perform an analogous operation as
  43. tf.clip_by_global_norm. It merely clips by norm (per gradient tensor) and
  44. then computes the global norm across all given tensors (but without clipping
  45. by that global norm).
  46. Args:
  47. policy: The TorchPolicy, which calculated `loss`.
  48. optimizer: A local torch optimizer object.
  49. loss: The torch loss tensor.
  50. Returns:
  51. An info dict containing the "grad_norm" key and the resulting clipped
  52. gradients.
  53. """
  54. grad_gnorm = 0
  55. if policy.config["grad_clip"] is not None:
  56. clip_value = policy.config["grad_clip"]
  57. else:
  58. clip_value = np.inf
  59. num_none_grads = 0
  60. for param_group in optimizer.param_groups:
  61. # Make sure we only pass params with grad != None into torch
  62. # clip_grad_norm_. Would fail otherwise.
  63. params = list(filter(lambda p: p.grad is not None, param_group["params"]))
  64. if params:
  65. # PyTorch clips gradients inplace and returns the norm before clipping
  66. # We therefore need to compute grad_gnorm further down (fixes #4965)
  67. global_norm = nn.utils.clip_grad_norm_(params, clip_value)
  68. if isinstance(global_norm, torch.Tensor):
  69. global_norm = global_norm.cpu().numpy()
  70. grad_gnorm += min(global_norm, clip_value)
  71. else:
  72. num_none_grads += 1
  73. # Note (Kourosh): grads could indeed be zero. This method should still return
  74. # grad_gnorm in that case.
  75. if num_none_grads == len(optimizer.param_groups):
  76. # No grads available
  77. return {}
  78. return {"grad_gnorm": grad_gnorm}
  79. @PublicAPI
  80. def clip_gradients(
  81. gradients_dict: "ParamDict",
  82. *,
  83. grad_clip: Optional[float] = None,
  84. grad_clip_by: str = "value",
  85. ) -> TensorType:
  86. """Performs gradient clipping on a grad-dict based on a clip value and clip mode.
  87. Changes the provided gradient dict in place.
  88. Args:
  89. gradients_dict: The gradients dict, mapping str to gradient tensors.
  90. grad_clip: The value to clip with. The way gradients are clipped is defined
  91. by the `grad_clip_by` arg (see below).
  92. grad_clip_by: One of 'value', 'norm', or 'global_norm'.
  93. Returns:
  94. If `grad_clip_by`="global_norm" and `grad_clip` is not None, returns the global
  95. norm of all tensors, otherwise returns None.
  96. """
  97. # No clipping, return.
  98. if grad_clip is None:
  99. return
  100. if grad_clip_by not in ["value", "norm", "global_norm"]:
  101. raise ValueError(
  102. f"`grad_clip_by` ({grad_clip_by}) must be one of [value|norm|global_norm]!"
  103. )
  104. # Clip by value (each gradient individually).
  105. if grad_clip_by == "value":
  106. for k, v in gradients_dict.items():
  107. gradients_dict[k] = (
  108. None if v is None else torch.clip(v, -grad_clip, grad_clip)
  109. )
  110. # Clip by L2-norm (per gradient tensor).
  111. elif grad_clip_by == "norm":
  112. for k, v in gradients_dict.items():
  113. if v is not None:
  114. # Compute the L2-norm of the gradient tensor.
  115. norm = v.norm(2).nan_to_num(neginf=-10e8, posinf=10e8)
  116. # Clip all the gradients.
  117. if norm > grad_clip:
  118. v.mul_(grad_clip / norm)
  119. # Clip by global L2-norm (across all gradient tensors).
  120. else:
  121. gradients_list = list(gradients_dict.values())
  122. total_norm = compute_global_norm(gradients_list)
  123. if len(gradients_list) == 0:
  124. return total_norm
  125. # We do want the coefficient to be in between 0.0 and 1.0, therefore
  126. # if the global_norm is smaller than the clip value, we use the clip value
  127. # as normalization constant.
  128. clip_coeff = grad_clip / torch.clamp(total_norm + 1e-6, min=grad_clip)
  129. # Note: multiplying by the clamped coefficient is redundant when the coefficient
  130. # is clamped to 1, but doing so avoids a `if clip_coeff < 1:` conditional which
  131. # can require a CPU <=> device synchronization when the gradients reside in GPU
  132. # memory.
  133. clip_coeff_clamped = torch.clamp(clip_coeff, max=1.0)
  134. for g in gradients_list:
  135. if g is not None:
  136. g.detach().mul_(clip_coeff_clamped.to(g.device))
  137. return total_norm
  138. @PublicAPI
  139. def compute_global_norm(gradients_list: "ParamList") -> TensorType:
  140. """Computes the global norm for a gradients dict.
  141. Args:
  142. gradients_list: The gradients list containing parameters.
  143. Returns:
  144. Returns the global norm of all tensors in `gradients_list`.
  145. """
  146. # Define the norm type to be L2.
  147. norm_type = 2.0
  148. # If we have no grads, return zero.
  149. if len(gradients_list) == 0:
  150. return torch.tensor(0.0)
  151. # Compute the global norm.
  152. total_norm = torch.norm(
  153. torch.stack(
  154. [
  155. torch.norm(g.detach(), norm_type)
  156. # Note, we want to avoid overflow in the norm computation, this does
  157. # not affect the gradients themselves as we clamp by multiplying and
  158. # not by overriding tensor values.
  159. .nan_to_num(neginf=-10e8, posinf=10e8)
  160. for g in gradients_list
  161. if g is not None
  162. ]
  163. ),
  164. norm_type,
  165. ).nan_to_num(neginf=-10e8, posinf=10e8)
  166. # Return the global norm.
  167. return total_norm
  168. @OldAPIStack
  169. def concat_multi_gpu_td_errors(
  170. policy: Union["TorchPolicy", "TorchPolicyV2"]
  171. ) -> Dict[str, TensorType]:
  172. """Concatenates multi-GPU (per-tower) TD error tensors given TorchPolicy.
  173. TD-errors are extracted from the TorchPolicy via its tower_stats property.
  174. Args:
  175. policy: The TorchPolicy to extract the TD-error values from.
  176. Returns:
  177. A dict mapping strings "td_error" and "mean_td_error" to the
  178. corresponding concatenated and mean-reduced values.
  179. """
  180. td_error = torch.cat(
  181. [
  182. t.tower_stats.get("td_error", torch.tensor([0.0])).to(policy.device)
  183. for t in policy.model_gpu_towers
  184. ],
  185. dim=0,
  186. )
  187. policy.td_error = td_error
  188. return {
  189. "td_error": td_error,
  190. "mean_td_error": torch.mean(td_error),
  191. }
  192. @PublicAPI
  193. def convert_to_torch_tensor(
  194. x,
  195. device: Optional[str] = None,
  196. pin_memory: bool = False,
  197. use_stream: bool = False,
  198. stream: Optional[Union["torch.cuda.Stream", "torch.cuda.classes.Stream"]] = None,
  199. ):
  200. """
  201. Converts any (possibly nested) structure to torch.Tensors.
  202. Args:
  203. x: The input structure whose leaves will be converted.
  204. device: The device to create the tensor on (e.g. "cuda:0" or "cpu").
  205. pin_memory: If True, calls `pin_memory()` on the created tensors.
  206. use_stream: If True, uses a separate CUDA stream for `Tensor.to()`.
  207. stream: An optional CUDA stream for the host-to-device copy in `Tensor.to()`.
  208. Returns:
  209. A new structure with the same layout as `x` but with all leaves converted
  210. to torch.Tensors. Leaves that are None are left unchanged.
  211. """
  212. # Convert the provided device (if any) to a torch.device; default to CPU.
  213. device = torch.device(device) if device is not None else torch.device("cpu")
  214. is_cuda = (device.type == "cuda") and torch.cuda.is_available()
  215. # Determine the appropriate stream.
  216. if is_cuda:
  217. if use_stream:
  218. if stream is not None:
  219. # Ensure the provided stream is of an acceptable type.
  220. assert isinstance(
  221. stream, (torch.cuda.Stream, torch.cuda.classes.Stream)
  222. ), f"`stream` must be a torch.cuda.Stream but got {type(stream)}."
  223. else:
  224. stream = torch.cuda.Stream()
  225. else:
  226. stream = torch.cuda.default_stream(device=device)
  227. else:
  228. stream = None
  229. def mapping(item):
  230. # Pass through None values.
  231. if item is None:
  232. return item
  233. # Special handling for "RepeatedValues" types.
  234. if isinstance(item, RepeatedValues):
  235. return RepeatedValues(
  236. tree.map_structure(mapping, item.values),
  237. item.lengths,
  238. item.max_len,
  239. )
  240. # Convert to a tensor if not already one.
  241. if torch.is_tensor(item):
  242. tensor = item
  243. elif isinstance(item, np.ndarray):
  244. # Leave object or string arrays as is.
  245. if item.dtype == object or item.dtype.type is np.str_:
  246. return item
  247. # If the numpy array is not writable, suppress warnings.
  248. if not item.flags.writeable:
  249. with warnings.catch_warnings():
  250. warnings.simplefilter("ignore")
  251. tensor = torch.from_numpy(item)
  252. else:
  253. tensor = torch.from_numpy(item)
  254. else:
  255. tensor = torch.from_numpy(np.asarray(item))
  256. # Convert floating-point tensors from float64 to float32 (unless they are float16).
  257. if tensor.is_floating_point() and tensor.dtype != torch.float16:
  258. tensor = tensor.float()
  259. # Optionally pin memory for faster host-to-GPU copies.
  260. if pin_memory and is_cuda:
  261. tensor = tensor.pin_memory()
  262. # Move the tensor to the desired device.
  263. # For CUDA devices, use the provided stream context if available.
  264. if is_cuda:
  265. if stream is not None:
  266. with torch.cuda.stream(stream):
  267. tensor = tensor.to(device, non_blocking=True)
  268. else:
  269. tensor = tensor.to(device, non_blocking=True)
  270. else:
  271. # For CPU (or non-CUDA), this is a no-op if already on the target device.
  272. tensor = tensor.to(device)
  273. return tensor
  274. return tree.map_structure(mapping, x)
  275. @PublicAPI
  276. def copy_torch_tensors(x: TensorStructType, device: Optional[str] = None):
  277. """Creates a copy of `x` and makes deep copies torch.Tensors in x.
  278. Also moves the copied tensors to the specified device (if not None).
  279. Note if an object in x is not a torch.Tensor, it will be shallow-copied.
  280. Args:
  281. x : Any (possibly nested) struct possibly containing torch.Tensors.
  282. device : The device to move the tensors to.
  283. Returns:
  284. Any: A new struct with the same structure as `x`, but with all
  285. torch.Tensors deep-copied and moved to the specified device.
  286. """
  287. def mapping(item):
  288. if isinstance(item, torch.Tensor):
  289. return (
  290. torch.clone(item.detach())
  291. if device is None
  292. else item.detach().to(device)
  293. )
  294. else:
  295. return item
  296. return tree.map_structure(mapping, x)
  297. @PublicAPI
  298. def explained_variance(y: TensorType, pred: TensorType) -> TensorType:
  299. """Computes the explained variance for a pair of labels and predictions.
  300. The formula used is:
  301. max(-1.0, 1.0 - (std(y - pred)^2 / std(y)^2))
  302. Args:
  303. y: The labels.
  304. pred: The predictions.
  305. Returns:
  306. The explained variance given a pair of labels and predictions.
  307. """
  308. squeezed_y = y.squeeze()
  309. y_var = torch.var(squeezed_y, dim=0)
  310. diff_var = torch.var(squeezed_y - pred.squeeze(), dim=0)
  311. min_ = torch.tensor([-1.0]).to(pred.device)
  312. return torch.max(min_, 1 - (diff_var / (y_var + SMALL_NUMBER)))[0]
  313. @PublicAPI
  314. def flatten_inputs_to_1d_tensor(
  315. inputs: TensorStructType,
  316. spaces_struct: Optional[SpaceStruct] = None,
  317. time_axis: bool = False,
  318. ) -> TensorType:
  319. """Flattens arbitrary input structs according to the given spaces struct.
  320. Returns a single 1D tensor resulting from the different input
  321. components' values.
  322. Thereby:
  323. - Boxes (any shape) get flattened to (B, [T]?, -1). Note that image boxes
  324. are not treated differently from other types of Boxes and get
  325. flattened as well.
  326. - Discrete (int) values are one-hot'd, e.g. a batch of [1, 0, 3] (B=3 with
  327. Discrete(4) space) results in [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1]].
  328. - MultiDiscrete values are multi-one-hot'd, e.g. a batch of
  329. [[0, 2], [1, 4]] (B=2 with MultiDiscrete([2, 5]) space) results in
  330. [[1, 0, 0, 0, 1, 0, 0], [0, 1, 0, 0, 0, 0, 1]].
  331. Args:
  332. inputs: The inputs to be flattened.
  333. spaces_struct: The structure of the spaces that behind the input
  334. time_axis: Whether all inputs have a time-axis (after the batch axis).
  335. If True, will keep not only the batch axis (0th), but the time axis
  336. (1st) as-is and flatten everything from the 2nd axis up.
  337. Returns:
  338. A single 1D tensor resulting from concatenating all
  339. flattened/one-hot'd input components. Depending on the time_axis flag,
  340. the shape is (B, n) or (B, T, n).
  341. .. testcode::
  342. from gymnasium.spaces import Discrete, Box
  343. from ray.rllib.utils.torch_utils import flatten_inputs_to_1d_tensor
  344. import torch
  345. struct = {
  346. "a": np.array([1, 3]),
  347. "b": (
  348. np.array([[1.0, 2.0], [4.0, 5.0]]),
  349. np.array(
  350. [[[8.0], [7.0]], [[5.0], [4.0]]]
  351. ),
  352. ),
  353. "c": {
  354. "cb": np.array([1.0, 2.0]),
  355. },
  356. }
  357. struct_torch = tree.map_structure(lambda s: torch.from_numpy(s), struct)
  358. spaces = dict(
  359. {
  360. "a": gym.spaces.Discrete(4),
  361. "b": (gym.spaces.Box(-1.0, 10.0, (2,)), gym.spaces.Box(-1.0, 1.0, (2,
  362. 1))),
  363. "c": dict(
  364. {
  365. "cb": gym.spaces.Box(-1.0, 1.0, ()),
  366. }
  367. ),
  368. }
  369. )
  370. print(flatten_inputs_to_1d_tensor(struct_torch, spaces_struct=spaces))
  371. .. testoutput::
  372. tensor([[0., 1., 0., 0., 1., 2., 8., 7., 1.],
  373. [0., 0., 0., 1., 4., 5., 5., 4., 2.]])
  374. """
  375. flat_inputs = tree.flatten(inputs)
  376. flat_spaces = (
  377. tree.flatten(spaces_struct)
  378. if spaces_struct is not None
  379. else [None] * len(flat_inputs)
  380. )
  381. B = None
  382. T = None
  383. out = []
  384. for input_, space in zip(flat_inputs, flat_spaces):
  385. # Store batch and (if applicable) time dimension.
  386. if B is None:
  387. B = input_.shape[0]
  388. if time_axis:
  389. T = input_.shape[1]
  390. # One-hot encoding.
  391. if isinstance(space, Discrete):
  392. if time_axis:
  393. input_ = torch.reshape(input_, [B * T])
  394. out.append(one_hot(input_, space).float())
  395. # Multi one-hot encoding.
  396. elif isinstance(space, MultiDiscrete):
  397. if time_axis:
  398. input_ = torch.reshape(input_, [B * T, -1])
  399. out.append(one_hot(input_, space).float())
  400. # Box: Flatten.
  401. else:
  402. if time_axis:
  403. input_ = torch.reshape(input_, [B * T, -1])
  404. else:
  405. input_ = torch.reshape(input_, [B, -1])
  406. out.append(input_.float())
  407. merged = torch.cat(out, dim=-1)
  408. # Restore the time-dimension, if applicable.
  409. if time_axis:
  410. merged = torch.reshape(merged, [B, T, -1])
  411. return merged
  412. @PublicAPI
  413. def global_norm(tensors: List[TensorType]) -> TensorType:
  414. """Returns the global L2 norm over a list of tensors.
  415. output = sqrt(SUM(t ** 2 for t in tensors)),
  416. where SUM reduces over all tensors and over all elements in tensors.
  417. Args:
  418. tensors: The list of tensors to calculate the global norm over.
  419. Returns:
  420. The global L2 norm over the given tensor list.
  421. """
  422. # List of single tensors' L2 norms: SQRT(SUM(xi^2)) over all xi in tensor.
  423. single_l2s = [torch.pow(torch.sum(torch.pow(t, 2.0)), 0.5) for t in tensors]
  424. # Compute global norm from all single tensors' L2 norms.
  425. return torch.pow(sum(torch.pow(l2, 2.0) for l2 in single_l2s), 0.5)
  426. @OldAPIStack
  427. def huber_loss(x: TensorType, delta: float = 1.0) -> TensorType:
  428. """Computes the huber loss for a given term and delta parameter.
  429. Reference: https://en.wikipedia.org/wiki/Huber_loss
  430. Note that the factor of 0.5 is implicitly included in the calculation.
  431. Formula:
  432. L = 0.5 * x^2 for small abs x (delta threshold)
  433. L = delta * (abs(x) - 0.5*delta) for larger abs x (delta threshold)
  434. Args:
  435. x: The input term, e.g. a TD error.
  436. delta: The delta parmameter in the above formula.
  437. Returns:
  438. The Huber loss resulting from `x` and `delta`.
  439. """
  440. return torch.where(
  441. torch.abs(x) < delta,
  442. torch.pow(x, 2.0) * 0.5,
  443. delta * (torch.abs(x) - 0.5 * delta),
  444. )
  445. @OldAPIStack
  446. def l2_loss(x: TensorType) -> TensorType:
  447. """Computes half the L2 norm over a tensor's values without the sqrt.
  448. output = 0.5 * sum(x ** 2)
  449. Args:
  450. x: The input tensor.
  451. Returns:
  452. 0.5 times the L2 norm over the given tensor's values (w/o sqrt).
  453. """
  454. return 0.5 * torch.sum(torch.pow(x, 2.0))
  455. @PublicAPI
  456. def one_hot(x: TensorType, space: gym.Space) -> TensorType:
  457. """Returns a one-hot tensor, given and int tensor and a space.
  458. Handles the MultiDiscrete case as well.
  459. Args:
  460. x: The input tensor.
  461. space: The space to use for generating the one-hot tensor.
  462. Returns:
  463. The resulting one-hot tensor.
  464. Raises:
  465. ValueError: If the given space is not a discrete one.
  466. .. testcode::
  467. import torch
  468. import gymnasium as gym
  469. from ray.rllib.utils.torch_utils import one_hot
  470. x = torch.IntTensor([0, 3]) # batch-dim=2
  471. # Discrete space with 4 (one-hot) slots per batch item.
  472. s = gym.spaces.Discrete(4)
  473. print(one_hot(x, s))
  474. x = torch.IntTensor([[0, 1, 2, 3]]) # batch-dim=1
  475. # MultiDiscrete space with 5 + 4 + 4 + 7 = 20 (one-hot) slots
  476. # per batch item.
  477. s = gym.spaces.MultiDiscrete([5, 4, 4, 7])
  478. print(one_hot(x, s))
  479. .. testoutput::
  480. tensor([[1, 0, 0, 0],
  481. [0, 0, 0, 1]])
  482. tensor([[1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0]])
  483. """
  484. if isinstance(space, Discrete):
  485. return nn.functional.one_hot(x.long(), space.n)
  486. elif isinstance(space, MultiDiscrete):
  487. if isinstance(space.nvec[0], np.ndarray):
  488. nvec = np.ravel(space.nvec)
  489. x = x.reshape(x.shape[0], -1)
  490. else:
  491. nvec = space.nvec
  492. return torch.cat(
  493. [nn.functional.one_hot(x[:, i].long(), n) for i, n in enumerate(nvec)],
  494. dim=-1,
  495. )
  496. else:
  497. raise ValueError("Unsupported space for `one_hot`: {}".format(space))
  498. @PublicAPI
  499. def reduce_mean_ignore_inf(x: TensorType, axis: Optional[int] = None) -> TensorType:
  500. """Same as torch.mean() but ignores -inf values.
  501. Args:
  502. x: The input tensor to reduce mean over.
  503. axis: The axis over which to reduce. None for all axes.
  504. Returns:
  505. The mean reduced inputs, ignoring inf values.
  506. """
  507. mask = torch.ne(x, float("-inf"))
  508. x_zeroed = torch.where(mask, x, torch.zeros_like(x))
  509. return torch.sum(x_zeroed, axis) / torch.sum(mask.float(), axis)
  510. @PublicAPI
  511. def sequence_mask(
  512. lengths: TensorType,
  513. maxlen: Optional[int] = None,
  514. dtype=None,
  515. time_major: bool = False,
  516. ) -> TensorType:
  517. """Offers same behavior as tf.sequence_mask for torch.
  518. Thanks to Dimitris Papatheodorou
  519. (https://discuss.pytorch.org/t/pytorch-equivalent-for-tf-sequence-mask/
  520. 39036).
  521. Args:
  522. lengths: The tensor of individual lengths to mask by.
  523. maxlen: The maximum length to use for the time axis. If None, use
  524. the max of `lengths`.
  525. dtype: The torch dtype to use for the resulting mask.
  526. time_major: Whether to return the mask as [B, T] (False; default) or
  527. as [T, B] (True).
  528. Returns:
  529. The sequence mask resulting from the given input and parameters.
  530. """
  531. # If maxlen not given, use the longest lengths in the `lengths` tensor.
  532. if maxlen is None:
  533. maxlen = lengths.max()
  534. mask = torch.ones(tuple(lengths.shape) + (maxlen,))
  535. mask = ~(mask.to(lengths.device).cumsum(dim=1).t() > lengths)
  536. # Time major transformation.
  537. if not time_major:
  538. mask = mask.t()
  539. # By default, set the mask to be boolean.
  540. mask.type(dtype or torch.bool)
  541. return mask
  542. @PublicAPI
  543. def update_target_network(
  544. main_net: NetworkType,
  545. target_net: NetworkType,
  546. tau: float,
  547. ) -> None:
  548. """Updates a torch.nn.Module target network using Polyak averaging.
  549. .. code-block:: text
  550. new_target_net_weight = (
  551. tau * main_net_weight + (1.0 - tau) * current_target_net_weight
  552. )
  553. Args:
  554. main_net: The nn.Module to update from.
  555. target_net: The target network to update.
  556. tau: The tau value to use in the Polyak averaging formula.
  557. """
  558. # Get the current parameters from the Q network.
  559. state_dict = main_net.state_dict()
  560. # Use here Polyak averaging.
  561. new_state_dict = {
  562. k: tau * state_dict[k] + (1 - tau) * v
  563. for k, v in target_net.state_dict().items()
  564. }
  565. # Apply the new parameters to the target Q network.
  566. target_net.load_state_dict(new_state_dict)
  567. @DeveloperAPI
  568. def warn_if_infinite_kl_divergence(
  569. policy: "TorchPolicy",
  570. kl_divergence: TensorType,
  571. ) -> None:
  572. if policy.loss_initialized() and kl_divergence.isinf():
  573. logger.warning(
  574. "KL divergence is non-finite, this will likely destabilize your model and"
  575. " the training process. Action(s) in a specific state have near-zero"
  576. " probability. This can happen naturally in deterministic environments"
  577. " where the optimal policy has zero mass for a specific action. To fix this"
  578. " issue, consider setting the coefficient for the KL loss term to zero or"
  579. " increasing policy entropy."
  580. )
  581. @PublicAPI
  582. def set_torch_seed(seed: Optional[int] = None) -> None:
  583. """Sets the torch random seed to the given value.
  584. Args:
  585. seed: The seed to use or None for no seeding.
  586. """
  587. if seed is not None and torch:
  588. torch.manual_seed(seed)
  589. # See https://github.com/pytorch/pytorch/issues/47672.
  590. cuda_version = torch.version.cuda
  591. if cuda_version is not None and float(torch.version.cuda) >= 10.2:
  592. # See https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility.
  593. os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
  594. torch.cuda.manual_seed(seed)
  595. torch.cuda.manual_seed_all(seed) # if using multi-GPU
  596. else:
  597. if version.Version(torch.__version__) >= version.Version("1.8.0"):
  598. # Not all Operations support this.
  599. torch.use_deterministic_algorithms(True)
  600. else:
  601. torch.set_deterministic(True)
  602. # This is only for Convolution no problem.
  603. torch.backends.cudnn.deterministic = True
  604. # For benchmark=True, CuDNN may choose different algorithms depending on runtime
  605. # conditions or slight differences in input sizes, even if the seed is fixed,
  606. # which breaks determinism.
  607. torch.backends.cudnn.benchmark = False
  608. @PublicAPI
  609. def softmax_cross_entropy_with_logits(
  610. logits: TensorType,
  611. labels: TensorType,
  612. ) -> TensorType:
  613. """Same behavior as tf.nn.softmax_cross_entropy_with_logits.
  614. Args:
  615. x: The input predictions.
  616. labels: The labels corresponding to `x`.
  617. Returns:
  618. The resulting softmax cross-entropy given predictions and labels.
  619. """
  620. return torch.sum(-labels * nn.functional.log_softmax(logits, -1), -1)
  621. @PublicAPI
  622. def symlog(x: "torch.Tensor") -> "torch.Tensor":
  623. """The symlog function as described in [1]:
  624. [1] Mastering Diverse Domains through World Models - 2023
  625. D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap
  626. https://arxiv.org/pdf/2301.04104v1.pdf
  627. """
  628. return torch.sign(x) * torch.log(torch.abs(x) + 1)
  629. @PublicAPI
  630. def inverse_symlog(y: "torch.Tensor") -> "torch.Tensor":
  631. """Inverse of the `symlog` function as desribed in [1]:
  632. [1] Mastering Diverse Domains through World Models - 2023
  633. D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap
  634. https://arxiv.org/pdf/2301.04104v1.pdf
  635. """
  636. # To get to symlog inverse, we solve the symlog equation for x:
  637. # y = sign(x) * log(|x| + 1)
  638. # <=> y / sign(x) = log(|x| + 1)
  639. # <=> y = log( x + 1) V x >= 0
  640. # -y = log(-x + 1) V x < 0
  641. # <=> exp(y) = x + 1 V x >= 0
  642. # exp(-y) = -x + 1 V x < 0
  643. # <=> exp(y) - 1 = x V x >= 0
  644. # exp(-y) - 1 = -x V x < 0
  645. # <=> exp(y) - 1 = x V x >= 0 (if x >= 0, then y must also be >= 0)
  646. # -exp(-y) - 1 = x V x < 0 (if x < 0, then y must also be < 0)
  647. # <=> sign(y) * (exp(|y|) - 1) = x
  648. return torch.sign(y) * (torch.exp(torch.abs(y)) - 1)
  649. @PublicAPI
  650. def two_hot(
  651. value: "torch.Tensor",
  652. num_buckets: int = 255,
  653. lower_bound: float = -20.0,
  654. upper_bound: float = 20.0,
  655. device: Optional[str] = None,
  656. ):
  657. """Returns a two-hot vector of dim=num_buckets with two entries that are non-zero.
  658. See [1] for more details:
  659. [1] Mastering Diverse Domains through World Models - 2023
  660. D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap
  661. https://arxiv.org/pdf/2301.04104v1.pdf
  662. Entries in the vector represent equally sized buckets within some fixed range
  663. (`lower_bound` to `upper_bound`).
  664. Those entries not 0.0 at positions k and k+1 encode the actual `value` and sum
  665. up to 1.0. They are the weights multiplied by the buckets values at k and k+1 for
  666. retrieving `value`.
  667. Example:
  668. num_buckets=11
  669. lower_bound=-5
  670. upper_bound=5
  671. value=2.5
  672. -> [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.5, 0.0, 0.0]
  673. -> [-5 -4 -3 -2 -1 0 1 2 3 4 5] (0.5*2 + 0.5*3=2.5)
  674. Example:
  675. num_buckets=5
  676. lower_bound=-1
  677. upper_bound=1
  678. value=0.1
  679. -> [0.0, 0.0, 0.8, 0.2, 0.0]
  680. -> [-1 -0.5 0 0.5 1] (0.2*0.5 + 0.8*0=0.1)
  681. Args:
  682. value: The input tensor of shape (B,) to be two-hot encoded.
  683. num_buckets: The number of buckets to two-hot encode into.
  684. lower_bound: The lower bound value used for the encoding. If input values are
  685. lower than this boundary, they will be encoded as `lower_bound`.
  686. upper_bound: The upper bound value used for the encoding. If input values are
  687. higher than this boundary, they will be encoded as `upper_bound`.
  688. Returns:
  689. The two-hot encoded tensor of shape (B, num_buckets).
  690. """
  691. # First make sure, values are clipped.
  692. value = torch.clamp(value, lower_bound, upper_bound)
  693. # Tensor of batch indices: [0, B=batch size).
  694. batch_indices = torch.arange(0, value.shape[0], device=device).float()
  695. # Calculate the step deltas (how much space between each bucket's central value?).
  696. bucket_delta = (upper_bound - lower_bound) / (num_buckets - 1)
  697. # Compute the float indices (might be non-int numbers: sitting between two buckets).
  698. idx = (-lower_bound + value) / bucket_delta
  699. # k
  700. k = torch.floor(idx)
  701. # k+1
  702. kp1 = torch.ceil(idx)
  703. # In case k == kp1 (idx is exactly on the bucket boundary), move kp1 up by 1.0.
  704. # Otherwise, this would result in a NaN in the returned two-hot tensor.
  705. kp1 = torch.where(k.eq(kp1), kp1 + 1.0, kp1)
  706. # Iff `kp1` is one beyond our last index (because incoming value is larger than
  707. # `upper_bound`), move it to one before k (kp1's weight is going to be 0.0 anyways,
  708. # so it doesn't matter where it points to; we are just avoiding an index error
  709. # with this).
  710. kp1 = torch.where(kp1.eq(num_buckets), kp1 - 2.0, kp1)
  711. # The actual values found at k and k+1 inside the set of buckets.
  712. values_k = lower_bound + k * bucket_delta
  713. values_kp1 = lower_bound + kp1 * bucket_delta
  714. # Compute the two-hot weights (adding up to 1.0) to use at index k and k+1.
  715. weights_k = (value - values_kp1) / (values_k - values_kp1)
  716. weights_kp1 = 1.0 - weights_k
  717. # Compile a tensor of full paths (indices from batch index to feature index) to
  718. # use for the scatter_nd op.
  719. indices_k = torch.stack([batch_indices, k], dim=-1)
  720. indices_kp1 = torch.stack([batch_indices, kp1], dim=-1)
  721. indices = torch.cat([indices_k, indices_kp1], dim=0).long()
  722. # The actual values (weights adding up to 1.0) to place at the computed indices.
  723. updates = torch.cat([weights_k, weights_kp1], dim=0)
  724. # Call the actual scatter update op, returning a zero-filled tensor, only changed
  725. # at the given indices.
  726. output = torch.zeros(value.shape[0], num_buckets, device=device)
  727. # Set our two-hot values at computed indices.
  728. output[indices[:, 0], indices[:, 1]] = updates
  729. return output
  730. def _dynamo_is_available():
  731. # This only works if torch._dynamo is available
  732. try:
  733. # TODO(Artur): Remove this once torch._dynamo is available on CI
  734. import torch._dynamo as dynamo # noqa: F401
  735. return True
  736. except ImportError:
  737. return False