config.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. import json
  2. import logging
  3. import os
  4. from dataclasses import dataclass
  5. from typing import List
  6. import ray
  7. from ray._common.network_utils import build_address
  8. from ray.train._internal.base_worker_group import BaseWorkerGroup
  9. from ray.train._internal.utils import get_address_and_port
  10. from ray.train.backend import Backend, BackendConfig
  11. from ray.util import PublicAPI
  12. logger = logging.getLogger(__name__)
  13. @PublicAPI(stability="beta")
  14. @dataclass
  15. class TensorflowConfig(BackendConfig):
  16. @property
  17. def backend_cls(self):
  18. return _TensorflowBackend
  19. def _setup_tensorflow_environment(worker_addresses: List[str], index: int):
  20. """Set up distributed Tensorflow training information.
  21. This function should be called on each worker.
  22. Args:
  23. worker_addresses: Addresses of all the workers.
  24. index: Index (i.e. world rank) of the current worker.
  25. """
  26. tf_config = {
  27. "cluster": {"worker": worker_addresses},
  28. "task": {"type": "worker", "index": index},
  29. }
  30. os.environ["TF_CONFIG"] = json.dumps(tf_config)
  31. class _TensorflowBackend(Backend):
  32. def on_start(self, worker_group: BaseWorkerGroup, backend_config: TensorflowConfig):
  33. # Compute URL for initializing distributed setup.
  34. def get_url():
  35. address, port = get_address_and_port()
  36. return build_address(address, port)
  37. urls = worker_group.execute(get_url)
  38. # Get setup tasks in order to throw errors on failure.
  39. setup_futures = []
  40. for i in range(len(worker_group)):
  41. setup_futures.append(
  42. worker_group.execute_single_async(
  43. i,
  44. _setup_tensorflow_environment,
  45. worker_addresses=urls,
  46. index=i,
  47. )
  48. )
  49. ray.get(setup_futures)