distributed.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. """ Distributed training/validation utils
  2. Hacked together by / Copyright 2020 Ross Wightman
  3. """
  4. import logging
  5. import os
  6. from typing import Optional
  7. import torch
  8. from torch import distributed as dist
  9. from .model import unwrap_model
  # Logger for the helpers in this module, named after the module itself.
  _logger = logging.getLogger(name=__name__)
  def reduce_tensor(tensor, n):
      """Sum ``tensor`` over every rank of the default group and divide by ``n``.

      The reduction runs on a clone, so the caller's tensor is not modified.
      ``n`` is normally the world size, which makes the result the cross-rank mean.

      Returns:
          A new tensor holding ``sum_over_ranks(tensor) / n``.
      """
      out = tensor.clone()
      dist.all_reduce(out, dist.ReduceOp.SUM)
      # In-place true division; equivalent to ``out /= n``.
      out.div_(n)
      return out
  def distribute_bn(model, world_size, reduce=False):
      """Make BatchNorm running statistics identical on every rank.

      Every buffer of ``unwrap_model(model)`` whose name contains ``running_mean``
      or ``running_var`` is modified in place:

      * ``reduce=True``: summed over all ranks with ``all_reduce``, then divided
        by ``world_size``. This gives the average when ``world_size`` is the
        group size.
      * ``reduce=False``: overwritten with rank 0's copy via ``broadcast``.
      """
      stat_markers = ('running_mean', 'running_var')
      for buf_name, buf in unwrap_model(model).named_buffers(recurse=True):
          if not any(marker in buf_name for marker in stat_markers):
              continue
          if not reduce:
              # Rank 0's statistics win.
              dist.broadcast(buf, 0)
              continue
          # Average the statistics across the whole group.
          dist.all_reduce(buf, op=dist.ReduceOp.SUM)
          buf /= float(world_size)
  def is_global_primary(args):
      """Return True when ``args.rank`` (the global rank) equals 0."""
      global_rank = args.rank
      return global_rank == 0


  def is_local_primary(args):
      """Return True when ``args.local_rank`` equals 0."""
      node_rank = args.local_rank
      return node_rank == 0


  def is_primary(args, local=False):
      """Return whether this process is the primary one.

      Checks ``args.local_rank`` when ``local`` is true, else ``args.rank``.
      """
      if local:
          return is_local_primary(args)
      return is_global_primary(args)
  def is_distributed_env():
      """Report whether the environment describes a multi-process job.

      ``WORLD_SIZE`` takes precedence. ``SLURM_NTASKS`` is consulted only when
      ``WORLD_SIZE`` is unset. The first variable found decides the answer: True
      iff it parses to an int greater than 1. With neither set, returns False.
      """
      for env_var in ('WORLD_SIZE', 'SLURM_NTASKS'):
          raw = os.environ.get(env_var)
          if raw is not None:
              return int(raw) > 1
      return False
  def world_info_from_env():
      """Read ``(local_rank, global_rank, world_size)`` from environment variables.

      For each value, the first variable present in its candidate list wins. The
      generic names (``LOCAL_RANK``/``RANK``/``WORLD_SIZE``) are checked first,
      then the MPI/PMI, SLURM and Open MPI (``OMPI_*``) names. Missing values
      default to 0, 0 and 1 respectively. A present but non-integer value raises
      ``ValueError``.
      """
      def first_env_int(candidates, fallback):
          # Return int(os.environ[name]) for the first candidate that is set.
          for name in candidates:
              if name in os.environ:
                  return int(os.environ[name])
          return fallback

      local_rank = first_env_int(
          ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 'OMPI_COMM_WORLD_LOCAL_RANK'), 0)
      global_rank = first_env_int(
          ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'), 0)
      world_size = first_env_int(
          ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'), 1)
      return local_rank, global_rank, world_size
  def init_distributed_device(args):
      """Set up the (possibly distributed) device and record the result on ``args``.

      ``args`` is first reset to single-process defaults (``distributed=False``,
      ``world_size=1``, ``rank=0``, ``local_rank=0``). Those values therefore
      remain if initialisation raises. ``args.device`` (default ``'cuda'``),
      ``args.dist_backend`` and ``args.dist_url`` (default ``None``) are passed to
      ``init_distributed_device_so``. Its results are written back to
      ``args.device``, ``args.world_size``, ``args.rank`` (global rank),
      ``args.local_rank`` and ``args.distributed``.

      Returns:
          ``torch.device`` built from the resolved device string.
      """
      for attr_name, default in (
          ('distributed', False),
          ('world_size', 1),
          ('rank', 0),  # global rank
          ('local_rank', 0),
      ):
          setattr(args, attr_name, default)

      info = init_distributed_device_so(
          device=getattr(args, 'device', 'cuda'),
          dist_backend=getattr(args, 'dist_backend', None),
          dist_url=getattr(args, 'dist_url', None),
      )

      # Map each args attribute to the key of the returned dict.
      for attr_name, key in (
          ('device', 'device'),
          ('world_size', 'world_size'),
          ('rank', 'global_rank'),
          ('local_rank', 'local_rank'),
          ('distributed', 'distributed'),
      ):
          setattr(args, attr_name, info[key])

      return torch.device(args.device)
  def init_distributed_device_so(
          device: str = 'cuda',
          dist_backend: Optional[str] = None,
          dist_url: Optional[str] = None,
  ):
      """Resolve the compute device and, if the environment requests it, join a process group.

      Distributed training means more than one process, on one or several nodes.
      Whether ``torch.distributed`` is initialised is decided by
      ``is_distributed_env()``. Under SLURM (``SLURM_PROCID`` set), the ranks and
      world size come from ``world_info_from_env()`` and are exported as
      ``LOCAL_RANK``/``RANK``/``WORLD_SIZE``. Otherwise they are read back from
      the process group that ``init_method=dist_url`` created.

      Args:
          device: Device string such as ``'cuda'``, ``'cuda:1'``, ``'npu'`` or
              ``'cpu'``. In distributed mode, any index given for a non-CPU device
              is replaced by the local rank, and a warning is logged.
          dist_backend: Process-group backend. When ``None``, it is chosen from the
              device type: ``xpu`` -> ``ccl``, ``hpu``/``npu`` -> ``hccl``,
              ``cuda`` -> ``nccl``, anything else -> ``gloo``.
          dist_url: ``init_method`` for the process group; defaults to ``'env://'``.

      Returns:
          dict with keys ``device`` (the resolved device string), ``global_rank``,
          ``local_rank``, ``world_size`` and ``distributed``.

      Raises:
          AssertionError: if ``device`` names CUDA or an Ascend NPU that is not
              available. This is checked before any process group is created.
      """
      distributed = False
      world_size = 1
      global_rank = 0
      local_rank = 0
      device_type, *device_idx = device.split(':', maxsplit=1)

      # Validate the requested accelerator before touching torch.distributed.
      # A missing device then fails with this message rather than a backend init
      # error, and no half-initialised process group is left behind.
      if device_type == 'cuda':
          assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'
      if device_type == 'npu':
          # torch.npu may not exist when no NPU backend extension is installed.
          npu = getattr(torch, 'npu', None)
          assert npu is not None and npu.is_available(), \
              f'Ascend NPU is not available but {device} was specified.'

      if dist_backend is None:
          # FIXME: verify that ROCm transform nccl to rccl
          dist_backends = {
              "xpu": "ccl",
              "hpu": "hccl",
              "cuda": "nccl",
              "npu": "hccl",
          }
          dist_backend = dist_backends.get(device_type, 'gloo')
      dist_url = dist_url or 'env://'

      # TODO: Horovod support (hvd.init() with hvd.local_rank()/rank()/size())
      # is not implemented.

      if is_distributed_env():
          if 'SLURM_PROCID' in os.environ:
              # DDP via SLURM
              local_rank, global_rank, world_size = world_info_from_env()
              # SLURM var -> torch.distributed vars in case needed
              os.environ['LOCAL_RANK'] = str(local_rank)
              os.environ['RANK'] = str(global_rank)
              os.environ['WORLD_SIZE'] = str(world_size)
              dist.init_process_group(
                  backend=dist_backend,
                  init_method=dist_url,
                  world_size=world_size,
                  rank=global_rank,
              )
          else:
              # DDP via torchrun, torch.distributed.launch
              local_rank, _, _ = world_info_from_env()
              dist.init_process_group(
                  backend=dist_backend,
                  init_method=dist_url,
              )
              world_size = dist.get_world_size()
              global_rank = dist.get_rank()
          distributed = True

      # Compare the device *type*, so 'cpu:N' is also treated as CPU.
      if distributed and device_type != 'cpu':
          # Ignore manually specified device index in distributed mode and
          # override with resolved local rank, fewer headaches in most setups.
          if device_idx:
              _logger.warning(f'device index {device_idx[0]} removed from specified ({device}).')
          device = f'{device_type}:{local_rank}'

      if device.startswith('cuda:'):
          torch.cuda.set_device(device)

      return dict(
          device=device,
          global_rank=global_rank,
          local_rank=local_rank,
          world_size=world_size,
          distributed=distributed,
      )