# _functions.py (6.1 KB)
  1. import warnings
  2. from itertools import chain
  3. import torch
  4. from torch._utils import _get_device_index
  5. from torch.autograd import Function
  6. from torch.nn.parallel import comm
  7. class Broadcast(Function):
  8. @staticmethod
  9. def forward(ctx, target_gpus, *inputs):
  10. if not all(i.device.type != "cpu" for i in inputs):
  11. raise AssertionError("Broadcast function not implemented for CPU tensors")
  12. target_gpus = [_get_device_index(x, True) for x in target_gpus]
  13. ctx.target_gpus = target_gpus
  14. if len(inputs) == 0:
  15. return ()
  16. ctx.num_inputs = len(inputs)
  17. ctx.input_device = inputs[0].get_device()
  18. ctx.complex_mask = [inp.is_complex() for inp in inputs]
  19. outputs = comm.broadcast_coalesced(inputs, ctx.target_gpus)
  20. for device_outputs in outputs:
  21. for i, is_complex in enumerate(ctx.complex_mask):
  22. if is_complex:
  23. device_outputs[i] = torch.view_as_complex(device_outputs[i])
  24. non_differentiables = []
  25. for idx, input_requires_grad in enumerate(ctx.needs_input_grad[1:]):
  26. if not input_requires_grad:
  27. non_differentiables.extend(output[idx] for output in outputs)
  28. ctx.mark_non_differentiable(*non_differentiables)
  29. return tuple(chain.from_iterable(outputs))
  30. @staticmethod
  31. def backward(ctx, *grad_outputs):
  32. grads = ReduceAddCoalesced.apply(
  33. ctx.input_device, ctx.num_inputs, *grad_outputs
  34. )
  35. return (None,) + grads
  36. class ReduceAddCoalesced(Function):
  37. @staticmethod
  38. def forward(ctx, destination, num_inputs, *grads):
  39. ctx.target_gpus = [
  40. grads[i].get_device() for i in range(0, len(grads), num_inputs)
  41. ]
  42. complex_mask = [grads[i].is_complex() for i in range(num_inputs)]
  43. ctx.complex_mask = complex_mask
  44. grads_converted = tuple(
  45. torch.view_as_real(g) if g.is_complex() else g for g in grads
  46. )
  47. grads_ = [
  48. grads_converted[i : i + num_inputs]
  49. for i in range(0, len(grads_converted), num_inputs)
  50. ]
  51. results = comm.reduce_add_coalesced(grads_, destination)
  52. results = tuple(
  53. torch.view_as_complex(r) if is_complex else r
  54. for r, is_complex in zip(results, complex_mask)
  55. )
  56. return results
  57. @staticmethod
  58. def backward(ctx, *grad_outputs):
  59. return (
  60. None,
  61. None,
  62. ) + Broadcast.apply(ctx.target_gpus, *grad_outputs)
  63. class Gather(Function):
  64. @staticmethod
  65. def forward(ctx, target_device, dim, *inputs):
  66. if not all(i.device.type != "cpu" for i in inputs):
  67. raise AssertionError("Gather function not implemented for CPU tensors")
  68. if target_device == "cpu":
  69. ctx.target_device = "cpu"
  70. else:
  71. target_device = _get_device_index(target_device, True)
  72. ctx.target_device = target_device
  73. ctx.dim = dim
  74. ctx.input_gpus = tuple(i.get_device() for i in inputs)
  75. if all(t.dim() == 0 for t in inputs) and dim == 0:
  76. inputs = tuple(t.view(1) for t in inputs)
  77. warnings.warn(
  78. "Was asked to gather along dimension 0, but all "
  79. "input tensors were scalars; will instead unsqueeze "
  80. "and return a vector.",
  81. stacklevel=2,
  82. )
  83. ctx.unsqueezed_scalar = True
  84. else:
  85. ctx.unsqueezed_scalar = False
  86. ctx.input_sizes = tuple(i.size(ctx.dim) for i in inputs)
  87. is_complex = len(inputs) > 0 and inputs[0].is_complex()
  88. output = comm.gather(inputs, ctx.dim, ctx.target_device)
  89. if is_complex:
  90. output = torch.view_as_complex(output)
  91. return output
  92. @staticmethod
  93. def backward(ctx, grad_output):
  94. scattered_grads = Scatter.apply(
  95. ctx.input_gpus, ctx.input_sizes, ctx.dim, grad_output
  96. )
  97. if ctx.unsqueezed_scalar:
  98. scattered_grads = tuple(g[0] for g in scattered_grads)
  99. return (None, None) + scattered_grads
  100. class Scatter(Function):
  101. @staticmethod
  102. def forward(ctx, target_gpus, chunk_sizes, dim, input):
  103. target_gpus = [_get_device_index(x, True) for x in target_gpus]
  104. ctx.dim = dim
  105. ctx.input_device = input.get_device() if input.device.type != "cpu" else -1
  106. streams = None
  107. if torch.accelerator.is_available() and ctx.input_device == -1:
  108. # Perform CPU to GPU copies in a background stream
  109. streams = [_get_stream(torch.device(device)) for device in target_gpus]
  110. is_complex = input.is_complex()
  111. outputs = comm.scatter(input, target_gpus, chunk_sizes, ctx.dim, streams)
  112. if is_complex:
  113. outputs = tuple(torch.view_as_complex(o) for o in outputs)
  114. # Synchronize with the copy stream
  115. if streams is not None:
  116. for i, output in enumerate(outputs):
  117. with torch.accelerator.device_index(target_gpus[i]):
  118. main_stream = torch.accelerator.current_stream()
  119. main_stream.wait_stream(streams[i])
  120. output.record_stream(main_stream)
  121. return outputs
  122. @staticmethod
  123. def backward(ctx, *grad_output):
  124. return None, None, None, Gather.apply(ctx.input_device, ctx.dim, *grad_output)
  125. # background streams used for copying
  126. _streams: list[torch.Stream | None] | None = None
  127. def _get_stream(device: torch.device):
  128. """Get a background stream for copying between CPU and target device."""
  129. global _streams
  130. if device.type == "cpu" or not torch.accelerator.is_available():
  131. return None
  132. if torch.accelerator.current_accelerator().type != device.type:
  133. raise AssertionError(
  134. f"Expected current accelerator type {torch.accelerator.current_accelerator().type} "
  135. f"to match device type {device.type}"
  136. )
  137. if _streams is None:
  138. _streams = [None] * torch.accelerator.device_count()
  139. if _streams[device.index] is None:
  140. _streams[device.index] = torch.Stream(device.index)
  141. return _streams[device.index]