| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409 |
- import logging
- import os
- import sys
- from typing import TYPE_CHECKING, Any, Optional
- import numpy as np
- import tree # pip install dm_tree
- import ray
- from ray._common.deprecation import Deprecated
- from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
- from ray.rllib.utils.typing import (
- TensorShape,
- TensorStructType,
- TensorType,
- )
- if TYPE_CHECKING:
- from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
- logger = logging.getLogger(__name__)
@PublicAPI
def convert_to_tensor(
    data: TensorStructType,
    framework: str,
    device: Optional[str] = None,
):
    """Turns a (possibly nested) numpy struct into framework-specific tensors.

    Args:
        data: The input data (numpy) to convert to framework-specific tensors.
        framework: The framework to convert to. Only "torch" and "tf2" are handled.
        device: An optional device name. Only forwarded in the "torch" case.

    Returns:
        The converted tensor struct matching the input data.

    Raises:
        NotImplementedError: If `framework` is neither "torch" nor "tf2".
    """
    if framework == "torch":
        # Imported lazily, only on the branch that needs it.
        from ray.rllib.utils.torch_utils import convert_to_torch_tensor

        return convert_to_torch_tensor(data, device=device)

    if framework == "tf2":
        # Only the 2nd item (the `tf` module slot) of the returned tuple is used.
        _, tf, _ = try_import_tf()

        def _leaf_to_tf(leaf):
            return tf.convert_to_tensor(leaf)

        return tree.map_structure(_leaf_to_tf, data)

    raise NotImplementedError(
        f"framework={framework} not supported in `convert_to_tensor()`!"
    )
@PublicAPI
def get_device(config: "AlgorithmConfig", num_gpus_requested: int = 1):
    """Picks a single device (CPU or some GPU) depending on a config.

    Args:
        config: An AlgorithmConfig. Its `framework_str` and `local_gpu_idx` are read.
        num_gpus_requested: The number of GPUs actually requested. This may be the
            value of `config.num_gpus_per_env_runner` when for example calling this
            function from an EnvRunner.

    Returns:
        A single device (or name) given `config` and `num_gpus_requested`.

    Raises:
        NotImplementedError: If `config.framework_str` is not "torch".
        AssertionError: If a GPU is picked through `config.local_gpu_idx` and that
            index is not smaller than `torch.cuda.device_count()`.
    """
    if config.framework_str != "torch":
        raise NotImplementedError(
            f"`framework_str` {config.framework_str} not supported!"
        )

    torch, _ = try_import_torch()

    # TODO (Kourosh): How do we handle model parallelism?
    # TODO (Kourosh): Instead of using _TorchAccelerator, we should use the public
    #  API in ray.train but allow for session to be None without any errors raised.
    if not num_gpus_requested > 0:
        return torch.device("cpu")

    from ray.air._internal.torch_utils import get_devices

    # `get_devices()` returns a list that contains the 0th device if it is called
    # from outside a Ray Train session. The user must still be able to run on the
    # GPU of their choice, which is what `config.local_gpu_idx` is for (see below).
    available = get_devices()

    # Note, if we have a single learner and we do not run on Ray Tune, the local
    # learner is not a Ray actor and Ray does not manage devices for it.
    single_device_in_worker = (
        len(available) == 1
        and ray._private.worker._mode() == ray._private.worker.WORKER_MODE
    )
    if single_device_in_worker:
        return available[0]

    assert config.local_gpu_idx < torch.cuda.device_count(), (
        f"local_gpu_idx {config.local_gpu_idx} is not a valid GPU ID "
        "or is not available."
    )
    # `local_gpu_idx` indexes into the *visible* CUDA devices. For example, if
    # `os.environ["CUDA_VISIBLE_DEVICES"] = "1"` then
    # `torch.cuda.device_count() = 1` and torch.device(0) maps to that GPU
    # with ID=1 on the node.
    return torch.device(config.local_gpu_idx)
@PublicAPI
def try_import_jax(error: bool = False):
    """Attempts to import JAX and FLAX and returns both modules (or Nones).

    Args:
        error: Whether to raise an error if JAX/FLAX cannot be imported.

    Returns:
        Tuple `(jax, flax)` of the imported modules, or `(None, None)` if the
        import failed (with `error=False`) or the `RLLIB_TEST_NO_JAX_IMPORT`
        environment variable is set.

    Raises:
        ImportError: If error=True and importing `flax` or `jax` fails.
    """
    # Test hook: pretend JAX is unavailable, regardless of what is installed.
    if "RLLIB_TEST_NO_JAX_IMPORT" in os.environ:
        logger.warning("Not importing JAX for test purposes.")
        return None, None

    try:
        import flax
        import jax
    except ImportError:
        if not error:
            return None, None
        raise ImportError(
            "Could not import JAX! RLlib requires you to "
            "install at least one deep-learning framework: "
            "`pip install [torch|tensorflow|jax]`."
        )
    else:
        return jax, flax
@PublicAPI
def try_import_tf(error: bool = False):
    """Attempts to import tf and returns the module(s) (or placeholders).

    Args:
        error: Whether to raise an error if tf cannot be imported.

    Returns:
        Tuple containing
        1) tf1.x module (either from tf2.x.compat.v1 OR as tf1.x).
        2) tf module (resulting from `import tensorflow`). Either tf1.x or 2.x.
        3) The actually installed tf version as int: 1 or 2.
        If tf is not imported (import failure with `error=False`, or the
        `RLLIB_TEST_NO_TF_IMPORT` environment variable is set), the tuple is
        `(None, <_TFStub instance>, None)` instead.

    Raises:
        ImportError: If error=True and tf is not installed.
    """
    stub = _TFStub()

    # Make sure, these are reset after each test case
    # that uses them: del os.environ["RLLIB_TEST_NO_TF_IMPORT"]
    if "RLLIB_TEST_NO_TF_IMPORT" in os.environ:
        logger.warning("Not importing TensorFlow for test purposes")
        return None, stub, None

    # Only set the log level if the user has not chosen one already.
    os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "3")

    # Try to reuse an already imported tf module. This avoids going through
    # the initial import steps below and thereby switching off v2_behavior
    # (switching off v2 behavior twice breaks all-framework tests for eager).
    already_loaded = "tensorflow" in sys.modules
    if already_loaded:
        tf_module = sys.modules["tensorflow"]
    else:
        try:
            import tensorflow as tf_module
        except ImportError:
            if not error:
                return None, stub, None
            raise ImportError(
                "Could not import TensorFlow! RLlib requires you to "
                "install at least one deep-learning framework: "
                "`pip install [torch|tensorflow|jax]`."
            )

    # Try "reducing" tf to tf.compat.v1.
    try:
        tf1_module = tf_module.compat.v1
        tf1_module.logging.set_verbosity(tf1_module.logging.ERROR)
        if not already_loaded:
            tf1_module.disable_v2_behavior()
            tf1_module.enable_resource_variables()
        tf1_module.logging.set_verbosity(tf1_module.logging.WARN)
    except AttributeError:
        # No compat.v1 -> return tf as is.
        tf1_module = tf_module

    # A module without `__version__` (e.g. during sphinx doc gen) counts as 1.
    if hasattr(tf_module, "__version__") and tf_module.__version__.startswith(
        "2."
    ):
        version = 2
    else:
        version = 1

    return tf1_module, tf_module, version
- # Fake module for tf.
- class _TFStub:
- def __init__(self) -> None:
- self.keras = _KerasStub()
- def __bool__(self):
- # if tf should return False
- return False
# Placeholder object standing in for `tf.keras`.
class _KerasStub:
    """Stand-in for the tf.keras namespace; only provides a `Model` attribute."""

    def __init__(self) -> None:
        # `Model` is the fake class itself (not an instance), so it can be
        # referenced / subclassed; instantiating it raises ImportError.
        self.Model = _FakeTfClassStub
- # Fake classes under keras (e.g for tf.keras.Model)
- class _FakeTfClassStub:
- def __init__(self, *a, **kw):
- raise ImportError("Could not import `tensorflow`. Try pip install tensorflow.")
@DeveloperAPI
def tf_function(tf_module):
    """Conditional decorator for @tf.function.

    Use @tf_function(tf) instead of @tf.function to avoid errors if tf is not
    installed.

    Args:
        tf_module: The tf module to use. May also be a falsy placeholder, i.e.
            None or the falsy `_TFStub` object that `try_import_tf()` returns in
            the `tf` slot when TensorFlow is not imported.

    Returns:
        A decorator that returns the function unchanged if tf is unavailable or
        executing eagerly, and `tf_module.function(func)` otherwise.
    """

    # The actual decorator to use.
    def decorator(func):
        # If tf is not installed -> return function as is (won't be used anyways).
        # Use a truthiness check rather than `is None`: the `_TFStub` placeholder
        # is not None but is falsy, and has no `executing_eagerly` attribute.
        if not tf_module or tf_module.executing_eagerly():
            return func
        # If tf installed (and not eager), return @tf.function-decorated function.
        return tf_module.function(func)

    return decorator
@PublicAPI
def try_import_tfp(error: bool = False):
    """Attempts to import tfp and returns the module (or None).

    Args:
        error: Whether to raise an error if tfp cannot be imported.

    Returns:
        The tfp module, or None if the import failed (with `error=False`) or the
        `RLLIB_TEST_NO_TF_IMPORT` environment variable is set.

    Raises:
        ImportError: If error=True and tfp is not installed.
    """
    # Test hook shared with the tf import: skip the import attempt entirely.
    if "RLLIB_TEST_NO_TF_IMPORT" in os.environ:
        logger.warning("Not importing TensorFlow Probability for test purposes.")
        return None

    try:
        import tensorflow_probability as tfp
    except ImportError as import_err:
        if not error:
            return None
        # Re-raise the original import error unchanged.
        raise import_err
    else:
        return tfp
# Placeholder object standing in for `torch.nn`.
class _NNStub:
    """Stand-in for torch.nn exposing `Module`, `parallel` and `functional`."""

    def __init__(self, *a, **kw):
        # Fake class for nn.Module; usable as a base class, fails on construction.
        self.Module = _FakeTorchClassStub
        # Fake nn.parallel namespace.
        self.parallel = _ParallelStub()
        # Fake nn.functional module within torch.nn (nothing provided).
        self.functional = None
- # Fake class for e.g. torch.nn.Module to allow it to be inherited from.
- class _FakeTorchClassStub:
- def __init__(self, *a, **kw):
- raise ImportError("Could not import `torch`. Try pip install torch.")
class _ParallelStub:
    """Stand-in for torch.nn.parallel; both wrappers map to the fake class."""

    def __init__(self, *a, **kw):
        # Both attributes hold the same fake class (not instances of it).
        fake_cls = _FakeTorchClassStub
        self.DataParallel = fake_cls
        self.DistributedDataParallel = fake_cls
@PublicAPI
def try_import_torch(error: bool = False):
    """Attempts to import torch and returns the torch and torch.nn modules.

    Args:
        error: Whether to raise an error if torch cannot be imported.

    Returns:
        Tuple consisting of the torch- AND torch.nn modules. If torch is not
        imported (import failure with `error=False`, or the
        `RLLIB_TEST_NO_TORCH_IMPORT` environment variable is set), the result of
        `_torch_stubs()` is returned instead.

    Raises:
        ImportError: If error=True and PyTorch is not installed.
    """
    # Test hook: pretend PyTorch is unavailable, regardless of what is installed.
    if "RLLIB_TEST_NO_TORCH_IMPORT" in os.environ:
        logger.warning("Not importing PyTorch for test purposes.")
        return _torch_stubs()

    try:
        import torch
        import torch.nn as nn
    except ImportError:
        if not error:
            return _torch_stubs()
        raise ImportError(
            "Could not import PyTorch! RLlib requires you to "
            "install at least one deep-learning framework: "
            "`pip install [torch|tensorflow|jax]`."
        )
    else:
        return torch, nn
def _torch_stubs():
    """Returns the `(torch, nn)` placeholders: None and a fresh `_NNStub`."""
    return None, _NNStub()
@DeveloperAPI
def get_variable(
    value: Any,
    framework: str = "tf",
    trainable: bool = False,
    tf_name: str = "unnamed-variable",
    torch_tensor: bool = False,
    device: Optional[str] = None,
    shape: Optional[TensorShape] = None,
    dtype: Optional[TensorType] = None,
) -> Any:
    """Builds a tf variable, a torch tensor, or returns a python primitive.

    Args:
        value: The initial value to use. In the non-tf case (and when no torch
            tensor is requested), this is returned as is. In the tf case, this
            could be a tf-Initializer object.
        framework: One of "tf", "tf2", "torch", or None.
        trainable: Whether the generated variable should be
            trainable (tf)/require_grad (torch) or not (default: False).
        tf_name: For framework="tf"/"tf2": A name for the tf.Variable.
        torch_tensor: For framework="torch": Whether to actually create
            a torch.tensor, or just return the python value (default).
        device: An optional torch device to move the created torch tensor to.
        shape: An optional shape, passed on to tf only if it is not None (e.g. if
            `value` is an initializer w/o explicit value).
        dtype: An optional dtype. In the tf case, used if truthy; otherwise
            `value.dtype` is used, falling back to tf.float32 / tf.int32 for
            python float / int values. In the torch case, float32, int32 and
            float64 (torch or numpy types) select the matching cast.
            This should always be a numpy dtype (e.g. np.float32, np.int64).

    Returns:
        A framework-specific variable (tf.Variable, torch.tensor, or
        python primitive).
    """
    if framework in ["tf2", "tf"]:
        import tensorflow as tf

        # Default dtype for plain python numbers that carry no `dtype` attribute.
        if isinstance(value, float):
            fallback_dtype = tf.float32
        elif isinstance(value, int):
            fallback_dtype = tf.int32
        else:
            fallback_dtype = None
        # NOTE(review): `or` also falls through for falsy dtype objects, not only
        # for None - confirm that this is intended.
        dtype = dtype or getattr(value, "dtype", fallback_dtype)

        variable_kwargs = {
            "initializer": value,
            "dtype": dtype,
            "trainable": trainable,
        }
        # Only hand `shape` to tf if the caller provided one.
        if shape is not None:
            variable_kwargs["shape"] = shape
        return tf.compat.v1.get_variable(tf_name, **variable_kwargs)

    if framework == "torch" and torch_tensor is True:
        torch, _ = try_import_torch()

        # `torch.from_numpy` needs an ndarray; wrap anything else first.
        array = value if isinstance(value, np.ndarray) else np.array(value)
        tensor = torch.from_numpy(array)

        # Cast only for the explicitly supported dtypes; otherwise keep as is.
        if dtype in [torch.float32, np.float32]:
            tensor = tensor.float()
        elif dtype in [torch.int32, np.int32]:
            tensor = tensor.int()
        elif dtype in [torch.float64, np.float64]:
            tensor = tensor.double()

        if device:
            tensor = tensor.to(device)
        tensor.requires_grad = trainable
        return tensor

    # torch (w/o tensor request) or None: Return python primitive.
    return value
@DeveloperAPI
@Deprecated(
    old="rllib/utils/framework.py::get_activation_fn",
    new="rllib/models/utils.py::get_activation_fn",
    error=True,
)
def get_activation_fn(name: Optional[str] = None, framework: str = "tf"):
    """Deprecated placeholder without any logic of its own.

    The replacement is named in the `new` argument of the `Deprecated` decorator
    above: rllib/models/utils.py::get_activation_fn.

    Args:
        name: Unused.
        framework: Unused.
    """
    return None
|