_utils.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. # mypy: allow-untyped-defs
  2. # Copyright (c) Meta Platforms, Inc. and affiliates
  3. import logging
  4. from dataclasses import dataclass
  5. import torch
  6. from torch import fx
  7. logger = logging.getLogger(__name__)
  8. def flatten_args_detach(args):
  9. """
  10. Flatten the args into a list form and detach the tensors from computational graph.
  11. """
  12. flat_detached_args = []
  13. def extract_tensor_args(a):
  14. nonlocal flat_detached_args
  15. if isinstance(a, torch.Tensor):
  16. val = a.detach().requires_grad_(a.requires_grad)
  17. flat_detached_args.append(val)
  18. return val
  19. else:
  20. flat_detached_args.append(a)
  21. return a
  22. new_args = fx.node.map_aggregate(
  23. args,
  24. extract_tensor_args,
  25. )
  26. return new_args, flat_detached_args
  27. def flatten_args(args):
  28. """
  29. Flatten the args into a list form.
  30. """
  31. flat_args = []
  32. def extract_tensor_args(a):
  33. nonlocal flat_args
  34. flat_args.append(a)
  35. return a
  36. fx.node.map_aggregate(
  37. args,
  38. extract_tensor_args,
  39. )
  40. return flat_args
  41. class PipeliningShapeError(RuntimeError):
  42. """Shape mismatch between configured and runtime values."""
  43. def validate_tensor_metadata(desc, expected, given):
  44. if not expected.shape == given.shape:
  45. raise PipeliningShapeError(
  46. f"{desc} has a shape mismatch: expected {expected.shape} actual {given.shape}"
  47. )
  48. if not expected.dtype == given.dtype:
  49. raise PipeliningShapeError(
  50. f"{desc} has a dtype mismatch: expected {expected.dtype} actual {given.dtype}"
  51. )
  52. if not expected.stride() == given.stride():
  53. raise PipeliningShapeError(
  54. f"{desc} has a stride mismatch: expected {expected.stride()} actual {given.stride()}"
  55. )
  56. def validate_tensors_metadata(
  57. desc,
  58. expected_tensors: list[torch.Tensor] | tuple[torch.Tensor, ...],
  59. actual_tensors: list[torch.Tensor] | tuple[torch.Tensor, ...],
  60. ):
  61. if len(expected_tensors) != len(actual_tensors):
  62. raise PipeliningShapeError(
  63. f"{desc}: Number of values ({len(actual_tensors)}) does not match expected number ({len(expected_tensors)})"
  64. )
  65. for i in range(len(expected_tensors)):
  66. validate_tensor_metadata(
  67. f"{desc}: value {i}", expected_tensors[i], actual_tensors[i]
  68. )
  69. def generate_stage_to_rank_mapping(
  70. pp_size: int, num_stages: int, style: str = "loop"
  71. ) -> dict[int, int]:
  72. """
  73. Compute the stage id to rank mapping for either a looped or V-style schedule.
  74. Most commonly num_stages == pp_size * 2, but this function can be used to
  75. compute the mapping for any number of stages per rank.
  76. """
  77. mapping = {}
  78. if style == "loop":
  79. for stage_index in range(num_stages):
  80. mapping[stage_index] = stage_index % pp_size
  81. elif style == "v":
  82. if num_stages % pp_size != 0:
  83. raise ValueError(
  84. f"num_stages {num_stages} must be evenly divisible by pp_size {pp_size} for V schedules"
  85. )
  86. rank_index = 0
  87. for stage_index in range(num_stages):
  88. mapping[stage_index] = rank_index
  89. # dont change rank if we are on the border (to keep v shape)
  90. if (stage_index + 1) % pp_size == 0:
  91. continue
  92. if (stage_index // pp_size) % 2 == 0:
  93. rank_index += 1
  94. else:
  95. rank_index -= 1
  96. else:
  97. raise ValueError(f"Style {style} is not supported.")
  98. return mapping
  99. def generate_rank_to_stage_mapping(
  100. pp_size: int, num_stages: int, style: str = "loop"
  101. ) -> dict[int, list[int]]:
  102. """
  103. Compute the rank to stage id mapping for either a looped or V-style schedule.
  104. This function inverts the stage_to_rank_mapping to get which stages are assigned to each rank.
  105. Returns a dictionary mapping rank -> list of stage indices assigned to that rank.
  106. """
  107. stage_to_rank = generate_stage_to_rank_mapping(pp_size, num_stages, style)
  108. # Invert the mapping: rank -> list of stages
  109. rank_to_stages: dict[int, list[int]] = {}
  110. for stage_id, rank in stage_to_rank.items():
  111. if rank not in rank_to_stages:
  112. rank_to_stages[rank] = []
  113. rank_to_stages[rank].append(stage_id)
  114. # Sort the stage lists for each rank to ensure consistent ordering
  115. for stages in rank_to_stages.values():
  116. stages.sort()
  117. return rank_to_stages
  118. @dataclass
  119. class PipeInfo:
  120. """
  121. Captures information for a pipeline (`Pipe` object).
  122. """
  123. graph: fx.Graph
  124. num_stages: int
  125. has_loss_and_backward: bool