| 123456789101112131415161718192021 |
- from contextlib import contextmanager
- from ray.train.v2._internal.execution.train_fn_utils import get_train_fn_utils
- from ray.train.xgboost.config import XGBoostConfig as XGBoostConfigV1
- class XGBoostConfig(XGBoostConfigV1):
- @property
- def train_func_context(self):
- distributed_context = super(XGBoostConfig, self).train_func_context
- @contextmanager
- def collective_communication_context():
- # The distributed_context is only needed in distributed mode
- if get_train_fn_utils().is_distributed():
- with distributed_context():
- yield
- else:
- yield
- return collective_communication_context
|