config.py 711 B

123456789101112131415161718192021
  1. from contextlib import contextmanager
  2. from ray.train.v2._internal.execution.train_fn_utils import get_train_fn_utils
  3. from ray.train.xgboost.config import XGBoostConfig as XGBoostConfigV1
  4. class XGBoostConfig(XGBoostConfigV1):
  5. @property
  6. def train_func_context(self):
  7. distributed_context = super(XGBoostConfig, self).train_func_context
  8. @contextmanager
  9. def collective_communication_context():
  10. # The distributed_context is only needed in distributed mode
  11. if get_train_fn_utils().is_distributed():
  12. with distributed_context():
  13. yield
  14. else:
  15. yield
  16. return collective_communication_context