config.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. import logging
  2. import threading
  3. from dataclasses import dataclass
  4. from typing import Any, Dict, Optional
  5. import ray
  6. from ray._common.network_utils import build_address
  7. from ray.train._internal.base_worker_group import BaseWorkerGroup
  8. from ray.train._internal.utils import get_address_and_port
  9. from ray.train.backend import Backend, BackendConfig
  10. logger = logging.getLogger(__name__)
  11. # Global LightGBM distributed network configuration for each worker process.
  12. _lightgbm_network_params: Optional[Dict[str, Any]] = None
  13. _lightgbm_network_params_lock = threading.Lock()
  14. def get_network_params() -> Dict[str, Any]:
  15. """Returns the network parameters to enable LightGBM distributed training."""
  16. global _lightgbm_network_params
  17. with _lightgbm_network_params_lock:
  18. if not _lightgbm_network_params:
  19. logger.warning(
  20. "`ray.train.lightgbm.get_network_params` was called outside "
  21. "the context of a `ray.train.lightgbm.LightGBMTrainer`. "
  22. "The current process has no knowledge of the distributed training "
  23. "worker group, so this method will return an empty dict. "
  24. "Please call this within the training loop of a "
  25. "`ray.train.lightgbm.LightGBMTrainer`. "
  26. "If you are in fact calling this within a `LightGBMTrainer`, "
  27. "this is unexpected: please file a bug report to the Ray Team."
  28. )
  29. return {}
  30. return _lightgbm_network_params.copy()
  31. def _set_network_params(
  32. num_machines: int,
  33. local_listen_port: int,
  34. machines: str,
  35. ):
  36. global _lightgbm_network_params
  37. with _lightgbm_network_params_lock:
  38. assert (
  39. _lightgbm_network_params is None
  40. ), "LightGBM network params are already initialized."
  41. _lightgbm_network_params = dict(
  42. num_machines=num_machines,
  43. local_listen_port=local_listen_port,
  44. machines=machines,
  45. )
  46. @dataclass
  47. class LightGBMConfig(BackendConfig):
  48. """Configuration for LightGBM distributed data-parallel training setup.
  49. See the LightGBM docs for more information on the "network parameters"
  50. that Ray Train sets up for you:
  51. https://lightgbm.readthedocs.io/en/latest/Parameters.html#network-parameters
  52. """
  53. @property
  54. def backend_cls(self):
  55. return _LightGBMBackend
  56. class _LightGBMBackend(Backend):
  57. def on_training_start(
  58. self, worker_group: BaseWorkerGroup, backend_config: LightGBMConfig
  59. ):
  60. node_ips_and_ports = worker_group.execute(get_address_and_port)
  61. ports = [port for _, port in node_ips_and_ports]
  62. machines = ",".join(
  63. [build_address(node_ip, port) for node_ip, port in node_ips_and_ports]
  64. )
  65. num_machines = len(worker_group)
  66. ray.get(
  67. [
  68. worker_group.execute_single_async(
  69. rank, _set_network_params, num_machines, ports[rank], machines
  70. )
  71. for rank in range(len(worker_group))
  72. ]
  73. )