placement_groups.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. import warnings
  2. from typing import Dict, Optional
  3. from ray.air.execution.resources.request import ResourceRequest
  4. from ray.util.annotations import DeveloperAPI, PublicAPI
  5. from ray.util.placement_group import placement_group
  6. @PublicAPI(stability="beta")
  7. class PlacementGroupFactory(ResourceRequest):
  8. """Wrapper class that creates placement groups for trials.
  9. This function should be used to define resource requests for Ray Tune
  10. trials. It holds the parameters to create
  11. :ref:`placement groups <ray-placement-group-doc-ref>`.
  12. At a minimum, this will hold at least one bundle specifying the
  13. resource requirements for each trial:
  14. .. code-block:: python
  15. from ray import tune
  16. tuner = tune.Tuner(
  17. tune.with_resources(
  18. train,
  19. resources=tune.PlacementGroupFactory([
  20. {"CPU": 1, "GPU": 0.5, "custom_resource": 2}
  21. ])
  22. )
  23. )
  24. tuner.fit()
  25. If the trial itself schedules further remote workers, the resource
  26. requirements should be specified in additional bundles. You can also
  27. pass the placement strategy for these bundles, e.g. to enforce
  28. co-located placement:
  29. .. code-block:: python
  30. from ray import tune
  31. tuner = tune.Tuner(
  32. tune.with_resources(
  33. train,
  34. resources=tune.PlacementGroupFactory([
  35. {"CPU": 1, "GPU": 0.5, "custom_resource": 2},
  36. {"CPU": 2},
  37. {"CPU": 2},
  38. ], strategy="PACK")
  39. )
  40. )
  41. tuner.fit()
  42. The example above will reserve 1 CPU, 0.5 GPUs and 2 custom_resources
  43. for the trainable itself, and reserve another 2 bundles of 2 CPUs each.
  44. The trial will only start when all these resources are available. This
  45. could be used e.g. if you had one learner running in the main trainable
  46. that schedules two remote workers that need access to 2 CPUs each.
  47. If the trainable itself doesn't require resources.
  48. You can specify it as:
  49. .. code-block:: python
  50. from ray import tune
  51. tuner = tune.Tuner(
  52. tune.with_resources(
  53. train,
  54. resources=tune.PlacementGroupFactory([
  55. {},
  56. {"CPU": 2},
  57. {"CPU": 2},
  58. ], strategy="PACK")
  59. )
  60. )
  61. tuner.fit()
  62. Args:
  63. bundles: A list of bundles which
  64. represent the resources requirements.
  65. strategy: The strategy to create the placement group.
  66. - "PACK": Packs Bundles into as few nodes as possible.
  67. - "SPREAD": Places Bundles across distinct nodes as even as possible.
  68. - "STRICT_PACK": Packs Bundles into one node. The group is
  69. not allowed to span multiple nodes.
  70. - "STRICT_SPREAD": Packs Bundles across distinct nodes.
  71. *args: Passed to the call of ``placement_group()``
  72. **kwargs: Passed to the call of ``placement_group()``
  73. """
  74. def __call__(self, *args, **kwargs):
  75. warnings.warn(
  76. "Calling PlacementGroupFactory objects is deprecated. Use "
  77. "`to_placement_group()` instead.",
  78. DeprecationWarning,
  79. )
  80. kwargs.update(self._bound.kwargs)
  81. # Call with bounded *args and **kwargs
  82. return placement_group(*self._bound.args, **kwargs)
  83. @DeveloperAPI
  84. def resource_dict_to_pg_factory(spec: Optional[Dict[str, float]] = None):
  85. """Translates resource dict into PlacementGroupFactory."""
  86. spec = spec or {"cpu": 1}
  87. spec = spec.copy()
  88. cpus = spec.pop("cpu", spec.pop("CPU", 0.0))
  89. gpus = spec.pop("gpu", spec.pop("GPU", 0.0))
  90. memory = spec.pop("memory", 0.0)
  91. # If there is a custom_resources key, use as base for bundle
  92. bundle = dict(spec.pop("custom_resources", {}))
  93. # Otherwise, consider all other keys as custom resources
  94. if not bundle:
  95. bundle = spec
  96. bundle.update(
  97. {
  98. "CPU": cpus,
  99. "GPU": gpus,
  100. "memory": memory,
  101. }
  102. )
  103. return PlacementGroupFactory([bundle])