| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176 |
- import warnings
- from itertools import chain
- import torch
- from torch._utils import _get_device_index
- from torch.autograd import Function
- from torch.nn.parallel import comm
class Broadcast(Function):
    """Autograd function that copies a set of tensors onto several devices.

    The forward pass hands the inputs to ``comm.broadcast_coalesced``; the
    backward pass reduces the per-device gradients onto the device of the
    first input through ``ReduceAddCoalesced``.
    """

    @staticmethod
    def forward(ctx, target_gpus, *inputs):
        """Broadcast ``inputs`` to every device listed in ``target_gpus``.

        Args:
            ctx: Autograd context. ``target_gpus`` is always stored on it;
                ``num_inputs``, ``input_device`` and ``complex_mask`` are
                stored only when at least one input is given.
            target_gpus: Iterable of device specifiers, each normalised with
                ``_get_device_index``.
            *inputs: Tensors to broadcast; none of them may live on the CPU.

        Returns:
            A flat tuple holding all copies for the first device, then all
            copies for the second device, and so on. An empty tuple when no
            inputs are given.

        Raises:
            AssertionError: If any input is a CPU tensor.
        """
        if any(tensor.device.type == "cpu" for tensor in inputs):
            raise AssertionError("Broadcast function not implemented for CPU tensors")

        ctx.target_gpus = [_get_device_index(gpu, True) for gpu in target_gpus]
        if not inputs:
            return ()

        ctx.num_inputs = len(inputs)
        ctx.input_device = inputs[0].get_device()
        ctx.complex_mask = [tensor.is_complex() for tensor in inputs]

        outputs = comm.broadcast_coalesced(inputs, ctx.target_gpus)

        # Copies made from complex inputs are re-interpreted as complex, in
        # place, inside each per-device sequence.
        # NOTE(review): this presumes comm.broadcast_coalesced returns those
        # copies as real views and that each per-device sequence is mutable.
        complex_positions = [
            pos for pos, flagged in enumerate(ctx.complex_mask) if flagged
        ]
        for per_device in outputs:
            for pos in complex_positions:
                per_device[pos] = torch.view_as_complex(per_device[pos])

        # needs_input_grad[0] belongs to ``target_gpus``; the remaining
        # entries line up with ``inputs``. Every copy of an input that needs
        # no gradient is marked as non-differentiable.
        frozen_copies = []
        for pos, needs_grad in enumerate(ctx.needs_input_grad[1:]):
            if needs_grad:
                continue
            frozen_copies.extend(per_device[pos] for per_device in outputs)
        ctx.mark_non_differentiable(*frozen_copies)

        return tuple(chain.from_iterable(outputs))

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Sum the per-device gradients back onto the original input device.

        Returns ``None`` for ``target_gpus`` followed by one reduced gradient
        per broadcast input.
        """
        reduced = ReduceAddCoalesced.apply(
            ctx.input_device, ctx.num_inputs, *grad_outputs
        )
        return (None,) + reduced
class ReduceAddCoalesced(Function):
    """Autograd function that sums matching tensors from several devices.

    ``grads`` is a flat, device-major sequence: the first ``num_inputs``
    entries belong to the first device, the next ``num_inputs`` entries to
    the second device, and so on.
    """

    @staticmethod
    def forward(ctx, destination, num_inputs, *grads):
        """Reduce ``grads`` onto ``destination`` with ``comm.reduce_add_coalesced``.

        Args:
            ctx: Autograd context; receives ``target_gpus`` (the device of the
                first tensor of every group) and ``complex_mask``.
            destination: Device passed through to ``comm.reduce_add_coalesced``.
            num_inputs: Number of tensors in each per-device group.
            *grads: The flat, device-major tensors to sum.

        Returns:
            A tuple with one summed tensor per position in a group.
        """

        def as_real(tensor):
            # Complex tensors are handed to the reduction as real views.
            return torch.view_as_real(tensor) if tensor.is_complex() else tensor

        group_starts = range(0, len(grads), num_inputs)
        ctx.target_gpus = [grads[start].get_device() for start in group_starts]

        # Complexness is read from the first group only, by index.
        mask = [grads[pos].is_complex() for pos in range(num_inputs)]
        ctx.complex_mask = mask

        real_grads = tuple(map(as_real, grads))
        grouped = [real_grads[start : start + num_inputs] for start in group_starts]

        reduced = comm.reduce_add_coalesced(grouped, destination)

        # Restore the complex dtype at the positions flagged in the mask.
        return tuple(
            torch.view_as_complex(total) if flagged else total
            for total, flagged in zip(reduced, mask)
        )

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Broadcast the incoming gradients back to the source devices.

        Returns ``None`` for ``destination`` and ``num_inputs`` followed by
        the broadcast gradients.
        """
        spread = Broadcast.apply(ctx.target_gpus, *grad_outputs)
        return (None, None) + spread
class Gather(Function):
    """Autograd function that concatenates tensors from several devices.

    The forward pass delegates to ``comm.gather``; the backward pass splits
    the gradient back to the source devices through ``Scatter``.
    """

    @staticmethod
    def forward(ctx, target_device, dim, *inputs):
        """Gather ``inputs`` along ``dim`` onto ``target_device``.

        Args:
            ctx: Autograd context; receives ``target_device``, ``dim``,
                ``input_gpus``, ``unsqueezed_scalar`` and ``input_sizes``.
            target_device: ``"cpu"`` or a device specifier normalised with
                ``_get_device_index``.
            dim: Dimension along which the inputs are gathered.
            *inputs: Tensors to gather; none of them may live on the CPU.

        Returns:
            The gathered tensor, viewed as complex when the first input is
            complex.

        Raises:
            AssertionError: If any input is a CPU tensor.
        """
        if any(tensor.device.type == "cpu" for tensor in inputs):
            raise AssertionError("Gather function not implemented for CPU tensors")

        ctx.target_device = (
            "cpu" if target_device == "cpu" else _get_device_index(target_device, True)
        )
        ctx.dim = dim
        ctx.input_gpus = tuple(tensor.get_device() for tensor in inputs)

        # Zero-dimensional inputs cannot be concatenated along dim 0, so they
        # are reshaped to one-element vectors; backward undoes this.
        ctx.unsqueezed_scalar = False
        if dim == 0 and all(tensor.dim() == 0 for tensor in inputs):
            inputs = tuple(tensor.view(1) for tensor in inputs)
            warnings.warn(
                "Was asked to gather along dimension 0, but all "
                "input tensors were scalars; will instead unsqueeze "
                "and return a vector.",
                stacklevel=2,
            )
            ctx.unsqueezed_scalar = True

        # Sizes are recorded after the optional reshape so that backward can
        # split the gradient into matching chunks.
        ctx.input_sizes = tuple(tensor.size(ctx.dim) for tensor in inputs)

        # Only the first input decides whether the result is complex.
        # NOTE(review): complex inputs are passed to comm.gather unchanged and
        # the result goes through view_as_complex, so comm.gather presumably
        # returns a real view with a trailing size-2 axis; confirm that a
        # negative ``dim`` is still interpreted as intended in that case.
        wants_complex = bool(inputs) and inputs[0].is_complex()
        gathered = comm.gather(inputs, ctx.dim, ctx.target_device)
        return torch.view_as_complex(gathered) if wants_complex else gathered

    @staticmethod
    def backward(ctx, grad_output):
        """Scatter ``grad_output`` back to the devices the inputs came from.

        Returns ``None`` for ``target_device`` and ``dim`` followed by one
        gradient per gathered input.
        """
        pieces = Scatter.apply(ctx.input_gpus, ctx.input_sizes, ctx.dim, grad_output)
        if ctx.unsqueezed_scalar:
            # Drop the extra dimension that forward added to scalar inputs.
            pieces = tuple(piece[0] for piece in pieces)
        return (None, None) + pieces
class Scatter(Function):
    """Autograd function that splits one tensor across several devices.

    The forward pass delegates to ``comm.scatter``; the backward pass
    concatenates the gradients again through ``Gather``.
    """

    @staticmethod
    def forward(ctx, target_gpus, chunk_sizes, dim, input):
        """Scatter ``input`` along ``dim`` onto ``target_gpus``.

        Args:
            ctx: Autograd context; receives ``dim`` and ``input_device``
                (``-1`` when ``input`` is a CPU tensor).
            target_gpus: Iterable of device specifiers, each normalised with
                ``_get_device_index``.
            chunk_sizes: Chunk sizes passed through to ``comm.scatter``.
            dim: Dimension along which ``input`` is split.
            input: Tensor to scatter.

        Returns:
            The chunks produced by ``comm.scatter``; a tuple of complex views
            when ``input`` is complex.
        """
        device_ids = [_get_device_index(gpu, True) for gpu in target_gpus]
        ctx.dim = dim

        source_device = -1 if input.device.type == "cpu" else input.get_device()
        ctx.input_device = source_device

        copy_streams = None
        if torch.accelerator.is_available() and source_device == -1:
            # Perform CPU to GPU copies in a background stream
            copy_streams = [_get_stream(torch.device(idx)) for idx in device_ids]

        wants_complex = input.is_complex()
        outputs = comm.scatter(input, device_ids, chunk_sizes, ctx.dim, copy_streams)
        if wants_complex:
            outputs = tuple(torch.view_as_complex(chunk) for chunk in outputs)

        if copy_streams is None:
            return outputs

        # Synchronize with the copy stream: on each target device the current
        # stream waits for that device's copy stream, and the chunk is
        # recorded on the current stream.
        for pos, chunk in enumerate(outputs):
            with torch.accelerator.device_index(device_ids[pos]):
                current = torch.accelerator.current_stream()
                current.wait_stream(copy_streams[pos])
                chunk.record_stream(current)
        return outputs

    @staticmethod
    def backward(ctx, *grad_output):
        """Gather the per-device gradients back onto the input's device.

        Returns ``None`` for ``target_gpus``, ``chunk_sizes`` and ``dim``
        followed by the gathered gradient for ``input``.
        """
        gathered = Gather.apply(ctx.input_device, ctx.dim, *grad_output)
        return None, None, None, gathered
- # background streams used for copying
- _streams: list[torch.Stream | None] | None = None
- def _get_stream(device: torch.device):
- """Get a background stream for copying between CPU and target device."""
- global _streams
- if device.type == "cpu" or not torch.accelerator.is_available():
- return None
- if torch.accelerator.current_accelerator().type != device.type:
- raise AssertionError(
- f"Expected current accelerator type {torch.accelerator.current_accelerator().type} "
- f"to match device type {device.type}"
- )
- if _streams is None:
- _streams = [None] * torch.accelerator.device_count()
- if _streams[device.index] is None:
- _streams[device.index] = torch.Stream(device.index)
- return _streams[device.index]
|