distributed.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from typing import Any, List, Optional
  15. import torch
  16. from torch import Tensor
  17. from torch.nn import functional as F # noqa: N812
  18. from typing_extensions import Literal
  19. def reduce(x: Tensor, reduction: Optional[Literal["elementwise_mean", "sum", "none"]]) -> Tensor:
  20. """Reduces a given tensor by a given reduction method.
  21. Args:
  22. x: the tensor, which shall be reduced
  23. reduction: a string specifying the reduction method ('elementwise_mean', 'none', 'sum')
  24. Return:
  25. reduced Tensor
  26. Raise:
  27. ValueError if an invalid reduction parameter was given
  28. """
  29. if reduction == "elementwise_mean":
  30. return torch.mean(x)
  31. if reduction == "none" or reduction is None:
  32. return x
  33. if reduction == "sum":
  34. return torch.sum(x)
  35. raise ValueError("Reduction parameter unknown.")
  36. def class_reduce(
  37. num: Tensor,
  38. denom: Tensor,
  39. weights: Tensor,
  40. class_reduction: Optional[Literal["micro", "macro", "weighted", "none"]] = "none",
  41. ) -> Tensor:
  42. """Reduce classification metrics of the form ``num / denom * weights``.
  43. For example for calculating standard accuracy the num would be number of true positives per class, denom would be
  44. the support per class, and weights would be a tensor of 1s.
  45. Args:
  46. num: numerator tensor
  47. denom: denominator tensor
  48. weights: weights for each class
  49. class_reduction: reduction method for multiclass problems:
  50. - ``'micro'``: calculate metrics globally (default)
  51. - ``'macro'``: calculate metrics for each label, and find their unweighted mean.
  52. - ``'weighted'``: calculate metrics for each label, and find their weighted mean.
  53. - ``'none'`` or ``None``: returns calculated metric per class
  54. Raises:
  55. ValueError:
  56. If ``class_reduction`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"none"`` or ``None``.
  57. """
  58. valid_reduction = ("micro", "macro", "weighted", "none", None)
  59. fraction = torch.sum(num) / torch.sum(denom) if class_reduction == "micro" else num / denom
  60. # We need to take care of instances where the denom can be 0
  61. # for some (or all) classes which will produce nans
  62. fraction[fraction != fraction] = 0
  63. if class_reduction == "micro":
  64. return fraction
  65. if class_reduction == "macro":
  66. return torch.mean(fraction)
  67. if class_reduction == "weighted":
  68. return torch.sum(fraction * (weights.float() / torch.sum(weights)))
  69. if class_reduction == "none" or class_reduction is None:
  70. return fraction
  71. raise ValueError(f"Reduction parameter {class_reduction} unknown. Choose between one of these: {valid_reduction}")
  72. def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> List[Tensor]:
  73. with torch.no_grad():
  74. gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
  75. torch.distributed.all_gather(gathered_result, result, group)
  76. # to propagate autograd graph from local rank
  77. gathered_result[torch.distributed.get_rank(group)] = result
  78. return gathered_result
  79. def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tensor]:
  80. """Gather all tensors from several ddp processes onto a list that is broadcast to all processes.
  81. Works on tensors that have the same number of dimensions, but where each dimension may differ. In this case
  82. tensors are padded, gathered and then trimmed to secure equal workload for all processes.
  83. Args:
  84. result: the value to sync
  85. group: the process group to gather results from. Defaults to all processes (world)
  86. Return:
  87. list with size equal to the process group where element i corresponds to result tensor from process i
  88. """
  89. if group is None:
  90. group = torch.distributed.group.WORLD
  91. # convert tensors to contiguous format
  92. result = result.contiguous()
  93. world_size = torch.distributed.get_world_size(group)
  94. torch.distributed.barrier(group=group)
  95. # if the tensor is scalar, things are easy
  96. if result.ndim == 0:
  97. return _simple_gather_all_tensors(result, group, world_size)
  98. # 1. Gather sizes of all tensors
  99. local_size = torch.tensor(result.shape, device=result.device)
  100. local_sizes = [torch.zeros_like(local_size) for _ in range(world_size)]
  101. torch.distributed.all_gather(local_sizes, local_size, group=group)
  102. max_size = torch.stack(local_sizes).max(dim=0).values
  103. all_sizes_equal = all(all(ls == max_size) for ls in local_sizes)
  104. # 2. If shapes are all the same, then do a simple gather:
  105. if all_sizes_equal:
  106. return _simple_gather_all_tensors(result, group, world_size)
  107. # 3. If not, we need to pad each local tensor to maximum size, gather and then truncate
  108. with torch.no_grad():
  109. pad_dims = []
  110. pad_by = (max_size - local_size).detach().cpu()
  111. for val in reversed(pad_by):
  112. pad_dims.append(0)
  113. pad_dims.append(val.item())
  114. result_padded = F.pad(result, pad_dims)
  115. gathered_result = [torch.zeros_like(result_padded) for _ in range(world_size)]
  116. torch.distributed.all_gather(gathered_result, result_padded, group)
  117. for idx, item_size in enumerate(local_sizes):
  118. slice_param = [slice(dim_size) for dim_size in item_size]
  119. gathered_result[idx] = gathered_result[idx][tuple(slice_param)]
  120. # to propagate autograd graph from local rank
  121. gathered_result[torch.distributed.get_rank(group)] = result
  122. return gathered_result