util.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. import inspect
  2. import logging
  3. import types
  4. from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, Union
  5. import ray
  6. from ray.tune.execution.placement_groups import (
  7. PlacementGroupFactory,
  8. resource_dict_to_pg_factory,
  9. )
  10. from ray.tune.registry import _ParameterRegistry
  11. from ray.util.annotations import PublicAPI
  12. if TYPE_CHECKING:
  13. from ray.tune.trainable import Trainable
  14. logger = logging.getLogger(__name__)
  15. @PublicAPI(stability="beta")
  16. def with_parameters(trainable: Union[Type["Trainable"], Callable], **kwargs):
  17. """Wrapper for trainables to pass arbitrary large data objects.
  18. This wrapper function will store all passed parameters in the Ray
  19. object store and retrieve them when calling the function. It can thus
  20. be used to pass arbitrary data, even datasets, to Tune trainables.
  21. This can also be used as an alternative to ``functools.partial`` to pass
  22. default arguments to trainables.
  23. When used with the function API, the trainable function is called with
  24. the passed parameters as keyword arguments. When used with the class API,
  25. the ``Trainable.setup()`` method is called with the respective kwargs.
  26. If the data already exists in the object store (are instances of
  27. ObjectRef), using ``tune.with_parameters()`` is not necessary. You can
  28. instead pass the object refs to the training function via the ``config``
  29. or use Python partials.
  30. Args:
  31. trainable: Trainable to wrap.
  32. **kwargs: parameters to store in object store.
  33. Function API example:
  34. .. code-block:: python
  35. from ray import tune
  36. def train_fn(config, data=None):
  37. for sample in data:
  38. loss = update_model(sample)
  39. tune.report(dict(loss=loss))
  40. data = HugeDataset(download=True)
  41. tuner = Tuner(
  42. tune.with_parameters(train_fn, data=data),
  43. # ...
  44. )
  45. tuner.fit()
  46. Class API example:
  47. .. code-block:: python
  48. from ray import tune
  49. class MyTrainable(tune.Trainable):
  50. def setup(self, config, data=None):
  51. self.data = data
  52. self.iter = iter(self.data)
  53. self.next_sample = next(self.iter)
  54. def step(self):
  55. loss = update_model(self.next_sample)
  56. try:
  57. self.next_sample = next(self.iter)
  58. except StopIteration:
  59. return {"loss": loss, done: True}
  60. return {"loss": loss}
  61. data = HugeDataset(download=True)
  62. tuner = Tuner(
  63. tune.with_parameters(MyTrainable, data=data),
  64. # ...
  65. )
  66. """
  67. from ray.tune.trainable import Trainable
  68. if not callable(trainable) or (
  69. inspect.isclass(trainable) and not issubclass(trainable, Trainable)
  70. ):
  71. raise ValueError(
  72. f"`tune.with_parameters() only works with function trainables "
  73. f"or classes that inherit from `tune.Trainable()`. Got type: "
  74. f"{type(trainable)}."
  75. )
  76. parameter_registry = _ParameterRegistry()
  77. ray._private.worker._post_init_hooks.append(parameter_registry.flush)
  78. # Objects are moved into the object store
  79. prefix = f"{str(trainable)}_"
  80. for k, v in kwargs.items():
  81. parameter_registry.put(prefix + k, v)
  82. trainable_name = getattr(trainable, "__name__", "tune_with_parameters")
  83. keys = set(kwargs.keys())
  84. if inspect.isclass(trainable):
  85. # Class trainable
  86. class _Inner(trainable):
  87. def setup(self, config):
  88. setup_kwargs = {}
  89. for k in keys:
  90. setup_kwargs[k] = parameter_registry.get(prefix + k)
  91. super(_Inner, self).setup(config, **setup_kwargs)
  92. trainable_with_params = _Inner
  93. else:
  94. # Function trainable
  95. def inner(config):
  96. fn_kwargs = {}
  97. for k in keys:
  98. fn_kwargs[k] = parameter_registry.get(prefix + k)
  99. return trainable(config, **fn_kwargs)
  100. trainable_with_params = inner
  101. if hasattr(trainable, "__mixins__"):
  102. trainable_with_params.__mixins__ = trainable.__mixins__
  103. # If the trainable has been wrapped with `tune.with_resources`, we should
  104. # keep the `_resources` attribute around
  105. if hasattr(trainable, "_resources"):
  106. trainable_with_params._resources = trainable._resources
  107. trainable_with_params.__name__ = trainable_name
  108. return trainable_with_params
  109. @PublicAPI(stability="beta")
  110. def with_resources(
  111. trainable: Union[Type["Trainable"], Callable],
  112. resources: Union[
  113. Dict[str, float],
  114. PlacementGroupFactory,
  115. Callable[[dict], PlacementGroupFactory],
  116. ],
  117. ):
  118. """Wrapper for trainables to specify resource requests.
  119. This wrapper allows specification of resource requirements for a specific
  120. trainable. It will override potential existing resource requests (use
  121. with caution!).
  122. The main use case is to request resources for function trainables when used
  123. with the Tuner() API.
  124. Class trainables should usually just implement the ``default_resource_request()``
  125. method.
  126. Args:
  127. trainable: Trainable to wrap.
  128. resources: Resource dict, placement group factory, or callable that takes
  129. in a config dict and returns a placement group factory.
  130. Example:
  131. .. code-block:: python
  132. from ray import tune
  133. from ray.tune.tuner import Tuner
  134. def train_fn(config):
  135. return len(ray.get_gpu_ids()) # Returns 2
  136. tuner = Tuner(
  137. tune.with_resources(train_fn, resources={"gpu": 2}),
  138. # ...
  139. )
  140. results = tuner.fit()
  141. """
  142. from ray.tune.trainable import Trainable
  143. if not callable(trainable) or (
  144. inspect.isclass(trainable) and not issubclass(trainable, Trainable)
  145. ):
  146. raise ValueError(
  147. f"`tune.with_resources() only works with function trainables "
  148. f"or classes that inherit from `tune.Trainable()`. Got type: "
  149. f"{type(trainable)}."
  150. )
  151. if isinstance(resources, PlacementGroupFactory):
  152. pgf = resources
  153. elif isinstance(resources, dict):
  154. pgf = resource_dict_to_pg_factory(resources)
  155. elif callable(resources):
  156. pgf = resources
  157. else:
  158. raise ValueError(
  159. f"Invalid resource type for `with_resources()`: {type(resources)}"
  160. )
  161. if not inspect.isclass(trainable):
  162. if isinstance(trainable, types.MethodType):
  163. # Methods cannot set arbitrary attributes, so we have to wrap them
  164. def _trainable(config):
  165. return trainable(config)
  166. _trainable._resources = pgf
  167. return _trainable
  168. # Just set an attribute. This will be resolved later in `wrap_function()`.
  169. try:
  170. trainable._resources = pgf
  171. except AttributeError as e:
  172. raise RuntimeError(
  173. "Could not use `tune.with_resources()` on the supplied trainable. "
  174. "Wrap your trainable in a regular function before passing it "
  175. "to Ray Tune."
  176. ) from e
  177. else:
  178. class ResourceTrainable(trainable):
  179. @classmethod
  180. def default_resource_request(
  181. cls, config: Dict[str, Any]
  182. ) -> Optional[PlacementGroupFactory]:
  183. if not isinstance(pgf, PlacementGroupFactory) and callable(pgf):
  184. return pgf(config)
  185. return pgf
  186. ResourceTrainable.__name__ = trainable.__name__
  187. trainable = ResourceTrainable
  188. return trainable