annotations.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. from typing import Any, Callable, TypeVar
  2. from ray._common.deprecation import Deprecated
  3. from ray.util.annotations import _mark_annotated
# TypeVar for preserving function/class signatures through decorators:
# annotating a decorator as `(obj: F) -> F` tells type checkers that it
# returns an object of exactly the type it was given.
F = TypeVar("F", bound=Callable[..., Any])
  6. def override(parent_cls: type) -> Callable[[F], F]:
  7. """Decorator for documenting method overrides.
  8. Args:
  9. parent_cls: The superclass that provides the overridden method. If
  10. `parent_class` does not actually have the method or the class, in which
  11. method is defined is not a subclass of `parent_class`, an error is raised.
  12. .. testcode::
  13. :skipif: True
  14. from ray.rllib.policy import Policy
  15. class TorchPolicy(Policy):
  16. ...
  17. # Indicates that `TorchPolicy.loss()` overrides the parent
  18. # Policy class' own `loss method. Leads to an error if Policy
  19. # does not have a `loss` method.
  20. @override(Policy)
  21. def loss(self, model, action_dist, train_batch):
  22. ...
  23. """
  24. class OverrideCheck:
  25. def __init__(self, func, expected_parent_cls):
  26. self.func = func
  27. self.expected_parent_cls = expected_parent_cls
  28. def __set_name__(self, owner, name):
  29. # Check if the owner (the class) is a subclass of the expected base class
  30. if not issubclass(owner, self.expected_parent_cls):
  31. raise TypeError(
  32. f"When using the @override decorator, {owner.__name__} must be a "
  33. f"subclass of {parent_cls.__name__}!"
  34. )
  35. # Set the function as a regular method on the class.
  36. setattr(owner, name, self.func)
  37. def decorator(method: F) -> F:
  38. # Check, whether `method` is actually defined by the parent class.
  39. if method.__name__ not in dir(parent_cls):
  40. raise NameError(
  41. f"When using the @override decorator, {method.__name__} must override "
  42. f"the respective method (with the same name) of {parent_cls.__name__}!"
  43. )
  44. # Check if the class is a subclass of the expected base class
  45. OverrideCheck(method, parent_cls)
  46. return method
  47. return decorator
def PublicAPI(obj: F) -> F:
    """Decorator for documenting public APIs.

    Public APIs are classes and methods exposed to end users of RLlib. You
    can expect these APIs to remain stable across RLlib releases.

    Subclasses that inherit from a ``@PublicAPI`` base class can be
    assumed part of the RLlib public API as well (e.g., all Algorithm classes
    are in public API because Algorithm is ``@PublicAPI``).

    In addition, you can assume all algo configurations are part of their
    public API as well.

    Args:
        obj: The class, method, or function to tag.

    Returns:
        ``obj`` itself (not wrapped), after passing it to ``_mark_annotated``.

    .. testcode::
        :skipif: True

        # Indicates that the `Algorithm` class is exposed to end users
        # of RLlib and will remain stable across RLlib releases.
        from ray import tune

        @PublicAPI
        class Algorithm(tune.Trainable):
            ...
    """
    # Hand the object to Ray's shared annotation helper (imported from
    # `ray.util.annotations`); this decorator does not wrap `obj`.
    _mark_annotated(obj)
    return obj
def DeveloperAPI(obj: F) -> F:
    """Decorator for documenting developer APIs.

    Developer APIs are classes and methods explicitly exposed to developers
    for the purposes of building custom algorithms or advanced training
    strategies on top of RLlib internals. You can generally expect these APIs
    to be stable sans minor changes (but less stable than public APIs).

    Subclasses that inherit from a ``@DeveloperAPI`` base class can be
    assumed part of the RLlib developer API as well.

    Args:
        obj: The class, method, or function to tag.

    Returns:
        ``obj`` itself (not wrapped), after passing it to ``_mark_annotated``.

    .. testcode::
        :skipif: True

        # Indicates that the `TorchPolicy` class is exposed to developers
        # building on top of RLlib and will remain (relatively) stable
        # across RLlib releases.
        from ray.rllib.policy import Policy

        @DeveloperAPI
        class TorchPolicy(Policy):
            ...
    """
    # Hand the object to Ray's shared annotation helper (imported from
    # `ray.util.annotations`); this decorator does not wrap `obj`.
    _mark_annotated(obj)
    return obj
def ExperimentalAPI(obj: F) -> F:
    """Decorator for documenting experimental APIs.

    Experimental APIs are classes and methods that are in development and may
    change at any time in their development process. You should not expect
    these APIs to be stable until their tag is changed to `DeveloperAPI` or
    `PublicAPI`.

    Subclasses that inherit from a ``@ExperimentalAPI`` base class can be
    assumed experimental as well.

    Args:
        obj: The class, method, or function to tag.

    Returns:
        ``obj`` itself (not wrapped), after passing it to ``_mark_annotated``.

    .. testcode::
        :skipif: True

        from ray.rllib.policy import Policy

        class TorchPolicy(Policy):
            ...
            # Indicates that the `TorchPolicy.loss` method is a new and
            # experimental API and may change frequently in future
            # releases.

            @ExperimentalAPI
            def loss(self, model, action_dist, train_batch):
                ...
    """
    # Hand the object to Ray's shared annotation helper (imported from
    # `ray.util.annotations`); this decorator does not wrap `obj`.
    _mark_annotated(obj)
    return obj
def OldAPIStack(obj: F) -> F:
    """Decorator for classes/methods/functions belonging to the old API stack.

    These should be deprecated at some point after Ray 3.0 (RLlib GA).
    It is recommended for users to start exploring (and coding against) the new API
    stack instead.

    Args:
        obj: The class, method, or function to tag.

    Returns:
        ``obj`` itself (not wrapped), after passing it to ``_mark_annotated``.
    """
    # No deprecation behavior yet (no warning, no wrapping): the object is
    # only handed to `_mark_annotated` and returned as-is.
    _mark_annotated(obj)
    return obj
  119. def OverrideToImplementCustomLogic(obj: F) -> F:
  120. """Users should override this in their sub-classes to implement custom logic.
  121. Used in Algorithm and Policy to tag methods that need overriding, e.g.
  122. `Policy.loss()`.
  123. .. testcode::
  124. :skipif: True
  125. from ray.rllib.policy.torch_policy import TorchPolicy
  126. @overrides(TorchPolicy)
  127. @OverrideToImplementCustomLogic
  128. def loss(self, ...):
  129. # implement custom loss function here ...
  130. # ... w/o calling the corresponding `super().loss()` method.
  131. ...
  132. """
  133. obj.__is_overridden__ = False # type: ignore[attr-defined]
  134. return obj
  135. def OverrideToImplementCustomLogic_CallToSuperRecommended(obj: F) -> F:
  136. """Users should override this in their sub-classes to implement custom logic.
  137. Thereby, it is recommended (but not required) to call the super-class'
  138. corresponding method.
  139. Used in Algorithm and Policy to tag methods that need overriding, but the
  140. super class' method should still be called, e.g.
  141. `Algorithm.setup()`.
  142. .. testcode::
  143. :skipif: True
  144. from ray import tune
  145. @overrides(tune.Trainable)
  146. @OverrideToImplementCustomLogic_CallToSuperRecommended
  147. def setup(self, config):
  148. # implement custom setup logic here ...
  149. super().setup(config)
  150. # ... or here (after having called super()'s setup method.
  151. """
  152. obj.__is_overridden__ = False # type: ignore[attr-defined]
  153. return obj
  154. def is_overridden(obj: Callable[..., Any]) -> bool:
  155. """Check whether a function has been overridden.
  156. Note, this only works for API calls decorated with OverrideToImplementCustomLogic
  157. or OverrideToImplementCustomLogic_CallToSuperRecommended.
  158. """
  159. return getattr(obj, "__is_overridden__", True)
# Backward compatibility: keep `Deprecated` (imported at the top of this file
# from `ray._common.deprecation`) available as an attribute of this module so
# existing imports of it from here keep working. The explicit self-assignment
# presumably also signals linters that the import is intentionally re-exported
# (NOTE(review): confirm intent).
Deprecated = Deprecated