tf_policy.py 48 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197
  1. import logging
  2. import math
  3. from typing import Dict, List, Optional, Tuple, Union
  4. import gymnasium as gym
  5. import numpy as np
  6. import tree # pip install dm_tree
  7. import ray
  8. from ray._common.deprecation import Deprecated
  9. from ray.rllib.models.modelv2 import ModelV2
  10. from ray.rllib.policy.policy import Policy, PolicySpec, PolicyState
  11. from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
  12. from ray.rllib.policy.sample_batch import SampleBatch
  13. from ray.rllib.utils import force_list
  14. from ray.rllib.utils.annotations import OldAPIStack, override
  15. from ray.rllib.utils.debug import summarize
  16. from ray.rllib.utils.error import ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL
  17. from ray.rllib.utils.framework import try_import_tf
  18. from ray.rllib.utils.metrics import (
  19. DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY,
  20. NUM_AGENT_STEPS_TRAINED,
  21. NUM_GRAD_UPDATES_LIFETIME,
  22. )
  23. from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
  24. from ray.rllib.utils.spaces.space_utils import normalize_action
  25. from ray.rllib.utils.tf_run_builder import _TFRunBuilder
  26. from ray.rllib.utils.tf_utils import TensorFlowVariables, get_gpu_devices
  27. from ray.rllib.utils.typing import (
  28. AlgorithmConfigDict,
  29. LocalOptimizer,
  30. ModelGradients,
  31. TensorType,
  32. )
  33. from ray.util.debug import log_once
# Import TensorFlow through RLlib's framework helper. In this file `tf1` is used
# for the graph-mode (v1) API (sessions, placeholders, variable scopes,
# collections) and `tf` for the main API. `tfv` is presumably the detected
# major TF version (it is not used in the visible code) -- see
# `ray.rllib.utils.framework.try_import_tf` to confirm.
tf1, tf, tfv = try_import_tf()
# Module-level logger for this file.
logger = logging.getLogger(__name__)
  36. @OldAPIStack
  37. class TFPolicy(Policy):
    """An agent policy and loss implemented in TensorFlow (session/graph mode).

    Do not sub-class this class directly (neither should you sub-class
    DynamicTFPolicy), but rather use
    rllib.policy.tf_policy_template.build_tf_policy
    to generate your custom tf (graph-mode or eager) Policy classes.

    Extending this class enables RLlib to perform TensorFlow specific
    optimizations on the policy, e.g., parallelization across gpus or
    fusing multiple graphs together in the multi-agent setting.

    Input tensors are typically shaped like [BATCH_SIZE, ...].

    The example below is illustrative only (it is skipped when the docs are
    tested); every `...` stands for an object you would build yourself. Note
    that the constructor expects the spaces and the config before the session.

    .. testcode::
        :skipif: True

        from ray.rllib.policy import TFPolicy

        class TFPolicySubclass(TFPolicy):
            ...

        obs_space, action_space, config = ...
        sess, obs_input, sampled_action, loss, loss_inputs = ...
        policy = TFPolicySubclass(
            obs_space, action_space, config,
            sess, obs_input, sampled_action, loss, loss_inputs)
        print(policy.compute_actions([1, 0, 2]))
        print(policy.postprocess_trajectory(SampleBatch({...})))

    .. testoutput::

        (array([0, 1, 1]), [], {})
        SampleBatch({"action": ..., "advantages": ..., ...})
    """
  61. # In order to create tf_policies from checkpoints, this class needs to separate
  62. # variables into their own scopes. Normally, we would do this in the model
  63. # catalog, but since Policy.from_state() can be called anywhere, we need to
  64. # keep track of it here to not break the from_state API.
  65. tf_var_creation_scope_counter = 0
  66. @staticmethod
  67. def next_tf_var_scope_name():
  68. # Tracks multiple instances that are spawned from this policy via .from_state()
  69. TFPolicy.tf_var_creation_scope_counter += 1
  70. return f"var_scope_{TFPolicy.tf_var_creation_scope_counter}"
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: AlgorithmConfigDict,
        sess: "tf1.Session",
        obs_input: TensorType,
        sampled_action: TensorType,
        loss: Union[TensorType, List[TensorType]],
        loss_inputs: List[Tuple[str, TensorType]],
        model: Optional[ModelV2] = None,
        sampled_action_logp: Optional[TensorType] = None,
        action_input: Optional[TensorType] = None,
        log_likelihood: Optional[TensorType] = None,
        dist_inputs: Optional[TensorType] = None,
        dist_class: Optional[type] = None,
        state_inputs: Optional[List[TensorType]] = None,
        state_outputs: Optional[List[TensorType]] = None,
        prev_action_input: Optional[TensorType] = None,
        prev_reward_input: Optional[TensorType] = None,
        seq_lens: Optional[TensorType] = None,
        max_seq_len: int = 20,
        batch_divisibility_req: int = 1,
        update_ops: Optional[List[TensorType]] = None,
        explore: Optional[TensorType] = None,
        timestep: Optional[TensorType] = None,
    ):
        """Initializes a TFPolicy from already-built graph components.

        Args:
            observation_space: Observation space of the policy.
            action_space: Action space of the policy.
            config: Policy-specific configuration data.
            sess: The TensorFlow session to use.
            obs_input: Input placeholder for observations, of shape
                [BATCH_SIZE, obs...].
            sampled_action: Tensor for sampling an action, of shape
                [BATCH_SIZE, action...]
            loss: Scalar policy loss output tensor or a list thereof
                (in case there is more than one loss). If empty, the loss is
                not initialized here.
            loss_inputs: A (name, placeholder) tuple for each loss input
                argument. Each placeholder name must
                correspond to a SampleBatch column key returned by
                postprocess_trajectory(), and has shape [BATCH_SIZE, data...].
                These keys will be read from postprocessed sample batches and
                fed into the specified placeholders during loss computation.
            model: The optional ModelV2 (or tf.keras.Model) to use for
                calculating actions and losses. If not None, TFPolicy will
                provide functionality for getting variables, calling the
                model's custom loss (if provided), and importing weights into
                the model.
            sampled_action_logp: log probability of the sampled action. Its
                exp() is stored as the sampled action's probability.
            action_input: Input placeholder for actions for
                logp/log-likelihood calculations.
            log_likelihood: Tensor to calculate the log_likelihood (given
                action_input and obs_input). If None, it is derived from
                `dist_class(dist_inputs, model).logp(action_input)` when both
                `dist_inputs` and `dist_class` are given.
            dist_inputs: Tensor to calculate the distribution
                inputs/parameters.
            dist_class: An optional ActionDistribution class to use for
                generating a dist object from distribution inputs.
            state_inputs: List of RNN state input Tensors.
            state_outputs: List of RNN state output Tensors.
            prev_action_input: placeholder for previous actions.
            prev_reward_input: placeholder for previous rewards.
            seq_lens: Placeholder for RNN sequence lengths, of shape
                [NUM_SEQUENCES]. Required whenever `state_inputs` are given.
                Note that NUM_SEQUENCES << BATCH_SIZE. See
                policy/rnn_sequencing.py for more information.
            max_seq_len: Max sequence length for LSTM training.
            batch_divisibility_req: pad all agent experiences batches to
                multiples of this value. This only has an effect if not using
                a LSTM model.
            update_ops: override the batchnorm update ops
                to run when applying gradients. Otherwise we run all update
                ops found in the current variable scope.
            explore: Placeholder for `explore` parameter into call to
                Exploration.get_exploration_action. Explicitly set this to
                False for not creating any Exploration component (in that case
                `self.exploration` is None). If None, a boolean placeholder
                defaulting to True is created.
            timestep: Placeholder for the global sampling timestep. If None,
                an int64 scalar placeholder defaulting to 0 is created.

        Raises:
            ValueError: If actual GPU placement is chosen but fewer GPU IDs
                than `num_gpus` are found, or if `state_inputs` are given
                without `seq_lens`.
            AssertionError: If `model` is neither None, a ModelV2, nor a
                tf.keras.Model.
        """
        self.framework = "tf"
        super().__init__(observation_space, action_space, config)

        # Get devices to build the graph on.
        num_gpus = self._get_num_gpus_for_policy()
        gpu_ids = get_gpu_devices()
        logger.info(f"Found {len(gpu_ids)} visible cuda devices.")

        # Place on one or more CPU(s) when either:
        # - Fake GPU mode.
        # - num_gpus=0 (either set by user or we are in local_mode=True).
        # - no GPUs available.
        if config["_fake_gpus"] or num_gpus == 0 or not gpu_ids:
            # One "/cpu:0" entry per requested (possibly fractional, rounded
            # up) GPU, but at least one device.
            self.devices = ["/cpu:0" for _ in range(int(math.ceil(num_gpus)) or 1)]
        # Place on one or more actual GPU(s), when:
        # - num_gpus > 0 (set by user) AND
        # - local_mode=False AND
        # - actual GPUs available AND
        # - non-fake GPU mode.
        else:
            # We are a remote worker (WORKER_MODE=1):
            # GPUs should be assigned to us by ray.
            if ray._private.worker._mode() == ray._private.worker.WORKER_MODE:
                gpu_ids = ray.get_gpu_ids()

            if len(gpu_ids) < num_gpus:
                raise ValueError(
                    "TFPolicy was not able to find enough GPU IDs! Found "
                    f"{gpu_ids}, but num_gpus={num_gpus}."
                )

            # Device names use the position in `gpu_ids`, not the ID itself.
            self.devices = [f"/gpu:{i}" for i, _ in enumerate(gpu_ids) if i < num_gpus]

        # Disable env-info placeholder.
        if SampleBatch.INFOS in self.view_requirements:
            self.view_requirements[SampleBatch.INFOS].used_for_compute_actions = False
            self.view_requirements[SampleBatch.INFOS].used_for_training = False
            # Optionally add `infos` to the output dataset
            if self.config["output_config"].get("store_infos", False):
                self.view_requirements[SampleBatch.INFOS].used_for_training = True

        assert model is None or isinstance(model, (ModelV2, tf.keras.Model)), (
            "Model classes for TFPolicy other than `ModelV2|tf.keras.Model` "
            "not allowed! You passed in {}.".format(model)
        )
        self.model = model
        # Auto-update model's inference view requirements, if recurrent.
        if self.model is not None:
            self._update_model_view_requirements_from_init_state()

        # If `explore` is explicitly set to False, don't create an exploration
        # component.
        self.exploration = self._create_exploration() if explore is not False else None

        self._sess = sess
        self._obs_input = obs_input
        self._prev_action_input = prev_action_input
        self._prev_reward_input = prev_reward_input
        self._sampled_action = sampled_action
        self._is_training = self._get_is_training_placeholder()
        # Use the given `explore` value/placeholder (note: this may be the
        # Python bool False); otherwise create a bool placeholder defaulting
        # to True.
        self._is_exploring = (
            explore
            if explore is not None
            else tf1.placeholder_with_default(True, (), name="is_exploring")
        )
        self._sampled_action_logp = sampled_action_logp
        self._sampled_action_prob = (
            tf.math.exp(self._sampled_action_logp)
            if self._sampled_action_logp is not None
            else None
        )
        self._action_input = action_input  # For logp calculations.
        self._dist_inputs = dist_inputs
        self.dist_class = dist_class
        self._cached_extra_action_out = None
        self._state_inputs = state_inputs or []
        self._state_outputs = state_outputs or []
        self._seq_lens = seq_lens
        self._max_seq_len = max_seq_len

        # Recurrent policies need the seq_lens placeholder to feed sequences.
        if self._state_inputs and self._seq_lens is None:
            raise ValueError(
                "seq_lens tensor must be given if state inputs are defined"
            )

        self._batch_divisibility_req = batch_divisibility_req
        self._update_ops = update_ops
        self._apply_op = None
        self._stats_fetches = {}
        self._timestep = (
            timestep
            if timestep is not None
            else tf1.placeholder_with_default(
                tf.zeros((), dtype=tf.int64), (), name="timestep"
            )
        )

        self._optimizers: List[LocalOptimizer] = []
        # Backward compatibility and for some code shared with tf-eager Policy.
        self._optimizer = None

        self._grads_and_vars: Union[ModelGradients, List[ModelGradients]] = []
        self._grads: Union[ModelGradients, List[ModelGradients]] = []
        # Policy tf-variables (weights), whose values to get/set via
        # get_weights/set_weights.
        self._variables = None
        # Local optimizer(s)' tf-variables (e.g. state vars for Adam).
        # Will be stored alongside `self._variables` when checkpointing.
        self._optimizer_variables: Optional[TensorFlowVariables] = None

        # The loss tf-op(s). Number of losses must match number of optimizers.
        self._losses = []
        # Backward compatibility (in case custom child TFPolicies access this
        # property).
        self._loss = None
        # A batch dict passed into loss function as input.
        self._loss_input_dict = {}
        # Only build loss/gradient/optimizer ops when at least one loss is
        # given; `_initialize_loss` relies on all attributes set above.
        losses = force_list(loss)
        if len(losses) > 0:
            self._initialize_loss(losses, loss_inputs)

        # The log-likelihood calculator op.
        self._log_likelihood = log_likelihood
        if (
            self._log_likelihood is None
            and self._dist_inputs is not None
            and self.dist_class is not None
        ):
            self._log_likelihood = self.dist_class(self._dist_inputs, self.model).logp(
                self._action_input
            )
  266. @override(Policy)
  267. def compute_actions_from_input_dict(
  268. self,
  269. input_dict: Union[SampleBatch, Dict[str, TensorType]],
  270. explore: bool = None,
  271. timestep: Optional[int] = None,
  272. episode=None,
  273. **kwargs,
  274. ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
  275. explore = explore if explore is not None else self.config["explore"]
  276. timestep = timestep if timestep is not None else self.global_timestep
  277. # Switch off is_training flag in our batch.
  278. if isinstance(input_dict, SampleBatch):
  279. input_dict.set_training(False)
  280. else:
  281. # Deprecated dict input.
  282. input_dict["is_training"] = False
  283. builder = _TFRunBuilder(self.get_session(), "compute_actions_from_input_dict")
  284. obs_batch = input_dict[SampleBatch.OBS]
  285. to_fetch = self._build_compute_actions(
  286. builder, input_dict=input_dict, explore=explore, timestep=timestep
  287. )
  288. # Execute session run to get action (and other fetches).
  289. fetched = builder.get(to_fetch)
  290. # Update our global timestep by the batch size.
  291. self.global_timestep += (
  292. len(obs_batch)
  293. if isinstance(obs_batch, list)
  294. else len(input_dict)
  295. if isinstance(input_dict, SampleBatch)
  296. else obs_batch.shape[0]
  297. )
  298. return fetched
  299. @override(Policy)
  300. def compute_actions(
  301. self,
  302. obs_batch: Union[List[TensorType], TensorType],
  303. state_batches: Optional[List[TensorType]] = None,
  304. prev_action_batch: Union[List[TensorType], TensorType] = None,
  305. prev_reward_batch: Union[List[TensorType], TensorType] = None,
  306. info_batch: Optional[Dict[str, list]] = None,
  307. episodes=None,
  308. explore: Optional[bool] = None,
  309. timestep: Optional[int] = None,
  310. **kwargs,
  311. ):
  312. explore = explore if explore is not None else self.config["explore"]
  313. timestep = timestep if timestep is not None else self.global_timestep
  314. builder = _TFRunBuilder(self.get_session(), "compute_actions")
  315. input_dict = {SampleBatch.OBS: obs_batch, "is_training": False}
  316. if state_batches:
  317. for i, s in enumerate(state_batches):
  318. input_dict[f"state_in_{i}"] = s
  319. if prev_action_batch is not None:
  320. input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
  321. if prev_reward_batch is not None:
  322. input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
  323. to_fetch = self._build_compute_actions(
  324. builder, input_dict=input_dict, explore=explore, timestep=timestep
  325. )
  326. # Execute session run to get action (and other fetches).
  327. fetched = builder.get(to_fetch)
  328. # Update our global timestep by the batch size.
  329. self.global_timestep += (
  330. len(obs_batch)
  331. if isinstance(obs_batch, list)
  332. else tree.flatten(obs_batch)[0].shape[0]
  333. )
  334. return fetched
  335. @override(Policy)
  336. def compute_log_likelihoods(
  337. self,
  338. actions: Union[List[TensorType], TensorType],
  339. obs_batch: Union[List[TensorType], TensorType],
  340. state_batches: Optional[List[TensorType]] = None,
  341. prev_action_batch: Optional[Union[List[TensorType], TensorType]] = None,
  342. prev_reward_batch: Optional[Union[List[TensorType], TensorType]] = None,
  343. actions_normalized: bool = True,
  344. **kwargs,
  345. ) -> TensorType:
  346. if self._log_likelihood is None:
  347. raise ValueError(
  348. "Cannot compute log-prob/likelihood w/o a self._log_likelihood op!"
  349. )
  350. # Exploration hook before each forward pass.
  351. self.exploration.before_compute_actions(
  352. explore=False, tf_sess=self.get_session()
  353. )
  354. builder = _TFRunBuilder(self.get_session(), "compute_log_likelihoods")
  355. # Normalize actions if necessary.
  356. if actions_normalized is False and self.config["normalize_actions"]:
  357. actions = normalize_action(actions, self.action_space_struct)
  358. # Feed actions (for which we want logp values) into graph.
  359. builder.add_feed_dict({self._action_input: actions})
  360. # Feed observations.
  361. builder.add_feed_dict({self._obs_input: obs_batch})
  362. # Internal states.
  363. state_batches = state_batches or []
  364. if len(self._state_inputs) != len(state_batches):
  365. raise ValueError(
  366. "Must pass in RNN state batches for placeholders {}, got {}".format(
  367. self._state_inputs, state_batches
  368. )
  369. )
  370. builder.add_feed_dict(dict(zip(self._state_inputs, state_batches)))
  371. if state_batches:
  372. builder.add_feed_dict({self._seq_lens: np.ones(len(obs_batch))})
  373. # Prev-a and r.
  374. if self._prev_action_input is not None and prev_action_batch is not None:
  375. builder.add_feed_dict({self._prev_action_input: prev_action_batch})
  376. if self._prev_reward_input is not None and prev_reward_batch is not None:
  377. builder.add_feed_dict({self._prev_reward_input: prev_reward_batch})
  378. # Fetch the log_likelihoods output and return.
  379. fetches = builder.add_fetches([self._log_likelihood])
  380. return builder.get(fetches)[0]
  381. @override(Policy)
  382. def learn_on_batch(self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]:
  383. assert self.loss_initialized()
  384. # Switch on is_training flag in our batch.
  385. postprocessed_batch.set_training(True)
  386. builder = _TFRunBuilder(self.get_session(), "learn_on_batch")
  387. # Callback handling.
  388. learn_stats = {}
  389. self.callbacks.on_learn_on_batch(
  390. policy=self, train_batch=postprocessed_batch, result=learn_stats
  391. )
  392. fetches = self._build_learn_on_batch(builder, postprocessed_batch)
  393. stats = builder.get(fetches)
  394. self.num_grad_updates += 1
  395. stats.update(
  396. {
  397. "custom_metrics": learn_stats,
  398. NUM_AGENT_STEPS_TRAINED: postprocessed_batch.count,
  399. NUM_GRAD_UPDATES_LIFETIME: self.num_grad_updates,
  400. # -1, b/c we have to measure this diff before we do the update above.
  401. DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY: (
  402. self.num_grad_updates
  403. - 1
  404. - (postprocessed_batch.num_grad_updates or 0)
  405. ),
  406. }
  407. )
  408. return stats
  409. @override(Policy)
  410. def compute_gradients(
  411. self, postprocessed_batch: SampleBatch
  412. ) -> Tuple[ModelGradients, Dict[str, TensorType]]:
  413. assert self.loss_initialized()
  414. # Switch on is_training flag in our batch.
  415. postprocessed_batch.set_training(True)
  416. builder = _TFRunBuilder(self.get_session(), "compute_gradients")
  417. fetches = self._build_compute_gradients(builder, postprocessed_batch)
  418. return builder.get(fetches)
  419. @staticmethod
  420. def _tf1_from_state_helper(state: PolicyState) -> "Policy":
  421. """Recovers a TFPolicy from a state object.
  422. The `state` of an instantiated TFPolicy can be retrieved by calling its
  423. `get_state` method. Is meant to be used by the Policy.from_state() method to
  424. aid with tracking variable creation.
  425. Args:
  426. state: The state to recover a new TFPolicy instance from.
  427. Returns:
  428. A new TFPolicy instance.
  429. """
  430. serialized_pol_spec: Optional[dict] = state.get("policy_spec")
  431. if serialized_pol_spec is None:
  432. raise ValueError(
  433. "No `policy_spec` key was found in given `state`! "
  434. "Cannot create new Policy."
  435. )
  436. pol_spec = PolicySpec.deserialize(serialized_pol_spec)
  437. with tf1.variable_scope(TFPolicy.next_tf_var_scope_name()):
  438. # Create the new policy.
  439. new_policy = pol_spec.policy_class(
  440. # Note(jungong) : we are intentionally not using keyward arguments here
  441. # because some policies name the observation space parameter obs_space,
  442. # and some others name it observation_space.
  443. pol_spec.observation_space,
  444. pol_spec.action_space,
  445. pol_spec.config,
  446. )
  447. # Set the new policy's state (weights, optimizer vars, exploration state,
  448. # etc..).
  449. new_policy.set_state(state)
  450. # Return the new policy.
  451. return new_policy
  452. @override(Policy)
  453. def apply_gradients(self, gradients: ModelGradients) -> None:
  454. assert self.loss_initialized()
  455. builder = _TFRunBuilder(self.get_session(), "apply_gradients")
  456. fetches = self._build_apply_gradients(builder, gradients)
  457. builder.get(fetches)
  458. @override(Policy)
  459. def get_weights(self) -> Union[Dict[str, TensorType], List[TensorType]]:
  460. return self._variables.get_weights()
  461. @override(Policy)
  462. def set_weights(self, weights) -> None:
  463. return self._variables.set_weights(weights)
  464. @override(Policy)
  465. def get_exploration_state(self) -> Dict[str, TensorType]:
  466. return self.exploration.get_state(sess=self.get_session())
  467. @Deprecated(new="get_exploration_state", error=True)
  468. def get_exploration_info(self) -> Dict[str, TensorType]:
  469. return self.get_exploration_state()
  470. @override(Policy)
  471. def is_recurrent(self) -> bool:
  472. return len(self._state_inputs) > 0
  473. @override(Policy)
  474. def num_state_tensors(self) -> int:
  475. return len(self._state_inputs)
  476. @override(Policy)
  477. def get_state(self) -> PolicyState:
  478. # For tf Policies, return Policy weights and optimizer var values.
  479. state = super().get_state()
  480. if len(self._optimizer_variables.variables) > 0:
  481. state["_optimizer_variables"] = self.get_session().run(
  482. self._optimizer_variables.variables
  483. )
  484. # Add exploration state.
  485. state["_exploration_state"] = self.exploration.get_state(self.get_session())
  486. return state
  487. @override(Policy)
  488. def set_state(self, state: PolicyState) -> None:
  489. # Set optimizer vars first.
  490. optimizer_vars = state.get("_optimizer_variables", None)
  491. if optimizer_vars is not None:
  492. self._optimizer_variables.set_weights(optimizer_vars)
  493. # Set exploration's state.
  494. if hasattr(self, "exploration") and "_exploration_state" in state:
  495. self.exploration.set_state(
  496. state=state["_exploration_state"], sess=self.get_session()
  497. )
  498. # Restore global timestep.
  499. self.global_timestep = state["global_timestep"]
  500. # Then the Policy's (NN) weights and connectors.
  501. super().set_state(state)
  502. @override(Policy)
  503. def export_model(self, export_dir: str, onnx: Optional[int] = None) -> None:
  504. """Export tensorflow graph to export_dir for serving."""
  505. if onnx:
  506. try:
  507. import tf2onnx
  508. except ImportError as e:
  509. raise RuntimeError(
  510. "Converting a TensorFlow model to ONNX requires "
  511. "`tf2onnx` to be installed. Install with "
  512. "`pip install tf2onnx`."
  513. ) from e
  514. with self.get_session().graph.as_default():
  515. signature_def_map = self._build_signature_def()
  516. sd = signature_def_map[
  517. tf1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY # noqa: E501
  518. ]
  519. inputs = [v.name for k, v in sd.inputs.items()]
  520. outputs = [v.name for k, v in sd.outputs.items()]
  521. from tf2onnx import tf_loader
  522. frozen_graph_def = tf_loader.freeze_session(
  523. self.get_session(), input_names=inputs, output_names=outputs
  524. )
  525. with tf1.Session(graph=tf.Graph()) as session:
  526. tf.import_graph_def(frozen_graph_def, name="")
  527. g = tf2onnx.tfonnx.process_tf_graph(
  528. session.graph,
  529. input_names=inputs,
  530. output_names=outputs,
  531. inputs_as_nchw=inputs,
  532. )
  533. model_proto = g.make_model("onnx_model")
  534. tf2onnx.utils.save_onnx_model(
  535. export_dir, "model", feed_dict={}, model_proto=model_proto
  536. )
  537. # Save the tf.keras.Model (architecture and weights, so it can be retrieved
  538. # w/o access to the original (custom) Model or Policy code).
  539. elif (
  540. hasattr(self, "model")
  541. and hasattr(self.model, "base_model")
  542. and isinstance(self.model.base_model, tf.keras.Model)
  543. ):
  544. with self.get_session().graph.as_default():
  545. try:
  546. self.model.base_model.save(filepath=export_dir, save_format="tf")
  547. except Exception:
  548. logger.warning(ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL)
  549. else:
  550. logger.warning(ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL)
  551. @override(Policy)
  552. def import_model_from_h5(self, import_file: str) -> None:
  553. """Imports weights into tf model."""
  554. if self.model is None:
  555. raise NotImplementedError("No `self.model` to import into!")
  556. # Make sure the session is the right one (see issue #7046).
  557. with self.get_session().graph.as_default():
  558. with self.get_session().as_default():
  559. return self.model.import_from_h5(import_file)
  560. @override(Policy)
  561. def get_session(self) -> Optional["tf1.Session"]:
  562. """Returns a reference to the TF session for this policy."""
  563. return self._sess
  564. def variables(self):
  565. """Return the list of all savable variables for this policy."""
  566. if self.model is None:
  567. raise NotImplementedError("No `self.model` to get variables for!")
  568. elif isinstance(self.model, tf.keras.Model):
  569. return self.model.variables
  570. else:
  571. return self.model.variables()
  572. def get_placeholder(self, name) -> "tf1.placeholder":
  573. """Returns the given action or loss input placeholder by name.
  574. If the loss has not been initialized and a loss input placeholder is
  575. requested, an error is raised.
  576. Args:
  577. name: The name of the placeholder to return. One of
  578. SampleBatch.CUR_OBS|PREV_ACTION/REWARD or a valid key from
  579. `self._loss_input_dict`.
  580. Returns:
  581. tf1.placeholder: The placeholder under the given str key.
  582. """
  583. if name == SampleBatch.CUR_OBS:
  584. return self._obs_input
  585. elif name == SampleBatch.PREV_ACTIONS:
  586. return self._prev_action_input
  587. elif name == SampleBatch.PREV_REWARDS:
  588. return self._prev_reward_input
  589. assert self._loss_input_dict, (
  590. "You need to populate `self._loss_input_dict` before "
  591. "`get_placeholder()` can be called"
  592. )
  593. return self._loss_input_dict[name]
  594. def loss_initialized(self) -> bool:
  595. """Returns whether the loss term(s) have been initialized."""
  596. return len(self._losses) > 0
    def _initialize_loss(
        self, losses: List[TensorType], loss_inputs: List[Tuple[str, TensorType]]
    ) -> None:
        """Initializes the loss op from given loss tensor and placeholders.

        In order, this: builds the loss-input feed mapping, lets a non-Keras
        model post-process the losses via `custom_loss()`, creates the
        optimizer(s) if none exist yet, computes gradients (dropping vars
        whose gradient is None), builds `self._apply_op` (single-device case
        only), and finally runs the global variables initializer in this
        policy's session.

        Args:
            losses (List[TensorType]): The list of loss ops returned by some
                loss function.
            loss_inputs (List[Tuple[str, TensorType]]): The list of Tuples:
                (name, tf1.placeholders) needed for calculating the loss.
        """
        self._loss_input_dict = dict(loss_inputs)
        # Same mapping minus RNN state-in placeholders and the seq-lens
        # placeholder. Its keys are used as `feature_keys` when padding
        # train batches in `_get_loss_inputs_dict()`.
        self._loss_input_dict_no_rnn = {
            k: v
            for k, v in self._loss_input_dict.items()
            if (v not in self._state_inputs and v != self._seq_lens)
        }
        # Make RNN state-in placeholders addressable as "state_in_{i}".
        for i, ph in enumerate(self._state_inputs):
            self._loss_input_dict["state_in_{}".format(i)] = ph

        # Non-Keras models may modify/extend the losses via `custom_loss()`
        # and contribute their own metrics to the stats fetches.
        if self.model and not isinstance(self.model, tf.keras.Model):
            self._losses = force_list(
                self.model.custom_loss(losses, self._loss_input_dict)
            )
            self._stats_fetches.update({"model": self.model.metrics()})
        else:
            self._losses = losses
        # Backward compatibility.
        # NOTE(review): raises IndexError if `self._losses` is an empty list.
        self._loss = self._losses[0] if self._losses is not None else None

        # Only create optimizer(s) if none have been set already.
        if not self._optimizers:
            self._optimizers = force_list(self.optimizer())
            # Backward compatibility.
            self._optimizer = self._optimizers[0] if self._optimizers else None

        # Supporting more than one loss/optimizer.
        if self.config["_tf_policy_handles_more_than_one_loss"]:
            self._grads_and_vars = []
            self._grads = []
            # One (grad, var)-list per optimizer/loss pair; drop vars that
            # received no gradient (g is None).
            for group in self.gradients(self._optimizers, self._losses):
                g_and_v = [(g, v) for (g, v) in group if g is not None]
                self._grads_and_vars.append(g_and_v)
                self._grads.append([g for (g, _) in g_and_v])
        # Only one optimizer and one loss term.
        else:
            self._grads_and_vars = [
                (g, v)
                for (g, v) in self.gradients(self._optimizer, self._loss)
                if g is not None
            ]
            self._grads = [g for (g, _) in self._grads_and_vars]

        if self.model:
            self._variables = TensorFlowVariables(
                [], self.get_session(), self.variables()
            )

        # Gather update ops for any batch norm layers.
        # NOTE(review): the apply op is only built here when there is at most
        # one device; presumably the multi-device path builds its own.
        if len(self.devices) <= 1:
            if not self._update_ops:
                self._update_ops = tf1.get_collection(
                    tf1.GraphKeys.UPDATE_OPS, scope=tf1.get_variable_scope().name
                )
            if self._update_ops:
                logger.info(
                    "Update ops to run on apply gradient: {}".format(self._update_ops)
                )
            # The apply op runs only after all update ops have run.
            with tf1.control_dependencies(self._update_ops):
                self._apply_op = self.build_apply_op(
                    optimizer=self._optimizers
                    if self.config["_tf_policy_handles_more_than_one_loss"]
                    else self._optimizer,
                    grads_and_vars=self._grads_and_vars,
                )

        if log_once("loss_used"):
            logger.debug(
                "These tensors were used in the loss functions:"
                f"\n{summarize(self._loss_input_dict)}\n"
            )

        # Initialize all global variables (incl. ones the optimizers created).
        self.get_session().run(tf1.global_variables_initializer())

        # TensorFlowVariables holding a flat list of all our optimizers'
        # variables.
        self._optimizer_variables = TensorFlowVariables(
            [v for o in self._optimizers for v in o.variables()], self.get_session()
        )
  676. def copy(self, existing_inputs: List[Tuple[str, "tf1.placeholder"]]) -> "TFPolicy":
  677. """Creates a copy of self using existing input placeholders.
  678. Optional: Only required to work with the multi-GPU optimizer.
  679. Args:
  680. existing_inputs (List[Tuple[str, tf1.placeholder]]): Dict mapping
  681. names (str) to tf1.placeholders to re-use (share) with the
  682. returned copy of self.
  683. Returns:
  684. TFPolicy: A copy of self.
  685. """
  686. raise NotImplementedError
  687. def extra_compute_action_feed_dict(self) -> Dict[TensorType, TensorType]:
  688. """Extra dict to pass to the compute actions session run.
  689. Returns:
  690. Dict[TensorType, TensorType]: A feed dict to be added to the
  691. feed_dict passed to the compute_actions session.run() call.
  692. """
  693. return {}
  694. def extra_compute_action_fetches(self) -> Dict[str, TensorType]:
  695. # Cache graph fetches for action computation for better
  696. # performance.
  697. # This function is called every time the static graph is run
  698. # to compute actions.
  699. if not self._cached_extra_action_out:
  700. self._cached_extra_action_out = self.extra_action_out_fn()
  701. return self._cached_extra_action_out
  702. def extra_action_out_fn(self) -> Dict[str, TensorType]:
  703. """Extra values to fetch and return from compute_actions().
  704. By default we return action probability/log-likelihood info
  705. and action distribution inputs (if present).
  706. Returns:
  707. Dict[str, TensorType]: An extra fetch-dict to be passed to and
  708. returned from the compute_actions() call.
  709. """
  710. extra_fetches = {}
  711. # Action-logp and action-prob.
  712. if self._sampled_action_logp is not None:
  713. extra_fetches[SampleBatch.ACTION_PROB] = self._sampled_action_prob
  714. extra_fetches[SampleBatch.ACTION_LOGP] = self._sampled_action_logp
  715. # Action-dist inputs.
  716. if self._dist_inputs is not None:
  717. extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = self._dist_inputs
  718. return extra_fetches
  719. def extra_compute_grad_feed_dict(self) -> Dict[TensorType, TensorType]:
  720. """Extra dict to pass to the compute gradients session run.
  721. Returns:
  722. Dict[TensorType, TensorType]: Extra feed_dict to be passed to the
  723. compute_gradients Session.run() call.
  724. """
  725. return {} # e.g, kl_coeff
  726. def extra_compute_grad_fetches(self) -> Dict[str, any]:
  727. """Extra values to fetch and return from compute_gradients().
  728. Returns:
  729. Dict[str, any]: Extra fetch dict to be added to the fetch dict
  730. of the compute_gradients Session.run() call.
  731. """
  732. return {LEARNER_STATS_KEY: {}} # e.g, stats, td error, etc.
  733. def optimizer(self) -> "tf.keras.optimizers.Optimizer":
  734. """TF optimizer to use for policy optimization.
  735. Returns:
  736. tf.keras.optimizers.Optimizer: The local optimizer to use for this
  737. Policy's Model.
  738. """
  739. if hasattr(self, "config") and "lr" in self.config:
  740. return tf1.train.AdamOptimizer(learning_rate=self.config["lr"])
  741. else:
  742. return tf1.train.AdamOptimizer()
  743. def gradients(
  744. self,
  745. optimizer: Union[LocalOptimizer, List[LocalOptimizer]],
  746. loss: Union[TensorType, List[TensorType]],
  747. ) -> Union[List[ModelGradients], List[List[ModelGradients]]]:
  748. """Override this for a custom gradient computation behavior.
  749. Args:
  750. optimizer (Union[LocalOptimizer, List[LocalOptimizer]]): A single
  751. LocalOptimizer of a list thereof to use for gradient
  752. calculations. If more than one optimizer given, the number of
  753. optimizers must match the number of losses provided.
  754. loss (Union[TensorType, List[TensorType]]): A single loss term
  755. or a list thereof to use for gradient calculations.
  756. If more than one loss given, the number of loss terms must
  757. match the number of optimizers provided.
  758. Returns:
  759. Union[List[ModelGradients], List[List[ModelGradients]]]: List of
  760. ModelGradients (grads and vars OR just grads) OR List of List
  761. of ModelGradients in case we have more than one
  762. optimizer/loss.
  763. """
  764. optimizers = force_list(optimizer)
  765. losses = force_list(loss)
  766. # We have more than one optimizers and loss terms.
  767. if self.config["_tf_policy_handles_more_than_one_loss"]:
  768. grads = []
  769. for optim, loss_ in zip(optimizers, losses):
  770. grads.append(optim.compute_gradients(loss_))
  771. # We have only one optimizer and one loss term.
  772. else:
  773. return optimizers[0].compute_gradients(losses[0])
  774. def build_apply_op(
  775. self,
  776. optimizer: Union[LocalOptimizer, List[LocalOptimizer]],
  777. grads_and_vars: Union[ModelGradients, List[ModelGradients]],
  778. ) -> "tf.Operation":
  779. """Override this for a custom gradient apply computation behavior.
  780. Args:
  781. optimizer (Union[LocalOptimizer, List[LocalOptimizer]]): The local
  782. tf optimizer to use for applying the grads and vars.
  783. grads_and_vars (Union[ModelGradients, List[ModelGradients]]): List
  784. of tuples with grad values and the grad-value's corresponding
  785. tf.variable in it.
  786. Returns:
  787. tf.Operation: The tf op that applies all computed gradients
  788. (`grads_and_vars`) to the model(s) via the given optimizer(s).
  789. """
  790. optimizers = force_list(optimizer)
  791. # We have more than one optimizers and loss terms.
  792. if self.config["_tf_policy_handles_more_than_one_loss"]:
  793. ops = []
  794. for i, optim in enumerate(optimizers):
  795. # Specify global_step (e.g. for TD3 which needs to count the
  796. # num updates that have happened).
  797. ops.append(
  798. optim.apply_gradients(
  799. grads_and_vars[i],
  800. global_step=tf1.train.get_or_create_global_step(),
  801. )
  802. )
  803. return tf.group(ops)
  804. # We have only one optimizer and one loss term.
  805. else:
  806. return optimizers[0].apply_gradients(
  807. grads_and_vars, global_step=tf1.train.get_or_create_global_step()
  808. )
  809. def _get_is_training_placeholder(self):
  810. """Get the placeholder for _is_training, i.e., for batch norm layers.
  811. This can be called safely before __init__ has run.
  812. """
  813. if not hasattr(self, "_is_training"):
  814. self._is_training = tf1.placeholder_with_default(
  815. False, (), name="is_training"
  816. )
  817. return self._is_training
  818. def _debug_vars(self):
  819. if log_once("grad_vars"):
  820. if self.config["_tf_policy_handles_more_than_one_loss"]:
  821. for group in self._grads_and_vars:
  822. for _, v in group:
  823. logger.info("Optimizing variable {}".format(v))
  824. else:
  825. for _, v in self._grads_and_vars:
  826. logger.info("Optimizing variable {}".format(v))
  827. def _extra_input_signature_def(self):
  828. """Extra input signatures to add when exporting tf model.
  829. Inferred from extra_compute_action_feed_dict()
  830. """
  831. feed_dict = self.extra_compute_action_feed_dict()
  832. return {
  833. k.name: tf1.saved_model.utils.build_tensor_info(k) for k in feed_dict.keys()
  834. }
  835. def _extra_output_signature_def(self):
  836. """Extra output signatures to add when exporting tf model.
  837. Inferred from extra_compute_action_fetches()
  838. """
  839. fetches = self.extra_compute_action_fetches()
  840. return {
  841. k: tf1.saved_model.utils.build_tensor_info(fetches[k])
  842. for k in fetches.keys()
  843. }
  844. def _build_signature_def(self):
  845. """Build signature def map for tensorflow SavedModelBuilder."""
  846. # build input signatures
  847. input_signature = self._extra_input_signature_def()
  848. input_signature["observations"] = tf1.saved_model.utils.build_tensor_info(
  849. self._obs_input
  850. )
  851. if self._seq_lens is not None:
  852. input_signature[
  853. SampleBatch.SEQ_LENS
  854. ] = tf1.saved_model.utils.build_tensor_info(self._seq_lens)
  855. if self._prev_action_input is not None:
  856. input_signature["prev_action"] = tf1.saved_model.utils.build_tensor_info(
  857. self._prev_action_input
  858. )
  859. if self._prev_reward_input is not None:
  860. input_signature["prev_reward"] = tf1.saved_model.utils.build_tensor_info(
  861. self._prev_reward_input
  862. )
  863. input_signature["is_training"] = tf1.saved_model.utils.build_tensor_info(
  864. self._is_training
  865. )
  866. if self._timestep is not None:
  867. input_signature["timestep"] = tf1.saved_model.utils.build_tensor_info(
  868. self._timestep
  869. )
  870. for state_input in self._state_inputs:
  871. input_signature[state_input.name] = tf1.saved_model.utils.build_tensor_info(
  872. state_input
  873. )
  874. # build output signatures
  875. output_signature = self._extra_output_signature_def()
  876. for i, a in enumerate(tf.nest.flatten(self._sampled_action)):
  877. output_signature[
  878. "actions_{}".format(i)
  879. ] = tf1.saved_model.utils.build_tensor_info(a)
  880. for state_output in self._state_outputs:
  881. output_signature[
  882. state_output.name
  883. ] = tf1.saved_model.utils.build_tensor_info(state_output)
  884. signature_def = tf1.saved_model.signature_def_utils.build_signature_def(
  885. input_signature,
  886. output_signature,
  887. tf1.saved_model.signature_constants.PREDICT_METHOD_NAME,
  888. )
  889. signature_def_key = (
  890. tf1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
  891. )
  892. signature_def_map = {signature_def_key: signature_def}
  893. return signature_def_map
  894. def _build_compute_actions(
  895. self,
  896. builder,
  897. *,
  898. input_dict=None,
  899. obs_batch=None,
  900. state_batches=None,
  901. prev_action_batch=None,
  902. prev_reward_batch=None,
  903. episodes=None,
  904. explore=None,
  905. timestep=None,
  906. ):
  907. explore = explore if explore is not None else self.config["explore"]
  908. timestep = timestep if timestep is not None else self.global_timestep
  909. # Call the exploration before_compute_actions hook.
  910. self.exploration.before_compute_actions(
  911. timestep=timestep, explore=explore, tf_sess=self.get_session()
  912. )
  913. builder.add_feed_dict(self.extra_compute_action_feed_dict())
  914. # `input_dict` given: Simply build what's in that dict.
  915. if hasattr(self, "_input_dict"):
  916. for key, value in input_dict.items():
  917. if key in self._input_dict:
  918. # Handle complex/nested spaces as well.
  919. tree.map_structure(
  920. lambda k, v: builder.add_feed_dict({k: v}),
  921. self._input_dict[key],
  922. value,
  923. )
  924. # For policies that inherit directly from TFPolicy.
  925. else:
  926. builder.add_feed_dict({self._obs_input: input_dict[SampleBatch.OBS]})
  927. if SampleBatch.PREV_ACTIONS in input_dict:
  928. builder.add_feed_dict(
  929. {self._prev_action_input: input_dict[SampleBatch.PREV_ACTIONS]}
  930. )
  931. if SampleBatch.PREV_REWARDS in input_dict:
  932. builder.add_feed_dict(
  933. {self._prev_reward_input: input_dict[SampleBatch.PREV_REWARDS]}
  934. )
  935. state_batches = []
  936. i = 0
  937. while "state_in_{}".format(i) in input_dict:
  938. state_batches.append(input_dict["state_in_{}".format(i)])
  939. i += 1
  940. builder.add_feed_dict(dict(zip(self._state_inputs, state_batches)))
  941. if "state_in_0" in input_dict and SampleBatch.SEQ_LENS not in input_dict:
  942. builder.add_feed_dict(
  943. {self._seq_lens: np.ones(len(input_dict["state_in_0"]))}
  944. )
  945. builder.add_feed_dict({self._is_exploring: explore})
  946. if timestep is not None:
  947. builder.add_feed_dict({self._timestep: timestep})
  948. # Determine, what exactly to fetch from the graph.
  949. to_fetch = (
  950. [self._sampled_action]
  951. + self._state_outputs
  952. + [self.extra_compute_action_fetches()]
  953. )
  954. # Add the ops to fetch for the upcoming session call.
  955. fetches = builder.add_fetches(to_fetch)
  956. return fetches[0], fetches[1:-1], fetches[-1]
  957. def _build_compute_gradients(self, builder, postprocessed_batch):
  958. self._debug_vars()
  959. builder.add_feed_dict(self.extra_compute_grad_feed_dict())
  960. builder.add_feed_dict(
  961. self._get_loss_inputs_dict(postprocessed_batch, shuffle=False)
  962. )
  963. fetches = builder.add_fetches([self._grads, self._get_grad_and_stats_fetches()])
  964. return fetches[0], fetches[1]
  965. def _build_apply_gradients(self, builder, gradients):
  966. if len(gradients) != len(self._grads):
  967. raise ValueError(
  968. "Unexpected number of gradients to apply, got {} for {}".format(
  969. gradients, self._grads
  970. )
  971. )
  972. builder.add_feed_dict({self._is_training: True})
  973. builder.add_feed_dict(dict(zip(self._grads, gradients)))
  974. fetches = builder.add_fetches([self._apply_op])
  975. return fetches[0]
  976. def _build_learn_on_batch(self, builder, postprocessed_batch):
  977. self._debug_vars()
  978. builder.add_feed_dict(self.extra_compute_grad_feed_dict())
  979. builder.add_feed_dict(
  980. self._get_loss_inputs_dict(postprocessed_batch, shuffle=False)
  981. )
  982. fetches = builder.add_fetches(
  983. [
  984. self._apply_op,
  985. self._get_grad_and_stats_fetches(),
  986. ]
  987. )
  988. return fetches[1]
  989. def _get_grad_and_stats_fetches(self):
  990. fetches = self.extra_compute_grad_fetches()
  991. if LEARNER_STATS_KEY not in fetches:
  992. raise ValueError("Grad fetches should contain 'stats': {...} entry")
  993. if self._stats_fetches:
  994. fetches[LEARNER_STATS_KEY] = dict(
  995. self._stats_fetches, **fetches[LEARNER_STATS_KEY]
  996. )
  997. return fetches
  998. def _get_loss_inputs_dict(self, train_batch: SampleBatch, shuffle: bool):
  999. """Return a feed dict from a batch.
  1000. Args:
  1001. train_batch: batch of data to derive inputs from.
  1002. shuffle: whether to shuffle batch sequences. Shuffle may
  1003. be done in-place. This only makes sense if you're further
  1004. applying minibatch SGD after getting the outputs.
  1005. Returns:
  1006. Feed dict of data.
  1007. """
  1008. # Get batch ready for RNNs, if applicable.
  1009. if not isinstance(train_batch, SampleBatch) or not train_batch.zero_padded:
  1010. pad_batch_to_sequences_of_same_size(
  1011. train_batch,
  1012. max_seq_len=self._max_seq_len,
  1013. shuffle=shuffle,
  1014. batch_divisibility_req=self._batch_divisibility_req,
  1015. feature_keys=list(self._loss_input_dict_no_rnn.keys()),
  1016. view_requirements=self.view_requirements,
  1017. )
  1018. # Mark the batch as "is_training" so the Model can use this
  1019. # information.
  1020. train_batch.set_training(True)
  1021. # Build the feed dict from the batch.
  1022. feed_dict = {}
  1023. for key, placeholders in self._loss_input_dict.items():
  1024. a = tree.map_structure(
  1025. lambda ph, v: feed_dict.__setitem__(ph, v),
  1026. placeholders,
  1027. train_batch[key],
  1028. )
  1029. del a
  1030. state_keys = ["state_in_{}".format(i) for i in range(len(self._state_inputs))]
  1031. for key in state_keys:
  1032. feed_dict[self._loss_input_dict[key]] = train_batch[key]
  1033. if state_keys:
  1034. feed_dict[self._seq_lens] = train_batch[SampleBatch.SEQ_LENS]
  1035. return feed_dict