| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
7037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051 |
- """Eager mode TF policy built using build_tf_policy().
- It supports both traced and non-traced eager execution modes."""
- import functools
- import logging
- import os
- import threading
- from typing import Dict, List, Optional, Tuple, Union
- import tree # pip install dm_tree
- from ray._common.deprecation import (
- DEPRECATED_VALUE,
- deprecation_warning,
- )
- from ray.rllib.models.catalog import ModelCatalog
- from ray.rllib.models.repeated_values import RepeatedValues
- from ray.rllib.policy.policy import Policy, PolicyState
- from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
- from ray.rllib.policy.sample_batch import SampleBatch
- from ray.rllib.utils import add_mixins, force_list
- from ray.rllib.utils.annotations import OldAPIStack, override
- from ray.rllib.utils.error import ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL
- from ray.rllib.utils.framework import try_import_tf
- from ray.rllib.utils.metrics import (
- DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY,
- NUM_AGENT_STEPS_TRAINED,
- NUM_GRAD_UPDATES_LIFETIME,
- )
- from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
- from ray.rllib.utils.numpy import convert_to_numpy
- from ray.rllib.utils.spaces.space_utils import normalize_action
- from ray.rllib.utils.tf_utils import get_gpu_devices
- from ray.rllib.utils.threading import with_lock
- from ray.rllib.utils.typing import (
- LocalOptimizer,
- ModelGradients,
- TensorStructType,
- TensorType,
- )
- from ray.util.debug import log_once
# TensorFlow handles obtained through RLlib's `try_import_tf()` helper
# (presumably the tf1-compat module, the tf module and the installed major
# version -- confirm against `ray.rllib.utils.framework`).
tf1, tf, tfv = try_import_tf()
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def _convert_to_tf(x, dtype=None):
    """Recursively convert `x` into tf tensors.

    Args:
        x: The value (or nested structure of values) to convert.
        dtype: Optional dtype handed to `tf.convert_to_tensor` for the leaves
            of a generic structure.

    Returns:
        For a SampleBatch: a converted dict without the INFOS column. For a
        Policy or `None`: `x` itself. For RepeatedValues: a new RepeatedValues
        with converted values. Otherwise: `x` with every leaf that is neither
        `None` nor already a tensor converted to a tensor.
    """
    if isinstance(x, SampleBatch):
        # The INFOS column is dropped before conversion.
        without_infos = {
            key: value for key, value in x.items() if key != SampleBatch.INFOS
        }
        return tree.map_structure(_convert_to_tf, without_infos)
    if isinstance(x, Policy):
        return x
    # Special handling of "Repeated" values.
    if isinstance(x, RepeatedValues):
        converted_values = tree.map_structure(_convert_to_tf, x.values)
        return RepeatedValues(converted_values, x.lengths, x.max_len)
    if x is None:
        return x

    def _leaf_to_tensor(leaf):
        if isinstance(leaf, RepeatedValues):
            return _convert_to_tf(leaf, dtype)
        if leaf is None or tf.is_tensor(leaf):
            return leaf
        return tf.convert_to_tensor(leaf, dtype)

    return tree.map_structure(_leaf_to_tensor, x)
def _convert_to_numpy(x):
    """Replace every `tf.Tensor` leaf in the structure `x` by its `.numpy()` value.

    Raises:
        TypeError: If mapping over `x` raises an AttributeError.
    """

    def _leaf_to_numpy(leaf):
        return leaf.numpy() if isinstance(leaf, tf.Tensor) else leaf

    try:
        converted = tf.nest.map_structure(_leaf_to_numpy, x)
    except AttributeError:
        raise TypeError(
            "Object of type {} has no method to convert to numpy.".format(type(x))
        )
    return converted
def _convert_eager_inputs(func):
    """Decorator: convert all call arguments to tf tensors when running eagerly.

    Outside of eager execution the wrapped function is called unchanged.
    The keyword arguments `info_batch` and `episodes` are dropped in eager
    mode and `timestep` is converted with dtype int64.
    """

    @functools.wraps(func)
    def _wrapper(*args, **kwargs):
        if not tf.executing_eagerly():
            return func(*args, **kwargs)

        tensor_args = [_convert_to_tf(arg) for arg in args]
        # TODO: (sven) find a way to remove key-specific hacks.
        tensor_kwargs = {}
        for key, value in kwargs.items():
            if key in ("info_batch", "episodes"):
                continue
            wanted_dtype = tf.int64 if key == "timestep" else None
            tensor_kwargs[key] = _convert_to_tf(value, dtype=wanted_dtype)
        return func(*tensor_args, **tensor_kwargs)

    return _wrapper
def _convert_eager_outputs(func):
    """Decorator: convert the wrapped function's output to numpy when eager.

    Outside of eager execution the output is returned untouched.
    """

    @functools.wraps(func)
    def _wrapper(*args, **kwargs):
        result = func(*args, **kwargs)
        if not tf.executing_eagerly():
            return result
        return tf.nest.map_structure(_convert_to_numpy, result)

    return _wrapper
- def _disallow_var_creation(next_creator, **kw):
- v = next_creator(**kw)
- raise ValueError(
- "Detected a variable being created during an eager "
- "forward pass. Variables should only be created during "
- "model initialization: {}".format(v.name)
- )
- def _check_too_many_retraces(obj):
- """Asserts that a given number of re-traces is not breached."""
- def _func(self_, *args, **kwargs):
- if (
- self_.config.get("eager_max_retraces") is not None
- and self_._re_trace_counter > self_.config["eager_max_retraces"]
- ):
- raise RuntimeError(
- "Too many tf-eager re-traces detected! This could lead to"
- " significant slow-downs (even slower than running in "
- "tf-eager mode w/ `eager_tracing=False`). To switch off "
- "these re-trace counting checks, set `eager_max_retraces`"
- " in your config to None."
- )
- return obj(self_, *args, **kwargs)
- return _func
@OldAPIStack
class EagerTFPolicy(Policy):
    """Marker base class used to recognize any eagerized TFPolicy via inheritance."""
def _traced_eager_policy(eager_policy_cls):
    """Return a subclass of `eager_policy_cls` whose helper methods get traced.

    This is enabled by the `--trace`/`eager_tracing=True` config when
    framework=tf2.

    Each public method below lazily replaces the matching `_*_helper` on the
    instance by a `tf.function`-wrapped version on first use (unless
    `self._no_tracing` is set) and then delegates to the parent class.

    Args:
        eager_policy_cls: The eager policy class to derive from.

    Returns:
        The derived class, named `<original name>_traced`.
    """

    class TracedEagerPolicy(eager_policy_cls):
        def __init__(self, *args, **kwargs):
            # One "already traced?" flag per helper method.
            self._traced_compute_actions_helper = False
            self._traced_learn_on_batch_helper = False
            self._traced_compute_gradients_helper = False
            self._traced_apply_gradients_helper = False
            super().__init__(*args, **kwargs)

        @_check_too_many_retraces
        @override(Policy)
        def compute_actions_from_input_dict(
            self,
            input_dict: Dict[str, TensorType],
            explore: bool = None,
            timestep: Optional[int] = None,
            episodes=None,
            **kwargs,
        ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
            """Traced version of Policy.compute_actions_from_input_dict."""
            # Lazily swap in a traced `self._compute_actions_helper`.
            needs_tracing = (
                self._traced_compute_actions_helper is False and not self._no_tracing
            )
            if needs_tracing:
                traced_helper = tf.function(
                    super()._compute_actions_helper,
                    autograph=False,
                    reduce_retracing=True,
                )
                self._compute_actions_helper = _convert_eager_inputs(traced_helper)
                self._traced_compute_actions_helper = True

            # The parent's implementation now calls the traced helper.
            return super().compute_actions_from_input_dict(
                input_dict=input_dict,
                explore=explore,
                timestep=timestep,
                episodes=episodes,
                **kwargs,
            )

        @_check_too_many_retraces
        @override(eager_policy_cls)
        def learn_on_batch(self, samples):
            """Traced version of Policy.learn_on_batch."""
            # Lazily swap in a traced `self._learn_on_batch_helper`.
            needs_tracing = (
                self._traced_learn_on_batch_helper is False and not self._no_tracing
            )
            if needs_tracing:
                traced_helper = tf.function(
                    super()._learn_on_batch_helper,
                    autograph=False,
                    reduce_retracing=True,
                )
                self._learn_on_batch_helper = _convert_eager_inputs(traced_helper)
                self._traced_learn_on_batch_helper = True

            # The parent's `learn_on_batch()` now calls the traced helper.
            return super().learn_on_batch(samples)

        @_check_too_many_retraces
        @override(eager_policy_cls)
        def compute_gradients(self, samples: SampleBatch) -> ModelGradients:
            """Traced version of Policy.compute_gradients."""
            # Lazily swap in a traced `self._compute_gradients_helper`.
            needs_tracing = (
                self._traced_compute_gradients_helper is False
                and not self._no_tracing
            )
            if needs_tracing:
                traced_helper = tf.function(
                    super()._compute_gradients_helper,
                    autograph=False,
                    reduce_retracing=True,
                )
                self._compute_gradients_helper = _convert_eager_inputs(traced_helper)
                self._traced_compute_gradients_helper = True

            # The parent's `compute_gradients()` now calls the traced helper.
            return super().compute_gradients(samples)

        @_check_too_many_retraces
        @override(Policy)
        def apply_gradients(self, grads: ModelGradients) -> None:
            """Traced version of Policy.apply_gradients."""
            # Lazily swap in a traced `self._apply_gradients_helper`.
            needs_tracing = (
                self._traced_apply_gradients_helper is False and not self._no_tracing
            )
            if needs_tracing:
                traced_helper = tf.function(
                    super()._apply_gradients_helper,
                    autograph=False,
                    reduce_retracing=True,
                )
                self._apply_gradients_helper = _convert_eager_inputs(traced_helper)
                self._traced_apply_gradients_helper = True

            # The parent's `apply_gradients()` now calls the traced helper.
            return super().apply_gradients(grads)

        @classmethod
        def with_tracing(cls):
            # Already traced -> Return same class.
            return cls

    traced_suffix = "_traced"
    TracedEagerPolicy.__name__ = eager_policy_cls.__name__ + traced_suffix
    TracedEagerPolicy.__qualname__ = eager_policy_cls.__qualname__ + traced_suffix
    return TracedEagerPolicy
- class _OptimizerWrapper:
- def __init__(self, tape):
- self.tape = tape
- def compute_gradients(self, loss, var_list):
- return list(zip(self.tape.gradient(loss, var_list), var_list))
- @OldAPIStack
- def _build_eager_tf_policy(
- name,
- loss_fn,
- get_default_config=None,
- postprocess_fn=None,
- stats_fn=None,
- optimizer_fn=None,
- compute_gradients_fn=None,
- apply_gradients_fn=None,
- grad_stats_fn=None,
- extra_learn_fetches_fn=None,
- extra_action_out_fn=None,
- validate_spaces=None,
- before_init=None,
- before_loss_init=None,
- after_init=None,
- make_model=None,
- action_sampler_fn=None,
- action_distribution_fn=None,
- mixins=None,
- get_batch_divisibility_req=None,
- # Deprecated args.
- obs_include_prev_action_reward=DEPRECATED_VALUE,
- extra_action_fetches_fn=None,
- gradients_fn=None,
- ):
- """Build an eager TF policy.
- An eager policy runs all operations in eager mode, which makes debugging
- much simpler, but has lower performance.
- You shouldn't need to call this directly. Rather, prefer to build a TF
- graph policy and use set `.framework("tf2", eager_tracing=False) in your
- AlgorithmConfig to have it automatically be converted to an eager policy.
- This has the same signature as build_tf_policy()."""
- base = add_mixins(EagerTFPolicy, mixins)
- if obs_include_prev_action_reward != DEPRECATED_VALUE:
- deprecation_warning(old="obs_include_prev_action_reward", error=True)
- if extra_action_fetches_fn is not None:
- deprecation_warning(
- old="extra_action_fetches_fn", new="extra_action_out_fn", error=True
- )
- if gradients_fn is not None:
- deprecation_warning(old="gradients_fn", new="compute_gradients_fn", error=True)
- class eager_policy_cls(base):
def __init__(self, observation_space, action_space, config):
    """Set up model, exploration, optimizers and loss of the eager policy.

    The hooks used below (`loss_fn`, `validate_spaces`, `before_init`,
    `make_model`, `before_loss_init`, `optimizer_fn`, `after_init`, ...)
    are the arguments of the enclosing `_build_eager_tf_policy()` call.

    Args:
        observation_space: Passed on to the base class, the hooks and the
            model construction.
        action_space: Passed on to the base class, the hooks and the model
            construction.
        config: The policy config; read here for the keys "framework",
            "explore", "model" and "lr".
    """
    # If this class runs as a @ray.remote actor, eager mode may not
    # have been activated yet.
    if not tf1.executing_eagerly():
        tf1.enable_eager_execution()
    self.framework = config.get("framework", "tf2")
    EagerTFPolicy.__init__(self, observation_space, action_space, config)

    # Global timestep should be a tensor.
    self.global_timestep = tf.Variable(0, trainable=False, dtype=tf.int64)
    self.explore = tf.Variable(
        self.config["explore"], trainable=False, dtype=tf.bool
    )

    # Log device and worker index.
    num_gpus = self._get_num_gpus_for_policy()
    if num_gpus > 0:
        gpu_ids = get_gpu_devices()
        logger.info(f"Found {len(gpu_ids)} visible cuda devices.")

    self._is_training = False

    # Only for `config.eager_tracing=True`: A counter to keep track of
    # how many times an eager-traced method (e.g.
    # `self._compute_actions_helper`) has been re-traced by tensorflow.
    # We will raise an error if more than n re-tracings have been
    # detected, since this would considerably slow down execution.
    # The variable below should only get incremented during the
    # tf.function trace operations, never when calling the already
    # traced function after that.
    self._re_trace_counter = 0

    self._loss_initialized = False
    # To ensure backward compatibility:
    # Old way: If `loss` provided here, use as-is (as a function).
    if loss_fn is not None:
        self._loss = loss_fn
    # New way: Convert the overridden `self.loss` into a plain
    # function, so it can be called the same way as `loss` would
    # be, ensuring backward compatibility.
    elif self.loss.__func__.__qualname__ != "Policy.loss":
        self._loss = self.loss.__func__
    # `loss` not provided nor overridden from Policy -> Set to None.
    else:
        self._loss = None

    # `get_batch_divisibility_req` may be a callable (called with this
    # policy), a plain value, or None/falsy (-> 1).
    self.batch_divisibility_req = (
        get_batch_divisibility_req(self)
        if callable(get_batch_divisibility_req)
        else (get_batch_divisibility_req or 1)
    )
    self._max_seq_len = config["model"]["max_seq_len"]

    if validate_spaces:
        validate_spaces(self, observation_space, action_space, config)

    if before_init:
        before_init(self, observation_space, action_space, config)

    self.config = config
    self.dist_class = None
    # Custom sampling/distribution functions require a custom model;
    # otherwise the action distribution comes from the ModelCatalog.
    if action_sampler_fn or action_distribution_fn:
        if not make_model:
            raise ValueError(
                "`make_model` is required if `action_sampler_fn` OR "
                "`action_distribution_fn` is given"
            )
    else:
        self.dist_class, logit_dim = ModelCatalog.get_action_dist(
            action_space, self.config["model"]
        )

    if make_model:
        self.model = make_model(self, observation_space, action_space, config)
    else:
        # `logit_dim` is defined here: without `make_model`, the `else`
        # branch above must have run (the `if` branch would have raised).
        self.model = ModelCatalog.get_model_v2(
            observation_space,
            action_space,
            logit_dim,
            config["model"],
            framework=self.framework,
        )
    # Lock used for locking some methods on the object-level.
    # This prevents possible race conditions when calling the model
    # first, then its value function (e.g. in a loss function), in
    # between of which another model call is made (e.g. to compute an
    # action).
    self._lock = threading.RLock()

    # Auto-update model's inference view requirements, if recurrent.
    self._update_model_view_requirements_from_init_state()
    # Combine view_requirements for Model and Policy.
    self.view_requirements.update(self.model.view_requirements)

    self.exploration = self._create_exploration()
    self._state_inputs = self.model.get_initial_state()
    self._is_recurrent = len(self._state_inputs) > 0

    if before_loss_init:
        before_loss_init(self, observation_space, action_space, config)

    # Default optimizer is Adam with the configured learning rate.
    if optimizer_fn:
        optimizers = optimizer_fn(self, config)
    else:
        optimizers = tf.keras.optimizers.Adam(config["lr"])
    optimizers = force_list(optimizers)
    if self.exploration:
        optimizers = self.exploration.get_exploration_optimizer(optimizers)

    # The list of local (tf) optimizers (one per loss term).
    self._optimizers: List[LocalOptimizer] = optimizers
    # Backward compatibility: A user's policy may only support a single
    # loss term and optimizer (no lists).
    self._optimizer: LocalOptimizer = optimizers[0] if optimizers else None

    self._initialize_loss_from_dummy_batch(
        auto_remove_unneeded_view_reqs=True,
        stats_fn=stats_fn,
    )
    self._loss_initialized = True

    if after_init:
        after_init(self, observation_space, action_space, config)

    # Got to reset global_timestep again after fake run-throughs.
    self.global_timestep.assign(0)
@override(Policy)
def compute_actions_from_input_dict(
    self,
    input_dict: Dict[str, TensorType],
    explore: bool = None,
    timestep: Optional[int] = None,
    episodes=None,
    **kwargs,
) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
    """Compute actions for the given input dict.

    Args:
        input_dict: The inputs for the forward pass. Columns whose key
            starts with "state_in" are collected as state inputs.
        explore: Exploration flag; defaults to `self.explore` if None.
        timestep: Current timestep; defaults to `self.global_timestep`
            if None.
        episodes: Forwarded to the actions helper, unless
            `config["eager_tracing"]` is set (then None is passed).
        **kwargs: Not used by this method.

    Returns:
        The result of `self._compute_actions_helper`, converted via
        `convert_to_numpy`.
    """
    if not self.config.get("eager_tracing") and not tf1.executing_eagerly():
        tf1.enable_eager_execution()

    self._is_training = False

    if explore is None:
        explore = self.explore
    if timestep is None:
        timestep = self.global_timestep
    if isinstance(timestep, tf.Tensor):
        timestep = int(timestep.numpy())

    # Pass lazy (eager) tensor dict to Model as `input_dict`.
    input_dict = self._lazy_tensor_dict(input_dict)
    input_dict.set_training(False)

    # Pack internal state inputs into (separate) list.
    state_batches = [
        input_dict[key] for key in input_dict.keys() if key.startswith("state_in")
    ]
    self._state_in = state_batches
    self._is_recurrent = len(state_batches) > 0

    # Call the exploration before_compute_actions hook.
    self.exploration.before_compute_actions(
        timestep=timestep, explore=explore, tf_sess=self.get_session()
    )

    # TODO: Passing episodes into a traced method does not work.
    helper_episodes = None if self.config["eager_tracing"] else episodes
    helper_out = self._compute_actions_helper(
        input_dict, state_batches, helper_episodes, explore, timestep
    )

    # Update our global timestep by the batch size.
    batch_size = tree.flatten(helper_out[0])[0].shape.as_list()[0]
    self.global_timestep.assign_add(batch_size)
    return convert_to_numpy(helper_out)
@override(Policy)
def compute_actions(
    self,
    obs_batch: Union[List[TensorStructType], TensorStructType],
    state_batches: Optional[List[TensorType]] = None,
    prev_action_batch: Union[List[TensorStructType], TensorStructType] = None,
    prev_reward_batch: Union[List[TensorStructType], TensorStructType] = None,
    info_batch: Optional[Dict[str, list]] = None,
    episodes: Optional[List] = None,
    explore: Optional[bool] = None,
    timestep: Optional[int] = None,
    **kwargs,
) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
    """Compute actions from separate batches.

    Packs all given (non-None) batches into a SampleBatch and hands the
    entire call over to `self.compute_actions_from_input_dict()`.
    """
    input_dict = SampleBatch(
        {SampleBatch.CUR_OBS: obs_batch},
        _is_training=tf.constant(False),
    )
    if state_batches is not None:
        for index, state in enumerate(state_batches):
            input_dict[f"state_in_{index}"] = state

    # Optional columns, added only when provided.
    optional_columns = (
        (SampleBatch.PREV_ACTIONS, prev_action_batch),
        (SampleBatch.PREV_REWARDS, prev_reward_batch),
        (SampleBatch.INFOS, info_batch),
    )
    for column, values in optional_columns:
        if values is not None:
            input_dict[column] = values

    return self.compute_actions_from_input_dict(
        input_dict=input_dict,
        explore=explore,
        timestep=timestep,
        episodes=episodes,
        **kwargs,
    )
@with_lock
@override(Policy)
def compute_log_likelihoods(
    self,
    actions,
    obs_batch,
    state_batches=None,
    prev_action_batch=None,
    prev_reward_batch=None,
    actions_normalized=True,
    **kwargs,
):
    """Compute the log-likelihoods of `actions` given the observations.

    Args:
        actions: The actions to compute the log-likelihood for.
        obs_batch: The observations.
        state_batches: State inputs for the default model call.
        prev_action_batch: Optional previous actions.
        prev_reward_batch: Optional previous rewards.
        actions_normalized: If False and `config["normalize_actions"]` is
            set, `actions` are normalized before computing `logp`.
        **kwargs: Not used by this method.

    Returns:
        The result of the action distribution's `logp(actions)`.

    Raises:
        ValueError: If an `action_sampler_fn` was given without an
            `action_distribution_fn`.
    """
    if action_sampler_fn and action_distribution_fn is None:
        raise ValueError(
            "Cannot compute log-prob/likelihood w/o an "
            "`action_distribution_fn` and a provided "
            "`action_sampler_fn`!"
        )

    seq_lens = tf.ones(len(obs_batch), dtype=tf.int32)
    input_batch = SampleBatch(
        {SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch)},
        _is_training=False,
    )
    for column, values in (
        (SampleBatch.PREV_ACTIONS, prev_action_batch),
        (SampleBatch.PREV_REWARDS, prev_reward_batch),
    ):
        if values is not None:
            input_batch[column] = tf.convert_to_tensor(values)

    # Exploration hook before each forward pass.
    if self.exploration:
        self.exploration.before_compute_actions(explore=False)

    if action_distribution_fn:
        # Action dist class and inputs are generated via custom function.
        dist_inputs, dist_class, _ = action_distribution_fn(
            self, self.model, input_batch, explore=False, is_training=False
        )
    else:
        # Default log-likelihood calculation.
        dist_inputs, _ = self.model(input_batch, state_batches, seq_lens)
        dist_class = self.dist_class

    action_dist = dist_class(dist_inputs, self.model)

    # Normalize actions if necessary.
    if not actions_normalized and self.config["normalize_actions"]:
        actions = normalize_action(actions, self.action_space_struct)

    return action_dist.logp(actions)
@override(Policy)
def postprocess_trajectory(
    self, sample_batch, other_agent_batches=None, episode=None
):
    """Postprocess a trajectory, optionally via the custom `postprocess_fn`.

    The base class' postprocessing always runs first. If a `postprocess_fn`
    was given at build time, its result is returned instead.
    """
    assert tf.executing_eagerly()
    base_batch = EagerTFPolicy.postprocess_trajectory(self, sample_batch)
    if not postprocess_fn:
        return base_batch
    return postprocess_fn(self, base_batch, other_agent_batches, episode)
@with_lock
@override(Policy)
def learn_on_batch(self, postprocessed_batch):
    """Run one learning step on the given batch and return its stats.

    Args:
        postprocessed_batch: The batch to learn on. It is padded in place
            before being passed to `self._learn_on_batch_helper`.

    Returns:
        The helper's stats, extended by custom metrics and update counters,
        converted via `convert_to_numpy`.
    """
    # Callback handling.
    custom_metrics = {}
    self.callbacks.on_learn_on_batch(
        policy=self, train_batch=postprocessed_batch, result=custom_metrics
    )

    pad_batch_to_sequences_of_same_size(
        postprocessed_batch,
        max_seq_len=self._max_seq_len,
        shuffle=False,
        batch_divisibility_req=self.batch_divisibility_req,
        view_requirements=self.view_requirements,
    )

    self._is_training = True
    postprocessed_batch = self._lazy_tensor_dict(postprocessed_batch)
    postprocessed_batch.set_training(True)
    stats = self._learn_on_batch_helper(postprocessed_batch)
    self.num_grad_updates += 1

    extra_stats = {
        "custom_metrics": custom_metrics,
        NUM_AGENT_STEPS_TRAINED: postprocessed_batch.count,
        NUM_GRAD_UPDATES_LIFETIME: self.num_grad_updates,
        # -1, b/c we have to measure this diff before we do the update
        # above.
        DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY: (
            self.num_grad_updates
            - 1
            - (postprocessed_batch.num_grad_updates or 0)
        ),
    }
    stats.update(extra_stats)
    return convert_to_numpy(stats)
@override(Policy)
def compute_gradients(
    self, postprocessed_batch: SampleBatch
) -> Tuple[ModelGradients, Dict[str, TensorType]]:
    """Compute gradients (and stats) on the given batch without applying them.

    Args:
        postprocessed_batch: The batch to compute gradients for. It is
            padded in place before the gradient computation.

    Returns:
        A tuple (grads, stats) as returned by
        `self._compute_gradients_helper`, converted via `convert_to_numpy`.
    """
    pad_batch_to_sequences_of_same_size(
        postprocessed_batch,
        shuffle=False,
        max_seq_len=self._max_seq_len,
        batch_divisibility_req=self.batch_divisibility_req,
        view_requirements=self.view_requirements,
    )

    self._is_training = True
    # Use the batch returned by `_lazy_tensor_dict` (same as
    # `learn_on_batch` and `compute_actions_from_input_dict` do), rather
    # than relying on it modifying its argument in place.
    postprocessed_batch = self._lazy_tensor_dict(postprocessed_batch)
    postprocessed_batch.set_training(True)
    # The (grad, var) pairs are not needed here, only grads and stats.
    _, grads, stats = self._compute_gradients_helper(postprocessed_batch)
    return convert_to_numpy((grads, stats))
@override(Policy)
def apply_gradients(self, gradients: ModelGradients) -> None:
    """Apply the given gradients to the model's trainable variables.

    Args:
        gradients: One gradient (or None) per trainable variable, in the
            order of the model's trainable variables.
    """
    grad_tensors = [
        (tf.convert_to_tensor(g) if g is not None else None) for g in gradients
    ]
    # On a `tf.keras.Model`, `trainable_variables` is a property; on other
    # models it is called as a method (mirrors `variables()` and
    # `_compute_gradients_helper`).
    if isinstance(self.model, tf.keras.Model):
        trainable_vars = self.model.trainable_variables
    else:
        trainable_vars = self.model.trainable_variables()
    self._apply_gradients_helper(list(zip(grad_tensors, trainable_vars)))
@override(Policy)
def get_weights(self, as_dict=False):
    """Return the values of all policy variables.

    Args:
        as_dict: If True, return a dict mapping variable name to value;
            otherwise a list of values in the order of `self.variables()`.
    """
    policy_vars = self.variables()
    if not as_dict:
        return [var.numpy() for var in policy_vars]
    return {var.name: var.numpy() for var in policy_vars}
@override(Policy)
def set_weights(self, weights):
    """Assign `weights` to the policy variables, pairing them up in order.

    Args:
        weights: One value per variable in `self.variables()`.
    """
    policy_vars = self.variables()
    num_given, num_expected = len(weights), len(policy_vars)
    assert num_given == num_expected, (num_given, num_expected)
    for var, value in zip(policy_vars, weights):
        var.assign(value)
@override(Policy)
def get_exploration_state(self):
    """Return the exploration object's state, converted via `convert_to_numpy`."""
    exploration_state = self.exploration.get_state()
    return convert_to_numpy(exploration_state)
@override(Policy)
def is_recurrent(self):
    """Return the policy's current `_is_recurrent` flag."""
    recurrent = self._is_recurrent
    return recurrent
@override(Policy)
def num_state_tensors(self):
    """Return the number of entries in the model's initial state list."""
    state_inputs = self._state_inputs
    return len(state_inputs)
@override(Policy)
def get_initial_state(self):
    """Return the model's initial state, or [] if no model has been set yet."""
    if not hasattr(self, "model"):
        return []
    return self.model.get_initial_state()
@override(Policy)
def get_state(self) -> PolicyState:
    """Return the policy state, extended by optimizer and exploration state.

    Returns:
        The parent's state dict with "global_timestep" converted via
        `.numpy()`, plus "_optimizer_variables" (if the optimizer has any
        variables) and "_exploration_state" (if an exploration exists).
    """
    # Legacy Policy state (w/o keras model and w/o PolicySpec).
    policy_state = super().get_state()
    policy_state["global_timestep"] = policy_state["global_timestep"].numpy()

    has_optimizer_vars = (
        self._optimizer and len(self._optimizer.variables()) > 0
    )
    if has_optimizer_vars:
        policy_state["_optimizer_variables"] = self._optimizer.variables()

    # Add exploration state.
    if self.exploration:
        # This is not compatible with RLModules, which have a method
        # `forward_exploration` to specify custom exploration behavior.
        policy_state["_exploration_state"] = self.exploration.get_state()
    return policy_state
@override(Policy)
def set_state(self, state: PolicyState) -> None:
    """Restore the policy from a state dict.

    Restores, in this order: optimizer variables (if present in `state`
    and this policy has an optimizer with variables), the exploration
    state, the global timestep, and finally everything the parent class
    restores.

    Args:
        state: The state dict. Must contain "global_timestep"; may contain
            "_optimizer_variables" and "_exploration_state".
    """
    # Set optimizer vars first.
    optimizer_vars = state.get("_optimizer_variables", None)
    # `self._optimizer` may be None (no optimizers were created), so it
    # is guarded here the same way `get_state()` guards it.
    if optimizer_vars and self._optimizer and self._optimizer.variables():
        if not type(self).__name__.endswith("_traced") and log_once(
            "set_state_optimizer_vars_tf_eager_policy_v2"
        ):
            logger.warning(
                "Cannot restore an optimizer's state for tf eager! Keras "
                "is not able to save the v1.x optimizers (from "
                "tf.compat.v1.train) since they aren't compatible with "
                "checkpoints."
            )
        for opt_var, value in zip(self._optimizer.variables(), optimizer_vars):
            opt_var.assign(value)
    # Set exploration's state (only if an exploration object exists;
    # `get_state()` likewise only saves it for a truthy exploration).
    if getattr(self, "exploration", None) and "_exploration_state" in state:
        self.exploration.set_state(state=state["_exploration_state"])

    # Restore global timestep (tf vars).
    self.global_timestep.assign(state["global_timestep"])

    # Then the Policy's (NN) weights and connectors.
    super().set_state(state)
@override(Policy)
def export_model(self, export_dir, onnx: Optional[int] = None) -> None:
    """Export the Policy's Model to a local directory for serving.

    Note: The TfModelV2 class used by EagerTfPolicy is not itself a
    tf.keras.Model. The export therefore relies on a `base_model` property
    of the model that is a tf.keras.Model; if there is none, only a warning
    is logged.

    TODO (kourosh): This restriction will be resolved once we move Policy and
    ModelV2 to the new Learner/RLModule APIs.

    Args:
        export_dir: Local writable directory.
        onnx: If given, will export model in ONNX format. The
            value of this parameter sets the ONNX OpSet version to use.

    Raises:
        RuntimeError: If ONNX export is requested but `tf2onnx` cannot be
            imported.
    """
    has_keras_base_model = (
        hasattr(self, "model")
        and hasattr(self.model, "base_model")
        and isinstance(self.model.base_model, tf.keras.Model)
    )
    if not has_keras_base_model:
        logger.warning(ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL)
        return

    # Store model in ONNX format.
    if onnx:
        try:
            import tf2onnx
        except ImportError as e:
            raise RuntimeError(
                "Converting a TensorFlow model to ONNX requires "
                "`tf2onnx` to be installed. Install with "
                "`pip install tf2onnx`."
            ) from e

        onnx_path = os.path.join(export_dir, "model.onnx")
        _model_proto, _external_storage = tf2onnx.convert.from_keras(
            self.model.base_model,
            output_path=onnx_path,
        )
        return

    # Save the tf.keras.Model (architecture and weights, so it can be
    # retrieved w/o access to the original (custom) Model or Policy code).
    try:
        self.model.base_model.save(export_dir, save_format="tf")
    except Exception:
        logger.warning(ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL)
def variables(self):
    """Return the list of all savable variables for this policy."""
    # On a keras Model, `variables` is an attribute; otherwise it is called.
    if not isinstance(self.model, tf.keras.Model):
        return self.model.variables()
    return self.model.variables
def loss_initialized(self):
    """Return the policy's `_loss_initialized` flag."""
    initialized = self._loss_initialized
    return initialized
@with_lock
def _compute_actions_helper(
    self, input_dict, state_batches, episodes, explore, timestep
):
    """Run the forward pass and sample actions.

    Depending on which build-time hooks exist, actions come from
    `action_sampler_fn`, or from an action distribution built from
    `action_distribution_fn` / a model forward pass and then sampled via
    the exploration object.

    Args:
        input_dict: The model inputs.
        state_batches: The state inputs; if truthy, `seq_lens` is a vector
            of ones (one per batch row), otherwise None.
        episodes: Only forwarded to `action_sampler_fn`.
        explore: Exploration flag handed to the hooks and the exploration.
        timestep: Timestep handed to the hooks and the exploration.

    Returns:
        A tuple (actions, state_out, extra_fetches).
    """
    # Increase the tracing counter to make sure we don't re-trace too
    # often. If eager_tracing=True, this counter should only get
    # incremented during the @tf.function trace operations, never when
    # calling the already traced function after that.
    self._re_trace_counter += 1

    # Calculate RNN sequence lengths.
    batch_size = tree.flatten(input_dict[SampleBatch.OBS])[0].shape[0]
    seq_lens = tf.ones(batch_size, dtype=tf.int32) if state_batches else None

    # Add default and custom fetches.
    extra_fetches = {}

    # Use Exploration object.
    # Any variable creation inside this scope raises a ValueError.
    with tf.variable_creator_scope(_disallow_var_creation):
        if action_sampler_fn:
            action_sampler_outputs = action_sampler_fn(
                self,
                self.model,
                input_dict[SampleBatch.CUR_OBS],
                explore=explore,
                timestep=timestep,
                episodes=episodes,
            )
            # The sampler may return either 4 values (incl. dist inputs
            # and state outs) or just (actions, logp).
            if len(action_sampler_outputs) == 4:
                actions, logp, dist_inputs, state_out = action_sampler_outputs
            else:
                dist_inputs = None
                state_out = []
                actions, logp = action_sampler_outputs
        else:
            if action_distribution_fn:
                # Try new action_distribution_fn signature, supporting
                # state_batches and seq_lens.
                try:
                    (
                        dist_inputs,
                        self.dist_class,
                        state_out,
                    ) = action_distribution_fn(
                        self,
                        self.model,
                        input_dict=input_dict,
                        state_batches=state_batches,
                        seq_lens=seq_lens,
                        explore=explore,
                        timestep=timestep,
                        is_training=False,
                    )
                # Trying the old way (to stay backward compatible).
                # TODO: Remove in future.
                except TypeError as e:
                    # Only signature mismatches trigger the fallback; any
                    # other TypeError is re-raised.
                    if (
                        "positional argument" in e.args[0]
                        or "unexpected keyword argument" in e.args[0]
                    ):
                        (
                            dist_inputs,
                            self.dist_class,
                            state_out,
                        ) = action_distribution_fn(
                            self,
                            self.model,
                            input_dict[SampleBatch.OBS],
                            explore=explore,
                            timestep=timestep,
                            is_training=False,
                        )
                    else:
                        raise e

            elif isinstance(self.model, tf.keras.Model):
                # Keras models receive everything (incl. seq_lens and
                # state inputs) through a single SampleBatch.
                input_dict = SampleBatch(input_dict, seq_lens=seq_lens)
                if state_batches and "state_in_0" not in input_dict:
                    for i, s in enumerate(state_batches):
                        input_dict[f"state_in_{i}"] = s
                self._lazy_tensor_dict(input_dict)
                dist_inputs, state_out, extra_fetches = self.model(input_dict)
            else:
                dist_inputs, state_out = self.model(
                    input_dict, state_batches, seq_lens
                )

            action_dist = self.dist_class(dist_inputs, self.model)

            # Get the exploration action from the forward results.
            actions, logp = self.exploration.get_exploration_action(
                action_distribution=action_dist,
                timestep=timestep,
                explore=explore,
            )

    # Action-logp and action-prob.
    if logp is not None:
        extra_fetches[SampleBatch.ACTION_PROB] = tf.exp(logp)
        extra_fetches[SampleBatch.ACTION_LOGP] = logp
    # Action-dist inputs.
    if dist_inputs is not None:
        extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs
    # Custom extra fetches.
    if extra_action_out_fn:
        extra_fetches.update(extra_action_out_fn(self))

    return actions, state_out, extra_fetches
# TODO: Figure out, why _ray_trace_ctx=None helps to prevent a crash in
# AlphaStar w/ framework=tf2; eager_tracing=True on the policy learner actors.
# It seems there may be a clash between the traced-by-tf function and the
# traced-by-ray functions (for making the policy class a ray actor).
def _learn_on_batch_helper(self, samples, _ray_trace_ctx=None):
    """Compute gradients on `samples`, apply them, and return the stats.

    Variable creation during the gradient computation raises a ValueError.
    """
    # Count this call; with eager_tracing=True, the counter should only
    # get incremented during the @tf.function trace operations, never when
    # calling the already traced function after that.
    self._re_trace_counter += 1

    with tf.variable_creator_scope(_disallow_var_creation):
        helper_out = self._compute_gradients_helper(samples)
        grads_and_vars, _unused_grads, stats = helper_out
    self._apply_gradients_helper(grads_and_vars)
    return stats
def _get_is_training_placeholder(self):
    """Returns `self._is_training` converted into a tf tensor."""
    is_training = self._is_training
    return tf.convert_to_tensor(is_training)
@with_lock
def _compute_gradients_helper(self, samples):
    """Computes and returns grads as eager tensors.

    Runs the loss on `samples` under a `tf.GradientTape` and derives the
    gradients either through the user-provided `compute_gradients_fn` or,
    by default, directly from the tape.

    Args:
        samples: The train batch handed to the loss and stats functions.

    Returns:
        A tuple `(grads_and_vars, grads, stats)`. If
        `config["_tf_policy_handles_more_than_one_loss"]` is set,
        `grads_and_vars` is a list (one entry per loss/optimizer) of lists
        of (grad, var) tuples and `grads` is a list of lists of grads.
        Otherwise, only the first entry is used: `grads_and_vars` is a
        list of (grad, var) tuples and `grads` a flat list of grads.
        `stats` is the result of `self._stats(self, samples, grads)`.
    """
    # Increase the tracing counter to make sure we don't re-trace too
    # often. If eager_tracing=True, this counter should only get
    # incremented during the @tf.function trace operations, never when
    # calling the already traced function after that.
    self._re_trace_counter += 1

    # Gather all variables for which to calculate losses.
    if isinstance(self.model, tf.keras.Model):
        variables = self.model.trainable_variables
    else:
        variables = self.model.trainable_variables()

    # The tape must be persistent whenever it may be queried more than
    # once: A custom compute_gradients_fn receives the (wrapped) tape, and
    # the default path below calls `tape.gradient()` once per loss. A
    # non-persistent tape only supports a single `gradient()` call, so the
    # multi-loss default path would otherwise fail on the second loss.
    persistent = compute_gradients_fn is not None or bool(
        self.config["_tf_policy_handles_more_than_one_loss"]
    )

    # Calculate the loss(es) inside a tf GradientTape.
    with tf.GradientTape(persistent=persistent) as tape:
        losses = self._loss(self, self.model, self.dist_class, samples)
    losses = force_list(losses)

    # User provided a compute_gradients_fn.
    if compute_gradients_fn:
        # Wrap our tape inside a wrapper, such that the resulting
        # object looks like a "classic" tf.optimizer. This way, custom
        # compute_gradients_fn will work on both tf static graph
        # and tf-eager.
        optimizer = _OptimizerWrapper(tape)
        # More than one loss terms/optimizers.
        if self.config["_tf_policy_handles_more_than_one_loss"]:
            grads_and_vars = compute_gradients_fn(
                self, [optimizer] * len(losses), losses
            )
        # Only one loss and one optimizer.
        else:
            grads_and_vars = [compute_gradients_fn(self, optimizer, losses[0])]
    # Default: Compute gradients using the above tape.
    else:
        grads_and_vars = [
            list(zip(tape.gradient(loss, variables), variables))
            for loss in losses
        ]

    # Log (only once) the names of all variables that receive a gradient.
    if log_once("grad_vars"):
        for g_and_v in grads_and_vars:
            for g, v in g_and_v:
                if g is not None:
                    logger.info(f"Optimizing variable {v.name}")

    # `grads_and_vars` is returned a list (len=num optimizers/losses)
    # of lists of (grad, var) tuples.
    if self.config["_tf_policy_handles_more_than_one_loss"]:
        grads = [[g for g, _ in g_and_v] for g_and_v in grads_and_vars]
    # `grads_and_vars` is returned as a list of (grad, var) tuples.
    else:
        grads_and_vars = grads_and_vars[0]
        grads = [g for g, _ in grads_and_vars]

    stats = self._stats(self, samples, grads)
    return grads_and_vars, grads, stats
def _apply_gradients_helper(self, grads_and_vars):
    """Applies the given gradients through the policy's optimizer(s).

    Args:
        grads_and_vars: A list of (grad, var) tuples or, if
            `config["_tf_policy_handles_more_than_one_loss"]` is set, a
            list (indexed like `self._optimizers`) of such lists.
    """
    # Bump the re-trace counter. With eager_tracing=True, this line should
    # only execute while tf.function traces this method, never on later
    # calls of the already traced function.
    self._re_trace_counter += 1

    multi_loss = self.config["_tf_policy_handles_more_than_one_loss"]

    # A user-provided apply_gradients_fn takes over entirely; it is handed
    # either all optimizers or the single one, depending on the config.
    if apply_gradients_fn:
        optimizers = self._optimizers if multi_loss else self._optimizer
        apply_gradients_fn(self, optimizers, grads_and_vars)
        return

    # Default, single optimizer: Skip variables without a gradient.
    if not multi_loss:
        self._optimizer.apply_gradients(
            [(grad, var) for grad, var in grads_and_vars if grad is not None]
        )
        return

    # Default, several optimizers: The i-th optimizer applies the i-th
    # list of (grad, var) tuples, again skipping None-gradients.
    for idx, opt in enumerate(self._optimizers):
        opt.apply_gradients(
            [
                (grad, var)
                for grad, var in grads_and_vars[idx]
                if grad is not None
            ]
        )
def _stats(self, outputs, samples, grads):
    """Builds the fetches dict from the optional user-provided stats fns.

    Args:
        outputs: Passed through as first argument to `stats_fn`.
        samples: The batch passed to `stats_fn` and `grad_stats_fn`.
        grads: The gradients passed to `grad_stats_fn`.

    Returns:
        A dict holding the `stats_fn` results (or an empty dict, if no
        `stats_fn` is given) under LEARNER_STATS_KEY, updated with the
        results of `extra_learn_fetches_fn` and `grad_stats_fn` (if given).
    """
    learner_stats = dict(stats_fn(outputs, samples)) if stats_fn else {}
    results = {LEARNER_STATS_KEY: learner_stats}

    # Optional extras are merged at the top level, in this order.
    if extra_learn_fetches_fn:
        extra_fetches = extra_learn_fetches_fn(self)
        results.update(dict(extra_fetches))
    if grad_stats_fn:
        grad_fetches = grad_stats_fn(self, samples, grads)
        results.update(dict(grad_fetches))

    return results
def _lazy_tensor_dict(self, postprocessed_batch: SampleBatch):
    """Installs `_convert_to_tf` as get-interceptor on the given batch.

    Args:
        postprocessed_batch: The batch to set the interceptor on. Anything
            that is not a SampleBatch gets wrapped into a new one first.

    Returns:
        The SampleBatch (the input itself or the newly created wrapper)
        with the get-interceptor set.
    """
    batch = postprocessed_batch
    # TODO: (sven): Keep for a while to ensure backward compatibility.
    if not isinstance(batch, SampleBatch):
        batch = SampleBatch(batch)
    batch.set_get_interceptor(_convert_to_tf)
    return batch
@classmethod
def with_tracing(cls):
    """Returns the result of passing this class to `_traced_eager_policy`."""
    traced_cls = _traced_eager_policy(cls)
    return traced_cls
- eager_policy_cls.__name__ = name + "_eager"
- eager_policy_cls.__qualname__ = name + "_eager"
- return eager_policy_cls
|