schedules.py 139 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
732783279328032813282328332843285328632873288328932903291329232933294329532963297329832993300330133023303330433053306330733083309331033113312331333143315331633173318331933203321332233233324332533263327332833293330333133323333333433353336333733383339334033413342334333443345334633473348334933503351335233533354335533563357335833593360336133623363336433653366336733683369337033713372337333743375337633773378337933803381338233833384338533863387338833893390339133923393339433953396339733983399340034013402340334043405340634073408340934103411341234133414341534163417341834193420342134223423342434253426342734283429343034313432343334343435343634373438343934403441344234433444344534463447
  1. # mypy: allow-untyped-defs
  2. # Copyright (c) Meta Platforms, Inc. and affiliates
  3. import copy
  4. import csv
  5. import itertools
  6. import logging
  7. import re
  8. from abc import ABC, abstractmethod
  9. from collections import Counter, defaultdict
  10. from collections.abc import Callable
  11. from dataclasses import dataclass
  12. from enum import Enum
  13. from functools import lru_cache
  14. from typing import Any, cast, NamedTuple, Protocol
  15. import torch
  16. import torch.distributed as dist
  17. from torch._dynamo import OptimizedModule
  18. from torch.distributed.fsdp import FSDPModule, UnshardHandle
  19. from torch.nn.modules.loss import _Loss
  20. from torch.profiler import record_function
  21. from ._utils import generate_rank_to_stage_mapping, generate_stage_to_rank_mapping
  22. from .microbatch import merge_chunks, split_args_kwargs_into_chunks, TensorChunkSpec
  23. from .stage import _PipelineStageBase
  24. __all__ = [
  25. "get_schedule_class",
  26. "PipelineScheduleSingle",
  27. "PipelineScheduleMulti",
  28. "Schedule1F1B",
  29. "ScheduleGPipe",
  30. "ScheduleInterleaved1F1B",
  31. "ScheduleLoopedBFS",
  32. "ScheduleInterleavedZeroBubble",
  33. "ScheduleZBVZeroBubble",
  34. "ScheduleDualPipeV",
  35. ]
  36. logger = logging.getLogger(__name__)
  37. class _ComputationType(str, Enum):
  38. # TODO(whc) rename to _ActType?
  39. FORWARD = "F"
  40. BACKWARD_INPUT = "I"
  41. BACKWARD_WEIGHT = "W"
  42. UNSHARD = "UNSHARD"
  43. RESHARD = "RESHARD"
  44. SEND_F = "SEND_F"
  45. RECV_F = "RECV_F"
  46. SEND_B = "SEND_B"
  47. RECV_B = "RECV_B"
  48. FULL_BACKWARD = "B"
  49. OVERLAP_F_B = "OVERLAP_F_B"
  50. REDUCE_GRAD = "REDUCE_GRAD"
  51. @staticmethod
  52. def from_str(action: str) -> "_ComputationType":
  53. try:
  54. return _ComputationType(action)
  55. except ValueError as exc:
  56. raise RuntimeError(f"Invalid computation type {action}") from exc
  57. FORWARD = _ComputationType.FORWARD
  58. BACKWARD_INPUT = _ComputationType.BACKWARD_INPUT
  59. BACKWARD_WEIGHT = _ComputationType.BACKWARD_WEIGHT
  60. UNSHARD = _ComputationType.UNSHARD
  61. RESHARD = _ComputationType.RESHARD
  62. SEND_F = _ComputationType.SEND_F
  63. RECV_F = _ComputationType.RECV_F
  64. SEND_B = _ComputationType.SEND_B
  65. RECV_B = _ComputationType.RECV_B
  66. FULL_BACKWARD = _ComputationType.FULL_BACKWARD
  67. OVERLAP_F_B = _ComputationType.OVERLAP_F_B
  68. REDUCE_GRAD = _ComputationType.REDUCE_GRAD
  69. # Convenience shorthand for compute actions only since they are used in 'simple schedule format'
  70. F = FORWARD
  71. I = BACKWARD_INPUT
  72. W = BACKWARD_WEIGHT
  73. B = FULL_BACKWARD
  74. # Helper to parse an action string like 1F0 into a tuple of (stage_index, computation_type, microbatch_index)
  75. _action_regex = re.compile(
  76. r"(\d+)(F|I|B|W|UNSHARD|RESHARD|REDUCE_GRAD|SEND_F|RECV_F|SEND_B|RECV_B)(\d*)"
  77. )
  78. class _Action(NamedTuple):
  79. stage_index: int
  80. computation_type: _ComputationType
  81. microbatch_index: int | None = None
  82. sub_actions: tuple["_Action", ...] | None = None
  83. def __str__(self):
  84. return self.__repr__()
  85. def __repr__(self):
  86. if self.sub_actions is not None:
  87. # Use recursive repr for sub_actions
  88. sub_action_reprs = [repr(sub_action) for sub_action in self.sub_actions]
  89. return f"({';'.join(sub_action_reprs)}){self.computation_type.value}"
  90. else:
  91. repr_str = str(self.stage_index)
  92. # Use .value to get the short string (e.g., "F", "B") instead of the full enum name
  93. repr_str += self.computation_type.value
  94. if self.microbatch_index is not None:
  95. repr_str += str(self.microbatch_index)
  96. return repr_str
  97. @property
  98. def is_compute_op(self) -> bool:
  99. return self.computation_type in (
  100. FORWARD,
  101. FULL_BACKWARD,
  102. BACKWARD_INPUT,
  103. BACKWARD_WEIGHT,
  104. OVERLAP_F_B,
  105. )
  106. @staticmethod
  107. def from_str(action_string: str):
  108. """
  109. Reverse of __repr__
  110. String should be formatted as [stage][action type][(microbatch)]
  111. e.g. `2F0`, `1UNSHARD`, `3SEND_F1`
  112. """
  113. action_string = action_string.strip()
  114. if action_string == "":
  115. return None
  116. # Check for sub_actions format: [sub_action1;sub_action2;...]ComputationType
  117. if action_string.startswith("(") and ")" in action_string:
  118. # Find the closing bracket to separate sub_actions from computation type
  119. bracket_end = action_string.find(")")
  120. sub_part = action_string[
  121. 1:bracket_end
  122. ] # Remove '[' and get content before ']'
  123. computation_type_part = action_string[
  124. bracket_end + 1 :
  125. ] # Get part after ']'
  126. # Parse sub_actions
  127. sub_actions = []
  128. if sub_part.strip():
  129. for sub_str in sub_part.split(";"):
  130. sub_action = _Action.from_str(sub_str.strip())
  131. if sub_action is not None:
  132. sub_actions.append(sub_action)
  133. # For sub_actions format, we create an action with just the computation type
  134. # The stage_index and microbatch_index are not meaningful for the container action
  135. return _Action(
  136. stage_index=-1, # Placeholder, not meaningful for sub_actions container
  137. computation_type=_ComputationType.from_str(computation_type_part),
  138. microbatch_index=None,
  139. sub_actions=tuple(sub_actions) if sub_actions else None,
  140. )
  141. # Handle regular single action format
  142. if match := _action_regex.match(action_string):
  143. stage_index, computation_type, microbatch_index = match.groups()
  144. return _Action(
  145. int(stage_index),
  146. _ComputationType.from_str(computation_type),
  147. int(microbatch_index) if len(microbatch_index) else None,
  148. )
  149. elif action_string == "":
  150. return None
  151. raise RuntimeError(
  152. f"Invalid action string: {action_string}, should be formatted as [stage][action type][(microbatch)] e.g. 2F0"
  153. )
  154. @lru_cache
  155. def _get_profiler_function_name(action: _Action) -> str:
  156. return f"PP:{str(action)}"
  157. def _format_pipeline_order(
  158. pipeline_order: dict[int, list[_Action | None]],
  159. error_step_number: int | None = None,
  160. ) -> str:
  161. """
  162. Formats the pipeline order in a timestep (row) x rank (column) grid of actions
  163. and returns the formatted string.
  164. If `error_step_number` is passed in, an additional label will be added to signify which step
  165. that it is erroring on.
  166. """
  167. # don't mutate the original
  168. pipeline_order = copy.deepcopy(pipeline_order)
  169. # Replace None with ""
  170. for rank in pipeline_order:
  171. for i in range(len(pipeline_order[rank])):
  172. if pipeline_order[rank][i] is None:
  173. # TODO make a real 'None action' that prints as empty string and make mypy happy
  174. pipeline_order[rank][i] = "" # type: ignore[call-overload]
  175. # Calculate the maximum number of steps across all ranks
  176. num_steps = max(len(actions) for actions in pipeline_order.values())
  177. step_labels = [
  178. "Step " + str(i).zfill(len(str(num_steps - 1))) for i in range(num_steps)
  179. ]
  180. # Sorting the dictionary by keys and retrieving values in that order
  181. rank_actions = [
  182. pipeline_order.get(key, [""] * num_steps) for key in sorted(pipeline_order)
  183. ]
  184. # Transpose the list of lists (rows to columns)
  185. # pyrefly: ignore [no-matching-overload]
  186. transposed_actions = list(itertools.zip_longest(*rank_actions, fillvalue=""))
  187. # Generate column labels for ranks
  188. num_ranks = len(pipeline_order)
  189. rank_labels = ["Rank " + str(i) for i in range(num_ranks)]
  190. # Calculate the maximum length of each column, considering labels
  191. max_lengths = [
  192. max(len(str(item)) if item is not None else 0 for item in col)
  193. for col in zip(step_labels, *transposed_actions)
  194. ]
  195. # Format the header row with rank labels
  196. header_row = " " * (len(step_labels[0]) + 2) + " ".join(
  197. f"{label:<{max_lengths[i]}}" for i, label in enumerate(rank_labels)
  198. )
  199. # Format each row with its corresponding label
  200. formatted_rows = [
  201. f"{label}: "
  202. + " ".join(f"{str(item):<{max_lengths[i]}}" for i, item in enumerate(row))
  203. + (
  204. " <-- ERROR HERE"
  205. if error_step_number is not None
  206. and int(label.split()[1]) == error_step_number
  207. else ""
  208. )
  209. for label, row in zip(step_labels, transposed_actions)
  210. ]
  211. # Join the rows into a single string
  212. formatted_table = header_row + "\n" + "\n".join(formatted_rows) + "\n"
  213. return formatted_table
  214. class _PipelineSchedule(ABC):
  215. def __init__(
  216. self,
  217. n_microbatches: int,
  218. loss_fn: Callable[..., torch.Tensor] | None = None,
  219. args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None,
  220. kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None,
  221. output_merge_spec: dict[str, Any] | tuple[Any] | None = None,
  222. scale_grads: bool = True,
  223. ):
  224. # From arguments
  225. self._n_microbatches = n_microbatches
  226. self._loss_fn = loss_fn
  227. # See documentation in `PipelineScheduleSingle` / `PipelineScheduleMulti`
  228. self.scale_grads = scale_grads
  229. # Chunking specification for positional inputs. (default: `None`)
  230. self._args_chunk_spec = args_chunk_spec
  231. # Chunking specification for keyword inputs. (default: `None`)
  232. self._kwargs_chunk_spec = kwargs_chunk_spec
  233. self._output_merge_spec = output_merge_spec
  234. """
  235. # args_chunk_spec and kwargs_chunk_spec specify how to chunk inputs.
  236. # They are used to convert batch to microbatches in `step(x)`. See
  237. # `TensorChunkSpec` for helper methods for creating them.
  238. """
  239. # Derived
  240. self._has_backward = self._loss_fn is not None
  241. # Holds the losses for each microbatch.
  242. self._internal_losses: list[torch.Tensor] = []
  243. logger.info("Using %s", self.__class__.__name__)
  244. def _maybe_compute_loss(self, stage, output, target_mbs, mb_index):
  245. if stage.is_last and self._loss_fn is not None:
  246. loss = self._compute_loss(output, target_mbs[mb_index]) # type: ignore[index]
  247. self._internal_losses.append(loss)
  248. def _maybe_get_loss(self, stage, mb_index):
  249. valid_index = 0 <= mb_index < len(self._internal_losses)
  250. if stage.is_last and self._loss_fn is not None and valid_index:
  251. return self._internal_losses[mb_index]
  252. elif len(self._internal_losses) != 0 and not valid_index:
  253. raise RuntimeError(
  254. f"Loss for microbatch {mb_index} is not available. "
  255. f"Available losses for microbatches: {self._internal_losses}"
  256. )
  257. else:
  258. return None
  259. def _update_losses(self, stages, losses):
  260. """
  261. Update the losses to those in the internal state
  262. """
  263. # if stages not a list turn into a list
  264. if not isinstance(stages, list):
  265. stages = [stages]
  266. contains_last_stage = any(stage.is_last for stage in stages)
  267. # Return losses if there is a container passed in
  268. if contains_last_stage and losses is not None:
  269. if len(self._internal_losses) != self._n_microbatches:
  270. raise RuntimeError(
  271. f"Expecting {self._n_microbatches} losses but got {len(self._internal_losses)}"
  272. )
  273. # Clean external container first
  274. losses.clear()
  275. # Copy internal losses to external container
  276. losses.extend(self._internal_losses)
  277. self._internal_losses.clear()
  278. @abstractmethod
  279. def _step_microbatches(
  280. self,
  281. arg_mbs: list | None = None,
  282. kwarg_mbs: list | None = None,
  283. target_mbs: list | None = None,
  284. losses: list | None = None,
  285. return_outputs: bool = True,
  286. ):
  287. """
  288. Run one iteration of the pipeline schedule with list of microbatches.
  289. Will go through all the microbatches according to the schedule
  290. implementation.
  291. Args:
  292. microbatches: list of microbatch args.
  293. return_outputs: whether to return the outputs from the last stage.
  294. """
  295. raise NotImplementedError
  296. @abstractmethod
  297. def step(
  298. self,
  299. *args,
  300. target=None,
  301. losses: list | None = None,
  302. return_outputs=True,
  303. **kwargs,
  304. ):
  305. """
  306. Run one iteration of the pipeline schedule with *whole-batch* input.
  307. Will chunk the input into microbatches automatically, and go through the
  308. microbatches according to the schedule implementation.
  309. args: positional arguments to the model (as in non-pipeline case).
  310. kwargs: keyword arguments to the model (as in non-pipeline case).
  311. target: target for the loss function.
  312. losses: a list to store the losses for each microbatch.
  313. return_outputs: whether to return the outputs from the last stage.
  314. """
  315. raise NotImplementedError
  316. def eval(self, *args, target=None, losses: list | None = None, **kwargs):
  317. """
  318. Run one iteration of the pipeline schedule with *whole-batch* input.
  319. Will chunk the input into microbatches automatically, and go through the
  320. microbatches, calling forward only.
  321. args: positional arguments to the model (as in non-pipeline case).
  322. kwargs: keyword arguments to the model (as in non-pipeline case).
  323. target: target values for the loss function.
  324. losses: a list to store the losses for each microbatch.
  325. """
  326. # Save the original has_backward state
  327. original_has_backward = self._has_backward
  328. try:
  329. self._has_backward = False
  330. return self.step(*args, target=target, losses=losses, **kwargs)
  331. finally:
  332. # Restore the original state
  333. self._has_backward = original_has_backward
  334. def _check_inputs(
  335. self,
  336. arg_mbs: list | None = None,
  337. kwarg_mbs: list | None = None,
  338. target_mbs: list | None = None,
  339. losses: list | None = None,
  340. ) -> tuple[list, list]:
  341. """
  342. Pre-process/check inputs
  343. """
  344. def check_type_and_len(mbs, name: str):
  345. if not isinstance(mbs, list):
  346. raise TypeError(f"{name} must be a list but got a {type(mbs)}")
  347. if len(mbs) != self._n_microbatches:
  348. raise ValueError(
  349. f"Expecting {self._n_microbatches} {name} but got {len(mbs)}"
  350. )
  351. if arg_mbs is not None:
  352. check_type_and_len(arg_mbs, "arg_mbs")
  353. else:
  354. arg_mbs = [()] * self._n_microbatches
  355. if kwarg_mbs is not None:
  356. check_type_and_len(kwarg_mbs, "kwarg_mbs")
  357. else:
  358. kwarg_mbs = [{}] * self._n_microbatches
  359. if target_mbs is not None:
  360. check_type_and_len(target_mbs, "target_mbs")
  361. if losses is not None:
  362. if not isinstance(losses, list):
  363. raise TypeError(f"losses must be a list but got a {type(losses)}")
  364. return arg_mbs, kwarg_mbs
  365. def _compute_loss(self, output, target):
  366. return self._loss_fn(output, target) # type: ignore[misc]
  367. def _split_inputs(
  368. self,
  369. args: tuple[Any, ...],
  370. kwargs: dict[str, Any] | None = None,
  371. ):
  372. """
  373. Splits a full-batch input into chunks (i.e. microbatches) and returns
  374. the chunks
  375. """
  376. if args or kwargs:
  377. args_split, kwargs_split = split_args_kwargs_into_chunks(
  378. args,
  379. kwargs,
  380. self._n_microbatches,
  381. self._args_chunk_spec,
  382. self._kwargs_chunk_spec,
  383. )
  384. return args_split, kwargs_split
  385. else:
  386. # Empty inputs (e.g. when called on middle stages)
  387. # Return a list of empty tuples/dicts with matching length as chunks
  388. return [()] * self._n_microbatches, [{}] * self._n_microbatches
  389. def _merge_outputs(self, output_chunks: list[Any]) -> Any:
  390. """
  391. Merge output chunks back to a batch state.
  392. If output_merge_spec is None, the utility will merge output chunks by dimension 0 (batch dim).
  393. """
  394. return merge_chunks(
  395. output_chunks,
  396. self._output_merge_spec,
  397. )
  398. def _batch_p2p(p2p_ops: list[dist.P2POp], desc: str | None = None) -> list[dist.Work]:
  399. """
  400. Simple wrapper over batch_isend_irecv from torch.distributed, which just adds a descriptive logger on top.
  401. """
  402. if len(p2p_ops) == 0:
  403. return []
  404. desc_str = f"{desc}, " if desc else ""
  405. logger.debug("batch_p2p %s%s", desc_str, p2p_ops)
  406. return dist.batch_isend_irecv(p2p_ops)
  407. def _sorted_batch_p2p(
  408. p2p_ops: list[dist.P2POp], desc: str | None = None
  409. ) -> dict[int, list[dist.Work]]:
  410. """
  411. Sorts the list of P2P ops by the peer rank, and then calls
  412. batch_isend_irecv. Return a dictionary of works by peer rank. This function
  413. helps us avoid hangs in case of skip connections.
  414. """
  415. # Arrange p2p_ops by peer rank:
  416. # int is the peer rank;
  417. # List is the list of ops towards the peer
  418. ops_by_peer: dict[int, list[dist.P2POp]] = defaultdict(list)
  419. work_by_peer: dict[int, list[dist.Work]] = {}
  420. if len(p2p_ops) == 0:
  421. return work_by_peer
  422. # Classify the ops by peer rank
  423. for op in p2p_ops:
  424. ops_by_peer[op.peer].append(op)
  425. # Call batch_isend_irecv per peer, in sorted order of the peers (to avoid hangs)
  426. for peer, ops in sorted(ops_by_peer.items()):
  427. work_by_peer[peer] = _batch_p2p(ops, desc=desc)
  428. return work_by_peer
  429. def _wait_batch_p2p(work: list[dist.Work]):
  430. """
  431. Waits for a list of dist.Work (typically from _batch_p2p / _sorted_batch_p2p).
  432. """
  433. for w in work:
  434. w.wait()
  435. class PipelineScheduleSingle(_PipelineSchedule):
  436. """
  437. Base class for single-stage schedules.
  438. Implements the `step` method.
  439. Derived classes should implement `_step_microbatches`.
  440. Gradients are scaled by num_microbatches depending on the `scale_grads` argument, defaulting to True. This setting
  441. should match the configuration of your loss_fn, which may either average losses (scale_grads=True)
  442. or sum losses (scale_grads=False).
  443. """
  444. def __init__(
  445. self,
  446. stage: _PipelineStageBase,
  447. n_microbatches: int,
  448. loss_fn: Callable | None = None,
  449. args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None,
  450. kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None,
  451. output_merge_spec: dict[str, Any] | tuple[Any] | None = None,
  452. scale_grads: bool = True,
  453. ):
  454. # Init parent
  455. super().__init__(
  456. n_microbatches=n_microbatches,
  457. loss_fn=loss_fn,
  458. args_chunk_spec=args_chunk_spec,
  459. kwargs_chunk_spec=kwargs_chunk_spec,
  460. output_merge_spec=output_merge_spec,
  461. scale_grads=scale_grads,
  462. )
  463. # Self attributes
  464. self._stage = stage
  465. self._num_stages = stage.num_stages
  466. self._stage_forward_initialized = False
  467. self._stage_backward_initialized = False
  468. self.pipeline_order: dict[int, list[_Action | None]] | None = (
  469. self._get_pipeline_order()
  470. )
  471. def _initialize_stage(self, args, kwargs):
  472. if not self._stage_forward_initialized:
  473. # Prepare the communication needed for the pipeline schedule execution
  474. # This is needed because during execution we always perform a series of batch P2P ops
  475. # The first call of the batched P2P needs to involve the global group
  476. all_ops: list[dist.P2POp] = []
  477. all_ops.extend(self._stage._get_init_p2p_neighbors_ops())
  478. _wait_batch_p2p(_batch_p2p(all_ops))
  479. self._stage._prepare_forward_infra(self._n_microbatches, args, kwargs)
  480. self._stage_forward_initialized = True
  481. if self._has_backward and not self._stage_backward_initialized:
  482. self._stage._prepare_backward_infra(self._n_microbatches)
  483. self._stage_backward_initialized = True
  484. def step(
  485. self,
  486. *args,
  487. target=None,
  488. losses: list | None = None,
  489. return_outputs: bool = True,
  490. **kwargs,
  491. ):
  492. """
  493. Run one iteration of the pipeline schedule with *whole-batch* input.
  494. Will chunk the input into microbatches automatically, and go through the
  495. microbatches according to the schedule implementation.
  496. args: positional arguments to the model (as in non-pipeline case).
  497. kwargs: keyword arguments to the model (as in non-pipeline case).
  498. target: target for the loss function.
  499. losses: a list to store the losses for each microbatch.
  500. return_outputs: whether to return the outputs from the last stage.
  501. """
  502. if self._has_backward and not torch.is_grad_enabled():
  503. raise RuntimeError(
  504. "step() requires gradients to be enabled for backward computation; "
  505. "it should not be used under torch.no_grad() context. "
  506. "Please call eval() instead."
  507. )
  508. # Set the same has_backward flag for stage object
  509. self._stage.has_backward = self._has_backward
  510. # Clean per iteration
  511. self._stage.clear_runtime_states()
  512. # Split inputs into microbatches
  513. args_split, kwargs_split = self._split_inputs(args, kwargs)
  514. # Split target into microbatches
  515. if target is not None:
  516. targets_split = list(torch.tensor_split(target, self._n_microbatches))
  517. else:
  518. targets_split = None
  519. # Run microbatches
  520. self._step_microbatches(
  521. args_split, kwargs_split, targets_split, losses, return_outputs
  522. )
  523. # Return merged results per original format
  524. if self._stage.is_last and return_outputs:
  525. return self._merge_outputs(self._stage.output_chunks)
  526. else:
  527. return None
  528. def _get_pipeline_order(self) -> dict[int, list[_Action | None]] | None:
  529. """
  530. Returns the pipeline execution order as a schedule IR.
  531. The returned IR is a dictionary mapping rank IDs to lists of actions.
  532. Each action is either an _Action object representing computation to perform,
  533. or None representing a deliberate idle step.
  534. The None values are used to represent pipeline bubbles where a rank
  535. must wait for dependencies from other ranks before proceeding. However
  536. during execution, with the _PipelineScheduleRuntime, these Nones are
  537. skipped since the relevant communication (send/recv) will be scheduled and waited on.
  538. Returns:
  539. A dictionary mapping rank -> list of actions
  540. """
  541. return None
  542. class _ScheduleForwardOnly(PipelineScheduleSingle):
  543. """
  544. The forward-only schedule.
  545. Will go through all the microbatches and perform only the forward pass
  546. """
  547. def _step_microbatches(
  548. self,
  549. arg_mbs: list | None = None,
  550. kwarg_mbs: list | None = None,
  551. target_mbs: list | None = None,
  552. losses: list | None = None,
  553. return_outputs: bool = True,
  554. ):
  555. """
  556. Run one iteration of the pipeline schedule
  557. """
  558. if target_mbs is not None or losses is not None:
  559. raise RuntimeError(
  560. "Forward-only schedule does not support loss computation"
  561. )
  562. arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)
  563. self._initialize_stage(arg_mbs[0], kwarg_mbs[0])
  564. # Delay send waits
  565. fwd_sends_to_wait: list[list[dist.Work]] = []
  566. # Run microbatches
  567. for i in range(self._n_microbatches):
  568. with record_function(f"Forward {i}"):
  569. ops = self._stage.get_fwd_recv_ops(i)
  570. works = _sorted_batch_p2p(ops, desc="fwd_recv")
  571. for work in works.values():
  572. _wait_batch_p2p(work)
  573. self._stage.forward_one_chunk(i, arg_mbs[i], kwarg_mbs[i]) # type: ignore[index]
  574. ops = self._stage.get_fwd_send_ops(i)
  575. works = _sorted_batch_p2p(ops, desc="fwd_send")
  576. fwd_sends_to_wait.extend(works.values())
  577. logger.debug("[%s] Forwarded microbatch %s", self._stage.stage_index, i)
  578. # Wait for all forward sends to finish
  579. # This should not have performance impact because by the time the first
  580. # backward arrives all the forward sends should have been finished.
  581. for work in fwd_sends_to_wait:
  582. _wait_batch_p2p(work)
class ScheduleGPipe(PipelineScheduleSingle):
    """
    The GPipe schedule.

    Goes through all the microbatches in a fill-drain manner: every microbatch
    is run forward first, then every microbatch is run backward.
    """

    def _step_microbatches(
        self,
        arg_mbs: list | None = None,
        kwarg_mbs: list | None = None,
        target_mbs: list | None = None,
        losses: list | None = None,
        return_outputs: bool = True,
    ):
        """
        Run one iteration of the GPipe schedule over a list of microbatches.

        Args:
            arg_mbs: per-microbatch positional inputs.
            kwarg_mbs: per-microbatch keyword inputs.
            target_mbs: per-microbatch loss targets (used by the last stage).
            losses: optional list that receives the per-microbatch losses.
            return_outputs: whether to keep forward outputs (passed to
                ``forward_one_chunk`` as ``save_forward_output``).
        """
        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)

        self._initialize_stage(arg_mbs[0], kwarg_mbs[0])

        # Delay send waits: forward sends are fired and only waited on after
        # all forwards have been issued.
        fwd_sends_to_wait: list[list[dist.Work]] = []

        # Run microbatches
        for i in range(self._n_microbatches):
            with record_function(f"Forward {i}"):
                # Block on this microbatch's incoming activations (per peer, sorted).
                ops = self._stage.get_fwd_recv_ops(i)
                works = _sorted_batch_p2p(ops, desc="fwd_recv")
                for work in works.values():
                    _wait_batch_p2p(work)

                output = self._stage.forward_one_chunk(
                    i, arg_mbs[i], kwarg_mbs[i], save_forward_output=return_outputs
                )  # type: ignore[index]

                # Fire the outgoing activations without waiting.
                ops = self._stage.get_fwd_send_ops(i)
                works = _sorted_batch_p2p(ops, desc="fwd_send")
                fwd_sends_to_wait.extend(works.values())

            logger.debug("[%s] Forwarded microbatch %s", self._stage.stage_index, i)

            self._maybe_compute_loss(self._stage, output, target_mbs, i)

        # Wait for all forward sends to finish
        # This should not have performance impact because by the time the first
        # backward arrives all the forward sends should have been finished.
        for work in fwd_sends_to_wait:
            _wait_batch_p2p(work)

        # Run backward
        # Delay send waits (same pattern as the forward phase)
        bwd_sends_to_wait: list[list[dist.Work]] = []
        for i in range(self._n_microbatches):
            with record_function(f"Backward {i}"):
                # Block on the incoming gradients for this microbatch.
                ops = self._stage.get_bwd_recv_ops(i)
                works = _sorted_batch_p2p(ops, desc="bwd_recv")
                for work in works.values():
                    _wait_batch_p2p(work)

                loss = self._maybe_get_loss(self._stage, i)
                self._stage.backward_one_chunk(
                    i,
                    loss=loss,
                    # Tell the stage when it is processing the final microbatch.
                    last_backward=i == self._n_microbatches - 1,
                )

                # Fire the outgoing gradients without waiting.
                ops = self._stage.get_bwd_send_ops(i)
                works = _sorted_batch_p2p(ops, desc="bwd_send")
                bwd_sends_to_wait.extend(works.values())

            logger.debug("[%s] Backwarded microbatch %s", self._stage.stage_index, i)

        # Wait for all backward sends to finish
        for work in bwd_sends_to_wait:
            _wait_batch_p2p(work)

        # Update losses if there is a container passed in
        self._update_losses(self._stage, losses)

        # Scale by the microbatch count only when scale_grads is set.
        self._stage.perform_reduce_grad(self._n_microbatches if self.scale_grads else 1)

    def _get_pipeline_order(self) -> dict[int, list[_Action | None]] | None:
        """
        Returns the pipeline order for GPipe schedule.

        Each rank idles ``rank`` steps, runs F for every microbatch, idles
        ``3 * (pp_group_size - 1 - rank)`` steps, then runs B for every
        microbatch. The list is then passed through ``_add_reduce_grad``, which
        appends REDUCE_GRAD after the final backward and omits the ``None``
        idle entries from its result.

        See base method in PipelineScheduleSingle for details on the schedule IR format.
        """
        pipeline_order = {}
        pp_group_size = self._num_stages

        for rank in range(pp_group_size):
            actions: list[_Action | None] = []

            # 1. Initial delay based on rank position
            warmup_delay = rank
            actions.extend([None] * warmup_delay)

            # 2. Forward passes for all microbatches
            for mb_idx in range(self._n_microbatches):
                actions.append(_Action(rank, _ComputationType.FORWARD, mb_idx))

            # 3. Wait period before backward passes can begin
            backward_delay = 3 * (pp_group_size - 1 - rank)
            actions.extend([None] * backward_delay)

            # 4. Backward passes for all microbatches
            for mb_idx in range(self._n_microbatches):
                actions.append(_Action(rank, _ComputationType.FULL_BACKWARD, mb_idx))

            pipeline_order[rank] = _add_reduce_grad(actions, self._n_microbatches)

        return pipeline_order  # type: ignore[return-value]
  675. class Schedule1F1B(PipelineScheduleSingle):
  676. """
  677. The 1F1B schedule.
  678. Will perform one forward and one backward on the microbatches in steady state.
  679. """
  680. def __init__(
  681. self,
  682. stage: _PipelineStageBase,
  683. n_microbatches: int,
  684. loss_fn: Callable | None = None,
  685. args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None,
  686. kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None,
  687. output_merge_spec: dict[str, Any] | tuple[Any] | None = None,
  688. scale_grads: bool = True,
  689. ):
  690. super().__init__(
  691. stage=stage,
  692. n_microbatches=n_microbatches,
  693. loss_fn=loss_fn,
  694. args_chunk_spec=args_chunk_spec,
  695. kwargs_chunk_spec=kwargs_chunk_spec,
  696. output_merge_spec=output_merge_spec,
  697. scale_grads=scale_grads,
  698. )
  699. if n_microbatches < self._num_stages:
  700. raise ValueError(
  701. f"Number of microbatches ({n_microbatches}) must be greater than \
  702. or equal to the number of stages ({self._num_stages})."
  703. )
  704. def _step_microbatches(
  705. self,
  706. arg_mbs: list | None = None,
  707. kwarg_mbs: list | None = None,
  708. target_mbs: list | None = None,
  709. losses: list | None = None,
  710. return_outputs: bool = True,
  711. ):
  712. """
  713. Run one iteration of the pipeline schedule with list of microbatches.
  714. Will go through all the microbatches according to the 1F1B schedule.
  715. Args:
  716. microbatches: list of microbatch args.
  717. return_outputs: whether to return the outputs from the last stage.
  718. """
  719. arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)
  720. self._initialize_stage(arg_mbs[0], kwarg_mbs[0])
  721. # Last stage has 1 warmup, second-to-last 2 warmups, ...
  722. # first stage `num_stages` warmups
  723. warmup_chunks = min(
  724. self._n_microbatches,
  725. self._num_stages - self._stage.stage_index,
  726. )
  727. # Chunk counters
  728. fwd_mb_index = 0
  729. bwd_mb_index = 0
  730. # Warmup phase
  731. send_work: list[dist.Work] = []
  732. fwd_sends = []
  733. for _ in range(warmup_chunks):
  734. # Receive activations
  735. fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index)
  736. _wait_batch_p2p(_batch_p2p(fwd_recvs, desc="fwd_recv"))
  737. # Compute
  738. output = self._stage.forward_one_chunk(
  739. fwd_mb_index,
  740. arg_mbs[fwd_mb_index],
  741. kwarg_mbs[fwd_mb_index],
  742. save_forward_output=return_outputs,
  743. ) # type: ignore[index]
  744. # Clear previous chunk's forward sends (hopefully they have well
  745. # finished, otherwise, we are heavily communication bound, in which
  746. # case it doesn't create a lot of benefit to compute next chunk
  747. # eagerly either)
  748. _wait_batch_p2p(send_work)
  749. # Send activations
  750. fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index)
  751. if fwd_mb_index != warmup_chunks - 1:
  752. # Safe to fire
  753. send_work = _batch_p2p(fwd_sends, desc="fwd_send")
  754. # otherwise:
  755. # The last forward send is left for fuse with first 1B in 1B1F below
  756. # Compute loss
  757. self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index)
  758. fwd_mb_index += 1
  759. # Now we should have send ops left over, to be fused with first 1B of 1B1F phase below.
  760. # 1B1F phase
  761. while True: # Don't worry, we have a break inside
  762. # We actually do 1B first as the `1B1F` name indicates, so prepare its recv ops
  763. bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index)
  764. # Now, we need to fire the fwd_sends and bwd_recvs together
  765. _wait_batch_p2p(_batch_p2p(fwd_sends + bwd_recvs, desc="fwd_send_bwd_recv"))
  766. # Backward one chunk
  767. loss = self._maybe_get_loss(self._stage, bwd_mb_index)
  768. self._stage.backward_one_chunk(
  769. bwd_mb_index,
  770. loss=loss,
  771. last_backward=bwd_mb_index == self._n_microbatches - 1,
  772. )
  773. # Get the bwd send ops, but don't fire, to be fused with the 1F below
  774. bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index)
  775. bwd_mb_index += 1
  776. if fwd_mb_index == self._n_microbatches:
  777. # We are done with 1B1F, so break with some left-over bwd_sends
  778. break
  779. # We prepare 1F of the `1B1F`
  780. fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index)
  781. # Fuse it with bwd_sends above
  782. _wait_batch_p2p(_batch_p2p(bwd_sends + fwd_recvs, desc="bwd_send_fwd_recv"))
  783. # Now do the fwd
  784. output = self._stage.forward_one_chunk(
  785. fwd_mb_index,
  786. arg_mbs[fwd_mb_index],
  787. kwarg_mbs[fwd_mb_index],
  788. save_forward_output=return_outputs,
  789. ) # type: ignore[index]
  790. # Compute loss
  791. self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index)
  792. # Get the fwd send ops, but don't fire, leave it for the next iter (wrap-around)
  793. fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index)
  794. fwd_mb_index += 1
  795. # Remember we still have some bwd_sends left over after the break? Now it is time to fire it
  796. send_work = _batch_p2p(bwd_sends, desc="bwd_send")
  797. # Cooldown
  798. while bwd_mb_index < self._n_microbatches:
  799. # prepare bwd recv ops
  800. bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index)
  801. _wait_batch_p2p(_batch_p2p(bwd_recvs, desc="bwd_recv"))
  802. # Backward one chunk
  803. loss = self._maybe_get_loss(self._stage, bwd_mb_index)
  804. self._stage.backward_one_chunk(
  805. bwd_mb_index,
  806. loss=loss,
  807. last_backward=bwd_mb_index == self._n_microbatches - 1,
  808. )
  809. # Clear previous chunk's backward sends (hopefully they have well finished)
  810. _wait_batch_p2p(send_work)
  811. # Get the bwd send ops, fire it
  812. bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index)
  813. send_work = _batch_p2p(bwd_sends, desc="bwd_send")
  814. bwd_mb_index += 1
  815. # Wait for the last backward send to finish
  816. _wait_batch_p2p(send_work)
  817. # Return losses if there is a container passed in
  818. self._update_losses(self._stage, losses)
  819. self._stage.perform_reduce_grad(self._n_microbatches if self.scale_grads else 1)
  820. def _get_pipeline_order(self) -> dict[int, list[_Action | None]] | None:
  821. """
  822. Returns the pipeline order for 1F1B schedule.
  823. See base method in PipelineScheduleSingle for details on the schedule IR format.
  824. """
  825. pipeline_order = {}
  826. pp_group_size = self._num_stages
  827. for rank in range(pp_group_size):
  828. actions: list[_Action | None] = []
  829. # 1. Warmup phase: initial delay based on rank
  830. actions.extend([None] * rank)
  831. # 2. Initial forward passes before 1F1B phase
  832. num_forward = (pp_group_size - 1) - rank
  833. forward_mb = 0
  834. for i in range(num_forward):
  835. actions.append(_Action(rank, _ComputationType.FORWARD, i))
  836. forward_mb = i
  837. # 3. Wait for backward to be ready
  838. wait_for_1f1b = max(0, 2 * (pp_group_size - 1 - rank))
  839. actions.extend([None] * wait_for_1f1b)
  840. # 4. 1F1B steady state phase
  841. backward_mb = 0
  842. remaining_forward = self._n_microbatches - num_forward
  843. while remaining_forward > 0:
  844. # One forward
  845. forward_mb += 1
  846. actions.append(_Action(rank, _ComputationType.FORWARD, forward_mb))
  847. remaining_forward -= 1
  848. # One backward
  849. actions.append(
  850. _Action(rank, _ComputationType.FULL_BACKWARD, backward_mb)
  851. )
  852. backward_mb += 1
  853. # 5. Cooldown phase: remaining backward passes
  854. remaining_backward = self._n_microbatches - backward_mb
  855. while remaining_backward > 0:
  856. # Add None and backward actions in alternating pattern
  857. # based on distance from the last stage
  858. if (pp_group_size - rank) > 0:
  859. actions.append(None)
  860. # Decrement the wait counter only if we still have backward passes to do
  861. if remaining_backward > 0:
  862. actions.append(
  863. _Action(rank, _ComputationType.FULL_BACKWARD, backward_mb)
  864. )
  865. backward_mb += 1
  866. remaining_backward -= 1
  867. else:
  868. # If we're at the last stage, just add backward actions without None
  869. actions.append(
  870. _Action(rank, _ComputationType.FULL_BACKWARD, backward_mb)
  871. )
  872. backward_mb += 1
  873. remaining_backward -= 1
  874. pipeline_order[rank] = _add_reduce_grad(actions, self._n_microbatches)
  875. return pipeline_order
  876. def _requires_reduce_grad(action_type: _ComputationType) -> bool:
  877. return action_type in (W, B)
  878. def _add_reduce_grad(
  879. actions: list[_Action | None], n_microbatches: int
  880. ) -> list[_Action | None]:
  881. """
  882. REDUCE_GRAD refers to joint across minibatches grad reduction.
  883. reduce_grad frees memory and we want to schedule it just after the last "backward"-like stage.
  884. """
  885. actions_with_reduce_grad: list[_Action | None] = []
  886. cnt: dict[int, int] = defaultdict(int)
  887. def _leaf_action(a, to_schedule):
  888. if _requires_reduce_grad(a.computation_type):
  889. stage_index = a.stage_index
  890. cnt[stage_index] += 1
  891. if cnt[stage_index] == n_microbatches:
  892. to_schedule.append(stage_index)
  893. for a in actions:
  894. if a is None:
  895. continue
  896. actions_with_reduce_grad.append(a)
  897. schedule_reduce_grad_stage_idxs: list[int] = []
  898. if a.computation_type == OVERLAP_F_B and a.sub_actions is not None:
  899. for sub_action in a.sub_actions:
  900. _leaf_action(sub_action, schedule_reduce_grad_stage_idxs)
  901. else:
  902. _leaf_action(a, schedule_reduce_grad_stage_idxs)
  903. for stage_idx in schedule_reduce_grad_stage_idxs:
  904. actions_with_reduce_grad.append(_Action(stage_idx, REDUCE_GRAD, None))
  905. return actions_with_reduce_grad
  906. def _add_unshard_reshard(
  907. compute_actions: list[_Action | None],
  908. max_active_stages: int = 3,
  909. ) -> list[_Action]:
  910. """Given a basic schedule involving only compute actions (F,B,W,OVERLAP_F_B), add UNSHARD/RESHARD actions for FSDP.
  911. UNSHARD refers to fetching the full contents of an FSDP-sharded layer, requiring an all-gather operation.
  912. RESHARD does the opposite, releasing memory (but doing no communication)
  913. We abandon the "timestep lock" during lowering
  914. max_active_stages controls how many prefetches we allow. It should be measured in mb and tuneable but in practice
  915. 3 stages is probably the thing we want?
  916. (to account for having one f and one b active, and something else prefetching?)
  917. """
  918. def next_stage_indices(count: int, next_actions: list[_Action | None]) -> list[int]:
  919. """Remove duplicates (same stage, different microbatch), find next 'count' stages that will do compute."""
  920. seen: set[int] = set()
  921. ret: list[int] = []
  922. for a in next_actions:
  923. if a is not None:
  924. # Handle OVERLAP_F_B actions by checking their sub_actions
  925. if a.computation_type == OVERLAP_F_B and a.sub_actions is not None:
  926. for sub_action in a.sub_actions:
  927. if sub_action.stage_index not in seen:
  928. seen.add(sub_action.stage_index)
  929. ret.append(sub_action.stage_index)
  930. if len(ret) >= count:
  931. break
  932. else:
  933. # Regular action
  934. if a.stage_index not in seen:
  935. seen.add(a.stage_index)
  936. ret.append(a.stage_index)
  937. if len(ret) == count:
  938. break
  939. return ret
  940. active_stages: set[int] = set()
  941. fsdp_aware_actions: list[_Action] = []
  942. def _unshard(stage_index: int):
  943. active_stages.add(stage_index)
  944. fsdp_aware_actions.append(_Action(stage_index, UNSHARD, None))
  945. def _reshard(stage_index: int):
  946. active_stages.remove(stage_index)
  947. fsdp_aware_actions.append(_Action(stage_index, RESHARD, None))
  948. for i, action in enumerate(compute_actions):
  949. if action is None:
  950. continue
  951. # We prefetch the next N stages we'll see, dropping existing stages to make room
  952. next_n = next_stage_indices(max_active_stages, compute_actions[i:])
  953. # Fetch needs to be ordered correctly, so don't use a set
  954. fetch = list(filter(lambda s: s not in active_stages, next_n))
  955. # Unclear what the best policy is for eviction, but we can maintain order so we do
  956. evict = list(filter(lambda s: s not in next_n, active_stages))
  957. # logger.debug(
  958. # "_add_unshard_reshard Step %d active: %s fetch %s, evict %s",
  959. # i,
  960. # active_stages,
  961. # fetch,
  962. # evict,
  963. # )
  964. for stage in evict:
  965. _reshard(stage)
  966. for stage in fetch:
  967. _unshard(stage)
  968. fsdp_aware_actions.append(action)
  969. # Reshard all remaining active stages after processing all operations
  970. for stage in list(active_stages):
  971. _reshard(stage)
  972. return fsdp_aware_actions
  973. def _merge_bw(
  974. compute_actions: list[_Action | None],
  975. ) -> list[_Action]:
  976. """Given a basic schedule involving only compute actions (F,I,W), merge adjacent I and W ops into B ops.
  977. (note: I = BACKWARD_INPUT, W = BACKWARD_WEIGHT, B = FULL_BACKWARD)
  978. B refers to running the whole backward (not separating grad_input and grad_weight), which can be more efficient
  979. in some cases.
  980. """
  981. merged_actions = []
  982. while compute_actions:
  983. action = compute_actions.pop(0)
  984. if action is None:
  985. continue
  986. # Remove any None actions and find the next non-None action
  987. while len(compute_actions) and compute_actions[0] is None:
  988. compute_actions.pop(0)
  989. # Get the next action if it exists
  990. next_action = compute_actions[0] if len(compute_actions) > 0 else None
  991. if (
  992. action.computation_type == BACKWARD_INPUT
  993. and next_action is not None
  994. and next_action.computation_type == BACKWARD_WEIGHT
  995. and action.stage_index == next_action.stage_index
  996. and action.microbatch_index == next_action.microbatch_index
  997. ):
  998. merged_actions.append(
  999. _Action(action.stage_index, FULL_BACKWARD, action.microbatch_index)
  1000. )
  1001. compute_actions.pop(0)
  1002. else:
  1003. merged_actions.append(action)
  1004. return merged_actions
def _add_send_recv(
    compute_actions: dict[int, list[_Action]],
    stage_to_rank: Callable[[int], int],
    num_stages: int,
) -> dict[int, list[_Action]]:
    """
    Transforms a compute-only schedule into a complete schedule with communication actions.

    Ranks are visited repeatedly in sorted order. A rank's next compute action
    is emitted once its input dependency has been satisfied. When an emitted
    action produces data for a stage on a different rank, a SEND_* is appended
    to this rank and the matching RECV_* is appended to the receiving rank.
    For actions with sub-actions (OVERLAP_F_B), all the sub-actions must be
    ready before the parent is emitted, and the communication of each
    sub-action is added.

    Args:
        compute_actions: rank -> ordered compute actions. NOTE: this dict is
            consumed; actions are popped and emptied ranks are deleted.
        stage_to_rank: maps a stage index to the rank that owns it.
        num_stages: total number of pipeline stages.

    Returns:
        rank -> ordered list of compute and communication actions.

    Raises:
        AssertionError: if a full pass over all ranks makes no progress
            (malformed compute schedule).
    """
    comm_actions: dict[int, list[_Action]] = {rank: [] for rank in compute_actions}
    # Per rank: every action (compute, send, or recv) already placed on that rank.
    prev_actions: dict[int, set[_Action]] = {rank: set() for rank in compute_actions}

    def _has_comms(action: _Action) -> bool:
        """True if ``action``'s output must be sent to a stage on another rank."""
        if action.computation_type == F:
            return action.stage_index != num_stages - 1 and stage_to_rank(
                action.stage_index + 1
            ) != stage_to_rank(action.stage_index)
        elif action.computation_type in (BACKWARD_INPUT, FULL_BACKWARD):
            return action.stage_index != 0 and stage_to_rank(
                action.stage_index - 1
            ) != stage_to_rank(action.stage_index)
        return False

    def _get_comms(action: _Action) -> tuple[_Action, _Action]:
        """Build the (send, recv) pair for ``action``: F sends forward, I/B send backward."""
        if not _has_comms(action):
            raise AssertionError(f"{action} is not a valid comm action")
        stage_idx = action.stage_index
        ctype = action.computation_type
        mb_idx = action.microbatch_index
        send = _Action(stage_idx, SEND_F if ctype == F else SEND_B, mb_idx)
        recv_stage_idx = stage_idx + 1 if ctype == F else stage_idx - 1
        recv = _Action(recv_stage_idx, RECV_F if ctype == F else RECV_B, mb_idx)
        return send, recv

    def _ready_to_schedule(action: _Action | None, prev_actions: set[_Action]) -> bool:
        """We don't put our own recv ops in the schedule, we let a sender on another rank put our recv ops in place.
        This helps ensure a sane (non-hanging) ordering of sends and recvs.
        But it also means we might not be able to schedule our next compute action yet.

        An input is considered available if either the matching RECV is already
        on this rank, or the producing action (previous stage's F, next stage's
        I/B) ran on this same rank.
        """
        if action is None:
            return True
        elif action.computation_type == F and action.stage_index != 0:
            if (
                _Action(action.stage_index, RECV_F, action.microbatch_index)
                in prev_actions
            ):
                return True
            elif (
                _Action(action.stage_index - 1, F, action.microbatch_index)
                in prev_actions
            ):
                return True
            return False
        elif (
            action.computation_type in (BACKWARD_INPUT, FULL_BACKWARD)
            and action.stage_index != num_stages - 1
        ):
            if (
                _Action(action.stage_index, RECV_B, action.microbatch_index)
                in prev_actions
            ):
                return True
            elif (
                _Action(action.stage_index + 1, BACKWARD_INPUT, action.microbatch_index)
                in prev_actions
            ):
                return True
            elif (
                _Action(action.stage_index + 1, FULL_BACKWARD, action.microbatch_index)
                in prev_actions
            ):
                return True
            return False
        else:
            # First-stage F, last-stage I/B, and all other types have no
            # cross-stage input dependency.
            return True

    while compute_actions:
        progress = False
        # go in order of ranks even if dict keys aren't ordered
        for rank in sorted(compute_actions):
            if not (len(compute_actions[rank]) > 0):
                raise AssertionError(f"{rank=}, {len(compute_actions[rank])=}")
            action = compute_actions[rank][0]
            # handle case where parent action (e.g. OVERLAP_F_B) can be comprised of subactions
            if action is not None and action.sub_actions is not None:
                all_actions = action.sub_actions
            else:
                all_actions = (action,)

            if not all(_ready_to_schedule(a, prev_actions[rank]) for a in all_actions):
                # Blocked on input from another rank; retry on a later pass.
                continue

            # The action's dependencies are satisfied, so add to schedule
            if action is not None:
                comm_actions[rank].append(action)
                for a in all_actions:
                    prev_actions[rank].add(a)
                    if _has_comms(a):
                        send, recv = _get_comms(a)
                        # TODO we can avoid send/recv if the 2 stages are on the same rank.
                        # should we avoid that in the runtime or here?
                        comm_actions[rank].append(send)
                        prev_actions[rank].add(send)
                        # The recv is placed on the receiving rank, which is what
                        # unblocks that rank's dependent compute action.
                        comm_actions[stage_to_rank(recv.stage_index)].append(recv)
                        prev_actions[stage_to_rank(recv.stage_index)].add(recv)

            compute_actions[rank].pop(0)
            if len(compute_actions[rank]) == 0:
                del compute_actions[rank]
            progress = True

        if not progress:
            raise AssertionError(
                "Malformed compute schedule, can't schedule sends/recvs"
            )
    return comm_actions
  1114. def _validate_schedule(
  1115. actions: dict[int, list[_Action | None]],
  1116. pp_group_size: int,
  1117. num_stages: int,
  1118. num_microbatches: int,
  1119. ) -> dict[int, int]:
  1120. if not (len(actions) == pp_group_size):
  1121. raise AssertionError(
  1122. f"Schedule has incorrect number of ranks - expected {pp_group_size}, actual {len(actions)}"
  1123. )
  1124. for rank in range(pp_group_size):
  1125. if rank not in actions:
  1126. raise AssertionError(f"Schedule is missing actions for rank {rank}")
  1127. # We will count all the actions per stage and ensure they happen in a valid order
  1128. # (e.g. F before (B, I) before W for a given microbatch)
  1129. stage_actions: dict[int, dict[_ComputationType, set]] = {
  1130. stage_id: {
  1131. F: set(),
  1132. B: set(),
  1133. I: set(),
  1134. W: set(),
  1135. }
  1136. for stage_id in range(num_stages)
  1137. }
  1138. stage_index_to_rank_mapping = {}
  1139. def _process_action(action: _Action, rank: int, step: int):
  1140. """Process a single action and update stage_actions and stage_index_to_rank_mapping"""
  1141. s_id = action.stage_index
  1142. ctype = action.computation_type
  1143. mb_id = action.microbatch_index
  1144. if ctype == F:
  1145. stage_actions[s_id][F].add(mb_id)
  1146. elif ctype == B:
  1147. if mb_id not in stage_actions[s_id][F]:
  1148. error_msg = (
  1149. f"Rank {rank}, step {step}: Running Full Backward for stage {s_id}, "
  1150. f"microbatch {mb_id} without first running Forward"
  1151. )
  1152. formatted_schedule = _format_pipeline_order(
  1153. actions, error_step_number=step
  1154. )
  1155. full_error_msg = (
  1156. f"{error_msg}\n\nFull pipeline schedule:\n{formatted_schedule}"
  1157. )
  1158. raise AssertionError(full_error_msg)
  1159. stage_actions[s_id][B].add(mb_id)
  1160. elif ctype == I:
  1161. if mb_id not in stage_actions[s_id][F]:
  1162. error_msg = (
  1163. f"Rank {rank}, step {step}: Running Backward Input for stage {s_id}, "
  1164. f"microbatch {mb_id} without first running Forward"
  1165. )
  1166. formatted_schedule = _format_pipeline_order(
  1167. actions, error_step_number=step
  1168. )
  1169. full_error_msg = (
  1170. f"{error_msg}\n\nFull pipeline schedule:\n{formatted_schedule}"
  1171. )
  1172. raise AssertionError(full_error_msg)
  1173. stage_actions[s_id][I].add(mb_id)
  1174. elif ctype == W:
  1175. if mb_id not in stage_actions[s_id][I]:
  1176. error_msg = (
  1177. f"Rank {rank}, step {step}: Running Backward Weight for stage {s_id}, "
  1178. f"microbatch {mb_id} without first running Backward Input"
  1179. )
  1180. formatted_schedule = _format_pipeline_order(
  1181. actions, error_step_number=step
  1182. )
  1183. full_error_msg = (
  1184. f"{error_msg}\n\nFull pipeline schedule:\n{formatted_schedule}"
  1185. )
  1186. raise AssertionError(full_error_msg)
  1187. stage_actions[s_id][W].add(mb_id)
  1188. if s_id not in stage_index_to_rank_mapping:
  1189. stage_index_to_rank_mapping[s_id] = rank
  1190. else:
  1191. existing_rank = stage_index_to_rank_mapping[s_id]
  1192. if not (rank == existing_rank):
  1193. raise AssertionError(
  1194. f"Rank {rank}, step {step}: Stage {s_id} is assigned to both rank {rank} and rank {existing_rank}"
  1195. )
  1196. for rank in actions:
  1197. for step, action in enumerate(actions[rank]):
  1198. if action is None:
  1199. continue
  1200. if not isinstance(action, _Action):
  1201. raise AssertionError(
  1202. f"Rank {rank}, step {step}: Got an invalid action: {action}, expected instance of _Action"
  1203. )
  1204. # Check if action has sub_actions
  1205. if action.sub_actions is not None:
  1206. # Process each sub_action instead of the main action
  1207. for sub_action in action.sub_actions:
  1208. _process_action(sub_action, rank, step)
  1209. else:
  1210. # Process the main action normally
  1211. _process_action(action, rank, step)
  1212. for s_id in stage_actions:
  1213. f_mb = len(stage_actions[s_id][F])
  1214. b_mb = len(stage_actions[s_id][B])
  1215. i_mb = len(stage_actions[s_id][I])
  1216. w_mb = len(stage_actions[s_id][W])
  1217. if not (f_mb == num_microbatches):
  1218. raise AssertionError(
  1219. f"Got {f_mb} {F} microbatches for stage {s_id}, expected {num_microbatches}"
  1220. )
  1221. if not (i_mb == w_mb):
  1222. raise AssertionError(
  1223. f"Invalid backward microbatches for stage {s_id}: I and W must have equal counts, \
  1224. but got I={i_mb}, W={w_mb}"
  1225. )
  1226. if not (b_mb + (i_mb + w_mb) // 2 == num_microbatches):
  1227. raise AssertionError(
  1228. f"Invalid backward microbatches for stage {s_id}: expected {num_microbatches} total backwards, \
  1229. but got B={b_mb}, I={i_mb}, W={w_mb}"
  1230. )
  1231. return stage_index_to_rank_mapping
  1232. class PipelineScheduleMulti(_PipelineSchedule):
  1233. """
  1234. Base class for multi-stage schedules.
  1235. Implements the `step` method.
  1236. Gradients are scaled by num_microbatches depending on the `scale_grads` argument, defaulting to True. This setting
  1237. should match the configuration of your loss_fn, which may either average losses (scale_grads=True)
  1238. or sum losses (scale_grads=False).
  1239. """
  1240. def __init__(
  1241. self,
  1242. stages: list[_PipelineStageBase],
  1243. n_microbatches: int,
  1244. loss_fn: Callable | None = None,
  1245. args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None,
  1246. kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None,
  1247. output_merge_spec: dict[str, Any] | tuple[Any] | None = None,
  1248. use_full_backward: bool | None = None,
  1249. scale_grads: bool = True,
  1250. backward_requires_autograd: bool = True,
  1251. ):
  1252. # Init parent
  1253. super().__init__(
  1254. n_microbatches=n_microbatches,
  1255. loss_fn=loss_fn,
  1256. args_chunk_spec=args_chunk_spec,
  1257. kwargs_chunk_spec=kwargs_chunk_spec,
  1258. output_merge_spec=output_merge_spec,
  1259. scale_grads=scale_grads,
  1260. )
  1261. # Self attributes
  1262. self._stages = stages
  1263. self._num_stages = stages[0].num_stages
  1264. self.pp_group_size = stages[0].group_size
  1265. self.rank = stages[0].group_rank
  1266. # Set the pipeline stage states
  1267. self.stage_index_to_group_rank = generate_stage_to_rank_mapping(
  1268. self.pp_group_size, self._num_stages
  1269. )
  1270. for stage in self._stages:
  1271. stage.stage_index_to_group_rank = self.stage_index_to_group_rank
  1272. self._stages_forward_initialized = False
  1273. self._stages_backward_initialized = False
  1274. # avoid putting a reference to 'self' inside the lambda, it creates a ref cycle
  1275. has_loss: bool = self._loss_fn is not None
  1276. self._should_compute_loss = lambda stage: stage.is_last and has_loss
  1277. # This will be set during init of derived schedules
  1278. self.pipeline_order: dict[int, list[_Action | None]] = {}
  1279. # When using a custom backward function, we may or may not need autograd to be used
  1280. # for the backward pass. This flag is used to determine whether or torch.is_grad_enabled()
  1281. # check should be performed before the step function.
  1282. self._backward_requires_autograd = backward_requires_autograd
  1283. if use_full_backward is not None:
  1284. logger.warning(
  1285. "Deprecation warning: 'use_full_backward' is no longer supported. "
  1286. "Simply stop passing it, and everything should still work fine."
  1287. )
  1288. def _initialize_stages(self, args: tuple[Any, ...], kwargs):
  1289. if not self._stages_forward_initialized:
  1290. # Prepare the communication needed for the pipeline schedule execution
  1291. # This is needed because during execution we always perform a series of batch P2P ops
  1292. # The first call of the batched P2P needs to involve the global group
  1293. all_ops: list[dist.P2POp] = []
  1294. for stage in self._stages:
  1295. all_ops.extend(stage._get_init_p2p_neighbors_ops())
  1296. _wait_batch_p2p(_batch_p2p(all_ops))
  1297. # may be 'none' value (if this stage sends its output shapes to the next stage via P2P)
  1298. # or real value (if this stage and next stage are on the same device)
  1299. next_stage_args: tuple[Any, ...] = tuple()
  1300. for stage in self._stages:
  1301. if stage.is_first:
  1302. next_stage_args = stage._prepare_forward_infra(
  1303. self._n_microbatches, args, kwargs
  1304. )
  1305. else:
  1306. next_stage_args = stage._prepare_forward_infra(
  1307. self._n_microbatches, next_stage_args, kwargs
  1308. )
  1309. self._stages_forward_initialized = True
  1310. if self._has_backward and not self._stages_backward_initialized:
  1311. for stage in self._stages:
  1312. stage._prepare_backward_infra(self._n_microbatches)
  1313. self._stages_backward_initialized = True
  1314. def _validate_and_set_stage_mapping(
  1315. self, actions: dict[int, list[_Action | None]]
  1316. ) -> None:
  1317. """
  1318. Allocates the stage index to rank mapping which is needed for communication
  1319. """
  1320. self.stage_index_to_group_rank = _validate_schedule(
  1321. actions,
  1322. self.pp_group_size,
  1323. self._num_stages,
  1324. self._n_microbatches,
  1325. )
  1326. for stage in self._stages:
  1327. stage.stage_index_to_group_rank = self.stage_index_to_group_rank
  1328. def _dump_csv(self, filename):
  1329. """Dump a CSV representation of the schedule into a file with the provided filename."""
  1330. with open(filename, "w", newline="") as csvfile:
  1331. writer = csv.writer(csvfile)
  1332. for rank in self.pipeline_order:
  1333. writer.writerow(self.pipeline_order[rank])
  1334. def _load_csv(self, filename, format="compute_only"):
  1335. """Load a CSV representation of the schedule from a file with the provided filename.
  1336. This API will most likely get renamed/refactored so is marked as internal for now.
  1337. format must be "compute_only" for PipelineScheduleMulti.
  1338. """
  1339. if format != "compute_only":
  1340. raise AssertionError(f'format must be "compute_only", got {format}')
  1341. with open(filename, newline="") as csvfile:
  1342. reader = csv.reader(csvfile)
  1343. for rank, row in enumerate(reader):
  1344. self.pipeline_order[rank] = [_Action.from_str(s) for s in row]
  1345. # Validates the order of the pipeline actions and infers the stage_to_rank_mapping.
  1346. # This will overwrite the default stage_to_rank_mapping created in the constructor
  1347. self._validate_and_set_stage_mapping(self.pipeline_order)
  1348. def step(
  1349. self,
  1350. *args,
  1351. target=None,
  1352. losses: list | None = None,
  1353. return_outputs: bool = True,
  1354. **kwargs,
  1355. ):
  1356. """
  1357. Run one iteration of the pipeline schedule with *whole-batch* input.
  1358. Will chunk the input into microbatches automatically, and go through the
  1359. microbatches according to the schedule implementation.
  1360. args: positional arguments to the model (as in non-pipeline case).
  1361. kwargs: keyword arguments to the model (as in non-pipeline case).
  1362. target: target for the loss function.
  1363. losses: a list to store the losses for each microbatch.
  1364. return_outputs: whether to return the outputs from the last stage.
  1365. """
  1366. if (
  1367. self._has_backward
  1368. and self._backward_requires_autograd
  1369. and not torch.is_grad_enabled()
  1370. ):
  1371. raise RuntimeError(
  1372. "step() requires gradients to be enabled for backward computation; "
  1373. "it should not be used under torch.no_grad() context. "
  1374. "Please call eval() instead."
  1375. )
  1376. # Set the same has_backward flag for stage object
  1377. for stage in self._stages:
  1378. stage.has_backward = self._has_backward
  1379. # Clean per iteration
  1380. for stage in self._stages:
  1381. stage.clear_runtime_states()
  1382. # Split inputs into microbatches
  1383. args_split, kwargs_split = self._split_inputs(args, kwargs)
  1384. # Split target into microbatches
  1385. if target is not None:
  1386. targets_split = list(torch.tensor_split(target, self._n_microbatches))
  1387. else:
  1388. targets_split = None
  1389. # Run microbatches
  1390. self._step_microbatches(
  1391. args_split, kwargs_split, targets_split, losses, return_outputs
  1392. )
  1393. # Return merged results per original format
  1394. for stage in self._stages:
  1395. if stage.is_last and return_outputs:
  1396. return self._merge_outputs(stage.output_chunks)
  1397. # Does not contain the last stage or we do not return output chunks
  1398. return None
  1399. def _step_microbatches(
  1400. self,
  1401. arg_mbs: list | None = None,
  1402. kwarg_mbs: list | None = None,
  1403. target_mbs: list | None = None,
  1404. losses: list | None = None,
  1405. return_outputs: bool = True,
  1406. ):
  1407. """
  1408. Operate on the microbatches for looped schedules (multiple stages on each rank).
  1409. TODO: Does not use sorted_batch_isend_irecv(). As a result, this schedule does
  1410. not support models with skip connections.
  1411. """
  1412. arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)
  1413. self._initialize_stages(arg_mbs[0], kwarg_mbs[0])
  1414. # Based on the plan in Step 1 created in __init__:
  1415. # 2. Perform communication based on the pipeline_order
  1416. stage_index_to_stage: dict[int, _PipelineStageBase] = {
  1417. stage.stage_index: stage for stage in self._stages
  1418. }
  1419. # determine prev_rank and next_rank based on which ranks are next to
  1420. # the stages in the pipeline_order
  1421. all_prev_ranks: set[int] = set()
  1422. all_next_ranks: set[int] = set()
  1423. for stage_index in stage_index_to_stage:
  1424. # TODO: assumption that stages only communicate from distances of +1/-1 (no skip connections)
  1425. if stage_index > 0:
  1426. all_prev_ranks.add(self.stage_index_to_group_rank[stage_index - 1])
  1427. if stage_index < self._num_stages - 1:
  1428. all_next_ranks.add(self.stage_index_to_group_rank[stage_index + 1])
  1429. # count either full_backward or backward_weight together, to determine when to sync DP grads
  1430. backward_counter: Counter[int] = Counter()
  1431. for time_step, action in enumerate(self.pipeline_order[self.rank]):
  1432. try:
  1433. ops: list[dist.P2POp] = []
  1434. if action is not None:
  1435. computation_type = action.computation_type
  1436. mb_index = action.microbatch_index
  1437. stage_index = action.stage_index
  1438. if mb_index is None:
  1439. raise AssertionError(
  1440. "All currently supported action types require valid microbatch_index"
  1441. )
  1442. if computation_type == _ComputationType.FORWARD:
  1443. # perform forward computation
  1444. stage = stage_index_to_stage[stage_index]
  1445. output = stage.forward_one_chunk(
  1446. mb_index,
  1447. arg_mbs[mb_index],
  1448. kwarg_mbs[mb_index],
  1449. save_forward_output=return_outputs,
  1450. )
  1451. self._maybe_compute_loss(stage, output, target_mbs, mb_index)
  1452. ops.extend(stage.get_fwd_send_ops(mb_index))
  1453. elif computation_type == _ComputationType.FULL_BACKWARD:
  1454. # perform backward computation
  1455. stage = stage_index_to_stage[stage_index]
  1456. loss = self._maybe_get_loss(stage, mb_index)
  1457. backward_counter[stage_index] += 1
  1458. last_backward = (
  1459. backward_counter[stage_index] == self._n_microbatches
  1460. )
  1461. grad_scale_factor = (
  1462. self._n_microbatches if self.scale_grads else 1
  1463. )
  1464. stage.backward_one_chunk(
  1465. mb_index,
  1466. loss=loss,
  1467. full_backward=True,
  1468. last_backward=last_backward,
  1469. )
  1470. if last_backward:
  1471. stage.scale_grads(grad_scale_factor)
  1472. ops.extend(stage.get_bwd_send_ops(mb_index))
  1473. elif computation_type == _ComputationType.BACKWARD_INPUT:
  1474. # perform backward computation
  1475. stage = stage_index_to_stage[stage_index]
  1476. loss = self._maybe_get_loss(stage, mb_index)
  1477. stage.backward_one_chunk(
  1478. mb_index,
  1479. loss=loss,
  1480. full_backward=False,
  1481. last_backward=False,
  1482. )
  1483. ops.extend(stage.get_bwd_send_ops(mb_index))
  1484. elif computation_type == _ComputationType.BACKWARD_WEIGHT:
  1485. # perform weight update
  1486. stage = stage_index_to_stage[stage_index]
  1487. backward_counter[stage_index] += 1
  1488. last_backward = (
  1489. backward_counter[stage_index] == self._n_microbatches
  1490. )
  1491. grad_scale_factor = (
  1492. self._n_microbatches if self.scale_grads else 1
  1493. )
  1494. stage.backward_weight_one_chunk(
  1495. mb_index,
  1496. last_backward=last_backward,
  1497. )
  1498. if last_backward:
  1499. stage.scale_grads(grad_scale_factor)
  1500. else:
  1501. raise ValueError(f"Unknown computation type {computation_type}")
  1502. # Look at the neighboring ranks for this current timestep and determine whether
  1503. # this current rank needs to do any recv communication
  1504. for prev_rank in all_prev_ranks:
  1505. prev_rank_ops = self.pipeline_order[prev_rank]
  1506. prev_rank_action = None
  1507. if time_step < len(prev_rank_ops):
  1508. prev_rank_action = prev_rank_ops[time_step]
  1509. if prev_rank_action is not None:
  1510. computation_type = prev_rank_action.computation_type
  1511. mb_index = prev_rank_action.microbatch_index
  1512. stage_index = prev_rank_action.stage_index
  1513. if mb_index is None:
  1514. raise AssertionError(
  1515. "All currently supported action types require valid microbatch_index"
  1516. )
  1517. # Only handle sends for the forward from a previous rank
  1518. if computation_type == _ComputationType.FORWARD:
  1519. # If not the last stage, then receive fwd activations
  1520. if stage_index + 1 in stage_index_to_stage:
  1521. # TODO: We are assuming that stage will always receive from stage-1
  1522. # however that is not necessarily true of get_fwd_recv_ops
  1523. stage = stage_index_to_stage[stage_index + 1]
  1524. ops.extend(stage.get_fwd_recv_ops(mb_index))
  1525. elif computation_type in (
  1526. FULL_BACKWARD,
  1527. BACKWARD_INPUT,
  1528. BACKWARD_WEIGHT,
  1529. ):
  1530. # Previous rank doing backward has no influence for the current rank forward recv
  1531. pass
  1532. else:
  1533. raise ValueError(
  1534. f"Unknown computation type {computation_type}"
  1535. )
  1536. for next_rank in all_next_ranks:
  1537. next_rank_ops = self.pipeline_order[next_rank]
  1538. next_rank_action = None
  1539. if time_step < len(next_rank_ops):
  1540. next_rank_action = next_rank_ops[time_step]
  1541. if next_rank_action is not None:
  1542. computation_type = next_rank_action.computation_type
  1543. mb_index = next_rank_action.microbatch_index
  1544. stage_index = next_rank_action.stage_index
  1545. if not (mb_index is not None):
  1546. raise AssertionError(
  1547. "All currently supported action types require valid microbatch_index"
  1548. )
  1549. # Only handle receives for the backwards from a next rank
  1550. if computation_type in (FORWARD, BACKWARD_WEIGHT):
  1551. # Next rank doing forward or weight update has no influence for the current rank backward recv
  1552. pass
  1553. elif computation_type in (BACKWARD_INPUT, FULL_BACKWARD):
  1554. # If not the first stage, then receive bwd gradients
  1555. if stage_index - 1 in stage_index_to_stage:
  1556. # TODO: We are assuming that stage will always receive from stage+1
  1557. # however that is not necessarily true of get_bwd_recv_ops
  1558. stage = stage_index_to_stage[stage_index - 1]
  1559. ops.extend(stage.get_bwd_recv_ops(mb_index))
  1560. else:
  1561. raise ValueError(
  1562. f"Unknown computation type {computation_type}"
  1563. )
  1564. # do the communication
  1565. _wait_batch_p2p(_batch_p2p(ops))
  1566. except Exception as e:
  1567. logger.error( # noqa: G200
  1568. "[Rank %s] pipeline schedule %s caught the following exception '%s' \
  1569. at time_step %s when running action %s",
  1570. self.rank,
  1571. self.__class__.__name__,
  1572. str(e),
  1573. time_step,
  1574. action,
  1575. )
  1576. logger.error(
  1577. "%s",
  1578. _format_pipeline_order(
  1579. self.pipeline_order, error_step_number=time_step
  1580. ),
  1581. )
  1582. raise e
  1583. # Return losses if there is a container passed in
  1584. self._update_losses(self._stages, losses)
@dataclass
class _PipelineContext:
    """Context passed to custom functions during pipeline execution.

    Field order is significant: _PipelineScheduleRuntime._step_microbatches builds
    this object positionally as (schedule, arg_mbs, kwarg_mbs, target_mbs, losses).
    """

    # The schedule instance that is invoking the custom function.
    schedule_ref: _PipelineSchedule
    # Per-microbatch positional args, as passed to _step_microbatches (may be None).
    arg_mbs: list[tuple] | None = None
    # Per-microbatch keyword args, as passed to _step_microbatches (may be None).
    kwarg_mbs: list[dict] | None = None
    # Per-microbatch loss targets, if any.
    target_mbs: list | None = None
    # Caller-provided container for per-microbatch losses, if any.
    losses: list | None = None
class _CustomFunctionProtocol(Protocol):
    """Signature of functions accepted by _PipelineScheduleRuntime.register_custom_function().

    The function receives the action being executed and a _PipelineContext, and returns None.
    """

    def __call__(self, action: _Action, ctx: _PipelineContext) -> None: ...
  1595. class _PipelineScheduleRuntime(PipelineScheduleMulti):
  1596. """
  1597. Provides a simple runtime that requires a 'schedule IR' including specified communication operations.
  1598. Can be instantiated directly by creating _PipelineScheduleRuntime and calling load_csv, or can be
  1599. subclassed and the subclass can be responsible for creating a schedule IR.
  1600. """
  1601. def __init__(self, *args, **kwargs):
  1602. super().__init__(*args, **kwargs)
  1603. # Action to custom function mapping
  1604. self._comp_type_to_function_map: dict[_ComputationType, Callable] = {}
  1605. # count either full_backward or backward_weight together, to determine when to sync DP grads
  1606. self.backward_counter: Counter[int] = Counter()
  1607. # recv ops indexed by (stage_idx, mb_idx) need to be waited on before use
  1608. self.bwd_recv_ops: dict[tuple[int, int], list[dist.Work]] = {}
  1609. self.fwd_recv_ops: dict[tuple[int, int], list[dist.Work]] = {}
  1610. # we track which stages are 'active' when used with FSDP, and wait on unshard ops before computing on stages
  1611. self.unshard_ops: dict[int, list[UnshardHandle]] = defaultdict(list)
  1612. self.unsharded_stages = set()
  1613. def register_custom_function(
  1614. self,
  1615. computation_type: _ComputationType,
  1616. custom_function: _CustomFunctionProtocol,
  1617. ) -> None:
  1618. """
  1619. Register a custom function to be executed for a specific computation type.
  1620. Args:
  1621. computation_type: The computation type for which to register the custom function
  1622. custom_function: The function to execute when this computation type is encountered.
  1623. Must have signature: (action: _Action, ctx: _PipelineContext) -> None
  1624. """
  1625. # Ensure that the computation type is valid
  1626. if computation_type not in (
  1627. FORWARD,
  1628. FULL_BACKWARD,
  1629. BACKWARD_INPUT,
  1630. BACKWARD_WEIGHT,
  1631. OVERLAP_F_B,
  1632. UNSHARD,
  1633. RESHARD,
  1634. REDUCE_GRAD,
  1635. ):
  1636. raise ValueError(
  1637. f"Invalid computation type {computation_type}. Only FORWARD, FULL_BACKWARD, \
  1638. BACKWARD_INPUT, BACKWARD_WEIGHT, OVERLAP_F_B, UNSHARD, RESHARD and REDUCE_GRAD are supported."
  1639. )
  1640. # Check if computation_type is already registered
  1641. if computation_type in self._comp_type_to_function_map:
  1642. logger.warning(
  1643. "Computation type %s is already registered. "
  1644. "Overwriting the existing custom function.",
  1645. computation_type,
  1646. )
  1647. self._comp_type_to_function_map[computation_type] = custom_function
  1648. def _prepare_schedule_with_comms(
  1649. self,
  1650. actions: dict[int, list[_Action | None]],
  1651. format: str = "compute_only",
  1652. ):
  1653. """
  1654. Given an in-memory representation for a simple compute-only schedule, lower it to a complex schedule including
  1655. communication actions. Stores the schedule in self, and must be called before running step_mo()
  1656. """
  1657. # validate the provided actions are valid and overrides the default stage_index_to_group_rank
  1658. super()._validate_and_set_stage_mapping(actions)
  1659. self.pipeline_order_with_comms: dict[int, list[_Action]] = {}
  1660. if format == "compute_comms":
  1661. for rank in actions:
  1662. self.pipeline_order_with_comms[rank] = []
  1663. for action in actions[rank]:
  1664. if action is None:
  1665. raise AssertionError(
  1666. f"Expected action to be not None, got {type(action)}"
  1667. )
  1668. self.pipeline_order_with_comms[rank].append(action)
  1669. # TODO what level of validation should we offer for compute+comms schedule?
  1670. elif format == "compute_only":
  1671. # Validate that the schedule does not have comms already added to it
  1672. for rank, action_list in actions.items():
  1673. for i, action in enumerate(action_list):
  1674. if action is not None:
  1675. if not action.is_compute_op:
  1676. raise ValueError(
  1677. f"Expected compute-only schedule but found communication action "
  1678. f"'{action}' at rank {rank}, position {i}. "
  1679. f"Communication actions (e.g. SEND_F, RECV_F, etc.) "
  1680. f"should not be present when format='compute_only'."
  1681. )
  1682. # Perform schedule lowering
  1683. for rank in actions:
  1684. self.pipeline_order_with_comms[rank] = _add_unshard_reshard(
  1685. actions[rank]
  1686. )
  1687. self.pipeline_order_with_comms[rank] = _add_reduce_grad( # type: ignore[assignment]
  1688. self.pipeline_order_with_comms[rank], # type: ignore[arg-type]
  1689. self._n_microbatches,
  1690. )
  1691. self.pipeline_order_with_comms = _add_send_recv(
  1692. self.pipeline_order_with_comms,
  1693. stage_to_rank=lambda s: self.stage_index_to_group_rank[s],
  1694. num_stages=self._num_stages,
  1695. )
  1696. else:
  1697. raise NotImplementedError(f"{format=} is not implemented")
  1698. def _load_csv(self, filename: str, format: str = "compute_only"):
  1699. """Loads a csv in simple format and then lowers it to include communication actions
  1700. format must be either "compute_only" or "compute_comms". If compute_only, the lowering passes
  1701. will automatically be run to generate a compute_comms schedule.
  1702. """
  1703. if format == "compute_only":
  1704. # this will populate self.pipeline_order
  1705. super()._load_csv(filename)
  1706. # this will populate self.pipeline_order_with_comms
  1707. self._prepare_schedule_with_comms(self.pipeline_order)
  1708. elif format == "compute_comms":
  1709. actions = {}
  1710. with open(filename, newline="") as csvfile:
  1711. reader = csv.reader(csvfile)
  1712. for rank, row in enumerate(reader):
  1713. actions[rank] = [_Action.from_str(s) for s in row]
  1714. self._prepare_schedule_with_comms(actions, format=format)
  1715. else:
  1716. raise NotImplementedError(f"{format=} is not implemented")
  1717. def _dump_csv(self, filename: str, format: str = "compute_comms"):
  1718. """Dump a CSV representation of the schedule into a file with the provided filename."""
  1719. if format == "compute_only":
  1720. if self.pipeline_order is None:
  1721. raise AssertionError("Compute only schedule must be available")
  1722. with open(filename, "w", newline="") as csvfile:
  1723. writer = csv.writer(csvfile)
  1724. for rank in self.pipeline_order:
  1725. writer.writerow(self.pipeline_order[rank])
  1726. elif format == "compute_comms":
  1727. if self.pipeline_order_with_comms is None:
  1728. raise AssertionError(
  1729. "Must initialize compute_comms schedule before dump_csv"
  1730. )
  1731. with open(filename, "w", newline="") as csvfile:
  1732. writer = csv.writer(csvfile)
  1733. for rank in self.pipeline_order_with_comms:
  1734. writer.writerow(self.pipeline_order_with_comms[rank])
  1735. def _simulate(self):
  1736. return _simulate_comms_compute(
  1737. self.pipeline_order_with_comms,
  1738. lambda s: self.stage_index_to_group_rank[s],
  1739. self._num_stages,
  1740. )
  1741. def _assert_unsharded(self, stage: _PipelineStageBase):
  1742. """If an unshard is active for `stage_idx`, wait() it and mark `stage_idx` unshared."""
  1743. stage_uses_fsdp = isinstance(stage.submod, FSDPModule)
  1744. if stage_uses_fsdp:
  1745. stage_idx = stage.stage_index
  1746. if stage_idx in self.unshard_ops:
  1747. for op in self.unshard_ops[stage_idx]:
  1748. op.wait()
  1749. del self.unshard_ops[stage_idx]
  1750. self.unsharded_stages.add(stage_idx)
  1751. if stage_idx not in self.unsharded_stages:
  1752. raise AssertionError(f"Attempted to compute on sharded {stage_idx=}")
  1753. def _step_microbatches(
  1754. self,
  1755. arg_mbs: list | None = None,
  1756. kwarg_mbs: list | None = None,
  1757. target_mbs: list | None = None,
  1758. losses: list | None = None,
  1759. return_outputs: bool = True,
  1760. ):
  1761. """
  1762. Operate on the microbatches for looped schedules (multiple stages on each rank).
  1763. TODO: Does not use sorted_batch_isend_irecv(). As a result, this schedule does
  1764. not support models with skip connections.
  1765. """
  1766. arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)
  1767. self._initialize_stages(arg_mbs[0], kwarg_mbs[0])
  1768. # Based on the plan in Step 1 created in __init__:
  1769. # 2. Perform communication based on the pipeline_order
  1770. stage_index_to_stage: dict[int, _PipelineStageBase] = {
  1771. stage.stage_index: stage for stage in self._stages
  1772. }
  1773. if self.pipeline_order_with_comms is None:
  1774. raise AssertionError(
  1775. "Must call _prepare_schedule_with_comms() before calling _step_microbatches()"
  1776. )
  1777. # send ops should be waited on before step() exists, mainly for hygiene
  1778. send_ops: list[list[dist.Work]] = []
  1779. def _perform_action(action: _Action) -> None:
  1780. comp_type = action.computation_type
  1781. mb_index: int = (
  1782. action.microbatch_index if action.microbatch_index is not None else -1
  1783. )
  1784. if not (
  1785. mb_index >= 0
  1786. or comp_type
  1787. in (
  1788. UNSHARD,
  1789. RESHARD,
  1790. REDUCE_GRAD,
  1791. )
  1792. ):
  1793. raise AssertionError(f"{action=} missing mb_index")
  1794. stage_idx = action.stage_index
  1795. stage = stage_index_to_stage[stage_idx]
  1796. stage_uses_fsdp = isinstance(stage.submod, FSDPModule)
  1797. # see [Note: V-schedule special case]
  1798. is_next_stage_on_this_rank = stage_idx + 1 in stage_index_to_stage
  1799. is_prev_stage_on_this_rank = stage_idx - 1 in stage_index_to_stage
  1800. # TODO(whc) it's not actually safe to use _batch_p2p here in the uncommon case the model has skip-connections,
  1801. # since we do not want to batch up ops between more than a pair of ranks. _sorted_batch_p2p would be
  1802. # safe to use instead.
  1803. # However, I was wondering if I should avoid calling batched operators at all in the case that there is
  1804. # only one operator per batch. I could iterate through the 'fwd_send_ops' one by one and run them.
  1805. if comp_type == SEND_F:
  1806. send_ops.append(_batch_p2p(stage.get_fwd_send_ops(mb_index)))
  1807. elif comp_type == SEND_B:
  1808. send_ops.append(_batch_p2p(stage.get_bwd_send_ops(mb_index)))
  1809. elif comp_type == RECV_F:
  1810. if (stage_idx, mb_index) in self.fwd_recv_ops:
  1811. raise AssertionError(
  1812. f"Recv twice for {stage_idx=} {mb_index=} without executing forward"
  1813. )
  1814. self.fwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p(
  1815. stage.get_fwd_recv_ops(mb_index)
  1816. )
  1817. elif comp_type == RECV_B:
  1818. if (stage_idx, mb_index) in self.bwd_recv_ops:
  1819. raise AssertionError(
  1820. f"Recv twice for {stage_idx=} {mb_index=} without executing backward"
  1821. )
  1822. self.bwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p(
  1823. stage.get_bwd_recv_ops(mb_index)
  1824. )
  1825. elif comp_type == UNSHARD:
  1826. if stage_uses_fsdp:
  1827. if not (
  1828. stage_idx not in self.unsharded_stages
  1829. and stage_idx not in self.unshard_ops
  1830. ):
  1831. raise AssertionError(f"Unsharding the same {stage_idx=} twice")
  1832. for submodule in stage.submod.modules():
  1833. if not isinstance(submodule, FSDPModule):
  1834. continue
  1835. handle = cast(UnshardHandle, submodule.unshard(async_op=True))
  1836. self.unshard_ops[stage_idx].append(handle)
  1837. elif comp_type == RESHARD:
  1838. if stage_uses_fsdp:
  1839. if stage_idx not in self.unsharded_stages:
  1840. raise AssertionError(
  1841. f"Resharding {stage_idx=} without unsharding"
  1842. )
  1843. if stage_idx in self.unshard_ops:
  1844. raise AssertionError(
  1845. f"Resharding {stage_idx=} before finishing unshard"
  1846. )
  1847. for submodule in stage.submod.modules():
  1848. if not isinstance(submodule, FSDPModule):
  1849. continue
  1850. submodule.reshard()
  1851. self.unsharded_stages.remove(stage_idx)
  1852. elif comp_type == FORWARD:
  1853. self._assert_unsharded(stage)
  1854. if (
  1855. not stage.is_first
  1856. # no recv op expected for V-schedule special case (see [Note: V-schedule special case])
  1857. and not is_prev_stage_on_this_rank
  1858. ):
  1859. if (stage_idx, mb_index) not in self.fwd_recv_ops:
  1860. raise AssertionError(
  1861. f"Computing {action=} before receiving input"
  1862. )
  1863. _wait_batch_p2p(self.fwd_recv_ops.pop((stage_idx, mb_index)))
  1864. output = stage.forward_one_chunk(
  1865. mb_index,
  1866. arg_mbs[mb_index], # type: ignore[index]
  1867. kwarg_mbs[mb_index], # type: ignore[index]
  1868. save_forward_output=return_outputs,
  1869. )
  1870. self._maybe_compute_loss(stage, output, target_mbs, mb_index)
  1871. # SEND/RECV op are avoided for special case with 2 adjacent stages on same rank
  1872. # see [Note: V-schedule special case]
  1873. if is_next_stage_on_this_rank:
  1874. stage_index_to_stage[stage_idx + 1].set_local_fwd_input(
  1875. output, mb_index
  1876. )
  1877. elif comp_type == FULL_BACKWARD:
  1878. self._assert_unsharded(stage)
  1879. if (
  1880. not stage.is_last
  1881. # no recv op expected for V-schedule special case (see [Note: V-schedule special case])
  1882. and not is_next_stage_on_this_rank
  1883. ):
  1884. if (stage_idx, mb_index) not in self.bwd_recv_ops:
  1885. raise AssertionError(
  1886. f"Attempted to run compute {action=} before receiving input"
  1887. )
  1888. _wait_batch_p2p(self.bwd_recv_ops.pop((stage_idx, mb_index)))
  1889. loss = self._maybe_get_loss(stage, mb_index)
  1890. self.backward_counter[stage_idx] += 1
  1891. last_backward = self.backward_counter[stage_idx] == self._n_microbatches
  1892. stage.backward_one_chunk(
  1893. mb_index,
  1894. loss=loss,
  1895. full_backward=True,
  1896. last_backward=last_backward,
  1897. )
  1898. # SEND/RECV op are avoided for special case with 2 adjacent stages on same rank
  1899. # see [Note: V-schedule special case]
  1900. if is_prev_stage_on_this_rank:
  1901. stage_index_to_stage[stage_idx - 1].set_local_bwd_input(
  1902. stage.get_local_bwd_output(mb_index), mb_index
  1903. )
  1904. elif comp_type == BACKWARD_INPUT:
  1905. self._assert_unsharded(stage)
  1906. if not stage.is_last and not is_next_stage_on_this_rank:
  1907. if (stage_idx, mb_index) not in self.bwd_recv_ops:
  1908. raise AssertionError(
  1909. f"Attempted to run compute {action=} before receiving input"
  1910. )
  1911. _wait_batch_p2p(self.bwd_recv_ops.pop((stage_idx, mb_index)))
  1912. loss = self._maybe_get_loss(stage, mb_index)
  1913. stage.backward_one_chunk(
  1914. mb_index,
  1915. loss=loss,
  1916. full_backward=False,
  1917. last_backward=False,
  1918. )
  1919. # SEND/RECV op are avoided for special case with 2 adjacent stages on same rank
  1920. # see [Note: V-schedule special case]
  1921. if is_prev_stage_on_this_rank:
  1922. stage_index_to_stage[stage_idx - 1].set_local_bwd_input(
  1923. stage.get_local_bwd_output(mb_index), mb_index
  1924. )
  1925. elif comp_type == BACKWARD_WEIGHT:
  1926. self._assert_unsharded(stage)
  1927. self.backward_counter[stage_idx] += 1
  1928. last_backward = self.backward_counter[stage_idx] == self._n_microbatches
  1929. stage.backward_weight_one_chunk(
  1930. mb_index,
  1931. last_backward=last_backward,
  1932. )
  1933. elif comp_type == REDUCE_GRAD:
  1934. grad_scale_factor = self._n_microbatches if self.scale_grads else 1
  1935. stage.perform_reduce_grad(grad_scale_factor)
  1936. else:
  1937. raise ValueError(f"{action=} is unknown or unsupported")
  1938. # count either full_backward or backward_weight together, to determine when to sync DP grads
  1939. self.backward_counter.clear()
  1940. for time_step, action in enumerate(self.pipeline_order_with_comms[self.rank]):
  1941. logger.debug(
  1942. "_PipelineScheduleRuntime running time_step %d, action %s",
  1943. time_step,
  1944. action,
  1945. )
  1946. try:
  1947. with record_function(_get_profiler_function_name(action)):
  1948. if action.computation_type in self._comp_type_to_function_map:
  1949. ctx = _PipelineContext(
  1950. self,
  1951. arg_mbs,
  1952. kwarg_mbs,
  1953. target_mbs,
  1954. losses,
  1955. )
  1956. self._comp_type_to_function_map[action.computation_type](
  1957. action, ctx
  1958. )
  1959. elif action.computation_type == OVERLAP_F_B:
  1960. if action.sub_actions is None:
  1961. raise AssertionError("sub_actions must be set")
  1962. for sub_a in action.sub_actions:
  1963. _perform_action(sub_a)
  1964. else:
  1965. _perform_action(action)
  1966. except Exception as e:
  1967. logger.error(
  1968. "_PipelineScheduleRuntime caught exception at step %s when running action %s. Full Schedule:",
  1969. time_step,
  1970. action,
  1971. )
  1972. logger.error(
  1973. _format_pipeline_order(
  1974. self.pipeline_order_with_comms, # type: ignore[arg-type]
  1975. error_step_number=time_step,
  1976. )
  1977. )
  1978. raise e
  1979. # Mostly these operations should have finished long ago, but there isn't an obvious time when to wait for them
  1980. while send_ops:
  1981. _wait_batch_p2p(send_ops.pop())
  1982. if len(self.unshard_ops) != 0:
  1983. raise AssertionError("Unused unshard operations")
  1984. # Return losses if there is a container passed in
  1985. self._update_losses(self._stages, losses)
  1986. class ScheduleLoopedBFS(_PipelineScheduleRuntime):
  1987. """
  1988. Breadth-First Pipeline Parallelism.
  1989. See https://arxiv.org/abs/2211.05953 for details.
  1990. Similar to Interleaved 1F1B, Looped BFS supports multiple stages per rank.
  1991. What is different is that when microbatches are ready for multiple local
  1992. stages, Loops BFS will prioritizes the earlier stage, running all available
  1993. microbatches at once.
  1994. """
  1995. def __init__(
  1996. self,
  1997. stages: list[_PipelineStageBase],
  1998. n_microbatches: int,
  1999. loss_fn: Callable | _Loss | None = None,
  2000. output_merge_spec: dict[str, Any] | tuple[Any] | None = None,
  2001. scale_grads: bool = True,
  2002. backward_requires_autograd: bool = True,
  2003. ):
  2004. super().__init__(
  2005. stages=stages,
  2006. n_microbatches=n_microbatches,
  2007. loss_fn=loss_fn,
  2008. output_merge_spec=output_merge_spec,
  2009. scale_grads=scale_grads,
  2010. backward_requires_autograd=backward_requires_autograd,
  2011. )
  2012. # 1. Create the pipeline_order (all ranks do this calculation)
  2013. # This will be used to keep track of the current state of the entire pipeline
  2014. # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
  2015. self.pipeline_order: dict[int, list[_Action | None]] = {}
  2016. # ========================================================================
  2017. for rank in range(self.pp_group_size):
  2018. rank_ops = self._calculate_single_rank_operations(rank)
  2019. self.pipeline_order[rank] = rank_ops
  2020. # Initialize the pipeline order with communication necessary to run with _PipelineScheduleRuntime
  2021. self._prepare_schedule_with_comms(self.pipeline_order)
  2022. def _calculate_single_rank_operations(self, rank):
  2023. n_local_stages = len(self._stages)
  2024. stage_indices = range(
  2025. rank, self.pp_group_size * n_local_stages, self.pp_group_size
  2026. )
  2027. # Store the list of operations used for that rank
  2028. # Pre-padding, rank starts with no-ops based on the warmup.
  2029. rank_ops: list[_Action | None] = [None for _ in range(rank)]
  2030. for stage_index in stage_indices:
  2031. rank_ops.extend(
  2032. _Action(stage_index, _ComputationType.FORWARD, mb_index)
  2033. for mb_index in range(self._n_microbatches)
  2034. )
  2035. # wait for the first backward to trickle up
  2036. # which is 2 for every hop away
  2037. post_warmup_ops = 2 * (self.pp_group_size - 1 - rank)
  2038. rank_ops.extend([None] * post_warmup_ops)
  2039. for stage_index in reversed(stage_indices):
  2040. rank_ops.extend(
  2041. _Action(stage_index, _ComputationType.FULL_BACKWARD, mb_index)
  2042. for mb_index in reversed(range(self._n_microbatches))
  2043. )
  2044. return rank_ops
  2045. def _get_1f1b_rank_ops(
  2046. n_local_stages,
  2047. pp_group_size,
  2048. warmup_ops,
  2049. fwd_bwd_ops,
  2050. cooldown_ops,
  2051. rank,
  2052. forward_stage_index,
  2053. backward_stage_index,
  2054. num_1f1b_microbatches=0,
  2055. enable_zero_bubble=False,
  2056. ):
  2057. # All stages start with handling microbatch 0
  2058. fwd_stage_mb_index: dict[int, int] = defaultdict(int)
  2059. bwd_stage_mb_index: dict[int, int] = defaultdict(int)
  2060. weight_stage_mb_index: dict[int, int] = defaultdict(int)
  2061. # Store the list of operations used for that rank
  2062. # Pre-padding, rank starts with no-ops based on the warmup.
  2063. rank_ops: list[_Action | None] = [None for _ in range(rank)]
  2064. # These are used to calculate the number of slots to fill with no-ops, to account for the delay in warmup
  2065. # when we want to wait for the backward to trickle back up and start 1f1b to align all ranks.
  2066. # Formula:
  2067. # pre-padding + warmup_ops + post_warmup_ops = earliest time step of first backward
  2068. # post_warmup_ops = [earliest time step of first backward] - (warmup_ops + pre-padding)
  2069. # earliest time step of first backward = [local_stages * group_size + 2 * (group_size - 1 - rank)]
  2070. # warmup_ops = calculated above
  2071. post_warmup_ops = (
  2072. n_local_stages * pp_group_size + 2 * (pp_group_size - 1 - rank)
  2073. ) - (warmup_ops + rank)
  2074. if enable_zero_bubble:
  2075. post_warmup_ops = pp_group_size - rank - 1
  2076. total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops
  2077. backward_op_ids = []
  2078. weight_op_count = 0
  2079. FULL_BACKWARD_OR_BACKWARD_INPUT = (
  2080. BACKWARD_INPUT if enable_zero_bubble else FULL_BACKWARD
  2081. )
  2082. for op in range(total_ops):
  2083. # Warmup phase
  2084. if op < warmup_ops:
  2085. fwd_stage_index = forward_stage_index(op)
  2086. # This will assign the current microbatch index and update it as well
  2087. fwd_stage_mb_index[fwd_stage_index] = (
  2088. mb_index := fwd_stage_mb_index[fwd_stage_index]
  2089. ) + 1
  2090. rank_ops.append(
  2091. _Action(fwd_stage_index, _ComputationType.FORWARD, mb_index)
  2092. )
  2093. if op == warmup_ops - 1:
  2094. # This is the last step in the warmup phase, so we need to wait for the backward to trickle back up
  2095. rank_ops.extend([None] * post_warmup_ops)
  2096. # 1F1B Phase (forward and backward)
  2097. elif warmup_ops <= op < warmup_ops + fwd_bwd_ops:
  2098. fwd_stage_index = forward_stage_index(op)
  2099. fwd_stage_mb_index[fwd_stage_index] = (
  2100. fwd_mb_index := fwd_stage_mb_index[fwd_stage_index]
  2101. ) + 1
  2102. rank_ops.append(
  2103. _Action(fwd_stage_index, _ComputationType.FORWARD, fwd_mb_index)
  2104. )
  2105. bwd_stage_index = backward_stage_index(op)
  2106. bwd_stage_mb_index[bwd_stage_index] = (
  2107. bwd_mb_index := bwd_stage_mb_index[bwd_stage_index]
  2108. ) + 1
  2109. rank_ops.append(
  2110. _Action(bwd_stage_index, FULL_BACKWARD_OR_BACKWARD_INPUT, bwd_mb_index)
  2111. )
  2112. backward_op_ids.append(op)
  2113. if enable_zero_bubble and op - warmup_ops >= num_1f1b_microbatches:
  2114. weight_stage_index = backward_stage_index(
  2115. backward_op_ids[weight_op_count]
  2116. )
  2117. weight_stage_mb_index[weight_stage_index] = (
  2118. weight_mb_index := weight_stage_mb_index[weight_stage_index]
  2119. ) + 1
  2120. rank_ops.append(
  2121. _Action(
  2122. weight_stage_index,
  2123. _ComputationType.BACKWARD_WEIGHT,
  2124. weight_mb_index,
  2125. )
  2126. )
  2127. weight_op_count += 1
  2128. # Cooldown phase
  2129. else:
  2130. # During cooldown phase, we need steps to align with 1f1b happening in other ranks
  2131. # TODO: we don't need to always append, after all 1f1b are finished we can stop appending None
  2132. if not enable_zero_bubble:
  2133. rank_ops.append(None)
  2134. bwd_stage_index = backward_stage_index(op)
  2135. bwd_stage_mb_index[bwd_stage_index] = (
  2136. bwd_mb_index := bwd_stage_mb_index[bwd_stage_index]
  2137. ) + 1
  2138. rank_ops.append(
  2139. _Action(bwd_stage_index, FULL_BACKWARD_OR_BACKWARD_INPUT, bwd_mb_index)
  2140. )
  2141. backward_op_ids.append(op)
  2142. if enable_zero_bubble and op - warmup_ops >= num_1f1b_microbatches:
  2143. weight_stage_index = backward_stage_index(
  2144. backward_op_ids[weight_op_count]
  2145. )
  2146. weight_stage_mb_index[weight_stage_index] = (
  2147. weight_mb_index := weight_stage_mb_index[weight_stage_index]
  2148. ) + 1
  2149. rank_ops.append(
  2150. _Action(
  2151. weight_stage_index,
  2152. _ComputationType.BACKWARD_WEIGHT,
  2153. weight_mb_index,
  2154. )
  2155. )
  2156. weight_op_count += 1
  2157. while enable_zero_bubble and weight_op_count < len(backward_op_ids):
  2158. weight_stage_index = backward_stage_index(backward_op_ids[weight_op_count])
  2159. weight_stage_mb_index[weight_stage_index] = (
  2160. weight_mb_index := weight_stage_mb_index[weight_stage_index]
  2161. ) + 1
  2162. rank_ops.append(
  2163. _Action(
  2164. weight_stage_index, _ComputationType.BACKWARD_WEIGHT, weight_mb_index
  2165. )
  2166. )
  2167. weight_op_count += 1
  2168. return rank_ops
  2169. def _get_warmup_ops(
  2170. rank: int,
  2171. n_local_stages: int,
  2172. microbatches_per_round: int,
  2173. pp_group_size: int,
  2174. n_microbatches: int,
  2175. multiply_factor: int = 2,
  2176. ) -> int:
  2177. """
  2178. Calculate the number of warmup operations for interleaved schedules.
  2179. """
  2180. # Warmup operations for last stage
  2181. warmups_ops_last_stage = (n_local_stages - 1) * microbatches_per_round
  2182. # Increment warmup operations by multiply_factor for each hop away from the last stage
  2183. warmup_ops = warmups_ops_last_stage + multiply_factor * ((pp_group_size - 1) - rank)
  2184. # We cannot have more warmup operations than there are number of microbatches, so cap it there
  2185. return min(warmup_ops, n_microbatches * n_local_stages)
  2186. class ScheduleInterleaved1F1B(_PipelineScheduleRuntime):
  2187. """
  2188. The Interleaved 1F1B schedule.
  2189. See https://arxiv.org/pdf/2104.04473 for details.
  2190. Will perform one forward and one backward on the microbatches in steady
  2191. state and supports multiple stages per rank. When microbatches are ready for
  2192. multiple local stages, Interleaved 1F1B prioritizes the earlier microbatch
  2193. (also called "depth first").
  2194. This schedule is mostly similar to the original paper.
  2195. It differs by being relaxing the requirement of num_microbatch % pp_size == 0.
  2196. Using the flex_pp schedule, we will have num_rounds = max(1, n_microbatches // pp_group_size) and
  2197. it works as long as n_microbatches % num_rounds is 0. As a few examples, support
  2198. 1. pp_group_size = 4, n_microbatches = 10. We will have num_rounds = 2 and n_microbatches % 2 is 0.
  2199. 2. pp_group_size = 4, n_microbatches = 3. We will have num_rounds = 1 and n_microbatches % 1 is 0.
  2200. """
  2201. def __init__(
  2202. self,
  2203. stages: list[_PipelineStageBase],
  2204. n_microbatches: int,
  2205. loss_fn: Callable | None = None,
  2206. args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None,
  2207. kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None,
  2208. output_merge_spec: dict[str, Any] | tuple[Any] | None = None,
  2209. scale_grads: bool = True,
  2210. backward_requires_autograd: bool = True,
  2211. ):
  2212. self.pp_group_size = stages[0].group_size
  2213. super().__init__(
  2214. stages=stages,
  2215. n_microbatches=n_microbatches,
  2216. loss_fn=loss_fn,
  2217. args_chunk_spec=args_chunk_spec,
  2218. kwargs_chunk_spec=kwargs_chunk_spec,
  2219. output_merge_spec=output_merge_spec,
  2220. scale_grads=scale_grads,
  2221. backward_requires_autograd=backward_requires_autograd,
  2222. )
  2223. self.n_local_stages = len(stages)
  2224. self.rank = stages[0].group_rank
  2225. self.number_of_rounds = max(1, n_microbatches // self.pp_group_size)
  2226. self.microbatches_per_round = n_microbatches // self.number_of_rounds
  2227. if n_microbatches % self.number_of_rounds != 0:
  2228. raise ValueError(
  2229. "Interleaved 1F1B requires the number of microbatches to be a "
  2230. f"multiple of the number of rounds ({self.number_of_rounds}), "
  2231. f"but got {n_microbatches}."
  2232. )
  2233. # 1. Create the pipeline_order (all ranks do this calculation)
  2234. # This will be used to keep track of the current state of the entire pipeline
  2235. # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
  2236. self.pipeline_order: dict[int, list[_Action | None]] = {}
  2237. for rank in range(self.pp_group_size):
  2238. rank_ops = self._calculate_single_rank_operations(rank)
  2239. self.pipeline_order[rank] = rank_ops
  2240. # Initialize the pipeline order with communication necessary to run with _PipelineScheduleRuntime
  2241. self._prepare_schedule_with_comms(self.pipeline_order)
  2242. def _calculate_single_rank_operations(self, rank) -> list[_Action | None]:
  2243. warmup_ops = _get_warmup_ops(
  2244. rank,
  2245. self.n_local_stages,
  2246. self.microbatches_per_round,
  2247. self.pp_group_size,
  2248. self._n_microbatches,
  2249. multiply_factor=2,
  2250. )
  2251. microbatch_ops = self.n_local_stages * self._n_microbatches
  2252. # fwd_bwd_ops should encompass the remaining forwards
  2253. fwd_bwd_ops = microbatch_ops - warmup_ops
  2254. # cooldown_ops should encompass the remaining backwards
  2255. cooldown_ops = microbatch_ops - fwd_bwd_ops
  2256. # total ops encompass both forward and backward ops
  2257. total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops
  2258. # warmup_ops + fwd_bwd_ops * 2 + cooldown_ops == microbatch_ops * 2
  2259. logger.debug(
  2260. "rank %s, warmup_ops %s, 1f1b %s, cooldown_ops %s total_ops %s",
  2261. rank,
  2262. warmup_ops,
  2263. fwd_bwd_ops,
  2264. cooldown_ops,
  2265. total_ops,
  2266. )
  2267. # Calculates the stage index based on step and pp_group_size
  2268. def forward_stage_index(step):
  2269. # Get the local index from 0 to n_local_stages-1
  2270. local_index = (step // self.microbatches_per_round) % self.n_local_stages
  2271. return (local_index * self.pp_group_size) + rank
  2272. def backward_stage_index(step):
  2273. local_index = (
  2274. self.n_local_stages
  2275. - 1
  2276. - ((step - warmup_ops) // self.microbatches_per_round)
  2277. % self.n_local_stages
  2278. )
  2279. return (local_index * self.pp_group_size) + rank
  2280. return _get_1f1b_rank_ops(
  2281. self.n_local_stages,
  2282. self.pp_group_size,
  2283. warmup_ops,
  2284. fwd_bwd_ops,
  2285. cooldown_ops,
  2286. rank,
  2287. forward_stage_index,
  2288. backward_stage_index,
  2289. )
  2290. class ScheduleInterleavedZeroBubble(_PipelineScheduleRuntime):
  2291. """
  2292. The Interleaved Zero Bubble schedule.
  2293. See https://arxiv.org/pdf/2401.10241 for details.
  2294. Will perform one forward and one backward on inputs for the microbatches in steady
  2295. state and supports multiple stages per rank. Uses the backward for weights to fill in
  2296. the pipeline bubble.
  2297. In particular this is implementing the ZB1P schedule in the paper.
  2298. """
  2299. def __init__(
  2300. self,
  2301. stages: list[_PipelineStageBase],
  2302. n_microbatches: int,
  2303. loss_fn: Callable | None = None,
  2304. args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None,
  2305. kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None,
  2306. output_merge_spec: dict[str, Any] | tuple[Any] | None = None,
  2307. scale_grads: bool = True,
  2308. backward_requires_autograd: bool = True,
  2309. ):
  2310. # TODO: we dont support input/weight backward split with torch.compile
  2311. _check_torch_compile_compatibility(stages, self.__class__.__name__)
  2312. self.pp_group_size = stages[0].group_size
  2313. super().__init__(
  2314. stages=stages,
  2315. n_microbatches=n_microbatches,
  2316. loss_fn=loss_fn,
  2317. args_chunk_spec=args_chunk_spec,
  2318. kwargs_chunk_spec=kwargs_chunk_spec,
  2319. output_merge_spec=output_merge_spec,
  2320. scale_grads=scale_grads,
  2321. backward_requires_autograd=backward_requires_autograd,
  2322. )
  2323. self.n_local_stages = len(stages)
  2324. self.rank = stages[0].group_rank
  2325. self.number_of_rounds = max(1, n_microbatches // self.pp_group_size)
  2326. self.microbatches_per_round = n_microbatches // self.number_of_rounds
  2327. if n_microbatches % self.number_of_rounds != 0:
  2328. raise ValueError(
  2329. "Zero bubble requires the number of microbatches to be a "
  2330. f"multiple of the number of rounds ({self.number_of_rounds}), "
  2331. f"but got {n_microbatches}."
  2332. )
  2333. # 1. Create the pipeline_order (all ranks do this calculation)
  2334. # This will be used to keep track of the current state of the entire pipeline
  2335. # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
  2336. self.pipeline_order: dict[int, list[_Action | None]] = {}
  2337. for rank in range(self.pp_group_size):
  2338. rank_ops = self._calculate_single_rank_operations(rank)
  2339. self.pipeline_order[rank] = rank_ops
  2340. # This function add bubbles to the generated schedule based on dependencies of actions
  2341. # Note that the ZB1P schedule will not require bubbles to be manually added and it is
  2342. # only useful when n_microbatches <= microbatches_per_round
  2343. self.pipeline_order = self._add_bubbles_to_actions(
  2344. self.n_local_stages * self.pp_group_size,
  2345. )
  2346. # Initialize the pipeline order with communication necessary to run with _PipelineScheduleRuntime
  2347. self._prepare_schedule_with_comms(self.pipeline_order)
  2348. def _calculate_single_rank_operations(self, rank) -> list[_Action | None]:
  2349. warmup_ops = _get_warmup_ops(
  2350. rank,
  2351. self.n_local_stages,
  2352. self.microbatches_per_round,
  2353. self.pp_group_size,
  2354. self._n_microbatches,
  2355. multiply_factor=1,
  2356. )
  2357. microbatch_ops = self.n_local_stages * self._n_microbatches
  2358. # fwd_bwd_ops should encompass the remaining forwards
  2359. fwd_bwd_ops = microbatch_ops - warmup_ops
  2360. # cooldown_ops should encompass the remaining backwards
  2361. cooldown_ops = microbatch_ops - fwd_bwd_ops
  2362. # total ops encompass both forward and backward ops
  2363. total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops
  2364. # warmup_ops + fwd_bwd_ops * 2 + cooldown_ops == microbatch_ops * 2
  2365. logger.debug(
  2366. "rank %s, warmup_ops %s, 1f1b %s, cooldown_ops %s total_ops %s",
  2367. rank,
  2368. warmup_ops,
  2369. fwd_bwd_ops,
  2370. cooldown_ops,
  2371. total_ops,
  2372. )
  2373. # Calculates the stage index based on step and pp_group_size
  2374. def forward_stage_index(step):
  2375. # Get the local index from 0 to n_local_stages-1
  2376. local_index = (step // self.microbatches_per_round) % self.n_local_stages
  2377. return (local_index * self.pp_group_size) + rank
  2378. def backward_stage_index(step):
  2379. local_index = (
  2380. self.n_local_stages
  2381. - 1
  2382. - ((step - warmup_ops) // self.microbatches_per_round)
  2383. % self.n_local_stages
  2384. )
  2385. return (local_index * self.pp_group_size) + rank
  2386. num_1f1b_microbatches = rank
  2387. return _get_1f1b_rank_ops(
  2388. self.n_local_stages,
  2389. self.pp_group_size,
  2390. warmup_ops,
  2391. fwd_bwd_ops,
  2392. cooldown_ops,
  2393. rank,
  2394. forward_stage_index,
  2395. backward_stage_index,
  2396. num_1f1b_microbatches,
  2397. enable_zero_bubble=True,
  2398. )
  2399. def _add_bubbles_to_actions(self, num_stages_global):
  2400. actions = self.pipeline_order
  2401. def need_bubble(stage, op, microbatch, num_stages_global, seen_ops):
  2402. if op == _ComputationType.FORWARD:
  2403. if stage != 0 and (stage - 1, op, microbatch) not in seen_ops:
  2404. return True
  2405. elif op == _ComputationType.FULL_BACKWARD:
  2406. if stage == num_stages_global - 1:
  2407. return (stage, _ComputationType.FORWARD, microbatch) not in seen_ops
  2408. return (stage + 1, op, microbatch) not in seen_ops
  2409. return False
  2410. seen_ops: set[tuple[int, _ComputationType, int]] = set()
  2411. result: dict[int, list[_Action | None]] = {}
  2412. next_pointer: dict[int, int] = {}
  2413. bubbles_added: dict[int, int] = {}
  2414. total_bubbles_added = 0
  2415. for rank in range(self.pp_group_size):
  2416. result[rank] = []
  2417. next_pointer[rank] = 0
  2418. bubbles_added[rank] = 0
  2419. while True:
  2420. should_stop = True
  2421. temp_seen_ops: set[tuple[int, _ComputationType, int]] = set()
  2422. for rank in range(self.pp_group_size):
  2423. timestamp = next_pointer[rank]
  2424. if timestamp >= len(actions[rank]):
  2425. continue
  2426. should_stop = False
  2427. if actions[rank][timestamp] is not None:
  2428. temp_action = actions[rank][timestamp]
  2429. if temp_action is None:
  2430. raise AssertionError(
  2431. f"Expected temp_action to be not None, got {type(temp_action)}"
  2432. )
  2433. stage_index, op, microbatch, _ = temp_action
  2434. if not need_bubble(
  2435. stage_index, op, microbatch, num_stages_global, seen_ops
  2436. ):
  2437. result[rank].append(actions[rank][timestamp])
  2438. if microbatch is not None:
  2439. temp_seen_ops.add((stage_index, op, microbatch))
  2440. next_pointer[rank] += 1
  2441. else:
  2442. result[rank].append(None)
  2443. bubbles_added[rank] += 1
  2444. else:
  2445. next_pointer[rank] += 1
  2446. result[rank].append(None)
  2447. seen_ops.update(temp_seen_ops)
  2448. if should_stop:
  2449. break
  2450. if total_bubbles_added > 0:
  2451. logger.warning(
  2452. "Non zero bubbles added: total_bubbles_added=%s bubbles_added=%s",
  2453. total_bubbles_added,
  2454. bubbles_added,
  2455. )
  2456. return result
  2457. class ScheduleZBVZeroBubble(_PipelineScheduleRuntime):
  2458. """
  2459. The Zero Bubble schedule (ZBV variant).
  2460. See https://arxiv.org/pdf/2401.10241 Section 6 for details.
  2461. This schedules requires exactly two stages per rank.
  2462. This schedule will perform one forward and one backward on inputs for the microbatches in steady
  2463. state and supports multiple stages per rank. Uses backward with respect to weights to fill in
  2464. the pipeline bubble.
  2465. This ZB-V schedule would have the "zero bubble" property only if time forward == time backward input == time backward weights.
  2466. In practice, this is not likely true for real models so alternatively
  2467. a greedy scheduler could be implemented for unequal/unbalanced time.
  2468. """
  2469. def __init__(
  2470. self,
  2471. stages: list[_PipelineStageBase],
  2472. n_microbatches: int,
  2473. loss_fn: Callable | None = None,
  2474. args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None,
  2475. kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None,
  2476. output_merge_spec: dict[str, Any] | tuple[Any] | None = None,
  2477. scale_grads: bool = True,
  2478. backward_requires_autograd: bool = True,
  2479. ):
  2480. # TODO: we dont support input/weight backward split with torch.compile
  2481. _check_torch_compile_compatibility(stages, self.__class__.__name__)
  2482. self.pp_group_size = stages[0].group_size
  2483. super().__init__(
  2484. stages=stages,
  2485. n_microbatches=n_microbatches,
  2486. loss_fn=loss_fn,
  2487. args_chunk_spec=args_chunk_spec,
  2488. kwargs_chunk_spec=kwargs_chunk_spec,
  2489. output_merge_spec=output_merge_spec,
  2490. scale_grads=scale_grads,
  2491. backward_requires_autograd=backward_requires_autograd,
  2492. )
  2493. self.stage_index_to_group_rank = generate_stage_to_rank_mapping(
  2494. self.pp_group_size, self._num_stages, style="v"
  2495. )
  2496. for stage in self._stages:
  2497. stage.stage_index_to_group_rank = self.stage_index_to_group_rank
  2498. self.n_local_stages = len(stages)
  2499. if self.n_local_stages != 2:
  2500. raise ValueError(
  2501. "ZBV requires exactly 2 stages per rank, but got "
  2502. f"{self.n_local_stages}."
  2503. )
  2504. self.rank = stages[0].group_rank
  2505. self.num_stages = stages[0].num_stages
  2506. # 1. Create the pipeline_order (all ranks do this calculation)
  2507. # This will be used to keep track of the current state of the entire pipeline
  2508. # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
  2509. self.pipeline_order: dict[int, list[_Action | None]] = {}
  2510. for rank in range(self.pp_group_size):
  2511. rank_ops = self._calculate_single_rank_operations(rank)
  2512. self.pipeline_order[rank] = rank_ops
  2513. # Initialize the pipeline order with communication necessary to run with _PipelineScheduleRuntime
  2514. self._prepare_schedule_with_comms(self.pipeline_order)
  2515. def _calculate_single_rank_operations(self, rank) -> list[_Action | None]:
  2516. # max(2 * self.pp_group_size - 1, ...) ensure the number of microbatches is at least
  2517. # as large of the number of microbatches needed to fully utilize the pipeline
  2518. n_micro = max(2 * self.pp_group_size - 1, self._n_microbatches)
  2519. rank_ops: list[_Action | None] = [None for _ in range(rank)]
  2520. # Forward and backward action counts for stage chunk 0 and chunk 1
  2521. f0_cnt, f1_cnt, b0_cnt, b1_cnt = 0, 0, 0, 0
  2522. # warm-up phase
  2523. warmup_n1 = 2 * (self.pp_group_size - rank) - 1
  2524. stage_id_chunk0 = rank
  2525. stage_id_chunk1 = self.num_stages - 1 - rank
  2526. for _ in range(warmup_n1):
  2527. rank_ops.append(
  2528. _Action(stage_id_chunk0, computation_type=F, microbatch_index=f0_cnt)
  2529. )
  2530. f0_cnt += 1
  2531. warmup_n2 = rank
  2532. for _ in range(warmup_n2):
  2533. rank_ops.append(
  2534. _Action(stage_id_chunk1, computation_type=F, microbatch_index=f1_cnt)
  2535. )
  2536. f1_cnt += 1
  2537. rank_ops.append(
  2538. _Action(stage_id_chunk0, computation_type=F, microbatch_index=f0_cnt)
  2539. )
  2540. f0_cnt += 1
  2541. warmup_n3 = self.pp_group_size - rank
  2542. for _ in range(warmup_n3):
  2543. rank_ops.append(
  2544. _Action(stage_id_chunk1, computation_type=F, microbatch_index=f1_cnt)
  2545. )
  2546. f1_cnt += 1
  2547. rank_ops.append(
  2548. _Action(stage_id_chunk1, computation_type=I, microbatch_index=b1_cnt)
  2549. )
  2550. rank_ops.append(
  2551. _Action(stage_id_chunk1, computation_type=W, microbatch_index=b1_cnt)
  2552. )
  2553. b1_cnt += 1
  2554. # stable phase
  2555. while f1_cnt < f0_cnt or f0_cnt < n_micro:
  2556. if f0_cnt < n_micro:
  2557. rank_ops.append(
  2558. _Action(
  2559. stage_id_chunk0, computation_type=F, microbatch_index=f0_cnt
  2560. )
  2561. )
  2562. f0_cnt += 1
  2563. rank_ops.append(
  2564. _Action(stage_id_chunk0, computation_type=I, microbatch_index=b0_cnt)
  2565. )
  2566. rank_ops.append(
  2567. _Action(stage_id_chunk0, computation_type=W, microbatch_index=b0_cnt)
  2568. )
  2569. b0_cnt += 1
  2570. rank_ops.append(
  2571. _Action(stage_id_chunk1, computation_type=F, microbatch_index=f1_cnt)
  2572. )
  2573. f1_cnt += 1
  2574. rank_ops.append(
  2575. _Action(stage_id_chunk1, computation_type=I, microbatch_index=b1_cnt)
  2576. )
  2577. rank_ops.append(
  2578. _Action(stage_id_chunk1, computation_type=W, microbatch_index=b1_cnt)
  2579. )
  2580. b1_cnt += 1
  2581. # cool-down phase
  2582. w0_cnt, w1_cnt = b0_cnt, b1_cnt
  2583. cooldown_n1 = rank
  2584. for _ in range(cooldown_n1):
  2585. rank_ops.append(
  2586. _Action(stage_id_chunk0, computation_type=I, microbatch_index=b0_cnt)
  2587. )
  2588. b0_cnt += 1
  2589. rank_ops.append(
  2590. _Action(stage_id_chunk1, computation_type=I, microbatch_index=b1_cnt)
  2591. )
  2592. b1_cnt += 1
  2593. cooldown_n2 = self.pp_group_size - rank
  2594. for _ in range(cooldown_n2):
  2595. rank_ops.append(
  2596. _Action(stage_id_chunk0, computation_type=I, microbatch_index=b0_cnt)
  2597. )
  2598. b0_cnt += 1
  2599. rank_ops.append(
  2600. _Action(stage_id_chunk0, computation_type=W, microbatch_index=w0_cnt)
  2601. )
  2602. w0_cnt += 1
  2603. while w1_cnt < b1_cnt:
  2604. rank_ops.append(
  2605. _Action(stage_id_chunk1, computation_type=W, microbatch_index=w1_cnt)
  2606. )
  2607. w1_cnt += 1
  2608. while w0_cnt < b0_cnt:
  2609. rank_ops.append(
  2610. _Action(stage_id_chunk0, computation_type=W, microbatch_index=w0_cnt)
  2611. )
  2612. w0_cnt += 1
  2613. if not (w0_cnt == b0_cnt and b0_cnt == f0_cnt):
  2614. raise AssertionError(
  2615. f"Expected w0_cnt == b0_cnt == f0_cnt, got w0_cnt={w0_cnt}, b0_cnt={b0_cnt}, f0_cnt={f0_cnt}"
  2616. )
  2617. if not (w1_cnt == b1_cnt and b1_cnt == f1_cnt):
  2618. raise AssertionError(
  2619. f"Expected w1_cnt == b1_cnt == f1_cnt, got w1_cnt={w1_cnt}, b1_cnt={b1_cnt}, f1_cnt={f1_cnt}"
  2620. )
  2621. # We use max() in the n_micro computation above, so we may need to
  2622. # remove redundant microbatches
  2623. rank_ops = [
  2624. (
  2625. action
  2626. if action is not None
  2627. and action.microbatch_index is not None
  2628. and action.microbatch_index < self._n_microbatches
  2629. else None
  2630. )
  2631. for action in rank_ops
  2632. ]
  2633. return rank_ops
  2634. class ScheduleDualPipeV(_PipelineScheduleRuntime):
  2635. """
  2636. The DualPipeV schedule. A more efficient schedule variant based on the
  2637. DualPipe schedule introduced by DeepSeek in https://arxiv.org/pdf/2412.19437
  2638. Based on the open sourced code from https://github.com/deepseek-ai/DualPipe
  2639. """
  2640. def __init__(
  2641. self,
  2642. stages: list[_PipelineStageBase],
  2643. n_microbatches: int,
  2644. loss_fn: Callable | None = None,
  2645. args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None,
  2646. kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None,
  2647. output_merge_spec: dict[str, Any] | tuple[Any] | None = None,
  2648. scale_grads: bool = True,
  2649. backward_requires_autograd: bool = True,
  2650. ):
  2651. # TODO: we dont support input/weight backward split with torch.compile
  2652. _check_torch_compile_compatibility(stages, self.__class__.__name__)
  2653. self.pp_group_size = stages[0].group_size
  2654. super().__init__(
  2655. stages=stages,
  2656. n_microbatches=n_microbatches,
  2657. loss_fn=loss_fn,
  2658. args_chunk_spec=args_chunk_spec,
  2659. kwargs_chunk_spec=kwargs_chunk_spec,
  2660. output_merge_spec=output_merge_spec,
  2661. scale_grads=scale_grads,
  2662. backward_requires_autograd=backward_requires_autograd,
  2663. )
  2664. self.stage_index_to_group_rank = generate_stage_to_rank_mapping(
  2665. self.pp_group_size, self._num_stages, style="v"
  2666. )
  2667. for stage in self._stages:
  2668. stage.stage_index_to_group_rank = self.stage_index_to_group_rank
  2669. self.n_local_stages = len(stages)
  2670. if self.n_local_stages != 2:
  2671. raise ValueError(
  2672. "ZBV requires exactly 2 stages per rank, but got "
  2673. f"{self.n_local_stages}."
  2674. )
  2675. if n_microbatches < self._num_stages:
  2676. raise ValueError(
  2677. "DualPipeV requires at least as many microbatches as stages, but got "
  2678. f"{n_microbatches} microbatches and {self._num_stages} stages."
  2679. )
  2680. self.rank = stages[0].group_rank
  2681. self.num_stages = stages[0].num_stages
  2682. # 1. Create the pipeline_order (all ranks do this calculation)
  2683. # This will be used to keep track of the current state of the entire pipeline
  2684. # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
  2685. self.pipeline_order: dict[int, list[_Action | None]] = {}
  2686. for rank in range(self.pp_group_size):
  2687. rank_ops = self._calculate_single_rank_operations(rank)
  2688. self.pipeline_order[rank] = rank_ops
  2689. # Initialize the pipeline order with communication necessary to run with _PipelineScheduleRuntime
  2690. self._prepare_schedule_with_comms(self.pipeline_order)
  2691. def _calculate_single_rank_operations(self, rank) -> list[_Action | None]:
  2692. actions: list[_Action | None] = []
  2693. counters: dict[
  2694. tuple[int, _ComputationType], int
  2695. ] = {} # (stage_index, computation_type) -> mb_index
  2696. weight_queue = [] # Queue of (stage_index, mb_index) for pending weight actions
  2697. num_ranks = self.pp_group_size
  2698. num_chunks = self._n_microbatches
  2699. rank_to_stages = generate_rank_to_stage_mapping(
  2700. num_ranks, num_ranks * 2, style="v"
  2701. )
  2702. stage0_index, stage1_index = rank_to_stages[rank]
  2703. def increment_backward_counts(stage_index: int):
  2704. """Helper method to increment BACKWARD_INPUT and BACKWARD_WEIGHT counters when FULL_BACKWARD is used."""
  2705. input_key = (stage_index, BACKWARD_INPUT)
  2706. weight_key = (stage_index, BACKWARD_WEIGHT)
  2707. counters[input_key] = counters.get(input_key, 0) + 1
  2708. counters[weight_key] = counters.get(weight_key, 0) + 1
  2709. def add_overlap_f_b(
  2710. actions: list,
  2711. forward_stage: int,
  2712. backward_stage: int,
  2713. ):
  2714. """Helper method to add an overlapped forward+backward action which tracks microbatch index."""
  2715. # Create new overlapped forward+backward action with sub_actions
  2716. forward_key = (forward_stage, FORWARD)
  2717. backward_key = (backward_stage, BACKWARD_INPUT)
  2718. forward_mb = counters.get(forward_key, 0)
  2719. backward_mb = counters.get(backward_key, 0)
  2720. sub_actions = (
  2721. _Action(forward_stage, FORWARD, forward_mb),
  2722. _Action(backward_stage, FULL_BACKWARD, backward_mb),
  2723. )
  2724. actions.append(_Action(-1, OVERLAP_F_B, None, sub_actions))
  2725. # Update counters for sub_actions
  2726. counters[forward_key] = forward_mb + 1
  2727. increment_backward_counts(backward_stage)
  2728. def add_action(
  2729. actions: list,
  2730. stage_index: int,
  2731. computation_type: _ComputationType,
  2732. ):
  2733. # Regular single action, for FULL_BACKWARD we only use the BACKWARD_INPUT counter
  2734. key = (
  2735. (stage_index, computation_type)
  2736. if computation_type != FULL_BACKWARD
  2737. else (stage_index, BACKWARD_INPUT)
  2738. )
  2739. mb_index = counters.get(key, 0)
  2740. actions.append(_Action(stage_index, computation_type, mb_index))
  2741. # If FULL_BACKWARD is used, just increment the separate BACKWARD_INPUT and BACKWARD_WEIGHT counters
  2742. if computation_type == FULL_BACKWARD:
  2743. increment_backward_counts(stage_index)
  2744. else:
  2745. # If BACKWARD_INPUT is updated, add corresponding weight action to queue
  2746. if computation_type == BACKWARD_INPUT:
  2747. # Add weight action to queue for later processing
  2748. weight_queue.append((stage_index, mb_index))
  2749. counters[key] = mb_index + 1
  2750. def add_weight_action_if_pending(actions: list):
  2751. """Helper method to add a weight action from the queue."""
  2752. if not weight_queue:
  2753. return # No pending weight actions, skip
  2754. # Pop the oldest weight action from the queue
  2755. actual_stage_index, weight_mb_index = weight_queue.pop(0)
  2756. actions.append(
  2757. _Action(
  2758. actual_stage_index,
  2759. BACKWARD_WEIGHT,
  2760. weight_mb_index,
  2761. )
  2762. )
  2763. # Update the counter for the actual stage that was processed
  2764. weight_key = (actual_stage_index, BACKWARD_WEIGHT)
  2765. counters[weight_key] = counters.get(weight_key, 0) + 1
  2766. # Step 1: F0
  2767. step_1 = (num_ranks - rank - 1) * 2
  2768. for _ in range(step_1):
  2769. add_action(actions, stage0_index, FORWARD)
  2770. # Step 2: F0F1
  2771. step_2 = rank + 1
  2772. for _ in range(step_2):
  2773. add_action(actions, stage0_index, FORWARD)
  2774. add_action(actions, stage1_index, FORWARD)
  2775. # Step 3: I1W1F1 (Use zero bubble)
  2776. step_3 = num_ranks - rank - 1
  2777. for _ in range(step_3):
  2778. add_action(actions, stage1_index, BACKWARD_INPUT)
  2779. add_weight_action_if_pending(actions)
  2780. add_action(actions, stage1_index, FORWARD)
  2781. # Step 4 (Main step): F0B1-F1B0 (combined, overlapped forward+backward)
  2782. step_4 = num_chunks - num_ranks * 2 + rank + 1
  2783. for i in range(step_4):
  2784. if i == 0 and rank == num_ranks - 1:
  2785. # NOTE: We don't overlap these two chunks to further reduce bubble size.
  2786. add_action(actions, stage0_index, FORWARD)
  2787. add_action(actions, stage1_index, FULL_BACKWARD)
  2788. else:
  2789. add_overlap_f_b(
  2790. actions,
  2791. forward_stage=stage0_index,
  2792. backward_stage=stage1_index,
  2793. )
  2794. add_overlap_f_b(
  2795. actions,
  2796. forward_stage=stage1_index,
  2797. backward_stage=stage0_index,
  2798. )
  2799. # Step 5: B1-F1B0
  2800. step_5 = num_ranks - rank - 1
  2801. for _ in range(step_5):
  2802. add_action(actions, stage1_index, FULL_BACKWARD)
  2803. add_overlap_f_b(
  2804. actions,
  2805. forward_stage=stage1_index,
  2806. backward_stage=stage0_index,
  2807. )
  2808. # Step 6: B1B0 (The second half of the chunks use zero bubble)
  2809. step_6 = rank + 1
  2810. enable_zb = False
  2811. for i in range(step_6):
  2812. if i == step_6 // 2 and rank % 2 == 1:
  2813. enable_zb = True
  2814. comp_type = BACKWARD_INPUT if enable_zb else FULL_BACKWARD
  2815. add_action(actions, stage1_index, comp_type)
  2816. if i == step_6 // 2 and rank % 2 == 0:
  2817. enable_zb = True
  2818. comp_type = BACKWARD_INPUT if enable_zb else FULL_BACKWARD
  2819. add_action(actions, stage0_index, comp_type)
  2820. # Step 7: W0B0
  2821. step_7 = num_ranks - rank - 1
  2822. for _ in range(step_7):
  2823. add_weight_action_if_pending(actions)
  2824. comp_type = BACKWARD_INPUT if enable_zb else FULL_BACKWARD
  2825. add_action(actions, stage0_index, comp_type)
  2826. # Step 8: W0
  2827. step_8 = rank + 1
  2828. for _ in range(step_8):
  2829. add_weight_action_if_pending(actions)
  2830. return actions
  2831. def get_schedule_class(schedule_name: str):
  2832. """
  2833. Maps a schedule name (case insensitive) to its corresponding class object.
  2834. Args:
  2835. schedule_name (str): The name of the schedule.
  2836. """
  2837. schedule_map = {
  2838. "1F1B": Schedule1F1B,
  2839. "Interleaved1F1B": ScheduleInterleaved1F1B,
  2840. "GPipe": ScheduleGPipe,
  2841. "LoopedBFS": ScheduleLoopedBFS,
  2842. "InterleavedZeroBubble": ScheduleInterleavedZeroBubble,
  2843. "PipelineScheduleSingle": PipelineScheduleSingle,
  2844. "PipelineScheduleMulti": PipelineScheduleMulti,
  2845. "ZBVZeroBubble": ScheduleZBVZeroBubble,
  2846. "DualPipeV": ScheduleDualPipeV,
  2847. }
  2848. lowercase_keys = {k.lower(): k for k in schedule_map}
  2849. lowercase_schedule_name = schedule_name.lower()
  2850. if lowercase_schedule_name not in lowercase_keys:
  2851. raise ValueError(
  2852. f"Unknown schedule name '{schedule_name}'. The valid options are {list(schedule_map.keys())}"
  2853. )
  2854. return schedule_map[lowercase_keys[lowercase_schedule_name]]
  2855. def _simulate_comms_compute(
  2856. pipeline_order, stage_to_rank: Callable[[int], int], num_stages: int
  2857. ):
  2858. """This function dry-run simulates the actions in the schedule from the perspective of all ranks, and flags
  2859. any deadlocks caused by missing or misordered communications. It also simulates any bubbles in time where a rank
  2860. can not execute any action due to waiting for unmet dependencies. The total number of simulator steps can be used
  2861. as a metric for unit tests involving IR optimization passes as reordering and merging of IR can reduce the number
  2862. of simulated steps.
  2863. The simulation is not high-fidelity and does not model overlapping of compute and communication, or cuda streams.
  2864. Future work may be to enhance this and model the compute time, comms overlap, and even memory.
  2865. """
  2866. pipeline_order = {
  2867. rank: [a for a in pipeline_order[rank] if a is not None]
  2868. for rank in sorted(pipeline_order)
  2869. }
  2870. _schedule: dict[int, list[_Action | None]] = {
  2871. rank: [] for rank in sorted(pipeline_order)
  2872. }
  2873. _prev_ops_rank: dict[int, set[_Action]] = {rank: set() for rank in _schedule}
  2874. def add_to_schedule(rank: int, action: _Action | None):
  2875. _schedule[rank].append(action)
  2876. if action is not None:
  2877. _prev_ops_rank[rank].add(action)
  2878. def _ready_to_schedule(action: _Action | None) -> bool:
  2879. if action is None:
  2880. return True
  2881. stage_idx = action.stage_index
  2882. prev_ops = _prev_ops_rank[stage_to_rank(stage_idx)]
  2883. if action.computation_type == F:
  2884. if action.stage_index == 0:
  2885. return True
  2886. elif (
  2887. _Action(action.stage_index, RECV_F, action.microbatch_index) in prev_ops
  2888. ):
  2889. return True
  2890. elif (
  2891. _Action(action.stage_index - 1, F, action.microbatch_index) in prev_ops
  2892. ):
  2893. return True
  2894. return False
  2895. elif action.computation_type in (BACKWARD_INPUT, FULL_BACKWARD):
  2896. if action.stage_index == num_stages - 1:
  2897. return True
  2898. if _Action(action.stage_index, RECV_B, action.microbatch_index) in prev_ops:
  2899. return True
  2900. if (
  2901. _Action(action.stage_index + 1, BACKWARD_INPUT, action.microbatch_index)
  2902. in prev_ops
  2903. ):
  2904. return True
  2905. if (
  2906. _Action(action.stage_index + 1, FULL_BACKWARD, action.microbatch_index)
  2907. in prev_ops
  2908. ):
  2909. return True
  2910. return False
  2911. elif action.computation_type == BACKWARD_WEIGHT:
  2912. return True
  2913. elif action.computation_type == SEND_F:
  2914. expected_f = _Action(action.stage_index, F, action.microbatch_index)
  2915. return expected_f in prev_ops
  2916. elif action.computation_type == RECV_F:
  2917. peer_stage_idx = stage_idx - 1
  2918. expected_send = _Action(peer_stage_idx, SEND_F, action.microbatch_index)
  2919. return expected_send in _prev_ops_rank[stage_to_rank(peer_stage_idx)]
  2920. elif action.computation_type == SEND_B:
  2921. expected_b = _Action(
  2922. action.stage_index, BACKWARD_INPUT, action.microbatch_index
  2923. )
  2924. expected_bw = _Action(
  2925. action.stage_index, FULL_BACKWARD, action.microbatch_index
  2926. )
  2927. return expected_b in prev_ops or expected_bw in prev_ops
  2928. elif action.computation_type == RECV_B:
  2929. peer_stage_idx = stage_idx + 1
  2930. expected_send = _Action(peer_stage_idx, SEND_B, action.microbatch_index)
  2931. return expected_send in _prev_ops_rank[stage_to_rank(peer_stage_idx)]
  2932. else:
  2933. raise ValueError(f"Unsupported action type {action}")
  2934. while pipeline_order:
  2935. progress = False
  2936. for rank in sorted(pipeline_order):
  2937. if len(pipeline_order[rank]) == 0:
  2938. continue
  2939. action = pipeline_order[rank][0]
  2940. if _ready_to_schedule(action):
  2941. if action is not None:
  2942. add_to_schedule(rank, action)
  2943. pipeline_order[rank].pop(0)
  2944. progress = True
  2945. else:
  2946. add_to_schedule(rank, None)
  2947. for i in sorted(pipeline_order, reverse=True):
  2948. if len(pipeline_order[i]) == 0:
  2949. del pipeline_order[i]
  2950. # hacky, but do a second pass to replace any 'none' at this timestep with a real action, if it got unblocked
  2951. # by one of the later ranks
  2952. for rank in sorted(pipeline_order):
  2953. if len(pipeline_order[rank]) == 0:
  2954. continue
  2955. if _schedule[rank][-1] is not None:
  2956. continue
  2957. action = pipeline_order[rank][0]
  2958. if _ready_to_schedule(action):
  2959. if action is not None:
  2960. _schedule[rank][-1] = action
  2961. _prev_ops_rank[rank].add(action)
  2962. pipeline_order[rank].pop(0)
  2963. for i in sorted(pipeline_order, reverse=True):
  2964. if len(pipeline_order[i]) == 0:
  2965. del pipeline_order[i]
  2966. if not progress:
  2967. print("WIP comms schedule:\n", _format_pipeline_order(_schedule))
  2968. for rank in pipeline_order:
  2969. print(f"{rank=} next action= {pipeline_order[rank][0]}")
  2970. raise ValueError("Schedule is not progressing")
  2971. return _schedule
  2972. def _dump_chrometrace(schedule, filename):
  2973. """
  2974. This function dumps a schedule IR into a chrometrace format so it can be visualized.
  2975. It is currently very basic and only serves as a graphical alternative to dumping the schedule IR as text.
  2976. As future work we may extend this to include more accurate heuristics for durations, or let users input durations,
  2977. add 'flow events' to let the UI show the connection between sends and recvs, and model cuda streams for comm/compute
  2978. as separate streams on the chrometrace view.
  2979. """
  2980. events = []
  2981. for rank in sorted(schedule):
  2982. for timestep, action in enumerate(schedule[rank]):
  2983. if action is None:
  2984. continue
  2985. events.append(
  2986. {
  2987. "name": str(action),
  2988. "cat": (
  2989. "computation"
  2990. if action.computation_type in (F, B, W)
  2991. else "communication"
  2992. ),
  2993. "ph": "X",
  2994. "pid": rank,
  2995. "tid": rank,
  2996. "ts": timestep,
  2997. "dur": 1,
  2998. }
  2999. )
  3000. import json
  3001. with open(filename, "w") as f:
  3002. json.dump({"traceEvents": events}, f)
  3003. def _check_torch_compile_compatibility(
  3004. stages: list[_PipelineStageBase], schedule_name: str
  3005. ):
  3006. """
  3007. Check if the schedule is compatible with torch.compile.
  3008. Args:
  3009. stages: List of pipeline stages to check
  3010. schedule_name: Name of the schedule for error message
  3011. Raises:
  3012. RuntimeError: If any stage uses torch.compile
  3013. """
  3014. for stage in stages:
  3015. if not isinstance(stage.submod, torch.nn.Module):
  3016. continue
  3017. for module in stage.submod.modules():
  3018. if isinstance(module, OptimizedModule):
  3019. raise RuntimeError(
  3020. f"The {schedule_name} schedule is not supported with "
  3021. "stage modules that have used torch.compile. "
  3022. f"Found OptimizedModule in {type(module).__name__}"
  3023. )