framework.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409
  1. import logging
  2. import os
  3. import sys
  4. from typing import TYPE_CHECKING, Any, Optional
  5. import numpy as np
  6. import tree # pip install dm_tree
  7. import ray
  8. from ray._common.deprecation import Deprecated
  9. from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
  10. from ray.rllib.utils.typing import (
  11. TensorShape,
  12. TensorStructType,
  13. TensorType,
  14. )
  15. if TYPE_CHECKING:
  16. from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
  17. logger = logging.getLogger(__name__)
  18. @PublicAPI
  19. def convert_to_tensor(
  20. data: TensorStructType,
  21. framework: str,
  22. device: Optional[str] = None,
  23. ):
  24. """Converts any nested numpy struct into framework-specific tensors.
  25. Args:
  26. data: The input data (numpy) to convert to framework-specific tensors.
  27. framework: The framework to convert to. Only "torch" and "tf2" allowed.
  28. device: An optional device name (for torch only).
  29. Returns:
  30. The converted tensor struct matching the input data.
  31. """
  32. if framework == "torch":
  33. from ray.rllib.utils.torch_utils import convert_to_torch_tensor
  34. return convert_to_torch_tensor(data, device=device)
  35. elif framework == "tf2":
  36. _, tf, _ = try_import_tf()
  37. return tree.map_structure(lambda s: tf.convert_to_tensor(s), data)
  38. raise NotImplementedError(
  39. f"framework={framework} not supported in `convert_to_tensor()`!"
  40. )
  41. @PublicAPI
  42. def get_device(config: "AlgorithmConfig", num_gpus_requested: int = 1):
  43. """Returns a single device (CPU or some GPU) depending on a config.
  44. Args:
  45. config: An AlgorithmConfig to extract information from about the device to use.
  46. num_gpus_requested: The number of GPUs actually requested. This may be the value
  47. of `config.num_gpus_per_env_runner` when for example calling this function
  48. from an EnvRunner.
  49. Returns:
  50. A single device (or name) given `config` and `num_gpus_requested`.
  51. """
  52. if config.framework_str == "torch":
  53. torch, _ = try_import_torch()
  54. # TODO (Kourosh): How do we handle model parallelism?
  55. # TODO (Kourosh): Instead of using _TorchAccelerator, we should use the public
  56. # API in ray.train but allow for session to be None without any errors raised.
  57. if num_gpus_requested > 0:
  58. from ray.air._internal.torch_utils import get_devices
  59. # `get_devices()` returns a list that contains the 0th device if
  60. # it is called from outside a Ray Train session. It's necessary to give
  61. # the user the option to run on the gpu of their choice, so we enable that
  62. # option here through `config.local_gpu_idx`.
  63. devices = get_devices()
  64. # Note, if we have a single learner and we do not run on Ray Tune, the local
  65. # learner is not an Ray actor and Ray does not manage devices for it.
  66. if (
  67. len(devices) == 1
  68. and ray._private.worker._mode() == ray._private.worker.WORKER_MODE
  69. ):
  70. return devices[0]
  71. else:
  72. assert config.local_gpu_idx < torch.cuda.device_count(), (
  73. f"local_gpu_idx {config.local_gpu_idx} is not a valid GPU ID "
  74. "or is not available."
  75. )
  76. # This is an index into the available CUDA devices. For example, if
  77. # `os.environ["CUDA_VISIBLE_DEVICES"] = "1"` then
  78. # `torch.cuda.device_count() = 1` and torch.device(0) maps to that GPU
  79. # with ID=1 on the node.
  80. return torch.device(config.local_gpu_idx)
  81. else:
  82. return torch.device("cpu")
  83. else:
  84. raise NotImplementedError(
  85. f"`framework_str` {config.framework_str} not supported!"
  86. )
  87. @PublicAPI
  88. def try_import_jax(error: bool = False):
  89. """Tries importing JAX and FLAX and returns both modules (or Nones).
  90. Args:
  91. error: Whether to raise an error if JAX/FLAX cannot be imported.
  92. Returns:
  93. Tuple containing the jax- and the flax modules.
  94. Raises:
  95. ImportError: If error=True and JAX is not installed.
  96. """
  97. if "RLLIB_TEST_NO_JAX_IMPORT" in os.environ:
  98. logger.warning("Not importing JAX for test purposes.")
  99. return None, None
  100. try:
  101. import flax
  102. import jax
  103. except ImportError:
  104. if error:
  105. raise ImportError(
  106. "Could not import JAX! RLlib requires you to "
  107. "install at least one deep-learning framework: "
  108. "`pip install [torch|tensorflow|jax]`."
  109. )
  110. return None, None
  111. return jax, flax
  112. @PublicAPI
  113. def try_import_tf(error: bool = False):
  114. """Tries importing tf and returns the module (or None).
  115. Args:
  116. error: Whether to raise an error if tf cannot be imported.
  117. Returns:
  118. Tuple containing
  119. 1) tf1.x module (either from tf2.x.compat.v1 OR as tf1.x).
  120. 2) tf module (resulting from `import tensorflow`). Either tf1.x or
  121. 2.x. 3) The actually installed tf version as int: 1 or 2.
  122. Raises:
  123. ImportError: If error=True and tf is not installed.
  124. """
  125. tf_stub = _TFStub()
  126. # Make sure, these are reset after each test case
  127. # that uses them: del os.environ["RLLIB_TEST_NO_TF_IMPORT"]
  128. if "RLLIB_TEST_NO_TF_IMPORT" in os.environ:
  129. logger.warning("Not importing TensorFlow for test purposes")
  130. return None, tf_stub, None
  131. if "TF_CPP_MIN_LOG_LEVEL" not in os.environ:
  132. os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
  133. # Try to reuse already imported tf module. This will avoid going through
  134. # the initial import steps below and thereby switching off v2_behavior
  135. # (switching off v2 behavior twice breaks all-framework tests for eager).
  136. was_imported = False
  137. if "tensorflow" in sys.modules:
  138. tf_module = sys.modules["tensorflow"]
  139. was_imported = True
  140. else:
  141. try:
  142. import tensorflow as tf_module
  143. except ImportError:
  144. if error:
  145. raise ImportError(
  146. "Could not import TensorFlow! RLlib requires you to "
  147. "install at least one deep-learning framework: "
  148. "`pip install [torch|tensorflow|jax]`."
  149. )
  150. return None, tf_stub, None
  151. # Try "reducing" tf to tf.compat.v1.
  152. try:
  153. tf1_module = tf_module.compat.v1
  154. tf1_module.logging.set_verbosity(tf1_module.logging.ERROR)
  155. if not was_imported:
  156. tf1_module.disable_v2_behavior()
  157. tf1_module.enable_resource_variables()
  158. tf1_module.logging.set_verbosity(tf1_module.logging.WARN)
  159. # No compat.v1 -> return tf as is.
  160. except AttributeError:
  161. tf1_module = tf_module
  162. if not hasattr(tf_module, "__version__"):
  163. version = 1 # sphinx doc gen
  164. else:
  165. version = 2 if "2." in tf_module.__version__[:2] else 1
  166. return tf1_module, tf_module, version
  167. # Fake module for tf.
  168. class _TFStub:
  169. def __init__(self) -> None:
  170. self.keras = _KerasStub()
  171. def __bool__(self):
  172. # if tf should return False
  173. return False
  174. # Fake module for tf.keras.
  175. class _KerasStub:
  176. def __init__(self) -> None:
  177. self.Model = _FakeTfClassStub
  178. # Fake classes under keras (e.g for tf.keras.Model)
  179. class _FakeTfClassStub:
  180. def __init__(self, *a, **kw):
  181. raise ImportError("Could not import `tensorflow`. Try pip install tensorflow.")
  182. @DeveloperAPI
  183. def tf_function(tf_module):
  184. """Conditional decorator for @tf.function.
  185. Use @tf_function(tf) instead to avoid errors if tf is not installed."""
  186. # The actual decorator to use (pass in `tf` (which could be None)).
  187. def decorator(func):
  188. # If tf not installed -> return function as is (won't be used anyways).
  189. if tf_module is None or tf_module.executing_eagerly():
  190. return func
  191. # If tf installed, return @tf.function-decorated function.
  192. return tf_module.function(func)
  193. return decorator
  194. @PublicAPI
  195. def try_import_tfp(error: bool = False):
  196. """Tries importing tfp and returns the module (or None).
  197. Args:
  198. error: Whether to raise an error if tfp cannot be imported.
  199. Returns:
  200. The tfp module.
  201. Raises:
  202. ImportError: If error=True and tfp is not installed.
  203. """
  204. if "RLLIB_TEST_NO_TF_IMPORT" in os.environ:
  205. logger.warning("Not importing TensorFlow Probability for test purposes.")
  206. return None
  207. try:
  208. import tensorflow_probability as tfp
  209. return tfp
  210. except ImportError as e:
  211. if error:
  212. raise e
  213. return None
  214. # Fake module for torch.nn.
  215. class _NNStub:
  216. def __init__(self, *a, **kw):
  217. # Fake nn.functional module within torch.nn.
  218. self.functional = None
  219. self.Module = _FakeTorchClassStub
  220. self.parallel = _ParallelStub()
  221. # Fake class for e.g. torch.nn.Module to allow it to be inherited from.
  222. class _FakeTorchClassStub:
  223. def __init__(self, *a, **kw):
  224. raise ImportError("Could not import `torch`. Try pip install torch.")
  225. class _ParallelStub:
  226. def __init__(self, *a, **kw):
  227. self.DataParallel = _FakeTorchClassStub
  228. self.DistributedDataParallel = _FakeTorchClassStub
  229. @PublicAPI
  230. def try_import_torch(error: bool = False):
  231. """Tries importing torch and returns the module (or None).
  232. Args:
  233. error: Whether to raise an error if torch cannot be imported.
  234. Returns:
  235. Tuple consisting of the torch- AND torch.nn modules.
  236. Raises:
  237. ImportError: If error=True and PyTorch is not installed.
  238. """
  239. if "RLLIB_TEST_NO_TORCH_IMPORT" in os.environ:
  240. logger.warning("Not importing PyTorch for test purposes.")
  241. return _torch_stubs()
  242. try:
  243. import torch
  244. import torch.nn as nn
  245. return torch, nn
  246. except ImportError:
  247. if error:
  248. raise ImportError(
  249. "Could not import PyTorch! RLlib requires you to "
  250. "install at least one deep-learning framework: "
  251. "`pip install [torch|tensorflow|jax]`."
  252. )
  253. return _torch_stubs()
  254. def _torch_stubs():
  255. nn = _NNStub()
  256. return None, nn
  257. @DeveloperAPI
  258. def get_variable(
  259. value: Any,
  260. framework: str = "tf",
  261. trainable: bool = False,
  262. tf_name: str = "unnamed-variable",
  263. torch_tensor: bool = False,
  264. device: Optional[str] = None,
  265. shape: Optional[TensorShape] = None,
  266. dtype: Optional[TensorType] = None,
  267. ) -> Any:
  268. """Creates a tf variable, a torch tensor, or a python primitive.
  269. Args:
  270. value: The initial value to use. In the non-tf case, this will
  271. be returned as is. In the tf case, this could be a tf-Initializer
  272. object.
  273. framework: One of "tf", "torch", or None.
  274. trainable: Whether the generated variable should be
  275. trainable (tf)/require_grad (torch) or not (default: False).
  276. tf_name: For framework="tf": An optional name for the
  277. tf.Variable.
  278. torch_tensor: For framework="torch": Whether to actually create
  279. a torch.tensor, or just a python value (default).
  280. device: An optional torch device to use for
  281. the created torch tensor.
  282. shape: An optional shape to use iff `value`
  283. does not have any (e.g. if it's an initializer w/o explicit value).
  284. dtype: An optional dtype to use iff `value` does
  285. not have any (e.g. if it's an initializer w/o explicit value).
  286. This should always be a numpy dtype (e.g. np.float32, np.int64).
  287. Returns:
  288. A framework-specific variable (tf.Variable, torch.tensor, or
  289. python primitive).
  290. """
  291. if framework in ["tf2", "tf"]:
  292. import tensorflow as tf
  293. dtype = dtype or getattr(
  294. value,
  295. "dtype",
  296. tf.float32
  297. if isinstance(value, float)
  298. else tf.int32
  299. if isinstance(value, int)
  300. else None,
  301. )
  302. return tf.compat.v1.get_variable(
  303. tf_name,
  304. initializer=value,
  305. dtype=dtype,
  306. trainable=trainable,
  307. **({} if shape is None else {"shape": shape}),
  308. )
  309. elif framework == "torch" and torch_tensor is True:
  310. torch, _ = try_import_torch()
  311. if not isinstance(value, np.ndarray):
  312. value = np.array(value)
  313. var_ = torch.from_numpy(value)
  314. if dtype in [torch.float32, np.float32]:
  315. var_ = var_.float()
  316. elif dtype in [torch.int32, np.int32]:
  317. var_ = var_.int()
  318. elif dtype in [torch.float64, np.float64]:
  319. var_ = var_.double()
  320. if device:
  321. var_ = var_.to(device)
  322. var_.requires_grad = trainable
  323. return var_
  324. # torch or None: Return python primitive.
  325. return value
  326. @DeveloperAPI
  327. @Deprecated(
  328. old="rllib/utils/framework.py::get_activation_fn",
  329. new="rllib/models/utils.py::get_activation_fn",
  330. error=True,
  331. )
  332. def get_activation_fn(name: Optional[str] = None, framework: str = "tf"):
  333. pass