torch_policy.py 48 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201
  1. import copy
  2. import functools
  3. import logging
  4. import math
  5. import os
  6. import threading
  7. import time
  8. from typing import (
  9. Any,
  10. Callable,
  11. Dict,
  12. List,
  13. Optional,
  14. Set,
  15. Tuple,
  16. Type,
  17. Union,
  18. )
  19. import gymnasium as gym
  20. import numpy as np
  21. import tree # pip install dm_tree
  22. import ray
  23. from ray.rllib.models.catalog import ModelCatalog
  24. from ray.rllib.models.modelv2 import ModelV2
  25. from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
  26. from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
  27. from ray.rllib.policy.policy import Policy, PolicyState
  28. from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
  29. from ray.rllib.policy.sample_batch import SampleBatch
  30. from ray.rllib.utils import NullContextManager, force_list
  31. from ray.rllib.utils.annotations import OldAPIStack, override
  32. from ray.rllib.utils.error import ERR_MSG_TORCH_POLICY_CANNOT_SAVE_MODEL
  33. from ray.rllib.utils.framework import try_import_torch
  34. from ray.rllib.utils.metrics import (
  35. DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY,
  36. NUM_AGENT_STEPS_TRAINED,
  37. NUM_GRAD_UPDATES_LIFETIME,
  38. )
  39. from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
  40. from ray.rllib.utils.numpy import convert_to_numpy
  41. from ray.rllib.utils.spaces.space_utils import normalize_action
  42. from ray.rllib.utils.threading import with_lock
  43. from ray.rllib.utils.torch_utils import convert_to_torch_tensor
  44. from ray.rllib.utils.typing import (
  45. AlgorithmConfigDict,
  46. GradInfoDict,
  47. ModelGradients,
  48. ModelWeights,
  49. TensorStructType,
  50. TensorType,
  51. )
  52. torch, nn = try_import_torch()
  53. logger = logging.getLogger(__name__)
  54. @OldAPIStack
  55. class TorchPolicy(Policy):
  56. """PyTorch specific Policy class to use with RLlib."""
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: AlgorithmConfigDict,
        *,
        model: Optional[TorchModelV2] = None,
        loss: Optional[
            Callable[
                [Policy, ModelV2, Type[TorchDistributionWrapper], SampleBatch],
                Union[TensorType, List[TensorType]],
            ]
        ] = None,
        action_distribution_class: Optional[Type[TorchDistributionWrapper]] = None,
        action_sampler_fn: Optional[
            Callable[
                [TensorType, List[TensorType]],
                Union[
                    Tuple[TensorType, TensorType, List[TensorType]],
                    Tuple[TensorType, TensorType, TensorType, List[TensorType]],
                ],
            ]
        ] = None,
        action_distribution_fn: Optional[
            Callable[
                [Policy, ModelV2, TensorType, TensorType, TensorType],
                Tuple[TensorType, Type[TorchDistributionWrapper], List[TensorType]],
            ]
        ] = None,
        max_seq_len: int = 20,
        get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None,
    ):
        """Initializes a TorchPolicy instance.

        Builds (or adopts) the model, places one model copy ("tower") per
        device, creates the exploration object and the optimizer(s), and
        records which model parameters each optimizer is responsible for.

        Args:
            observation_space: Observation space of the policy.
            action_space: Action space of the policy.
            config: The Policy's config dict. NOTE: Its "framework" key is
                overwritten with "torch" by this constructor.
            model: PyTorch policy module. Given observations as
                input, this module must return a list of outputs where the
                first item is action logits, and the rest can be any value.
                If None, a default model is built via `ModelCatalog`.
            loss: Callable that returns one or more (a list of) scalar loss
                terms. If None, an overridden `self.loss` method is used
                (if there is one).
            action_distribution_class: Class for a torch action distribution.
                Only defaults to the catalog's distribution class if `model`
                is None as well.
            action_sampler_fn: A callable returning either a sampled action,
                its log-likelihood and updated state or a sampled action, its
                log-likelihood, updated state and action distribution inputs
                given Policy, ModelV2, input_dict, state batches (optional),
                explore, and timestep. Provide `action_sampler_fn` if you would
                like to have full control over the action computation step,
                including the model forward pass, possible sampling from a
                distribution, and exploration logic.
                Note: If `action_sampler_fn` is given, `action_distribution_fn`
                must be None. If both `action_sampler_fn` and
                `action_distribution_fn` are None, RLlib will simply pass
                inputs through `self.model` to get distribution inputs, create
                the distribution object, sample from it, and apply some
                exploration logic to the results.
                The callable takes as inputs: Policy, ModelV2, input_dict
                (SampleBatch), state_batches (optional), explore, and timestep.
            action_distribution_fn: A callable returning distribution inputs
                (parameters), a dist-class to generate an action distribution
                object from, and internal-state outputs (or an empty list if
                not applicable).
                Provide `action_distribution_fn` if you would like to only
                customize the model forward pass call. The resulting
                distribution parameters are then used by RLlib to create a
                distribution object, sample from it, and execute any
                exploration logic.
                Note: If `action_distribution_fn` is given, `action_sampler_fn`
                must be None. If both `action_sampler_fn` and
                `action_distribution_fn` are None, RLlib will simply pass
                inputs through `self.model` to get distribution inputs, create
                the distribution object, sample from it, and apply some
                exploration logic to the results.
                The callable takes as inputs: Policy, ModelV2, ModelInputDict,
                explore, timestep, is_training.
            max_seq_len: Max sequence length for LSTM training.
            get_batch_divisibility_req: Optional callable that returns the
                divisibility requirement for sample batches given the Policy.
                A non-callable value is used as-is; None (or any other falsy
                value) results in a requirement of 1.

        Raises:
            ValueError: If real GPUs are to be used, but fewer GPU IDs than
                the requested number of GPUs are found.
        """
        # Force the framework to "torch", both on this object and inside the
        # passed-in config dict.
        self.framework = config["framework"] = "torch"
        # Set before calling the parent constructor.
        self._loss_initialized = False
        super().__init__(observation_space, action_space, config)

        # Create multi-GPU model towers, if necessary.
        # - The central main model will be stored under self.model, residing
        #   on self.device (normally, a CPU).
        # - Each GPU will have a copy of that model under
        #   self.model_gpu_towers, matching the devices in self.devices.
        # - Parallelization is done by splitting the train batch and passing
        #   it through the model copies in parallel, then averaging over the
        #   resulting gradients, applying these averages on the main model and
        #   updating all towers' weights from the main model.
        # - In case of just one device (1 (fake or real) GPU or 1 CPU), no
        #   parallelization will be done.

        # If no Model is provided, build a default one here.
        if model is None:
            dist_class, logit_dim = ModelCatalog.get_action_dist(
                action_space, self.config["model"], framework=self.framework
            )
            model = ModelCatalog.get_model_v2(
                obs_space=self.observation_space,
                action_space=self.action_space,
                num_outputs=logit_dim,
                model_config=self.config["model"],
                framework=self.framework,
            )
            # `dist_class` only exists in this branch, hence the default for
            # the action distribution class can only be applied here.
            if action_distribution_class is None:
                action_distribution_class = dist_class

        # Get devices to build the graph on.
        num_gpus = self._get_num_gpus_for_policy()
        gpu_ids = list(range(torch.cuda.device_count()))
        logger.info(f"Found {len(gpu_ids)} visible cuda devices.")

        # Place on one or more CPU(s) when either:
        # - Fake GPU mode.
        # - num_gpus=0 (either set by user or we are in local_mode=True).
        # - No GPUs available.
        if config["_fake_gpus"] or num_gpus == 0 or not gpu_ids:
            self.device = torch.device("cpu")
            # One (CPU) device entry per requested GPU (rounded up), but at
            # least one.
            self.devices = [self.device for _ in range(int(math.ceil(num_gpus)) or 1)]
            # Tower 0 is the given model itself; all other towers are deep
            # copies of it.
            self.model_gpu_towers = [
                model if i == 0 else copy.deepcopy(model)
                for i in range(int(math.ceil(num_gpus)) or 1)
            ]
            # `target_model` is not set in this constructor; it is only
            # picked up if already present on `self`. Here, all towers map to
            # the very same target model object.
            if hasattr(self, "target_model"):
                self.target_models = {
                    m: self.target_model for m in self.model_gpu_towers
                }
            self.model = model
        # Place on one or more actual GPU(s), when:
        # - num_gpus > 0 (set by user) AND
        # - local_mode=False AND
        # - actual GPUs available AND
        # - non-fake GPU mode.
        else:
            # We are a remote worker (WORKER_MODE=1):
            # GPUs should be assigned to us by ray.
            if ray._private.worker._mode() == ray._private.worker.WORKER_MODE:
                gpu_ids = ray.get_gpu_ids()

            if len(gpu_ids) < num_gpus:
                raise ValueError(
                    "TorchPolicy was not able to find enough GPU IDs! Found "
                    f"{gpu_ids}, but num_gpus={num_gpus}."
                )

            # Devices are addressed by position ("cuda:0", "cuda:1", ...),
            # not by the GPU id values themselves.
            self.devices = [
                torch.device("cuda:{}".format(i))
                for i, id_ in enumerate(gpu_ids)
                if i < num_gpus
            ]
            self.device = self.devices[0]
            ids = [id_ for i, id_ in enumerate(gpu_ids) if i < num_gpus]
            # Every tower (including tower 0) is a deep copy of the given
            # model, moved to its device.
            self.model_gpu_towers = []
            for i, _ in enumerate(ids):
                model_copy = copy.deepcopy(model)
                self.model_gpu_towers.append(model_copy.to(self.devices[i]))
            # Each tower gets its own copy of the target model, residing on
            # the tower's device.
            if hasattr(self, "target_model"):
                self.target_models = {
                    m: copy.deepcopy(self.target_model).to(self.devices[i])
                    for i, m in enumerate(self.model_gpu_towers)
                }
            self.model = self.model_gpu_towers[0]

        # Lock used for locking some methods on the object-level.
        # This prevents possible race conditions when calling the model
        # first, then its value function (e.g. in a loss function), in
        # between of which another model call is made (e.g. to compute an
        # action).
        self._lock = threading.RLock()

        self._state_inputs = self.model.get_initial_state()
        # A non-empty initial state marks this policy as recurrent.
        self._is_recurrent = len(self._state_inputs) > 0
        # Auto-update model's inference view requirements, if recurrent.
        self._update_model_view_requirements_from_init_state()
        # Combine view_requirements for Model and Policy.
        self.view_requirements.update(self.model.view_requirements)

        self.exploration = self._create_exploration()
        # NOTE: This is the local `model` (given or default-built). In the
        # real-GPU branch above, `self.model` is a deep copy of it instead.
        self.unwrapped_model = model  # used to support DistributedDataParallel

        # To ensure backward compatibility:
        # Old way: If `loss` provided here, use as-is (as a function).
        if loss is not None:
            self._loss = loss
        # New way: Convert the overridden `self.loss` into a plain function,
        # so it can be called the same way as `loss` would be, ensuring
        # backward compatibility.
        elif self.loss.__func__.__qualname__ != "Policy.loss":
            self._loss = self.loss.__func__
        # `loss` not provided nor overridden from Policy -> Set to None.
        else:
            self._loss = None
        self._optimizers = force_list(self.optimizer())

        # Store, which params (by index within the model's list of
        # parameters) should be updated per optimizer.
        # Maps optimizer idx to set or param indices.
        self.multi_gpu_param_groups: List[Set[int]] = []
        # Lookup from parameter object to its position in
        # `self.model.parameters()`.
        main_params = {p: i for i, p in enumerate(self.model.parameters())}
        for o in self._optimizers:
            param_indices = []
            for pg_idx, pg in enumerate(o.param_groups):
                for p in pg["params"]:
                    param_indices.append(main_params[p])
            self.multi_gpu_param_groups.append(set(param_indices))

        # Create n sample-batch buffers (num_multi_gpu_tower_stacks), each
        # one with m towers (num_gpus).
        num_buffers = self.config.get("num_multi_gpu_tower_stacks", 1)
        self._loaded_batches = [[] for _ in range(num_buffers)]

        self.dist_class = action_distribution_class
        self.action_sampler_fn = action_sampler_fn
        self.action_distribution_fn = action_distribution_fn

        # If set, means we are using distributed allreduce during learning.
        self.distributed_world_size = None

        self.max_seq_len = max_seq_len
        # Either call the given function with this policy, or use the given
        # (non-callable) value directly, falling back to 1.
        self.batch_divisibility_req = (
            get_batch_divisibility_req(self)
            if callable(get_batch_divisibility_req)
            else (get_batch_divisibility_req or 1)
        )
  270. @override(Policy)
  271. def compute_actions_from_input_dict(
  272. self,
  273. input_dict: Dict[str, TensorType],
  274. explore: bool = None,
  275. timestep: Optional[int] = None,
  276. **kwargs,
  277. ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
  278. with torch.no_grad():
  279. # Pass lazy (torch) tensor dict to Model as `input_dict`.
  280. input_dict = self._lazy_tensor_dict(input_dict)
  281. input_dict.set_training(True)
  282. # Pack internal state inputs into (separate) list.
  283. state_batches = [
  284. input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]
  285. ]
  286. # Calculate RNN sequence lengths.
  287. seq_lens = (
  288. torch.tensor(
  289. [1] * len(state_batches[0]),
  290. dtype=torch.long,
  291. device=state_batches[0].device,
  292. )
  293. if state_batches
  294. else None
  295. )
  296. return self._compute_action_helper(
  297. input_dict, state_batches, seq_lens, explore, timestep
  298. )
  299. @override(Policy)
  300. def compute_actions(
  301. self,
  302. obs_batch: Union[List[TensorStructType], TensorStructType],
  303. state_batches: Optional[List[TensorType]] = None,
  304. prev_action_batch: Union[List[TensorStructType], TensorStructType] = None,
  305. prev_reward_batch: Union[List[TensorStructType], TensorStructType] = None,
  306. info_batch: Optional[Dict[str, list]] = None,
  307. episodes=None,
  308. explore: Optional[bool] = None,
  309. timestep: Optional[int] = None,
  310. **kwargs,
  311. ) -> Tuple[TensorStructType, List[TensorType], Dict[str, TensorType]]:
  312. with torch.no_grad():
  313. seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
  314. input_dict = self._lazy_tensor_dict(
  315. {
  316. SampleBatch.CUR_OBS: obs_batch,
  317. "is_training": False,
  318. }
  319. )
  320. if prev_action_batch is not None:
  321. input_dict[SampleBatch.PREV_ACTIONS] = np.asarray(prev_action_batch)
  322. if prev_reward_batch is not None:
  323. input_dict[SampleBatch.PREV_REWARDS] = np.asarray(prev_reward_batch)
  324. state_batches = [
  325. convert_to_torch_tensor(s, self.device) for s in (state_batches or [])
  326. ]
  327. return self._compute_action_helper(
  328. input_dict, state_batches, seq_lens, explore, timestep
  329. )
  330. @with_lock
  331. @override(Policy)
  332. def compute_log_likelihoods(
  333. self,
  334. actions: Union[List[TensorStructType], TensorStructType],
  335. obs_batch: Union[List[TensorStructType], TensorStructType],
  336. state_batches: Optional[List[TensorType]] = None,
  337. prev_action_batch: Optional[
  338. Union[List[TensorStructType], TensorStructType]
  339. ] = None,
  340. prev_reward_batch: Optional[
  341. Union[List[TensorStructType], TensorStructType]
  342. ] = None,
  343. actions_normalized: bool = True,
  344. **kwargs,
  345. ) -> TensorType:
  346. if self.action_sampler_fn and self.action_distribution_fn is None:
  347. raise ValueError(
  348. "Cannot compute log-prob/likelihood w/o an "
  349. "`action_distribution_fn` and a provided "
  350. "`action_sampler_fn`!"
  351. )
  352. with torch.no_grad():
  353. input_dict = self._lazy_tensor_dict(
  354. {SampleBatch.CUR_OBS: obs_batch, SampleBatch.ACTIONS: actions}
  355. )
  356. if prev_action_batch is not None:
  357. input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
  358. if prev_reward_batch is not None:
  359. input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
  360. seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
  361. state_batches = [
  362. convert_to_torch_tensor(s, self.device) for s in (state_batches or [])
  363. ]
  364. # Exploration hook before each forward pass.
  365. self.exploration.before_compute_actions(explore=False)
  366. # Action dist class and inputs are generated via custom function.
  367. if self.action_distribution_fn:
  368. # Try new action_distribution_fn signature, supporting
  369. # state_batches and seq_lens.
  370. try:
  371. dist_inputs, dist_class, state_out = self.action_distribution_fn(
  372. self,
  373. self.model,
  374. input_dict=input_dict,
  375. state_batches=state_batches,
  376. seq_lens=seq_lens,
  377. explore=False,
  378. is_training=False,
  379. )
  380. # Trying the old way (to stay backward compatible).
  381. # TODO: Remove in future.
  382. except TypeError as e:
  383. if (
  384. "positional argument" in e.args[0]
  385. or "unexpected keyword argument" in e.args[0]
  386. ):
  387. dist_inputs, dist_class, _ = self.action_distribution_fn(
  388. policy=self,
  389. model=self.model,
  390. obs_batch=input_dict[SampleBatch.CUR_OBS],
  391. explore=False,
  392. is_training=False,
  393. )
  394. else:
  395. raise e
  396. # Default action-dist inputs calculation.
  397. else:
  398. dist_class = self.dist_class
  399. dist_inputs, _ = self.model(input_dict, state_batches, seq_lens)
  400. action_dist = dist_class(dist_inputs, self.model)
  401. # Normalize actions if necessary.
  402. actions = input_dict[SampleBatch.ACTIONS]
  403. if not actions_normalized and self.config["normalize_actions"]:
  404. actions = normalize_action(actions, self.action_space_struct)
  405. log_likelihoods = action_dist.logp(actions)
  406. return log_likelihoods
  407. @with_lock
  408. @override(Policy)
  409. def learn_on_batch(self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]:
  410. # Set Model to train mode.
  411. if self.model:
  412. self.model.train()
  413. # Callback handling.
  414. learn_stats = {}
  415. self.callbacks.on_learn_on_batch(
  416. policy=self, train_batch=postprocessed_batch, result=learn_stats
  417. )
  418. # Compute gradients (will calculate all losses and `backward()`
  419. # them to get the grads).
  420. grads, fetches = self.compute_gradients(postprocessed_batch)
  421. # Step the optimizers.
  422. self.apply_gradients(_directStepOptimizerSingleton)
  423. self.num_grad_updates += 1
  424. if self.model:
  425. fetches["model"] = self.model.metrics()
  426. fetches.update(
  427. {
  428. "custom_metrics": learn_stats,
  429. NUM_AGENT_STEPS_TRAINED: postprocessed_batch.count,
  430. NUM_GRAD_UPDATES_LIFETIME: self.num_grad_updates,
  431. # -1, b/c we have to measure this diff before we do the update above.
  432. DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY: (
  433. self.num_grad_updates
  434. - 1
  435. - (postprocessed_batch.num_grad_updates or 0)
  436. ),
  437. }
  438. )
  439. return fetches
  440. @override(Policy)
  441. def load_batch_into_buffer(
  442. self,
  443. batch: SampleBatch,
  444. buffer_index: int = 0,
  445. ) -> int:
  446. # Set the is_training flag of the batch.
  447. batch.set_training(True)
  448. # Shortcut for 1 CPU only: Store batch in `self._loaded_batches`.
  449. if len(self.devices) == 1 and self.devices[0].type == "cpu":
  450. assert buffer_index == 0
  451. pad_batch_to_sequences_of_same_size(
  452. batch=batch,
  453. max_seq_len=self.max_seq_len,
  454. shuffle=False,
  455. batch_divisibility_req=self.batch_divisibility_req,
  456. view_requirements=self.view_requirements,
  457. )
  458. self._lazy_tensor_dict(batch)
  459. self._loaded_batches[0] = [batch]
  460. return len(batch)
  461. # Batch (len=28, seq-lens=[4, 7, 4, 10, 3]):
  462. # 0123 0123456 0123 0123456789ABC
  463. # 1) split into n per-GPU sub batches (n=2).
  464. # [0123 0123456] [012] [3 0123456789 ABC]
  465. # (len=14, 14 seq-lens=[4, 7, 3] [1, 10, 3])
  466. slices = batch.timeslices(num_slices=len(self.devices))
  467. # 2) zero-padding (max-seq-len=10).
  468. # - [0123000000 0123456000 0120000000]
  469. # - [3000000000 0123456789 ABC0000000]
  470. for slice in slices:
  471. pad_batch_to_sequences_of_same_size(
  472. batch=slice,
  473. max_seq_len=self.max_seq_len,
  474. shuffle=False,
  475. batch_divisibility_req=self.batch_divisibility_req,
  476. view_requirements=self.view_requirements,
  477. )
  478. # 3) Load splits into the given buffer (consisting of n GPUs).
  479. slices = [slice.to_device(self.devices[i]) for i, slice in enumerate(slices)]
  480. self._loaded_batches[buffer_index] = slices
  481. # Return loaded samples per-device.
  482. return len(slices[0])
  483. @override(Policy)
  484. def get_num_samples_loaded_into_buffer(self, buffer_index: int = 0) -> int:
  485. if len(self.devices) == 1 and self.devices[0] == "/cpu:0":
  486. assert buffer_index == 0
  487. return sum(len(b) for b in self._loaded_batches[buffer_index])
  488. @override(Policy)
  489. def learn_on_loaded_batch(self, offset: int = 0, buffer_index: int = 0):
  490. if not self._loaded_batches[buffer_index]:
  491. raise ValueError(
  492. "Must call Policy.load_batch_into_buffer() before "
  493. "Policy.learn_on_loaded_batch()!"
  494. )
  495. # Get the correct slice of the already loaded batch to use,
  496. # based on offset and batch size.
  497. device_batch_size = self.config.get("minibatch_size")
  498. if device_batch_size is None:
  499. device_batch_size = self.config.get(
  500. "sgd_minibatch_size",
  501. self.config["train_batch_size"],
  502. )
  503. device_batch_size //= len(self.devices)
  504. # Set Model to train mode.
  505. if self.model_gpu_towers:
  506. for t in self.model_gpu_towers:
  507. t.train()
  508. # Shortcut for 1 CPU only: Batch should already be stored in
  509. # `self._loaded_batches`.
  510. if len(self.devices) == 1 and self.devices[0].type == "cpu":
  511. assert buffer_index == 0
  512. if device_batch_size >= len(self._loaded_batches[0][0]):
  513. batch = self._loaded_batches[0][0]
  514. else:
  515. batch = self._loaded_batches[0][0][offset : offset + device_batch_size]
  516. return self.learn_on_batch(batch)
  517. if len(self.devices) > 1:
  518. # Copy weights of main model (tower-0) to all other towers.
  519. state_dict = self.model.state_dict()
  520. # Just making sure tower-0 is really the same as self.model.
  521. assert self.model_gpu_towers[0] is self.model
  522. for tower in self.model_gpu_towers[1:]:
  523. tower.load_state_dict(state_dict)
  524. if device_batch_size >= sum(len(s) for s in self._loaded_batches[buffer_index]):
  525. device_batches = self._loaded_batches[buffer_index]
  526. else:
  527. device_batches = [
  528. b[offset : offset + device_batch_size]
  529. for b in self._loaded_batches[buffer_index]
  530. ]
  531. # Callback handling.
  532. batch_fetches = {}
  533. for i, batch in enumerate(device_batches):
  534. custom_metrics = {}
  535. self.callbacks.on_learn_on_batch(
  536. policy=self, train_batch=batch, result=custom_metrics
  537. )
  538. batch_fetches[f"tower_{i}"] = {"custom_metrics": custom_metrics}
  539. # Do the (maybe parallelized) gradient calculation step.
  540. tower_outputs = self._multi_gpu_parallel_grad_calc(device_batches)
  541. # Mean-reduce gradients over GPU-towers (do this on CPU: self.device).
  542. all_grads = []
  543. for i in range(len(tower_outputs[0][0])):
  544. if tower_outputs[0][0][i] is not None:
  545. all_grads.append(
  546. torch.mean(
  547. torch.stack([t[0][i].to(self.device) for t in tower_outputs]),
  548. dim=0,
  549. )
  550. )
  551. else:
  552. all_grads.append(None)
  553. # Set main model's grads to mean-reduced values.
  554. for i, p in enumerate(self.model.parameters()):
  555. p.grad = all_grads[i]
  556. self.apply_gradients(_directStepOptimizerSingleton)
  557. self.num_grad_updates += 1
  558. for i, (model, batch) in enumerate(zip(self.model_gpu_towers, device_batches)):
  559. batch_fetches[f"tower_{i}"].update(
  560. {
  561. LEARNER_STATS_KEY: self.extra_grad_info(batch),
  562. "model": model.metrics(),
  563. NUM_GRAD_UPDATES_LIFETIME: self.num_grad_updates,
  564. # -1, b/c we have to measure this diff before we do the update
  565. # above.
  566. DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY: (
  567. self.num_grad_updates - 1 - (batch.num_grad_updates or 0)
  568. ),
  569. }
  570. )
  571. batch_fetches.update(self.extra_compute_grad_fetches())
  572. return batch_fetches
  573. @with_lock
  574. @override(Policy)
  575. def compute_gradients(self, postprocessed_batch: SampleBatch) -> ModelGradients:
  576. assert len(self.devices) == 1
  577. # If not done yet, see whether we have to zero-pad this batch.
  578. if not postprocessed_batch.zero_padded:
  579. pad_batch_to_sequences_of_same_size(
  580. batch=postprocessed_batch,
  581. max_seq_len=self.max_seq_len,
  582. shuffle=False,
  583. batch_divisibility_req=self.batch_divisibility_req,
  584. view_requirements=self.view_requirements,
  585. )
  586. postprocessed_batch.set_training(True)
  587. self._lazy_tensor_dict(postprocessed_batch, device=self.devices[0])
  588. # Do the (maybe parallelized) gradient calculation step.
  589. tower_outputs = self._multi_gpu_parallel_grad_calc([postprocessed_batch])
  590. all_grads, grad_info = tower_outputs[0]
  591. grad_info["allreduce_latency"] /= len(self._optimizers)
  592. grad_info.update(self.extra_grad_info(postprocessed_batch))
  593. fetches = self.extra_compute_grad_fetches()
  594. return all_grads, dict(fetches, **{LEARNER_STATS_KEY: grad_info})
  595. @override(Policy)
  596. def apply_gradients(self, gradients: ModelGradients) -> None:
  597. if gradients == _directStepOptimizerSingleton:
  598. for i, opt in enumerate(self._optimizers):
  599. opt.step()
  600. else:
  601. # TODO(sven): Not supported for multiple optimizers yet.
  602. assert len(self._optimizers) == 1
  603. for g, p in zip(gradients, self.model.parameters()):
  604. if g is not None:
  605. if torch.is_tensor(g):
  606. p.grad = g.to(self.device)
  607. else:
  608. p.grad = torch.from_numpy(g).to(self.device)
  609. self._optimizers[0].step()
  610. def get_tower_stats(self, stats_name: str) -> List[TensorStructType]:
  611. """Returns list of per-tower stats, copied to this Policy's device.
  612. Args:
  613. stats_name: The name of the stats to average over (this str
  614. must exist as a key inside each tower's `tower_stats` dict).
  615. Returns:
  616. The list of stats tensor (structs) of all towers, copied to this
  617. Policy's device.
  618. Raises:
  619. AssertionError: If the `stats_name` cannot be found in any one
  620. of the tower's `tower_stats` dicts.
  621. """
  622. data = []
  623. for tower in self.model_gpu_towers:
  624. if stats_name in tower.tower_stats:
  625. data.append(
  626. tree.map_structure(
  627. lambda s: s.to(self.device), tower.tower_stats[stats_name]
  628. )
  629. )
  630. assert len(data) > 0, (
  631. f"Stats `{stats_name}` not found in any of the towers (you have "
  632. f"{len(self.model_gpu_towers)} towers in total)! Make "
  633. "sure you call the loss function on at least one of the towers."
  634. )
  635. return data
  636. @override(Policy)
  637. def get_weights(self) -> ModelWeights:
  638. return {k: v.cpu().detach().numpy() for k, v in self.model.state_dict().items()}
  639. @override(Policy)
  640. def set_weights(self, weights: ModelWeights) -> None:
  641. weights = convert_to_torch_tensor(weights, device=self.device)
  642. self.model.load_state_dict(weights)
    @override(Policy)
    def is_recurrent(self) -> bool:
        """Returns the `_is_recurrent` flag, as stored on this Policy."""
        return self._is_recurrent
  646. @override(Policy)
  647. def num_state_tensors(self) -> int:
  648. return len(self.model.get_initial_state())
  649. @override(Policy)
  650. def get_initial_state(self) -> List[TensorType]:
  651. return [s.detach().cpu().numpy() for s in self.model.get_initial_state()]
  652. @override(Policy)
  653. def get_state(self) -> PolicyState:
  654. state = super().get_state()
  655. state["_optimizer_variables"] = []
  656. for i, o in enumerate(self._optimizers):
  657. optim_state_dict = convert_to_numpy(o.state_dict())
  658. state["_optimizer_variables"].append(optim_state_dict)
  659. # Add exploration state.
  660. if self.exploration:
  661. # This is not compatible with RLModules, which have a method
  662. # `forward_exploration` to specify custom exploration behavior.
  663. state["_exploration_state"] = self.exploration.get_state()
  664. return state
  @override(Policy)
  def set_state(self, state: PolicyState) -> None:
      """Restore optimizer, exploration, timestep and parent-class state.

      Args:
          state: The state dict to restore from. Reads the optional keys
              `_optimizer_variables` and `_exploration_state`, the required
              key `global_timestep`, and finally forwards the whole dict to
              `super().set_state()`.

      Raises:
          AssertionError: If `_optimizer_variables` is non-empty but its
              length differs from the number of local optimizers.
          KeyError: If `state` has no `global_timestep` entry.
      """
      # Set optimizer vars first.
      optimizer_vars = state.get("_optimizer_variables", None)
      if optimizer_vars:
          assert len(optimizer_vars) == len(self._optimizers), (
              f"Got {len(optimizer_vars)} optimizer states, but this policy "
              f"has {len(self._optimizers)} optimizers!"
          )
          for opt, opt_vars in zip(self._optimizers, optimizer_vars):
              # Torch optimizer param_groups include things like beta, etc. These
              # parameters should be left as scalar and not converted to tensors.
              # otherwise, torch.optim.step() will start to complain.
              optim_state_dict = {
                  "param_groups": opt_vars["param_groups"],
                  "state": convert_to_torch_tensor(
                      opt_vars["state"], device=self.device
                  ),
              }
              opt.load_state_dict(optim_state_dict)

      # Set exploration's state. The attribute may exist but be None, in
      # which case there is nothing to restore into (a bare `hasattr` check
      # would then fail with an AttributeError on `None.set_state`).
      exploration = getattr(self, "exploration", None)
      if exploration is not None and "_exploration_state" in state:
          exploration.set_state(state=state["_exploration_state"])

      # Restore global timestep.
      self.global_timestep = state["global_timestep"]

      # Then the Policy's (NN) weights and connectors.
      super().set_state(state)
  def extra_grad_process(
      self, optimizer: "torch.optim.Optimizer", loss: TensorType
  ) -> Dict[str, TensorType]:
      """Hook for processing gradients before `optimizer.step()`.

      Invoked once per optimizer/loss pair after the loss' `backward()` call,
      e.g. to apply gradient clipping. This default implementation does
      nothing.

      Args:
          optimizer: A torch optimizer object.
          loss: The loss tensor associated with the optimizer.

      Returns:
          A dict with information on the gradient processing step (empty
          here).
      """
      processing_info = {}
      return processing_info
  def extra_compute_grad_fetches(self) -> Dict[str, Any]:
      """Return additional values to fetch from `compute_gradients()`.

      Returns:
          A dict to be merged into the fetch dict of the `compute_gradients`
          call. By default it holds a single, empty entry under
          `LEARNER_STATS_KEY` (a place for e.g. stats, td error, etc.).
      """
      extra_fetches = {}
      extra_fetches[LEARNER_STATS_KEY] = {}
      return extra_fetches
  def extra_action_out(
      self,
      input_dict: Dict[str, TensorType],
      state_batches: List[TensorType],
      model: TorchModelV2,
      action_dist: TorchDistributionWrapper,
  ) -> Dict[str, TensorType]:
      """Hook returning extra data to include in the experience batch.

      This default implementation contributes nothing.

      Args:
          input_dict: Dict of model input tensors.
          state_batches: List of state tensors.
          model: Reference to the model object.
          action_dist: Torch action dist object to get log-probs (e.g. for
              already sampled actions).

      Returns:
          Extra outputs to return in a `compute_actions_from_input_dict()`
          call (3rd return value); an empty dict here.
      """
      extra_outputs = {}
      return extra_outputs
  def extra_grad_info(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
      """Hook returning extra gradient info for a train batch.

      This default implementation reports nothing.

      Args:
          train_batch: The training batch for which to produce extra grad
              info for.

      Returns:
          The info dict carrying grad info per str key; an empty dict here.
      """
      grad_info = {}
      return grad_info
  def optimizer(
      self,
  ) -> Union[List["torch.optim.Optimizer"], "torch.optim.Optimizer"]:
      """Build the local PyTorch optimizer(s) for this Policy.

      A single Adam optimizer over `self.model.parameters()` is created. Its
      learning rate is `self.config["lr"]` when a `config` attribute exists,
      otherwise Adam's own default is used. If `self.exploration` is truthy,
      the list is passed through its `get_exploration_optimizer()` and that
      result is returned instead.

      Returns:
          The local PyTorch optimizer(s) to use for this Policy.
      """
      # Only override the learning rate if a config is available.
      adam_kwargs = {"lr": self.config["lr"]} if hasattr(self, "config") else {}
      local_optimizers = [torch.optim.Adam(self.model.parameters(), **adam_kwargs)]

      if not self.exploration:
          return local_optimizers
      return self.exploration.get_exploration_optimizer(local_optimizers)
  @override(Policy)
  def export_model(self, export_dir: str, onnx: Optional[int] = None) -> None:
      """Export the Policy's model into `export_dir`.

      Without `onnx`, the whole model object is written with `torch.save` to
      `<export_dir>/model.pt`; if that fails, a partially written file is
      removed and a warning is logged instead of raising. With `onnx`, the
      model is exported through `torch.onnx.export` to
      `<export_dir>/model.onnx`, using `self._dummy_batch` as example input.

      Args:
          export_dir: Local writable directory (created if missing).
          onnx: If given, will export model in ONNX format. The value of this
              parameter sets the ONNX OpSet version to use.
      """
      os.makedirs(export_dir, exist_ok=True)

      if not onnx:
          # Save the torch.Model (architecture and weights, so it can be
          # retrieved w/o access to the original (custom) Model or Policy code).
          model_path = os.path.join(export_dir, "model.pt")
          try:
              torch.save(self.model, f=model_path)
          except Exception:
              if os.path.exists(model_path):
                  os.remove(model_path)
              logger.warning(ERR_MSG_TORCH_POLICY_CANNOT_SAVE_MODEL)
          return

      # ONNX export: make the dummy batch hand out torch tensors.
      self._lazy_tensor_dict(self._dummy_batch)

      # Provide dummy state inputs if not an RNN (torch cannot jit with
      # returned empty internal states list).
      if "state_in_0" not in self._dummy_batch:
          placeholder = np.array([1.0])
          self._dummy_batch["state_in_0"] = placeholder
          self._dummy_batch[SampleBatch.SEQ_LENS] = placeholder
      seq_lens = self._dummy_batch[SampleBatch.SEQ_LENS]

      # Collect all consecutive `state_in_<i>` entries, starting at 0.
      state_ins = []
      state_idx = 0
      while f"state_in_{state_idx}" in self._dummy_batch:
          state_ins.append(self._dummy_batch[f"state_in_{state_idx}"])
          state_idx += 1

      dummy_inputs = {}
      for key in self._dummy_batch.keys():
          if key != "is_training":
              dummy_inputs[key] = self._dummy_batch[key]

      tensor_names = list(dummy_inputs.keys()) + ["state_ins", SampleBatch.SEQ_LENS]
      torch.onnx.export(
          self.model,
          (dummy_inputs, state_ins, seq_lens),
          os.path.join(export_dir, "model.onnx"),
          export_params=True,
          opset_version=onnx,
          do_constant_folding=True,
          input_names=tensor_names,
          output_names=["output", "state_outs"],
          # Mark axis 0 of every named input as the variable batch axis.
          dynamic_axes={name: {0: "batch_size"} for name in tensor_names},
      )
  @override(Policy)
  def import_model_from_h5(self, import_file: str) -> None:
      """Import weights from `import_file` by delegating to the model.

      Args:
          import_file: Passed unchanged to `self.model.import_from_h5()`.

      Returns:
          Whatever the model's `import_from_h5()` returns.
      """
      import_result = self.model.import_from_h5(import_file)
      return import_result
  @with_lock
  def _compute_action_helper(
      self, input_dict, state_batches, seq_lens, explore, timestep
  ):
      """Shared forward pass logic (w/ and w/o trajectory view API).

      Args:
          input_dict: The model inputs; `SampleBatch.ACTIONS` is written into
              it and `SampleBatch.CUR_OBS` is read from it.
          state_batches: State inputs; also used to set `self._is_recurrent`.
          seq_lens: Sequence lengths passed on to the model or
              `action_distribution_fn`.
          explore: Whether to explore; `None` falls back to
              `self.config["explore"]`.
          timestep: The timestep to use; `None` falls back to
              `self.global_timestep`.

      Returns:
          A tuple consisting of a) actions, b) state_out, c) extra_fetches,
          passed through `convert_to_numpy`.

      Raises:
          ValueError: If the distribution class is neither a
              `functools.partial` nor a `TorchDistributionWrapper` subclass.
          TypeError: Re-raised from `action_distribution_fn` if it is not
              recognized as a signature mismatch.
      """
      explore = explore if explore is not None else self.config["explore"]
      timestep = timestep if timestep is not None else self.global_timestep
      self._is_recurrent = state_batches is not None and state_batches != []

      # Switch to eval mode.
      if self.model:
          self.model.eval()

      if self.action_sampler_fn:
          action_dist = dist_inputs = None
          action_sampler_outputs = self.action_sampler_fn(
              self,
              self.model,
              input_dict,
              state_batches,
              explore=explore,
              timestep=timestep,
          )
          # The sampler may or may not return the dist inputs (4 vs 3 items).
          if len(action_sampler_outputs) == 4:
              actions, logp, dist_inputs, state_out = action_sampler_outputs
          else:
              actions, logp, state_out = action_sampler_outputs
      else:
          # Call the exploration before_compute_actions hook.
          self.exploration.before_compute_actions(explore=explore, timestep=timestep)
          if self.action_distribution_fn:
              # Try new action_distribution_fn signature, supporting
              # state_batches and seq_lens.
              try:
                  dist_inputs, dist_class, state_out = self.action_distribution_fn(
                      self,
                      self.model,
                      input_dict=input_dict,
                      state_batches=state_batches,
                      seq_lens=seq_lens,
                      explore=explore,
                      timestep=timestep,
                      is_training=False,
                  )
              # Trying the old way (to stay backward compatible).
              # TODO: Remove in future.
              except TypeError as e:
                  # Use `str(e)` rather than `e.args[0]`: the latter fails
                  # for TypeErrors raised without args or with a non-string
                  # first arg, which would hide the original error.
                  err_msg = str(e)
                  if (
                      "positional argument" not in err_msg
                      and "unexpected keyword argument" not in err_msg
                  ):
                      raise
                  (
                      dist_inputs,
                      dist_class,
                      state_out,
                  ) = self.action_distribution_fn(
                      self,
                      self.model,
                      input_dict[SampleBatch.CUR_OBS],
                      explore=explore,
                      timestep=timestep,
                      is_training=False,
                  )
          else:
              dist_class = self.dist_class
              dist_inputs, state_out = self.model(input_dict, state_batches, seq_lens)

          if not (
              isinstance(dist_class, functools.partial)
              or issubclass(dist_class, TorchDistributionWrapper)
          ):
              raise ValueError(
                  "`dist_class` ({}) not a TorchDistributionWrapper "
                  "subclass! Make sure your `action_distribution_fn` or "
                  "`make_model_and_action_dist` return a correct "
                  "distribution class.".format(dist_class.__name__)
              )
          action_dist = dist_class(dist_inputs, self.model)

          # Get the exploration action from the forward results.
          actions, logp = self.exploration.get_exploration_action(
              action_distribution=action_dist, timestep=timestep, explore=explore
          )

      input_dict[SampleBatch.ACTIONS] = actions

      # Add default and custom fetches.
      extra_fetches = self.extra_action_out(
          input_dict, state_batches, self.model, action_dist
      )

      # Action-dist inputs.
      if dist_inputs is not None:
          extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs

      # Action-logp and action-prob.
      if logp is not None:
          extra_fetches[SampleBatch.ACTION_PROB] = torch.exp(logp.float())
          extra_fetches[SampleBatch.ACTION_LOGP] = logp

      # Update our global timestep by the batch size.
      self.global_timestep += len(input_dict[SampleBatch.CUR_OBS])

      return convert_to_numpy((actions, state_out, extra_fetches))
  def _lazy_tensor_dict(self, postprocessed_batch: SampleBatch, device=None):
      """Install a torch-converting get-interceptor on the given batch.

      Args:
          postprocessed_batch: The batch to prepare. Anything that is not
              already a `SampleBatch` is wrapped into one first.
          device: Target device for the conversion; falls back to
              `self.device` when falsy.

      Returns:
          The `SampleBatch` on which `set_get_interceptor` was called with a
          `convert_to_torch_tensor` partial bound to the chosen device.
      """
      # TODO: (sven): Keep for a while to ensure backward compatibility.
      batch = (
          postprocessed_batch
          if isinstance(postprocessed_batch, SampleBatch)
          else SampleBatch(postprocessed_batch)
      )
      to_torch = functools.partial(
          convert_to_torch_tensor, device=device or self.device
      )
      batch.set_get_interceptor(to_torch)
      return batch
  def _multi_gpu_parallel_grad_calc(
      self, sample_batches: List[SampleBatch]
  ) -> List[Tuple[List[TensorType], GradInfoDict]]:
      """Performs a parallelized loss and gradient calculation over the batch.

      Splits up the given train batch into n shards (n=number of this
      Policy's devices) and passes each data shard (in parallel) through
      the loss function using the individual devices' models
      (self.model_gpu_towers). Then returns each tower's outputs.

      With a single device, or if `config["_fake_gpus"]` is set, the shards
      are processed one after the other in the calling thread; otherwise one
      thread per shard is started and joined.

      Args:
          sample_batches: A list of SampleBatch shards to
              calculate loss and gradients for.

      Returns:
          A list (one item per device) of 2-tuples, each with 1) gradient
          list and 2) grad info dict.

      Raises:
          ValueError: If the computation for any shard failed; the original
              exception is attached as the cause.
      """
      assert len(self.model_gpu_towers) == len(sample_batches)
      lock = threading.Lock()
      # Maps shard index -> (all_grads, grad_info) or (ValueError, cause).
      results = {}
      # Captured here and re-applied inside each worker call.
      grad_enabled = torch.is_grad_enabled()

      def _worker(shard_idx, model, sample_batch, device):
          torch.set_grad_enabled(grad_enabled)
          try:
              device_ctx = (
                  NullContextManager()
                  if device.type == "cpu"
                  else torch.cuda.device(device)
              )
              with device_ctx:
                  loss_out = force_list(
                      self._loss(self, model, self.dist_class, sample_batch)
                  )

                  # Call Model's custom-loss with Policy loss outputs and
                  # train_batch.
                  loss_out = model.custom_loss(loss_out, sample_batch)

                  assert len(loss_out) == len(self._optimizers)

                  # Loop through all optimizers.
                  grad_info = {"allreduce_latency": 0.0}

                  parameters = list(model.parameters())
                  all_grads = [None for _ in range(len(parameters))]
                  for opt_idx, opt in enumerate(self._optimizers):
                      # Erase gradients in all vars of the tower that this
                      # optimizer would affect.
                      param_indices = self.multi_gpu_param_groups[opt_idx]
                      for param_idx, param in enumerate(parameters):
                          if param_idx in param_indices and param.grad is not None:
                              param.grad.data.zero_()
                      # Recompute gradients of loss over all variables.
                      loss_out[opt_idx].backward(retain_graph=True)
                      grad_info.update(
                          self.extra_grad_process(opt, loss_out[opt_idx])
                      )

                      grads = []
                      # Note that return values are just references;
                      # Calling zero_grad would modify the values.
                      for param_idx, param in enumerate(parameters):
                          if param_idx in param_indices:
                              if param.grad is not None:
                                  grads.append(param.grad)
                              all_grads[param_idx] = param.grad

                      if self.distributed_world_size:
                          start = time.time()
                          if torch.cuda.is_available():
                              # Sadly, allreduce_coalesced does not work with
                              # CUDA yet.
                              for g in grads:
                                  torch.distributed.all_reduce(
                                      g, op=torch.distributed.ReduceOp.SUM
                                  )
                          else:
                              torch.distributed.all_reduce_coalesced(
                                  grads, op=torch.distributed.ReduceOp.SUM
                              )

                          # Turn the summed gradients into an average.
                          for param_group in opt.param_groups:
                              for p in param_group["params"]:
                                  if p.grad is not None:
                                      p.grad /= self.distributed_world_size

                          grad_info["allreduce_latency"] += time.time() - start

              with lock:
                  results[shard_idx] = (all_grads, grad_info)
          except Exception as e:
              import traceback

              # Store the error instead of raising, so that the caller can
              # re-raise it in its own thread.
              with lock:
                  results[shard_idx] = (
                      ValueError(
                          f"Error In tower {shard_idx} on device "
                          f"{device} during multi GPU parallel gradient "
                          f"calculation: {e}\n"
                          f"Traceback: \n"
                          f"{traceback.format_exc()}\n"
                      ),
                      e,
                  )

      # Single device (GPU) or fake-GPU case (serialize for better
      # debugging).
      if len(self.devices) == 1 or self.config["_fake_gpus"]:
          for shard_idx, (model, sample_batch, device) in enumerate(
              zip(self.model_gpu_towers, sample_batches, self.devices)
          ):
              _worker(shard_idx, model, sample_batch, device)
              # Raise errors right away for better debugging.
              shard_result = results[shard_idx]
              if isinstance(shard_result[0], Exception):
                  raise shard_result[0] from shard_result[1]
      # Multi device (GPU) case: Parallelize via threads.
      else:
          threads = [
              threading.Thread(
                  target=_worker, args=(shard_idx, model, sample_batch, device)
              )
              for shard_idx, (model, sample_batch, device) in enumerate(
                  zip(self.model_gpu_towers, sample_batches, self.devices)
              )
          ]

          for thread in threads:
              thread.start()
          for thread in threads:
              thread.join()

      # Gather all threads' outputs and return.
      outputs = []
      for shard_idx in range(len(sample_batches)):
          output = results[shard_idx]
          if isinstance(output[0], Exception):
              raise output[0] from output[1]
          outputs.append(output)
      return outputs
  @OldAPIStack
  class DirectStepOptimizer:
      """Typesafe method for indicating `apply_gradients` can directly step the
      optimizers with in-place gradients.

      The class hands out a single shared instance: the first construction
      creates it, every later construction returns the same object.
      """

      # The one shared instance; created lazily by `__new__`.
      _instance = None

      def __new__(cls):
          if DirectStepOptimizer._instance is None:
              DirectStepOptimizer._instance = super().__new__(cls)
          return DirectStepOptimizer._instance

      def __eq__(self, other):
          # Equal to anything of exactly the same type.
          return type(self) is type(other)

      def __hash__(self):
          # Defining `__eq__` alone would make instances unhashable; hash by
          # type so that the hash agrees with `__eq__`.
          return hash(type(self))

      def __repr__(self):
          return "DirectStepOptimizer"


  _directStepOptimizerSingleton = DirectStepOptimizer()