# utils.py
from typing import Callable, List, Optional, Sequence, Union

from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.framework import try_import_jax, try_import_tf, try_import_torch
  4. @DeveloperAPI
  5. def get_activation_fn(
  6. name: Optional[Union[Callable, str]] = None,
  7. framework: str = "tf",
  8. ):
  9. """Returns a framework specific activation function, given a name string.
  10. Args:
  11. name: One of "relu" (default), "tanh", "elu",
  12. "swish" (or "silu", which is the same), or "linear" (same as None).
  13. framework: One of "jax", "tf|tf2" or "torch".
  14. Returns:
  15. A framework-specific activtion function. e.g. tf.nn.tanh or
  16. torch.nn.ReLU. None if name in ["linear", None].
  17. Raises:
  18. ValueError: If name is an unknown activation function.
  19. """
  20. # Already a callable, return as-is.
  21. if callable(name):
  22. return name
  23. name_lower = name.lower() if isinstance(name, str) else name
  24. # Infer the correct activation function from the string specifier.
  25. if framework == "torch":
  26. if name_lower in ["linear", None]:
  27. return None
  28. _, nn = try_import_torch()
  29. # First try getting the correct activation function from nn directly.
  30. # Note that torch activation functions are not all lower case.
  31. fn = getattr(nn, name, None)
  32. if fn is not None:
  33. return fn
  34. if name_lower in ["swish", "silu"]:
  35. return nn.SiLU
  36. elif name_lower == "relu":
  37. return nn.ReLU
  38. elif name_lower == "tanh":
  39. return nn.Tanh
  40. elif name_lower == "elu":
  41. return nn.ELU
  42. elif name_lower == "softmax":
  43. return nn.Softmax
  44. elif framework == "jax":
  45. if name_lower in ["linear", None]:
  46. return None
  47. jax, _ = try_import_jax()
  48. if name_lower in ["swish", "silu"]:
  49. return jax.nn.swish
  50. if name_lower == "relu":
  51. return jax.nn.relu
  52. elif name_lower == "tanh":
  53. return jax.nn.hard_tanh
  54. elif name_lower == "elu":
  55. return jax.nn.elu
  56. else:
  57. assert framework in ["tf", "tf2"], "Unsupported framework `{}`!".format(
  58. framework
  59. )
  60. if name_lower in ["linear", None]:
  61. return None
  62. tf1, tf, tfv = try_import_tf()
  63. # Try getting the correct activation function from tf.nn directly.
  64. # Note that tf activation functions are all lower case, so this should always
  65. # work.
  66. fn = getattr(tf.nn, name_lower, None)
  67. if fn is not None:
  68. return fn
  69. raise ValueError(
  70. "Unknown activation ({}) for framework={}!".format(name, framework)
  71. )
  72. @DeveloperAPI
  73. def get_initializer_fn(name: Optional[Union[str, Callable]], framework: str = "torch"):
  74. """Returns the framework-specific initializer class or function.
  75. This function relies fully on the specified initializer classes and
  76. functions in the frameworks `torch` and `tf2` (see for `torch`
  77. https://pytorch.org/docs/stable/nn.init.html and for `tf2` see
  78. https://www.tensorflow.org/api_docs/python/tf/keras/initializers).
  79. Note, for framework `torch` the in-place initializers are needed, i.e. names
  80. should end with an underscore `_`, e.g. `glorot_uniform_`.
  81. Args:
  82. name: Name of the initializer class or function in one of the two
  83. supported frameworks, i.e. `torch` or `tf2`.
  84. framework: The framework string, either `torch or `tf2`.
  85. Returns:
  86. A framework-specific function or class defining an initializer to be used
  87. for network initialization,
  88. Raises:
  89. `ValueError` if the `name` is neither class or function in the specified
  90. `framework`. Raises also a `ValueError`, if `name` does not define an
  91. in-place initializer for framework `torch`.
  92. """
  93. # Already a callable or `None` return as is. If `None` we use the default
  94. # initializer defined in the framework-specific layers themselves.
  95. if callable(name) or name is None:
  96. return name
  97. if framework == "torch":
  98. name_lower = name.lower() if isinstance(name, str) else name
  99. _, nn = try_import_torch()
  100. # Check, if the name includes an underscore. We must use the
  101. # in-place initialization from Torch.
  102. if not name_lower.endswith("_"):
  103. raise ValueError(
  104. "Not an in-place initializer: Torch weight initializers "
  105. "need to be provided as their in-place version, i.e. "
  106. "<initializaer_name> + '_'. See "
  107. "https://pytorch.org/docs/stable/nn.init.html. "
  108. f"User provided {name}."
  109. )
  110. # First, try to get the initialization directly from `nn.init`.
  111. # Note, that all initialization methods in `nn.init` are lower
  112. # case and that `<method>_` defines the "in-place" method.
  113. fn = getattr(nn.init, name_lower, None)
  114. if fn is not None:
  115. # TODO (simon): Raise a warning if not "in-place" method.
  116. return fn
  117. # Unknown initializer.
  118. else:
  119. # Inform the user that this initializer does not exist.
  120. raise ValueError(
  121. f"Unknown initializer name: {name_lower} is not a method in "
  122. "`torch.nn.init`!"
  123. )
  124. elif framework == "tf2":
  125. # Note, as initializer classes in TensorFlow can be either given by their
  126. # name in camel toe typing or by their shortcut we use the `name` as it is.
  127. # See https://www.tensorflow.org/api_docs/python/tf/keras/initializers.
  128. _, tf, _ = try_import_tf()
  129. # Try to get the initialization function directly from `tf.keras.initializers`.
  130. fn = getattr(tf.keras.initializers, name, None)
  131. if fn is not None:
  132. return fn
  133. # Unknown initializer.
  134. else:
  135. # Inform the user that this initializer does not exist.
  136. raise ValueError(
  137. f"Unknown initializer: {name} is not a initializer in "
  138. "`tf.keras.initializers`!"
  139. )
  140. @DeveloperAPI
  141. def get_filter_config(shape):
  142. """Returns a default Conv2D filter config (list) for a given image shape.
  143. Args:
  144. shape (Tuple[int]): The input (image) shape, e.g. (84,84,3).
  145. Returns:
  146. List[list]: The Conv2D filter configuration usable as `conv_filters`
  147. inside a model config dict.
  148. """
  149. # 96x96x3 (e.g. CarRacing-v0).
  150. filters_96x96 = [
  151. [16, [8, 8], 4],
  152. [32, [4, 4], 2],
  153. [256, [11, 11], 2],
  154. ]
  155. # Atari.
  156. filters_84x84 = [
  157. [16, [8, 8], 4],
  158. [32, [4, 4], 2],
  159. [256, [11, 11], 1],
  160. ]
  161. # Dreamer-style (XS-sized model) Atari or DM Control Suite.
  162. filters_64x64 = [
  163. [16, [4, 4], 2],
  164. [32, [4, 4], 2],
  165. [64, [4, 4], 2],
  166. [128, [4, 4], 2],
  167. ]
  168. # Small (1/2) Atari.
  169. filters_42x42 = [
  170. [16, [4, 4], 2],
  171. [32, [4, 4], 2],
  172. [256, [11, 11], 1],
  173. ]
  174. # Test image (10x10).
  175. filters_10x10 = [
  176. [16, [5, 5], 2],
  177. [32, [5, 5], 2],
  178. ]
  179. shape = list(shape)
  180. if len(shape) in [2, 3] and (shape[:2] == [96, 96] or shape[1:] == [96, 96]):
  181. return filters_96x96
  182. elif len(shape) in [2, 3] and (shape[:2] == [84, 84] or shape[1:] == [84, 84]):
  183. return filters_84x84
  184. elif len(shape) in [2, 3] and (shape[:2] == [64, 64] or shape[1:] == [64, 64]):
  185. return filters_64x64
  186. elif len(shape) in [2, 3] and (shape[:2] == [42, 42] or shape[1:] == [42, 42]):
  187. return filters_42x42
  188. elif len(shape) in [2, 3] and (shape[:2] == [10, 10] or shape[1:] == [10, 10]):
  189. return filters_10x10
  190. else:
  191. if list(shape) == [210, 160, 3]:
  192. atari_help = (
  193. "This is the default atari obs shape. You may want to look at one of "
  194. "RLlib's Atari examples for an example of how to wrap an Atari env. "
  195. )
  196. else:
  197. atari_help = ""
  198. raise ValueError(
  199. "No default CNN configuration for obs shape {}. ".format(shape)
  200. + atari_help
  201. + "You can specify `conv_filters` manually through your "
  202. "AlgorithmConfig's model_config. "
  203. "Default configurations are only available for inputs of the following "
  204. "shapes: [42, 42, K], [84, 84, K], [64, 64, K], [10, 10, K]. You may "
  205. "want to use a custom RLModule or a ConnectorV2 for that."
  206. )
  207. @DeveloperAPI
  208. def get_initializer(name, framework="tf"):
  209. """Returns a framework specific initializer, given a name string.
  210. Args:
  211. name: One of "xavier_uniform" (default), "xavier_normal".
  212. framework: One of "jax", "tf|tf2" or "torch".
  213. Returns:
  214. A framework-specific initializer function, e.g.
  215. tf.keras.initializers.GlorotUniform or
  216. torch.nn.init.xavier_uniform_.
  217. Raises:
  218. ValueError: If name is an unknown initializer.
  219. """
  220. # Already a callable, return as-is.
  221. if callable(name):
  222. return name
  223. if framework == "jax":
  224. _, flax = try_import_jax()
  225. assert flax is not None, "`flax` not installed. Try `pip install jax flax`."
  226. import flax.linen as nn
  227. if name in [None, "default", "xavier_uniform"]:
  228. return nn.initializers.xavier_uniform()
  229. elif name == "xavier_normal":
  230. return nn.initializers.xavier_normal()
  231. if framework == "torch":
  232. _, nn = try_import_torch()
  233. assert nn is not None, "`torch` not installed. Try `pip install torch`."
  234. if name in [None, "default", "xavier_uniform"]:
  235. return nn.init.xavier_uniform_
  236. elif name == "xavier_normal":
  237. return nn.init.xavier_normal_
  238. else:
  239. assert framework in ["tf", "tf2"], "Unsupported framework `{}`!".format(
  240. framework
  241. )
  242. tf1, tf, tfv = try_import_tf()
  243. assert (
  244. tf is not None
  245. ), "`tensorflow` not installed. Try `pip install tensorflow`."
  246. if name in [None, "default", "xavier_uniform"]:
  247. return tf.keras.initializers.GlorotUniform
  248. elif name == "xavier_normal":
  249. return tf.keras.initializers.GlorotNormal
  250. raise ValueError(
  251. "Unknown activation ({}) for framework={}!".format(name, framework)
  252. )