tf_modelv2.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. import contextlib
  2. import re
  3. from typing import Dict, List, Union
  4. import gymnasium as gym
  5. from ray._common.deprecation import deprecation_warning
  6. from ray.rllib.models.modelv2 import ModelV2
  7. from ray.rllib.utils.annotations import OldAPIStack, override
  8. from ray.rllib.utils.framework import try_import_tf
  9. from ray.rllib.utils.typing import ModelConfigDict, TensorType
  10. from ray.util import log_once
  11. tf1, tf, tfv = try_import_tf()
  12. @OldAPIStack
  13. class TFModelV2(ModelV2):
  14. """TF version of ModelV2, which should contain a tf keras Model.
  15. Note that this class by itself is not a valid model unless you
  16. implement forward() in a subclass."""
  17. def __init__(
  18. self,
  19. obs_space: gym.spaces.Space,
  20. action_space: gym.spaces.Space,
  21. num_outputs: int,
  22. model_config: ModelConfigDict,
  23. name: str,
  24. ):
  25. """Initializes a TFModelV2 instance.
  26. Here is an example implementation for a subclass
  27. ``MyModelClass(TFModelV2)``::
  28. def __init__(self, *args, **kwargs):
  29. super(MyModelClass, self).__init__(*args, **kwargs)
  30. input_layer = tf.keras.layers.Input(...)
  31. hidden_layer = tf.keras.layers.Dense(...)(input_layer)
  32. output_layer = tf.keras.layers.Dense(...)(hidden_layer)
  33. value_layer = tf.keras.layers.Dense(...)(hidden_layer)
  34. self.base_model = tf.keras.Model(
  35. input_layer, [output_layer, value_layer])
  36. """
  37. super().__init__(
  38. obs_space, action_space, num_outputs, model_config, name, framework="tf"
  39. )
  40. # Deprecated: TFModelV2 now automatically track their variables.
  41. self.var_list = []
  42. if tf1.executing_eagerly():
  43. self.graph = None
  44. else:
  45. self.graph = tf1.get_default_graph()
  46. def context(self) -> contextlib.AbstractContextManager:
  47. """Returns a contextmanager for the current TF graph."""
  48. if self.graph:
  49. return self.graph.as_default()
  50. else:
  51. return ModelV2.context(self)
  52. def update_ops(self) -> List[TensorType]:
  53. """Return the list of update ops for this model.
  54. For example, this should include any BatchNorm update ops."""
  55. return []
  56. def register_variables(self, variables: List[TensorType]) -> None:
  57. """Register the given list of variables with this model."""
  58. if log_once("deprecated_tfmodelv2_register_variables"):
  59. deprecation_warning(old="TFModelV2.register_variables", error=False)
  60. self.var_list.extend(variables)
  61. @override(ModelV2)
  62. def variables(
  63. self, as_dict: bool = False
  64. ) -> Union[List[TensorType], Dict[str, TensorType]]:
  65. if as_dict:
  66. # Old way using `register_variables`.
  67. if self.var_list:
  68. return {v.name: v for v in self.var_list}
  69. # New way: Automatically determine the var tree.
  70. else:
  71. return self._find_sub_modules("", self.__dict__)
  72. # Old way using `register_variables`.
  73. if self.var_list:
  74. return list(self.var_list)
  75. # New way: Automatically determine the var tree.
  76. else:
  77. return list(self.variables(as_dict=True).values())
  78. @override(ModelV2)
  79. def trainable_variables(
  80. self, as_dict: bool = False
  81. ) -> Union[List[TensorType], Dict[str, TensorType]]:
  82. if as_dict:
  83. return {
  84. k: v for k, v in self.variables(as_dict=True).items() if v.trainable
  85. }
  86. return [v for v in self.variables() if v.trainable]
  87. @staticmethod
  88. def _find_sub_modules(current_key, struct):
  89. # Keras Model: key=k + "." + var-name (replace '/' by '.').
  90. if isinstance(struct, tf.keras.models.Model) or isinstance(struct, tf.Module):
  91. ret = {}
  92. for var in struct.variables:
  93. name = re.sub("/", ".", var.name)
  94. key = current_key + "." + name
  95. ret[key] = var
  96. return ret
  97. # Other TFModelV2: Include its vars into ours.
  98. elif isinstance(struct, TFModelV2):
  99. return {
  100. current_key + "." + key: var
  101. for key, var in struct.variables(as_dict=True).items()
  102. }
  103. # tf.Variable
  104. elif isinstance(struct, tf.Variable):
  105. return {current_key: struct}
  106. # List/Tuple.
  107. elif isinstance(struct, (tuple, list)):
  108. ret = {}
  109. for i, value in enumerate(struct):
  110. sub_vars = TFModelV2._find_sub_modules(
  111. current_key + "_{}".format(i), value
  112. )
  113. ret.update(sub_vars)
  114. return ret
  115. # Dict.
  116. elif isinstance(struct, dict):
  117. if current_key:
  118. current_key += "_"
  119. ret = {}
  120. for key, value in struct.items():
  121. sub_vars = TFModelV2._find_sub_modules(current_key + str(key), value)
  122. ret.update(sub_vars)
  123. return ret
  124. return {}