| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
27722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756275727582759276027612762276327642765276627672768276927702771277227732774277527762
77727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875287628772878287928802881288228832884288528862887288828892890289128922893289428952896289728982899290029012902290329042905290629072908290929102911291229132914291529162917291829192920292129222923292429252926292729282929293029312932293329342935293629372938293929402941294229432944294529462947294829492950295129522953295429552956295729582959296029612962296329642965296629672968296929702971297229732974297529762977297829792980298129822983298429852986298729882989299029912992299329942995299629972998299930003001300230033004300530063007300830093010301130123013301430153016301730183019302030213022302330243025302630273028302930303031303230333034303530363037303830393040304130423043304430453046304730483049305030513052305330543055305630573058305930603061306230633064306530663067306830693070307130723073307430753076307730783079308030813082308330843085308630873088308930903091309230933094309530963097309830993100310131023103310431053106310731083109311031113112311331143115311631173118311931203121312231233124312531263127312831293130313131323133313431353136313731383139314031413142314331443145314631473148314931503151315231533154315531563157315831593160316131623163316431653166316731683169317031713172317331743175317631773178317931803181318231833184318531863187318831893190319131923193319431953196319731983199320032013202320332043205320632073208320932103211321232133214321532163217321832193220322132223223322432253226322732283229323032313232323332343235323632373238323932403241324232433244324532463247324832493250325132523253325432553256325732583259326032613262326332643265326632673268326932703271327232733274327532763
27732783279328032813282328332843285328632873288328932903291329232933294329532963297329832993300330133023303330433053306330733083309331033113312331333143315331633173318331933203321332233233324332533263327332833293330333133323333333433353336333733383339334033413342334333443345334633473348334933503351335233533354335533563357335833593360336133623363336433653366336733683369337033713372337333743375337633773378337933803381338233833384338533863387338833893390339133923393339433953396339733983399340034013402340334043405340634073408340934103411341234133414341534163417341834193420342134223423342434253426342734283429343034313432343334343435343634373438343934403441344234433444344534463447 |
- # mypy: allow-untyped-defs
- # Copyright (c) Meta Platforms, Inc. and affiliates
- import copy
- import csv
- import itertools
- import logging
- import re
- from abc import ABC, abstractmethod
- from collections import Counter, defaultdict
- from collections.abc import Callable
- from dataclasses import dataclass
- from enum import Enum
- from functools import lru_cache
- from typing import Any, cast, NamedTuple, Protocol
- import torch
- import torch.distributed as dist
- from torch._dynamo import OptimizedModule
- from torch.distributed.fsdp import FSDPModule, UnshardHandle
- from torch.nn.modules.loss import _Loss
- from torch.profiler import record_function
- from ._utils import generate_rank_to_stage_mapping, generate_stage_to_rank_mapping
- from .microbatch import merge_chunks, split_args_kwargs_into_chunks, TensorChunkSpec
- from .stage import _PipelineStageBase
# Public names exported by `from <this module> import *`.
__all__ = [
    "get_schedule_class",
    "PipelineScheduleSingle",
    "PipelineScheduleMulti",
    "Schedule1F1B",
    "ScheduleGPipe",
    "ScheduleInterleaved1F1B",
    "ScheduleLoopedBFS",
    "ScheduleInterleavedZeroBubble",
    "ScheduleZBVZeroBubble",
    "ScheduleDualPipeV",
]

# Module-wide logger, named after this module.
logger = logging.getLogger(__name__)
- class _ComputationType(str, Enum):
- # TODO(whc) rename to _ActType?
- FORWARD = "F"
- BACKWARD_INPUT = "I"
- BACKWARD_WEIGHT = "W"
- UNSHARD = "UNSHARD"
- RESHARD = "RESHARD"
- SEND_F = "SEND_F"
- RECV_F = "RECV_F"
- SEND_B = "SEND_B"
- RECV_B = "RECV_B"
- FULL_BACKWARD = "B"
- OVERLAP_F_B = "OVERLAP_F_B"
- REDUCE_GRAD = "REDUCE_GRAD"
- @staticmethod
- def from_str(action: str) -> "_ComputationType":
- try:
- return _ComputationType(action)
- except ValueError as exc:
- raise RuntimeError(f"Invalid computation type {action}") from exc
# Module-level aliases for the _ComputationType members, so schedule code can
# write e.g. ``FORWARD`` instead of ``_ComputationType.FORWARD``.

# Forward / backward compute actions.
FORWARD, BACKWARD_INPUT, BACKWARD_WEIGHT, FULL_BACKWARD, OVERLAP_F_B = (
    _ComputationType.FORWARD,
    _ComputationType.BACKWARD_INPUT,
    _ComputationType.BACKWARD_WEIGHT,
    _ComputationType.FULL_BACKWARD,
    _ComputationType.OVERLAP_F_B,
)

# UNSHARD / RESHARD / REDUCE_GRAD actions.
UNSHARD, RESHARD, REDUCE_GRAD = (
    _ComputationType.UNSHARD,
    _ComputationType.RESHARD,
    _ComputationType.REDUCE_GRAD,
)

# Send / receive actions for forward (F) and backward (B) data.
SEND_F, RECV_F, SEND_B, RECV_B = (
    _ComputationType.SEND_F,
    _ComputationType.RECV_F,
    _ComputationType.SEND_B,
    _ComputationType.RECV_B,
)

# Single-letter shorthands for the compute actions used in the
# 'simple schedule format'.
F, I, W, B = FORWARD, BACKWARD_INPUT, BACKWARD_WEIGHT, FULL_BACKWARD
- # Helper to parse an action string like 1F0 into a tuple of (stage_index, computation_type, microbatch_index)
- _action_regex = re.compile(
- r"(\d+)(F|I|B|W|UNSHARD|RESHARD|REDUCE_GRAD|SEND_F|RECV_F|SEND_B|RECV_B)(\d*)"
- )
class _Action(NamedTuple):
    """
    One entry of a pipeline schedule.

    A plain action is rendered as ``[stage][computation type][(microbatch)]``,
    e.g. ``2F0``, ``1UNSHARD``, ``3SEND_F1``. A container action (one with
    ``sub_actions``) is rendered as ``(sub1;sub2;...)TYPE``, e.g.
    ``(0F1;1B0)OVERLAP_F_B``. ``from_str`` parses both forms back.
    """

    stage_index: int
    computation_type: _ComputationType
    microbatch_index: int | None = None
    # Nested actions of a container action; None for a plain action.
    sub_actions: tuple["_Action", ...] | None = None

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        if self.sub_actions is not None:
            # Container form "(<sub>;<sub>;...)<type>"; stage and microbatch are
            # not printed for containers.
            sub_action_reprs = [repr(sub_action) for sub_action in self.sub_actions]
            return f"({';'.join(sub_action_reprs)}){self.computation_type.value}"
        # Use .value to get the short string (e.g., "F", "B") instead of the full enum name
        repr_str = str(self.stage_index) + self.computation_type.value
        if self.microbatch_index is not None:
            repr_str += str(self.microbatch_index)
        return repr_str

    @property
    def is_compute_op(self) -> bool:
        """True for forward/backward compute actions, including OVERLAP_F_B."""
        return self.computation_type in (
            FORWARD,
            FULL_BACKWARD,
            BACKWARD_INPUT,
            BACKWARD_WEIGHT,
            OVERLAP_F_B,
        )

    @staticmethod
    def from_str(action_string: str) -> "_Action | None":
        """
        Reverse of __repr__.

        String should be formatted as [stage][action type][(microbatch)]
        e.g. `2F0`, `1UNSHARD`, `3SEND_F1`, or as a container of
        ';'-separated sub-actions followed by its type, e.g.
        `(0F1;1B0)OVERLAP_F_B`.

        Returns:
            The parsed action, or None if the string is empty or whitespace.

        Raises:
            RuntimeError: if the string, or its computation type, is not valid.
                The whole string must match; trailing characters are rejected.
        """
        action_string = action_string.strip()
        if action_string == "":
            return None

        # Container format: (sub_action1;sub_action2;...)ComputationType
        if action_string.startswith("(") and ")" in action_string:
            bracket_end = action_string.find(")")
            # Text between the leading '(' and the first ')'.
            sub_part = action_string[1:bracket_end]
            # Everything after the ')' is the container's computation type.
            computation_type_part = action_string[bracket_end + 1 :]

            sub_actions = []
            if sub_part.strip():
                for sub_str in sub_part.split(";"):
                    sub_action = _Action.from_str(sub_str.strip())
                    if sub_action is not None:
                        sub_actions.append(sub_action)

            # stage_index and microbatch_index are not meaningful for a container.
            return _Action(
                stage_index=-1,  # Placeholder, not meaningful for sub_actions container
                computation_type=_ComputationType.from_str(computation_type_part),
                microbatch_index=None,
                sub_actions=tuple(sub_actions) if sub_actions else None,
            )

        # Plain format. fullmatch (not match) so that typos such as "1FF0" or
        # "1F0x" raise instead of being silently truncated to a valid prefix.
        if match := _action_regex.fullmatch(action_string):
            stage_index, computation_type, microbatch_index = match.groups()
            return _Action(
                int(stage_index),
                _ComputationType.from_str(computation_type),
                int(microbatch_index) if microbatch_index else None,
            )

        raise RuntimeError(
            f"Invalid action string: {action_string}, should be formatted as [stage][action type][(microbatch)] e.g. 2F0"
        )
- @lru_cache
- def _get_profiler_function_name(action: _Action) -> str:
- return f"PP:{str(action)}"
- def _format_pipeline_order(
- pipeline_order: dict[int, list[_Action | None]],
- error_step_number: int | None = None,
- ) -> str:
- """
- Formats the pipeline order in a timestep (row) x rank (column) grid of actions
- and returns the formatted string.
- If `error_step_number` is passed in, an additional label will be added to signify which step
- that it is erroring on.
- """
- # don't mutate the original
- pipeline_order = copy.deepcopy(pipeline_order)
- # Replace None with ""
- for rank in pipeline_order:
- for i in range(len(pipeline_order[rank])):
- if pipeline_order[rank][i] is None:
- # TODO make a real 'None action' that prints as empty string and make mypy happy
- pipeline_order[rank][i] = "" # type: ignore[call-overload]
- # Calculate the maximum number of steps across all ranks
- num_steps = max(len(actions) for actions in pipeline_order.values())
- step_labels = [
- "Step " + str(i).zfill(len(str(num_steps - 1))) for i in range(num_steps)
- ]
- # Sorting the dictionary by keys and retrieving values in that order
- rank_actions = [
- pipeline_order.get(key, [""] * num_steps) for key in sorted(pipeline_order)
- ]
- # Transpose the list of lists (rows to columns)
- # pyrefly: ignore [no-matching-overload]
- transposed_actions = list(itertools.zip_longest(*rank_actions, fillvalue=""))
- # Generate column labels for ranks
- num_ranks = len(pipeline_order)
- rank_labels = ["Rank " + str(i) for i in range(num_ranks)]
- # Calculate the maximum length of each column, considering labels
- max_lengths = [
- max(len(str(item)) if item is not None else 0 for item in col)
- for col in zip(step_labels, *transposed_actions)
- ]
- # Format the header row with rank labels
- header_row = " " * (len(step_labels[0]) + 2) + " ".join(
- f"{label:<{max_lengths[i]}}" for i, label in enumerate(rank_labels)
- )
- # Format each row with its corresponding label
- formatted_rows = [
- f"{label}: "
- + " ".join(f"{str(item):<{max_lengths[i]}}" for i, item in enumerate(row))
- + (
- " <-- ERROR HERE"
- if error_step_number is not None
- and int(label.split()[1]) == error_step_number
- else ""
- )
- for label, row in zip(step_labels, transposed_actions)
- ]
- # Join the rows into a single string
- formatted_table = header_row + "\n" + "\n".join(formatted_rows) + "\n"
- return formatted_table
- class _PipelineSchedule(ABC):
- def __init__(
- self,
- n_microbatches: int,
- loss_fn: Callable[..., torch.Tensor] | None = None,
- args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None,
- kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None,
- output_merge_spec: dict[str, Any] | tuple[Any] | None = None,
- scale_grads: bool = True,
- ):
- # From arguments
- self._n_microbatches = n_microbatches
- self._loss_fn = loss_fn
- # See documentation in `PipelineScheduleSingle` / `PipelineScheduleMulti`
- self.scale_grads = scale_grads
- # Chunking specification for positional inputs. (default: `None`)
- self._args_chunk_spec = args_chunk_spec
- # Chunking specification for keyword inputs. (default: `None`)
- self._kwargs_chunk_spec = kwargs_chunk_spec
- self._output_merge_spec = output_merge_spec
- """
- # args_chunk_spec and kwargs_chunk_spec specify how to chunk inputs.
- # They are used to convert batch to microbatches in `step(x)`. See
- # `TensorChunkSpec` for helper methods for creating them.
- """
- # Derived
- self._has_backward = self._loss_fn is not None
- # Holds the losses for each microbatch.
- self._internal_losses: list[torch.Tensor] = []
- logger.info("Using %s", self.__class__.__name__)
- def _maybe_compute_loss(self, stage, output, target_mbs, mb_index):
- if stage.is_last and self._loss_fn is not None:
- loss = self._compute_loss(output, target_mbs[mb_index]) # type: ignore[index]
- self._internal_losses.append(loss)
- def _maybe_get_loss(self, stage, mb_index):
- valid_index = 0 <= mb_index < len(self._internal_losses)
- if stage.is_last and self._loss_fn is not None and valid_index:
- return self._internal_losses[mb_index]
- elif len(self._internal_losses) != 0 and not valid_index:
- raise RuntimeError(
- f"Loss for microbatch {mb_index} is not available. "
- f"Available losses for microbatches: {self._internal_losses}"
- )
- else:
- return None
- def _update_losses(self, stages, losses):
- """
- Update the losses to those in the internal state
- """
- # if stages not a list turn into a list
- if not isinstance(stages, list):
- stages = [stages]
- contains_last_stage = any(stage.is_last for stage in stages)
- # Return losses if there is a container passed in
- if contains_last_stage and losses is not None:
- if len(self._internal_losses) != self._n_microbatches:
- raise RuntimeError(
- f"Expecting {self._n_microbatches} losses but got {len(self._internal_losses)}"
- )
- # Clean external container first
- losses.clear()
- # Copy internal losses to external container
- losses.extend(self._internal_losses)
- self._internal_losses.clear()
- @abstractmethod
- def _step_microbatches(
- self,
- arg_mbs: list | None = None,
- kwarg_mbs: list | None = None,
- target_mbs: list | None = None,
- losses: list | None = None,
- return_outputs: bool = True,
- ):
- """
- Run one iteration of the pipeline schedule with list of microbatches.
- Will go through all the microbatches according to the schedule
- implementation.
- Args:
- microbatches: list of microbatch args.
- return_outputs: whether to return the outputs from the last stage.
- """
- raise NotImplementedError
- @abstractmethod
- def step(
- self,
- *args,
- target=None,
- losses: list | None = None,
- return_outputs=True,
- **kwargs,
- ):
- """
- Run one iteration of the pipeline schedule with *whole-batch* input.
- Will chunk the input into microbatches automatically, and go through the
- microbatches according to the schedule implementation.
- args: positional arguments to the model (as in non-pipeline case).
- kwargs: keyword arguments to the model (as in non-pipeline case).
- target: target for the loss function.
- losses: a list to store the losses for each microbatch.
- return_outputs: whether to return the outputs from the last stage.
- """
- raise NotImplementedError
- def eval(self, *args, target=None, losses: list | None = None, **kwargs):
- """
- Run one iteration of the pipeline schedule with *whole-batch* input.
- Will chunk the input into microbatches automatically, and go through the
- microbatches, calling forward only.
- args: positional arguments to the model (as in non-pipeline case).
- kwargs: keyword arguments to the model (as in non-pipeline case).
- target: target values for the loss function.
- losses: a list to store the losses for each microbatch.
- """
- # Save the original has_backward state
- original_has_backward = self._has_backward
- try:
- self._has_backward = False
- return self.step(*args, target=target, losses=losses, **kwargs)
- finally:
- # Restore the original state
- self._has_backward = original_has_backward
- def _check_inputs(
- self,
- arg_mbs: list | None = None,
- kwarg_mbs: list | None = None,
- target_mbs: list | None = None,
- losses: list | None = None,
- ) -> tuple[list, list]:
- """
- Pre-process/check inputs
- """
- def check_type_and_len(mbs, name: str):
- if not isinstance(mbs, list):
- raise TypeError(f"{name} must be a list but got a {type(mbs)}")
- if len(mbs) != self._n_microbatches:
- raise ValueError(
- f"Expecting {self._n_microbatches} {name} but got {len(mbs)}"
- )
- if arg_mbs is not None:
- check_type_and_len(arg_mbs, "arg_mbs")
- else:
- arg_mbs = [()] * self._n_microbatches
- if kwarg_mbs is not None:
- check_type_and_len(kwarg_mbs, "kwarg_mbs")
- else:
- kwarg_mbs = [{}] * self._n_microbatches
- if target_mbs is not None:
- check_type_and_len(target_mbs, "target_mbs")
- if losses is not None:
- if not isinstance(losses, list):
- raise TypeError(f"losses must be a list but got a {type(losses)}")
- return arg_mbs, kwarg_mbs
- def _compute_loss(self, output, target):
- return self._loss_fn(output, target) # type: ignore[misc]
- def _split_inputs(
- self,
- args: tuple[Any, ...],
- kwargs: dict[str, Any] | None = None,
- ):
- """
- Splits a full-batch input into chunks (i.e. microbatches) and returns
- the chunks
- """
- if args or kwargs:
- args_split, kwargs_split = split_args_kwargs_into_chunks(
- args,
- kwargs,
- self._n_microbatches,
- self._args_chunk_spec,
- self._kwargs_chunk_spec,
- )
- return args_split, kwargs_split
- else:
- # Empty inputs (e.g. when called on middle stages)
- # Return a list of empty tuples/dicts with matching length as chunks
- return [()] * self._n_microbatches, [{}] * self._n_microbatches
- def _merge_outputs(self, output_chunks: list[Any]) -> Any:
- """
- Merge output chunks back to a batch state.
- If output_merge_spec is None, the utility will merge output chunks by dimension 0 (batch dim).
- """
- return merge_chunks(
- output_chunks,
- self._output_merge_spec,
- )
- def _batch_p2p(p2p_ops: list[dist.P2POp], desc: str | None = None) -> list[dist.Work]:
- """
- Simple wrapper over batch_isend_irecv from torch.distributed, which just adds a descriptive logger on top.
- """
- if len(p2p_ops) == 0:
- return []
- desc_str = f"{desc}, " if desc else ""
- logger.debug("batch_p2p %s%s", desc_str, p2p_ops)
- return dist.batch_isend_irecv(p2p_ops)
- def _sorted_batch_p2p(
- p2p_ops: list[dist.P2POp], desc: str | None = None
- ) -> dict[int, list[dist.Work]]:
- """
- Sorts the list of P2P ops by the peer rank, and then calls
- batch_isend_irecv. Return a dictionary of works by peer rank. This function
- helps us avoid hangs in case of skip connections.
- """
- # Arrange p2p_ops by peer rank:
- # int is the peer rank;
- # List is the list of ops towards the peer
- ops_by_peer: dict[int, list[dist.P2POp]] = defaultdict(list)
- work_by_peer: dict[int, list[dist.Work]] = {}
- if len(p2p_ops) == 0:
- return work_by_peer
- # Classify the ops by peer rank
- for op in p2p_ops:
- ops_by_peer[op.peer].append(op)
- # Call batch_isend_irecv per peer, in sorted order of the peers (to avoid hangs)
- for peer, ops in sorted(ops_by_peer.items()):
- work_by_peer[peer] = _batch_p2p(ops, desc=desc)
- return work_by_peer
- def _wait_batch_p2p(work: list[dist.Work]):
- """
- Waits for a list of dist.Work (typically from _batch_p2p / _sorted_batch_p2p).
- """
- for w in work:
- w.wait()
- class PipelineScheduleSingle(_PipelineSchedule):
- """
- Base class for single-stage schedules.
- Implements the `step` method.
- Derived classes should implement `_step_microbatches`.
- Gradients are scaled by num_microbatches depending on the `scale_grads` argument, defaulting to True. This setting
- should match the configuration of your loss_fn, which may either average losses (scale_grads=True)
- or sum losses (scale_grads=False).
- """
- def __init__(
- self,
- stage: _PipelineStageBase,
- n_microbatches: int,
- loss_fn: Callable | None = None,
- args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None,
- kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None,
- output_merge_spec: dict[str, Any] | tuple[Any] | None = None,
- scale_grads: bool = True,
- ):
- # Init parent
- super().__init__(
- n_microbatches=n_microbatches,
- loss_fn=loss_fn,
- args_chunk_spec=args_chunk_spec,
- kwargs_chunk_spec=kwargs_chunk_spec,
- output_merge_spec=output_merge_spec,
- scale_grads=scale_grads,
- )
- # Self attributes
- self._stage = stage
- self._num_stages = stage.num_stages
- self._stage_forward_initialized = False
- self._stage_backward_initialized = False
- self.pipeline_order: dict[int, list[_Action | None]] | None = (
- self._get_pipeline_order()
- )
- def _initialize_stage(self, args, kwargs):
- if not self._stage_forward_initialized:
- # Prepare the communication needed for the pipeline schedule execution
- # This is needed because during execution we always perform a series of batch P2P ops
- # The first call of the batched P2P needs to involve the global group
- all_ops: list[dist.P2POp] = []
- all_ops.extend(self._stage._get_init_p2p_neighbors_ops())
- _wait_batch_p2p(_batch_p2p(all_ops))
- self._stage._prepare_forward_infra(self._n_microbatches, args, kwargs)
- self._stage_forward_initialized = True
- if self._has_backward and not self._stage_backward_initialized:
- self._stage._prepare_backward_infra(self._n_microbatches)
- self._stage_backward_initialized = True
def step(
    self,
    *args,
    target=None,
    losses: list | None = None,
    return_outputs: bool = True,
    **kwargs,
):
    """
    Execute one pipeline iteration on a full (not yet chunked) batch.

    The positional/keyword inputs (and ``target``, if given) are split into
    ``self._n_microbatches`` microbatches, which are then processed in the
    order defined by the concrete schedule's ``_step_microbatches``.

    args: positional model inputs, exactly as for a non-pipelined model.
    kwargs: keyword model inputs, exactly as for a non-pipelined model.
    target: loss-function target, split along the microbatch dimension.
    losses: optional list that receives the per-microbatch losses.
    return_outputs: if True, the last stage returns the merged outputs.

    Returns the merged outputs when this rank holds the last stage and
    ``return_outputs`` is True; otherwise None.

    Raises RuntimeError if the schedule performs backward while autograd is
    disabled.
    """
    if self._has_backward:
        if not torch.is_grad_enabled():
            raise RuntimeError(
                "step() requires gradients to be enabled for backward computation; "
                "it should not be used under torch.no_grad() context. "
                "Please call eval() instead."
            )

    # The stage needs to know whether this schedule will run backward
    self._stage.has_backward = self._has_backward
    # Drop any state left over from the previous iteration
    self._stage.clear_runtime_states()

    mb_args, mb_kwargs = self._split_inputs(args, kwargs)
    mb_targets = (
        None
        if target is None
        else list(torch.tensor_split(target, self._n_microbatches))
    )

    self._step_microbatches(mb_args, mb_kwargs, mb_targets, losses, return_outputs)

    # Only the last stage produces model outputs worth merging
    if not (self._stage.is_last and return_outputs):
        return None
    return self._merge_outputs(self._stage.output_chunks)
def _get_pipeline_order(self) -> dict[int, list[_Action | None]] | None:
    """
    Describe this schedule's execution order as a schedule IR.

    The IR maps each rank ID to an ordered list of steps. A step is either an
    ``_Action`` (a computation to perform) or ``None`` (an intentional idle
    slot: a pipeline bubble in which the rank waits on dependencies from other
    ranks). When executed by ``_PipelineScheduleRuntime`` these ``None`` slots
    are skipped, because the needed send/recv ops are scheduled and waited on
    explicitly.

    This base implementation has no IR to offer and returns None; schedules
    that can describe themselves override it.

    Returns:
        A dict mapping rank -> list of actions, or None.
    """
    no_order: dict[int, list[_Action | None]] | None = None
    return no_order
class _ScheduleForwardOnly(PipelineScheduleSingle):
    """
    The forward-only schedule.
    Will go through all the microbatches and perform only the forward pass.
    Loss computation (targets / losses) is not supported.
    """

    def _step_microbatches(
        self,
        arg_mbs: list | None = None,
        kwarg_mbs: list | None = None,
        target_mbs: list | None = None,
        losses: list | None = None,
        return_outputs: bool = True,
    ):
        """
        Run one forward-only iteration of the pipeline schedule.

        Args:
            arg_mbs: per-microbatch positional args.
            kwarg_mbs: per-microbatch keyword args.
            target_mbs: must be None; loss computation is not supported.
            losses: must be None; loss computation is not supported.
            return_outputs: whether the forward outputs should be kept so the
                caller can merge and return them from the last stage.

        Raises:
            RuntimeError: if ``target_mbs`` or ``losses`` is provided.
        """
        if target_mbs is not None or losses is not None:
            raise RuntimeError(
                "Forward-only schedule does not support loss computation"
            )

        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)
        self._initialize_stage(arg_mbs[0], kwarg_mbs[0])

        # Delay send waits so sends overlap with the next microbatch's compute
        fwd_sends_to_wait: list[list[dist.Work]] = []

        # Run microbatches
        for i in range(self._n_microbatches):
            with record_function(f"Forward {i}"):
                ops = self._stage.get_fwd_recv_ops(i)
                works = _sorted_batch_p2p(ops, desc="fwd_recv")
                for work in works.values():
                    _wait_batch_p2p(work)

                # Honor return_outputs like the other single-stage schedules,
                # so outputs are not retained when the caller discards them.
                self._stage.forward_one_chunk(
                    i, arg_mbs[i], kwarg_mbs[i], save_forward_output=return_outputs
                )  # type: ignore[index]

                ops = self._stage.get_fwd_send_ops(i)
                works = _sorted_batch_p2p(ops, desc="fwd_send")
                fwd_sends_to_wait.extend(works.values())

            logger.debug("[%s] Forwarded microbatch %s", self._stage.stage_index, i)

        # Wait for all outstanding forward sends so no P2P work from this
        # iteration is still in flight when we return.
        for work in fwd_sends_to_wait:
            _wait_batch_p2p(work)
class ScheduleGPipe(PipelineScheduleSingle):
    """
    The GPipe schedule.
    Will go through all the microbatches in a fill-drain manner: every
    microbatch runs forward first, then every microbatch runs backward.
    Sends are fired without blocking and only waited on once the whole
    forward (resp. backward) phase has been issued.
    """

    def _step_microbatches(
        self,
        arg_mbs: list | None = None,
        kwarg_mbs: list | None = None,
        target_mbs: list | None = None,
        losses: list | None = None,
        return_outputs: bool = True,
    ):
        """
        Run one iteration of the pipeline schedule with list of microbatches.
        Will go through all the microbatches according to the GPipe schedule.

        Args:
            arg_mbs: per-microbatch positional args.
            kwarg_mbs: per-microbatch keyword args.
            target_mbs: per-microbatch loss targets, handed to
                ``_maybe_compute_loss``.
            losses: optional container handed to ``_update_losses`` after the
                backward phase.
            return_outputs: whether to keep (and later return) the outputs
                from the last stage.
        """
        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)
        self._initialize_stage(arg_mbs[0], kwarg_mbs[0])

        # Delay send waits
        fwd_sends_to_wait: list[list[dist.Work]] = []

        # Run microbatches (fill phase: forward only)
        for i in range(self._n_microbatches):
            with record_function(f"Forward {i}"):
                # Block until this microbatch's activations have arrived
                ops = self._stage.get_fwd_recv_ops(i)
                works = _sorted_batch_p2p(ops, desc="fwd_recv")
                for work in works.values():
                    _wait_batch_p2p(work)

                output = self._stage.forward_one_chunk(
                    i, arg_mbs[i], kwarg_mbs[i], save_forward_output=return_outputs
                )  # type: ignore[index]

                # Fire the sends but do not wait; collected and waited below
                ops = self._stage.get_fwd_send_ops(i)
                works = _sorted_batch_p2p(ops, desc="fwd_send")
                fwd_sends_to_wait.extend(works.values())

            logger.debug("[%s] Forwarded microbatch %s", self._stage.stage_index, i)

            self._maybe_compute_loss(self._stage, output, target_mbs, i)

        # Wait for all forward sends to finish
        # This should not have performance impact because by the time the first
        # backward arrives all the forward sends should have been finished.
        for work in fwd_sends_to_wait:
            _wait_batch_p2p(work)

        # Run backward (drain phase)
        # Delay send waits
        bwd_sends_to_wait: list[list[dist.Work]] = []
        for i in range(self._n_microbatches):
            with record_function(f"Backward {i}"):
                ops = self._stage.get_bwd_recv_ops(i)
                works = _sorted_batch_p2p(ops, desc="bwd_recv")
                for work in works.values():
                    _wait_batch_p2p(work)

                loss = self._maybe_get_loss(self._stage, i)
                self._stage.backward_one_chunk(
                    i,
                    loss=loss,
                    # Flag the final microbatch's backward to the stage
                    last_backward=i == self._n_microbatches - 1,
                )

                ops = self._stage.get_bwd_send_ops(i)
                works = _sorted_batch_p2p(ops, desc="bwd_send")
                bwd_sends_to_wait.extend(works.values())

            logger.debug("[%s] Backwarded microbatch %s", self._stage.stage_index, i)

        # Wait for all backward sends to finish
        for work in bwd_sends_to_wait:
            _wait_batch_p2p(work)

        # Update losses if there is a container passed in
        self._update_losses(self._stage, losses)

        # Reduce gradients accumulated over all microbatches. The argument is
        # n_microbatches when scale_grads is set, else 1 (presumably a divisor
        # for gradient scaling -- confirm in the stage implementation).
        self._stage.perform_reduce_grad(self._n_microbatches if self.scale_grads else 1)

    def _get_pipeline_order(self) -> dict[int, list[_Action | None]] | None:
        """
        Returns the pipeline order for GPipe schedule.
        See base method in PipelineScheduleSingle for details on the schedule IR format.

        Each rank idles for ``rank`` slots, runs every forward, idles for
        ``3 * (pp_group_size - 1 - rank)`` slots, runs every full backward, and
        finally gets REDUCE_GRAD actions inserted by ``_add_reduce_grad``.
        """
        pipeline_order = {}
        pp_group_size = self._num_stages

        for rank in range(pp_group_size):
            actions: list[_Action | None] = []

            # 1. Initial delay based on rank position
            warmup_delay = rank
            actions.extend([None] * warmup_delay)

            # 2. Forward passes for all microbatches
            for mb_idx in range(self._n_microbatches):
                actions.append(_Action(rank, _ComputationType.FORWARD, mb_idx))

            # 3. Wait period before backward passes can begin
            # (the factor 3 appears to be a layout heuristic for the idle
            # slots, not a measured timing -- NOTE(review): confirm)
            backward_delay = 3 * (pp_group_size - 1 - rank)
            actions.extend([None] * backward_delay)

            # 4. Backward passes for all microbatches
            for mb_idx in range(self._n_microbatches):
                actions.append(_Action(rank, _ComputationType.FULL_BACKWARD, mb_idx))

            pipeline_order[rank] = _add_reduce_grad(actions, self._n_microbatches)

        return pipeline_order  # type: ignore[return-value]
class Schedule1F1B(PipelineScheduleSingle):
    """
    The 1F1B schedule.
    Will perform one forward and one backward on the microbatches in steady state.
    """

    def __init__(
        self,
        stage: _PipelineStageBase,
        n_microbatches: int,
        loss_fn: Callable | None = None,
        args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None,
        kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None,
        output_merge_spec: dict[str, Any] | tuple[Any] | None = None,
        scale_grads: bool = True,
    ):
        """
        Raises:
            ValueError: if ``n_microbatches`` is smaller than the number of
                pipeline stages.
        """
        super().__init__(
            stage=stage,
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            args_chunk_spec=args_chunk_spec,
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
            scale_grads=scale_grads,
        )
        if n_microbatches < self._num_stages:
            # Implicit string concatenation: a backslash continuation inside
            # the f-string would embed the source indentation in the message.
            raise ValueError(
                f"Number of microbatches ({n_microbatches}) must be greater than "
                f"or equal to the number of stages ({self._num_stages})."
            )

    def _step_microbatches(
        self,
        arg_mbs: list | None = None,
        kwarg_mbs: list | None = None,
        target_mbs: list | None = None,
        losses: list | None = None,
        return_outputs: bool = True,
    ):
        """
        Run one iteration of the pipeline schedule with list of microbatches.
        Will go through all the microbatches according to the 1F1B schedule:
        a warmup of forwards, a steady 1B1F phase, then a backward cooldown.

        Args:
            arg_mbs: per-microbatch positional args.
            kwarg_mbs: per-microbatch keyword args.
            target_mbs: per-microbatch loss targets.
            losses: optional container filled by ``_update_losses``.
            return_outputs: whether to keep (and later return) the outputs
                from the last stage.
        """
        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)
        self._initialize_stage(arg_mbs[0], kwarg_mbs[0])

        # Last stage has 1 warmup, second-to-last 2 warmups, ...
        # first stage `num_stages` warmups
        warmup_chunks = min(
            self._n_microbatches,
            self._num_stages - self._stage.stage_index,
        )

        # Chunk counters
        fwd_mb_index = 0
        bwd_mb_index = 0

        # Warmup phase
        send_work: list[dist.Work] = []
        fwd_sends = []
        for _ in range(warmup_chunks):
            # Receive activations
            fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index)
            _wait_batch_p2p(_batch_p2p(fwd_recvs, desc="fwd_recv"))

            # Compute
            output = self._stage.forward_one_chunk(
                fwd_mb_index,
                arg_mbs[fwd_mb_index],
                kwarg_mbs[fwd_mb_index],
                save_forward_output=return_outputs,
            )  # type: ignore[index]

            # Clear previous chunk's forward sends (hopefully they have well
            # finished, otherwise, we are heavily communication bound, in which
            # case it doesn't create a lot of benefit to compute next chunk
            # eagerly either)
            _wait_batch_p2p(send_work)

            # Send activations
            fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index)
            if fwd_mb_index != warmup_chunks - 1:
                # Safe to fire
                send_work = _batch_p2p(fwd_sends, desc="fwd_send")
            # otherwise:
            #   The last forward send is left for fuse with first 1B in 1B1F below

            # Compute loss
            self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index)
            fwd_mb_index += 1

        # Now we should have send ops left over, to be fused with first 1B of 1B1F phase below.

        # 1B1F phase
        while True:  # Don't worry, we have a break inside
            # We actually do 1B first as the `1B1F` name indicates, so prepare its recv ops
            bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index)

            # Now, we need to fire the fwd_sends and bwd_recvs together
            _wait_batch_p2p(_batch_p2p(fwd_sends + bwd_recvs, desc="fwd_send_bwd_recv"))

            # Backward one chunk
            loss = self._maybe_get_loss(self._stage, bwd_mb_index)
            self._stage.backward_one_chunk(
                bwd_mb_index,
                loss=loss,
                last_backward=bwd_mb_index == self._n_microbatches - 1,
            )

            # Get the bwd send ops, but don't fire, to be fused with the 1F below
            bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index)
            bwd_mb_index += 1

            if fwd_mb_index == self._n_microbatches:
                # We are done with 1B1F, so break with some left-over bwd_sends
                break

            # We prepare 1F of the `1B1F`
            fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index)

            # Fuse it with bwd_sends above
            _wait_batch_p2p(_batch_p2p(bwd_sends + fwd_recvs, desc="bwd_send_fwd_recv"))

            # Now do the fwd
            output = self._stage.forward_one_chunk(
                fwd_mb_index,
                arg_mbs[fwd_mb_index],
                kwarg_mbs[fwd_mb_index],
                save_forward_output=return_outputs,
            )  # type: ignore[index]

            # Compute loss
            self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index)

            # Get the fwd send ops, but don't fire, leave it for the next iter (wrap-around)
            fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index)
            fwd_mb_index += 1

        # Remember we still have some bwd_sends left over after the break? Now it is time to fire it
        send_work = _batch_p2p(bwd_sends, desc="bwd_send")

        # Cooldown
        while bwd_mb_index < self._n_microbatches:
            # prepare bwd recv ops
            bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index)
            _wait_batch_p2p(_batch_p2p(bwd_recvs, desc="bwd_recv"))

            # Backward one chunk
            loss = self._maybe_get_loss(self._stage, bwd_mb_index)
            self._stage.backward_one_chunk(
                bwd_mb_index,
                loss=loss,
                last_backward=bwd_mb_index == self._n_microbatches - 1,
            )

            # Clear previous chunk's backward sends (hopefully they have well finished)
            _wait_batch_p2p(send_work)

            # Get the bwd send ops, fire it
            bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index)
            send_work = _batch_p2p(bwd_sends, desc="bwd_send")
            bwd_mb_index += 1

        # Wait for the last backward send to finish
        _wait_batch_p2p(send_work)

        # Return losses if there is a container passed in
        self._update_losses(self._stage, losses)

        self._stage.perform_reduce_grad(self._n_microbatches if self.scale_grads else 1)

    def _get_pipeline_order(self) -> dict[int, list[_Action | None]] | None:
        """
        Returns the pipeline order for 1F1B schedule.
        See base method in PipelineScheduleSingle for details on the schedule IR format.

        Every rank runs forward on microbatches 0..n-1 exactly once, in order,
        and full backward on microbatches 0..n-1 exactly once, in order.
        """
        pipeline_order = {}
        pp_group_size = self._num_stages
        n_microbatches = self._n_microbatches

        for rank in range(pp_group_size):
            actions: list[_Action | None] = []

            # 1. Warmup phase: initial delay based on rank
            actions.extend([None] * rank)

            # 2. Initial forward passes before 1F1B phase (capped so we never
            # emit a microbatch index that does not exist)
            num_forward = min((pp_group_size - 1) - rank, n_microbatches)
            for mb_idx in range(num_forward):
                actions.append(_Action(rank, _ComputationType.FORWARD, mb_idx))

            # Index of the next microbatch to run forward on. Tracking "next"
            # rather than "last" avoids skipping microbatch 0 on ranks with no
            # warmup forwards (e.g. the last rank).
            forward_mb = num_forward

            # 3. Wait for backward to be ready
            wait_for_1f1b = max(0, 2 * (pp_group_size - 1 - rank))
            actions.extend([None] * wait_for_1f1b)

            # 4. 1F1B steady state phase: one forward then one backward
            backward_mb = 0
            while forward_mb < n_microbatches:
                actions.append(_Action(rank, _ComputationType.FORWARD, forward_mb))
                forward_mb += 1
                actions.append(
                    _Action(rank, _ComputationType.FULL_BACKWARD, backward_mb)
                )
                backward_mb += 1

            # 5. Cooldown phase: each remaining backward is preceded by one
            # idle slot while waiting on gradients from the next rank
            while backward_mb < n_microbatches:
                actions.append(None)
                actions.append(
                    _Action(rank, _ComputationType.FULL_BACKWARD, backward_mb)
                )
                backward_mb += 1

            pipeline_order[rank] = _add_reduce_grad(actions, n_microbatches)

        return pipeline_order
def _requires_reduce_grad(action_type: _ComputationType) -> bool:
    """Return True when ``action_type`` is W (BACKWARD_WEIGHT) or B (FULL_BACKWARD).

    Those are the action types whose completion counts toward scheduling a
    REDUCE_GRAD for a stage.
    """
    return action_type == W or action_type == B
def _add_reduce_grad(
    actions: list[_Action | None], n_microbatches: int
) -> list[_Action | None]:
    """
    Insert REDUCE_GRAD actions into a schedule.

    REDUCE_GRAD is the gradient reduction across all microbatches of a stage.
    Since it frees memory, it is placed immediately after the action that
    completes the stage's ``n_microbatches``-th gradient-producing (W or B)
    computation. For OVERLAP_F_B actions, the sub-actions are inspected
    instead of the parent. ``None`` (idle) entries are dropped from the result.
    """
    result: list[_Action | None] = []
    grad_counts: dict[int, int] = defaultdict(int)

    for action in actions:
        if action is None:
            continue
        result.append(action)

        if action.computation_type == OVERLAP_F_B and action.sub_actions is not None:
            leaves = action.sub_actions
        else:
            leaves = (action,)

        # Stages whose final gradient-producing action was just scheduled
        ready: list[int] = []
        for leaf in leaves:
            if not _requires_reduce_grad(leaf.computation_type):
                continue
            grad_counts[leaf.stage_index] += 1
            if grad_counts[leaf.stage_index] == n_microbatches:
                ready.append(leaf.stage_index)

        result.extend(_Action(idx, REDUCE_GRAD, None) for idx in ready)

    return result
def _add_unshard_reshard(
    compute_actions: list[_Action | None],
    max_active_stages: int = 3,
) -> list[_Action]:
    """Given a basic schedule involving only compute actions (F,B,W,OVERLAP_F_B), add UNSHARD/RESHARD actions for FSDP.

    UNSHARD refers to fetching the full contents of an FSDP-sharded layer, requiring an all-gather operation.
    RESHARD does the opposite, releasing memory (but doing no communication).

    We abandon the "timestep lock" during lowering: ``None`` (idle) entries are dropped.

    max_active_stages bounds how many stages may be unsharded at once, i.e. how far ahead we prefetch.
    Ideally it would be measured in microbatches and be tunable, but in practice 3 stages is probably
    what we want (one F and one B active, plus something else prefetching).

    Returns the compute actions interleaved with UNSHARD/RESHARD actions. Every stage that gets
    unsharded is resharded again by the end of the returned list.
    """

    def next_stage_indices(count: int, next_actions: list[_Action | None]) -> list[int]:
        """Return at most ``count`` distinct stage indices, in first-use order, from ``next_actions``.

        OVERLAP_F_B actions contribute the stage indices of their sub-actions.
        """
        seen: set[int] = set()
        ret: list[int] = []
        for a in next_actions:
            if a is None:
                continue
            if a.computation_type == OVERLAP_F_B and a.sub_actions is not None:
                stage_indices = [sub_action.stage_index for sub_action in a.sub_actions]
            else:
                stage_indices = [a.stage_index]
            for stage_index in stage_indices:
                if stage_index in seen:
                    continue
                seen.add(stage_index)
                ret.append(stage_index)
                # Return immediately so we never exceed ``count``. Breaking out
                # of only the inner loop would let later actions append more.
                if len(ret) >= count:
                    return ret
        return ret

    active_stages: set[int] = set()
    fsdp_aware_actions: list[_Action] = []

    def _unshard(stage_index: int):
        active_stages.add(stage_index)
        fsdp_aware_actions.append(_Action(stage_index, UNSHARD, None))

    def _reshard(stage_index: int):
        active_stages.remove(stage_index)
        fsdp_aware_actions.append(_Action(stage_index, RESHARD, None))

    for i, action in enumerate(compute_actions):
        if action is None:
            continue

        # We prefetch the next N stages we'll see, dropping existing stages to make room
        next_n = next_stage_indices(max_active_stages, compute_actions[i:])
        # Fetch needs to be ordered correctly (first use first), so don't use a set
        fetch = [s for s in next_n if s not in active_stages]
        # Evict active stages not needed soon; the best eviction policy is unclear,
        # and the order here simply follows set iteration order
        evict = [s for s in active_stages if s not in next_n]

        for stage in evict:
            _reshard(stage)
        for stage in fetch:
            _unshard(stage)
        fsdp_aware_actions.append(action)

    # Reshard all remaining active stages after processing all operations
    for stage in list(active_stages):
        _reshard(stage)

    return fsdp_aware_actions
def _merge_bw(
    compute_actions: list[_Action | None],
) -> list[_Action]:
    """Given a basic schedule involving only compute actions (F,I,W), merge adjacent I and W ops into B ops.

    (note: I = BACKWARD_INPUT, W = BACKWARD_WEIGHT, B = FULL_BACKWARD)

    B refers to running the whole backward (not separating grad_input and grad_weight), which can be more efficient
    in some cases. An I immediately followed by a W for the same stage and microbatch becomes a single B.
    ``None`` entries are dropped first, so an I and W separated only by ``None`` entries still count as adjacent.

    The input list is left unmodified. The previous implementation emptied it via ``pop(0)``,
    which was also quadratic.
    """
    actions = [a for a in compute_actions if a is not None]
    merged_actions: list[_Action] = []

    i = 0
    while i < len(actions):
        action = actions[i]
        next_action = actions[i + 1] if i + 1 < len(actions) else None
        if (
            action.computation_type == BACKWARD_INPUT
            and next_action is not None
            and next_action.computation_type == BACKWARD_WEIGHT
            and action.stage_index == next_action.stage_index
            and action.microbatch_index == next_action.microbatch_index
        ):
            merged_actions.append(
                _Action(action.stage_index, FULL_BACKWARD, action.microbatch_index)
            )
            # Consume both the I and the W
            i += 2
        else:
            merged_actions.append(action)
            i += 1

    return merged_actions
def _add_send_recv(
    compute_actions: dict[int, list[_Action]],
    stage_to_rank: Callable[[int], int],
    num_stages: int,
) -> dict[int, list[_Action]]:
    """
    Transforms a compute-only schedule into a complete schedule with communication actions.
    For actions with sub-actions (OVERLAP_F_B) we ensure that all the subactions have been
    computed and the communication is ready.

    Args:
        compute_actions: per-rank lists of compute actions. NOTE: this dict is
            consumed in place. Actions are popped from the front of each list,
            and a rank's key is deleted once its list is empty.
        stage_to_rank: maps a stage index to the rank that owns it.
        num_stages: total number of pipeline stages.

    Returns:
        Per-rank action lists. Each scheduled compute action is followed by its
        SEND when the neighbouring stage lives on a different rank. The matching
        RECV is appended to the receiving rank's list at the same time.

    Raises:
        AssertionError: if a pass over all ranks makes no progress (malformed
            compute schedule), or if a rank's action list is unexpectedly empty.
    """
    comm_actions: dict[int, list[_Action]] = {rank: [] for rank in compute_actions}
    # Everything already placed on each rank (compute, sends and recvs); used
    # to decide whether a compute action's inputs are available.
    prev_actions: dict[int, set[_Action]] = {rank: set() for rank in compute_actions}

    def _has_comms(action: _Action) -> bool:
        """True if ``action``'s result must be sent to a stage on another rank."""
        if action.computation_type == F:
            # Forward output goes to the next stage (none after the last stage)
            return action.stage_index != num_stages - 1 and stage_to_rank(
                action.stage_index + 1
            ) != stage_to_rank(action.stage_index)
        elif action.computation_type in (BACKWARD_INPUT, FULL_BACKWARD):
            # Input gradient goes to the previous stage (none before stage 0)
            return action.stage_index != 0 and stage_to_rank(
                action.stage_index - 1
            ) != stage_to_rank(action.stage_index)
        return False

    def _get_comms(action: _Action) -> tuple[_Action, _Action]:
        """Return the (send, recv) action pair for a communicating compute action."""
        if not _has_comms(action):
            raise AssertionError(f"{action} is not a valid comm action")
        stage_idx = action.stage_index
        ctype = action.computation_type
        mb_idx = action.microbatch_index
        send = _Action(stage_idx, SEND_F if ctype == F else SEND_B, mb_idx)
        recv_stage_idx = stage_idx + 1 if ctype == F else stage_idx - 1
        recv = _Action(recv_stage_idx, RECV_F if ctype == F else RECV_B, mb_idx)
        return send, recv

    def _ready_to_schedule(action: _Action | None, prev_actions: set[_Action]) -> bool:
        """We don't put our own recv ops in the schedule, we let a sender on another rank put our recv ops in place.
        This helps ensure a sane (non-hanging) ordering of sends and recvs.
        But it also means we might not be able to schedule our next compute action yet.

        An action is ready when its input is already present on this rank, either
        as a RECV placed by the sender or as the producing compute action itself
        (when the neighbouring stage is on this same rank).
        """
        if action is None:
            return True
        elif action.computation_type == F and action.stage_index != 0:
            if (
                _Action(action.stage_index, RECV_F, action.microbatch_index)
                in prev_actions
            ):
                return True
            elif (
                _Action(action.stage_index - 1, F, action.microbatch_index)
                in prev_actions
            ):
                return True
            return False
        elif (
            action.computation_type in (BACKWARD_INPUT, FULL_BACKWARD)
            and action.stage_index != num_stages - 1
        ):
            if (
                _Action(action.stage_index, RECV_B, action.microbatch_index)
                in prev_actions
            ):
                return True
            elif (
                _Action(action.stage_index + 1, BACKWARD_INPUT, action.microbatch_index)
                in prev_actions
            ):
                return True
            elif (
                _Action(action.stage_index + 1, FULL_BACKWARD, action.microbatch_index)
                in prev_actions
            ):
                return True
            return False
        else:
            return True

    # Round-robin over ranks: each pass schedules at most one action per rank,
    # until every rank's compute list has been drained.
    while compute_actions:
        progress = False
        # go in order of ranks even if dict keys aren't ordered
        for rank in sorted(compute_actions):
            if not (len(compute_actions[rank]) > 0):
                raise AssertionError(f"{rank=}, {len(compute_actions[rank])=}")
            action = compute_actions[rank][0]
            # handle case where parent action (e.g. OVERLAP_F_B) can be comprised of subactions
            if action is not None and action.sub_actions is not None:
                all_actions = action.sub_actions
            else:
                all_actions = (action,)

            if not all(_ready_to_schedule(a, prev_actions[rank]) for a in all_actions):
                continue

            # The action's dependencies are satisfied, so add to schedule
            if action is not None:
                comm_actions[rank].append(action)
                for a in all_actions:
                    prev_actions[rank].add(a)
                    if _has_comms(a):
                        send, recv = _get_comms(a)
                        # TODO we can avoid send/recv if the 2 stages are on the same rank.
                        # should we avoid that in the runtime or here?
                        comm_actions[rank].append(send)
                        prev_actions[rank].add(send)
                        # The recv goes onto the receiving rank's schedule now,
                        # which unblocks its dependent compute action.
                        comm_actions[stage_to_rank(recv.stage_index)].append(recv)
                        prev_actions[stage_to_rank(recv.stage_index)].add(recv)

            # Idle (None) slots are consumed here without emitting anything
            compute_actions[rank].pop(0)
            if len(compute_actions[rank]) == 0:
                del compute_actions[rank]
            progress = True

        if not progress:
            raise AssertionError(
                "Malformed compute schedule, can't schedule sends/recvs"
            )
    return comm_actions
def _validate_schedule(
    actions: dict[int, list[_Action | None]],
    pp_group_size: int,
    num_stages: int,
    num_microbatches: int,
) -> dict[int, int]:
    """
    Check that a compute schedule is complete and well-ordered, and infer stage placement.

    Verifies that:
      * there is exactly one action list per rank ``0..pp_group_size-1``;
      * for every stage/microbatch, F precedes B and I, and I precedes W
        (in per-rank list order);
      * every stage runs F for all ``num_microbatches`` microbatches, has equal
        I and W counts, and ``B + (I + W) // 2 == num_microbatches``;
      * each stage appears on only one rank.

    Actions with ``sub_actions`` are validated through their sub-actions.

    Returns:
        Mapping from stage index to the rank that runs it.

    Raises:
        AssertionError: if any of the checks above fails.
    """
    if len(actions) != pp_group_size:
        raise AssertionError(
            f"Schedule has incorrect number of ranks - expected {pp_group_size}, actual {len(actions)}"
        )
    for rank in range(pp_group_size):
        if rank not in actions:
            raise AssertionError(f"Schedule is missing actions for rank {rank}")

    # We will count all the actions per stage and ensure they happen in a valid order
    # (e.g. F before (B, I) before W for a given microbatch)
    stage_actions: dict[int, dict[_ComputationType, set]] = {
        stage_id: {
            F: set(),
            B: set(),
            I: set(),
            W: set(),
        }
        for stage_id in range(num_stages)
    }
    stage_index_to_rank_mapping: dict[int, int] = {}

    # For each backward-type action: (required predecessor type, action label,
    # predecessor label) used for both the check and its error message.
    prerequisites = {
        B: (F, "Full Backward", "Forward"),
        I: (F, "Backward Input", "Forward"),
        W: (I, "Backward Weight", "Backward Input"),
    }

    def _raise_ordering_error(rank: int, step: int, message: str) -> None:
        """Raise an AssertionError for an out-of-order action, including the full formatted schedule."""
        error_msg = f"Rank {rank}, step {step}: {message}"
        formatted_schedule = _format_pipeline_order(actions, error_step_number=step)
        raise AssertionError(
            f"{error_msg}\n\nFull pipeline schedule:\n{formatted_schedule}"
        )

    def _process_action(action: _Action, rank: int, step: int):
        """Process a single action and update stage_actions and stage_index_to_rank_mapping"""
        s_id = action.stage_index
        ctype = action.computation_type
        mb_id = action.microbatch_index

        if ctype == F:
            stage_actions[s_id][F].add(mb_id)
        elif ctype in prerequisites:
            required, label, required_label = prerequisites[ctype]
            if mb_id not in stage_actions[s_id][required]:
                _raise_ordering_error(
                    rank,
                    step,
                    f"Running {label} for stage {s_id}, "
                    f"microbatch {mb_id} without first running {required_label}",
                )
            stage_actions[s_id][ctype].add(mb_id)

        # Every action (compute or not) pins its stage to this rank
        existing_rank = stage_index_to_rank_mapping.setdefault(s_id, rank)
        if rank != existing_rank:
            raise AssertionError(
                f"Rank {rank}, step {step}: Stage {s_id} is assigned to both rank {rank} and rank {existing_rank}"
            )

    for rank in actions:
        for step, action in enumerate(actions[rank]):
            if action is None:
                continue
            if not isinstance(action, _Action):
                raise AssertionError(
                    f"Rank {rank}, step {step}: Got an invalid action: {action}, expected instance of _Action"
                )
            if action.sub_actions is not None:
                # Validate each sub_action instead of the parent action
                for sub_action in action.sub_actions:
                    _process_action(sub_action, rank, step)
            else:
                _process_action(action, rank, step)

    for s_id in stage_actions:
        f_mb = len(stage_actions[s_id][F])
        b_mb = len(stage_actions[s_id][B])
        i_mb = len(stage_actions[s_id][I])
        w_mb = len(stage_actions[s_id][W])

        if f_mb != num_microbatches:
            raise AssertionError(
                f"Got {f_mb} {F} microbatches for stage {s_id}, expected {num_microbatches}"
            )
        # Implicit concatenation below: a backslash continuation inside an
        # f-string would embed the source indentation in the message.
        if i_mb != w_mb:
            raise AssertionError(
                f"Invalid backward microbatches for stage {s_id}: I and W must have equal counts, "
                f"but got I={i_mb}, W={w_mb}"
            )
        if b_mb + (i_mb + w_mb) // 2 != num_microbatches:
            raise AssertionError(
                f"Invalid backward microbatches for stage {s_id}: expected {num_microbatches} total backwards, "
                f"but got B={b_mb}, I={i_mb}, W={w_mb}"
            )

    return stage_index_to_rank_mapping
- class PipelineScheduleMulti(_PipelineSchedule):
- """
- Base class for multi-stage schedules.
- Implements the `step` method.
- Gradients are scaled by num_microbatches depending on the `scale_grads` argument, defaulting to True. This setting
- should match the configuration of your loss_fn, which may either average losses (scale_grads=True)
- or sum losses (scale_grads=False).
- """
def __init__(
    self,
    stages: list[_PipelineStageBase],
    n_microbatches: int,
    loss_fn: Callable | None = None,
    args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None,
    kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None,
    output_merge_spec: dict[str, Any] | tuple[Any] | None = None,
    use_full_backward: bool | None = None,
    scale_grads: bool = True,
    backward_requires_autograd: bool = True,
):
    """
    Build a schedule that drives several pipeline stages owned by this rank.

    Stage count, group size and group rank are read from ``stages[0]``. A
    default stage-index -> group-rank mapping is generated and shared with
    every stage; derived schedules fill in ``pipeline_order``.

    ``use_full_backward`` is deprecated and ignored; passing it only logs a
    warning. ``backward_requires_autograd`` controls whether ``step()``
    insists on autograd being enabled when backward is run.
    """
    super().__init__(
        n_microbatches=n_microbatches,
        loss_fn=loss_fn,
        args_chunk_spec=args_chunk_spec,
        kwargs_chunk_spec=kwargs_chunk_spec,
        output_merge_spec=output_merge_spec,
        scale_grads=scale_grads,
    )

    lead = stages[0]
    self._stages = stages
    self._num_stages = lead.num_stages
    self.pp_group_size = lead.group_size
    self.rank = lead.group_rank

    # Default placement; a loaded/validated schedule may overwrite it later
    mapping = generate_stage_to_rank_mapping(self.pp_group_size, self._num_stages)
    self.stage_index_to_group_rank = mapping
    for stage in self._stages:
        stage.stage_index_to_group_rank = mapping

    self._stages_forward_initialized = False
    self._stages_backward_initialized = False

    # The closure captures only a bool, never 'self', so no reference cycle is created
    has_loss: bool = self._loss_fn is not None

    def _should_compute_loss(stage):
        return stage.is_last and has_loss

    self._should_compute_loss = _should_compute_loss

    # Populated by the derived schedules' __init__
    self.pipeline_order: dict[int, list[_Action | None]] = {}

    # A custom backward may not rely on autograd; this flag decides whether
    # step() checks torch.is_grad_enabled() before running.
    self._backward_requires_autograd = backward_requires_autograd

    if use_full_backward is not None:
        logger.warning(
            "Deprecation warning: 'use_full_backward' is no longer supported. "
            "Simply stop passing it, and everything should still work fine."
        )
def _initialize_stages(self, args: tuple[Any, ...], kwargs):
    """
    Lazily set up forward (and, if needed, backward) infrastructure for all local stages.

    On first use, one batched P2P exchange spanning every local stage's
    neighbours is issued and waited on. Every batched P2P that follows depends
    on this first call involving the global group. Each stage's forward
    infrastructure is then prepared in order. The first stage receives the
    real ``args``; every later stage receives what the preceding stage's
    ``_prepare_forward_infra`` returned. That value may be empty if shapes
    travel over P2P, or real values if both stages share a device.
    """
    if not self._stages_forward_initialized:
        init_ops: list[dist.P2POp] = [
            op for stage in self._stages for op in stage._get_init_p2p_neighbors_ops()
        ]
        _wait_batch_p2p(_batch_p2p(init_ops))

        carried: tuple[Any, ...] = tuple()
        for stage in self._stages:
            stage_inputs = args if stage.is_first else carried
            carried = stage._prepare_forward_infra(
                self._n_microbatches, stage_inputs, kwargs
            )
        self._stages_forward_initialized = True

    if self._has_backward and not self._stages_backward_initialized:
        for stage in self._stages:
            stage._prepare_backward_infra(self._n_microbatches)
        self._stages_backward_initialized = True
def _validate_and_set_stage_mapping(
    self, actions: dict[int, list[_Action | None]]
) -> None:
    """
    Validate ``actions`` and install the stage-index -> group-rank mapping it implies.

    The mapping returned by ``_validate_schedule`` is stored on the schedule
    and shared with every local stage, because communication needs it.
    """
    mapping = _validate_schedule(
        actions, self.pp_group_size, self._num_stages, self._n_microbatches
    )
    self.stage_index_to_group_rank = mapping
    for stage in self._stages:
        stage.stage_index_to_group_rank = mapping
- def _dump_csv(self, filename):
- """Dump a CSV representation of the schedule into a file with the provided filename."""
- with open(filename, "w", newline="") as csvfile:
- writer = csv.writer(csvfile)
- for rank in self.pipeline_order:
- writer.writerow(self.pipeline_order[rank])
- def _load_csv(self, filename, format="compute_only"):
- """Load a CSV representation of the schedule from a file with the provided filename.
- This API will most likely get renamed/refactored so is marked as internal for now.
- format must be "compute_only" for PipelineScheduleMulti.
- """
- if format != "compute_only":
- raise AssertionError(f'format must be "compute_only", got {format}')
- with open(filename, newline="") as csvfile:
- reader = csv.reader(csvfile)
- for rank, row in enumerate(reader):
- self.pipeline_order[rank] = [_Action.from_str(s) for s in row]
- # Validates the order of the pipeline actions and infers the stage_to_rank_mapping.
- # This will overwrite the default stage_to_rank_mapping created in the constructor
- self._validate_and_set_stage_mapping(self.pipeline_order)
def step(
    self,
    *args,
    target=None,
    losses: list | None = None,
    return_outputs: bool = True,
    **kwargs,
):
    """
    Execute one pipeline iteration on a full (not yet chunked) batch.

    Inputs (and ``target``, if given) are split into ``self._n_microbatches``
    microbatches and run across all local stages in the order given by
    ``_step_microbatches``.

    args: positional model inputs, exactly as for a non-pipelined model.
    kwargs: keyword model inputs, exactly as for a non-pipelined model.
    target: loss-function target, split along the microbatch dimension.
    losses: optional list that receives the per-microbatch losses.
    return_outputs: if True, the merged outputs of the last stage are returned.

    Returns the merged outputs if this rank hosts the last stage and
    ``return_outputs`` is True; otherwise None.

    Raises RuntimeError if backward needs autograd but autograd is disabled.
    """
    needs_autograd = self._has_backward and self._backward_requires_autograd
    if needs_autograd and not torch.is_grad_enabled():
        raise RuntimeError(
            "step() requires gradients to be enabled for backward computation; "
            "it should not be used under torch.no_grad() context. "
            "Please call eval() instead."
        )

    # Every local stage must agree with the schedule about running backward
    for stage in self._stages:
        stage.has_backward = self._has_backward

    # Drop per-iteration state from the previous step
    for stage in self._stages:
        stage.clear_runtime_states()

    mb_args, mb_kwargs = self._split_inputs(args, kwargs)
    mb_targets = (
        None
        if target is None
        else list(torch.tensor_split(target, self._n_microbatches))
    )

    self._step_microbatches(mb_args, mb_kwargs, mb_targets, losses, return_outputs)

    if not return_outputs:
        return None
    # Only the last stage has model outputs; it may not live on this rank
    last_stage = next((s for s in self._stages if s.is_last), None)
    if last_stage is None:
        return None
    return self._merge_outputs(last_stage.output_chunks)
def _step_microbatches(
    self,
    arg_mbs: list | None = None,
    kwarg_mbs: list | None = None,
    target_mbs: list | None = None,
    losses: list | None = None,
    return_outputs: bool = True,
):
    """
    Operate on the microbatches for looped schedules (multiple stages on each rank).

    For every time step of ``self.pipeline_order[self.rank]`` this:

    1. runs this rank's compute action at that step (if any) and collects the
       send ops it produces;
    2. looks at what the neighbouring ranks do at the same time step and adds
       the matching recv ops (forward activations when a previous rank runs a
       FORWARD, gradients when a next rank runs BACKWARD_INPUT/FULL_BACKWARD);
    3. issues all collected P2P ops as one batch and waits for it.

    TODO: Does not use sorted_batch_isend_irecv(). As a result, this schedule does
    not support models with skip connections.

    Args:
        arg_mbs / kwarg_mbs: per-microbatch positional / keyword inputs.
        target_mbs: per-microbatch targets for the loss, if any.
        losses: optional container that receives the per-microbatch losses.
        return_outputs: forwarded to ``forward_one_chunk(save_forward_output=...)``.

    Raises:
        AssertionError: an action without a microbatch index.
        ValueError: an unknown computation type in this or a neighbouring rank's order.
    """
    arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)

    self._initialize_stages(arg_mbs[0], kwarg_mbs[0])

    # Based on the plan in Step 1 created in __init__:
    # 2. Perform communication based on the pipeline_order
    stage_index_to_stage: dict[int, _PipelineStageBase] = {
        stage.stage_index: stage for stage in self._stages
    }

    # determine prev_rank and next_rank based on which ranks are next to
    # the stages in the pipeline_order
    all_prev_ranks: set[int] = set()
    all_next_ranks: set[int] = set()
    for stage_index in stage_index_to_stage:
        # TODO: assumption that stages only communicate from distances of +1/-1 (no skip connections)
        if stage_index > 0:
            all_prev_ranks.add(self.stage_index_to_group_rank[stage_index - 1])
        if stage_index < self._num_stages - 1:
            all_next_ranks.add(self.stage_index_to_group_rank[stage_index + 1])

    # count either full_backward or backward_weight together, to determine when to sync DP grads
    backward_counter: Counter[int] = Counter()
    for time_step, action in enumerate(self.pipeline_order[self.rank]):
        try:
            ops: list[dist.P2POp] = []
            if action is not None:
                computation_type = action.computation_type
                mb_index = action.microbatch_index
                stage_index = action.stage_index
                if mb_index is None:
                    raise AssertionError(
                        "All currently supported action types require valid microbatch_index"
                    )
                if computation_type == _ComputationType.FORWARD:
                    # perform forward computation
                    stage = stage_index_to_stage[stage_index]
                    output = stage.forward_one_chunk(
                        mb_index,
                        arg_mbs[mb_index],
                        kwarg_mbs[mb_index],
                        save_forward_output=return_outputs,
                    )
                    self._maybe_compute_loss(stage, output, target_mbs, mb_index)
                    ops.extend(stage.get_fwd_send_ops(mb_index))
                elif computation_type == _ComputationType.FULL_BACKWARD:
                    # perform backward computation
                    stage = stage_index_to_stage[stage_index]
                    loss = self._maybe_get_loss(stage, mb_index)
                    backward_counter[stage_index] += 1
                    last_backward = (
                        backward_counter[stage_index] == self._n_microbatches
                    )
                    stage.backward_one_chunk(
                        mb_index,
                        loss=loss,
                        full_backward=True,
                        last_backward=last_backward,
                    )
                    if last_backward:
                        # Grads are scaled once, after the stage's final backward.
                        grad_scale_factor = (
                            self._n_microbatches if self.scale_grads else 1
                        )
                        stage.scale_grads(grad_scale_factor)
                    ops.extend(stage.get_bwd_send_ops(mb_index))
                elif computation_type == _ComputationType.BACKWARD_INPUT:
                    # perform backward computation (input grads only; not counted,
                    # the matching BACKWARD_WEIGHT is)
                    stage = stage_index_to_stage[stage_index]
                    loss = self._maybe_get_loss(stage, mb_index)
                    stage.backward_one_chunk(
                        mb_index,
                        loss=loss,
                        full_backward=False,
                        last_backward=False,
                    )
                    ops.extend(stage.get_bwd_send_ops(mb_index))
                elif computation_type == _ComputationType.BACKWARD_WEIGHT:
                    # perform weight update
                    stage = stage_index_to_stage[stage_index]
                    backward_counter[stage_index] += 1
                    last_backward = (
                        backward_counter[stage_index] == self._n_microbatches
                    )
                    stage.backward_weight_one_chunk(
                        mb_index,
                        last_backward=last_backward,
                    )
                    if last_backward:
                        grad_scale_factor = (
                            self._n_microbatches if self.scale_grads else 1
                        )
                        stage.scale_grads(grad_scale_factor)
                else:
                    raise ValueError(f"Unknown computation type {computation_type}")

            # Look at the neighboring ranks for this current timestep and determine whether
            # this current rank needs to do any recv communication
            for prev_rank in all_prev_ranks:
                prev_rank_ops = self.pipeline_order[prev_rank]
                prev_rank_action = None
                if time_step < len(prev_rank_ops):
                    prev_rank_action = prev_rank_ops[time_step]
                if prev_rank_action is not None:
                    computation_type = prev_rank_action.computation_type
                    mb_index = prev_rank_action.microbatch_index
                    stage_index = prev_rank_action.stage_index
                    if mb_index is None:
                        raise AssertionError(
                            "All currently supported action types require valid microbatch_index"
                        )
                    # Only a forward on a previous rank produces something this rank must receive
                    if computation_type == _ComputationType.FORWARD:
                        # If not the last stage, then receive fwd activations
                        if stage_index + 1 in stage_index_to_stage:
                            # TODO: We are assuming that stage will always receive from stage-1
                            # however that is not necessarily true of get_fwd_recv_ops
                            stage = stage_index_to_stage[stage_index + 1]
                            ops.extend(stage.get_fwd_recv_ops(mb_index))
                    elif computation_type in (
                        FULL_BACKWARD,
                        BACKWARD_INPUT,
                        BACKWARD_WEIGHT,
                    ):
                        # Previous rank doing backward has no influence for the current rank forward recv
                        pass
                    else:
                        raise ValueError(
                            f"Unknown computation type {computation_type}"
                        )
            for next_rank in all_next_ranks:
                next_rank_ops = self.pipeline_order[next_rank]
                next_rank_action = None
                if time_step < len(next_rank_ops):
                    next_rank_action = next_rank_ops[time_step]
                if next_rank_action is not None:
                    computation_type = next_rank_action.computation_type
                    mb_index = next_rank_action.microbatch_index
                    stage_index = next_rank_action.stage_index
                    if mb_index is None:
                        raise AssertionError(
                            "All currently supported action types require valid microbatch_index"
                        )
                    # Only handle receives for the backwards from a next rank
                    if computation_type in (FORWARD, BACKWARD_WEIGHT):
                        # Next rank doing forward or weight update has no influence for the current rank backward recv
                        pass
                    elif computation_type in (BACKWARD_INPUT, FULL_BACKWARD):
                        # If not the first stage, then receive bwd gradients
                        if stage_index - 1 in stage_index_to_stage:
                            # TODO: We are assuming that stage will always receive from stage+1
                            # however that is not necessarily true of get_bwd_recv_ops
                            stage = stage_index_to_stage[stage_index - 1]
                            ops.extend(stage.get_bwd_recv_ops(mb_index))
                    else:
                        raise ValueError(
                            f"Unknown computation type {computation_type}"
                        )

            # do the communication
            _wait_batch_p2p(_batch_p2p(ops))
        except Exception as e:
            # Implicit string concatenation (not a backslash continuation) keeps
            # source indentation out of the logged message.
            logger.error(  # noqa: G200
                "[Rank %s] pipeline schedule %s caught the following exception '%s' "
                "at time_step %s when running action %s",
                self.rank,
                self.__class__.__name__,
                str(e),
                time_step,
                action,
            )
            logger.error(
                "%s",
                _format_pipeline_order(
                    self.pipeline_order, error_step_number=time_step
                ),
            )
            # Bare raise re-raises with the original traceback untouched.
            raise
    # Return losses if there is a container passed in
    self._update_losses(self._stages, losses)
@dataclass
class _PipelineContext:
    """Context passed to custom functions during pipeline execution.

    ``_PipelineScheduleRuntime._step_microbatches`` builds a fresh instance for
    every custom-function dispatch and passes the fields *positionally*, so the
    field order below is part of the contract.
    """

    # The schedule object that is executing the action.
    schedule_ref: _PipelineSchedule
    # Per-microbatch positional inputs given to ``_step_microbatches``.
    arg_mbs: list[tuple] | None = None
    # Per-microbatch keyword inputs given to ``_step_microbatches``.
    kwarg_mbs: list[dict] | None = None
    # Per-microbatch targets for the loss function, if any.
    target_mbs: list | None = None
    # Caller-provided container for losses, if any.
    losses: list | None = None
class _CustomFunctionProtocol(Protocol):
    """Signature required of functions passed to ``register_custom_function``.

    The runtime calls the function with the action being executed and a
    ``_PipelineContext`` for the current step; its return value is not used.
    """

    def __call__(self, action: _Action, ctx: _PipelineContext) -> None: ...
class _PipelineScheduleRuntime(PipelineScheduleMulti):
    """
    Provides a simple runtime that requires a 'schedule IR' including specified communication operations.

    Can be instantiated directly by creating _PipelineScheduleRuntime and calling _load_csv, or can be
    subclassed and the subclass can be responsible for creating a schedule IR (typically by calling
    _prepare_schedule_with_comms from its constructor).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Action to custom function mapping (populated by register_custom_function)
        self._comp_type_to_function_map: dict[_ComputationType, Callable] = {}
        # count either full_backward or backward_weight together, to determine when to sync DP grads
        self.backward_counter: Counter[int] = Counter()

        # recv ops indexed by (stage_idx, mb_idx) need to be waited on before use
        self.bwd_recv_ops: dict[tuple[int, int], list[dist.Work]] = {}
        self.fwd_recv_ops: dict[tuple[int, int], list[dist.Work]] = {}

        # we track which stages are 'active' when used with FSDP, and wait on unshard ops before computing on stages
        self.unshard_ops: dict[int, list[UnshardHandle]] = defaultdict(list)
        self.unsharded_stages: set[int] = set()

    def register_custom_function(
        self,
        computation_type: _ComputationType,
        custom_function: _CustomFunctionProtocol,
    ) -> None:
        """
        Register a custom function to be executed for a specific computation type.

        During ``_step_microbatches``, a top-level action whose computation type has a
        registered function is handed to that function instead of the built-in handling.
        Registering a type twice overwrites the earlier function (with a warning).

        Args:
            computation_type: The computation type for which to register the custom function
            custom_function: The function to execute when this computation type is encountered.
                Must have signature: (action: _Action, ctx: _PipelineContext) -> None

        Raises:
            ValueError: if ``computation_type`` is not one of the supported types.
        """
        # Ensure that the computation type is valid
        if computation_type not in (
            FORWARD,
            FULL_BACKWARD,
            BACKWARD_INPUT,
            BACKWARD_WEIGHT,
            OVERLAP_F_B,
            UNSHARD,
            RESHARD,
            REDUCE_GRAD,
        ):
            # Implicit concatenation: a backslash continuation inside the literal would
            # embed the source indentation in the message.
            raise ValueError(
                f"Invalid computation type {computation_type}. Only FORWARD, FULL_BACKWARD, "
                "BACKWARD_INPUT, BACKWARD_WEIGHT, OVERLAP_F_B, UNSHARD, RESHARD and REDUCE_GRAD are supported."
            )

        # Check if computation_type is already registered
        if computation_type in self._comp_type_to_function_map:
            logger.warning(
                "Computation type %s is already registered. "
                "Overwriting the existing custom function.",
                computation_type,
            )

        self._comp_type_to_function_map[computation_type] = custom_function

    def _prepare_schedule_with_comms(
        self,
        actions: dict[int, list[_Action | None]],
        format: str = "compute_only",
    ):
        """
        Given an in-memory representation for a simple compute-only schedule, lower it to a complex schedule
        including communication actions.

        The result is stored in ``self.pipeline_order_with_comms``; this must be called before ``step()`` /
        ``_step_microbatches()``. It also validates ``actions`` and overrides the stage-index -> group-rank
        mapping.

        Args:
            actions: per-rank action lists.
            format: ``"compute_only"`` runs the ``_add_unshard_reshard``, ``_add_reduce_grad`` and
                ``_add_send_recv`` lowering passes; ``"compute_comms"`` uses ``actions`` as-is.

        Raises:
            AssertionError: a ``None`` action in a ``compute_comms`` schedule.
            ValueError: a communication action in a ``compute_only`` schedule.
            NotImplementedError: any other ``format``.
        """
        # validate the provided actions are valid and overrides the default stage_index_to_group_rank
        super()._validate_and_set_stage_mapping(actions)

        self.pipeline_order_with_comms: dict[int, list[_Action]] = {}
        if format == "compute_comms":
            for rank in actions:
                self.pipeline_order_with_comms[rank] = []
                for action in actions[rank]:
                    if action is None:
                        raise AssertionError(
                            f"Expected action to be not None, got {type(action)}"
                        )
                    self.pipeline_order_with_comms[rank].append(action)
            # TODO what level of validation should we offer for compute+comms schedule?
        elif format == "compute_only":
            # Validate that the schedule does not have comms already added to it
            for rank, action_list in actions.items():
                for i, action in enumerate(action_list):
                    if action is not None:
                        if not action.is_compute_op:
                            raise ValueError(
                                f"Expected compute-only schedule but found communication action "
                                f"'{action}' at rank {rank}, position {i}. "
                                f"Communication actions (e.g. SEND_F, RECV_F, etc.) "
                                f"should not be present when format='compute_only'."
                            )

            # Perform schedule lowering
            for rank in actions:
                self.pipeline_order_with_comms[rank] = _add_unshard_reshard(
                    actions[rank]
                )
                self.pipeline_order_with_comms[rank] = _add_reduce_grad(  # type: ignore[assignment]
                    self.pipeline_order_with_comms[rank],  # type: ignore[arg-type]
                    self._n_microbatches,
                )

            self.pipeline_order_with_comms = _add_send_recv(
                self.pipeline_order_with_comms,
                stage_to_rank=lambda s: self.stage_index_to_group_rank[s],
                num_stages=self._num_stages,
            )
        else:
            raise NotImplementedError(f"{format=} is not implemented")

    def _load_csv(self, filename: str, format: str = "compute_only"):
        """Loads a csv in simple format and then lowers it to include communication actions.

        format must be either "compute_only" or "compute_comms". If compute_only, the lowering passes
        will automatically be run to generate a compute_comms schedule.

        Raises:
            NotImplementedError: for any other ``format``.
        """
        if format == "compute_only":
            # this will populate self.pipeline_order
            super()._load_csv(filename)
            # this will populate self.pipeline_order_with_comms
            self._prepare_schedule_with_comms(self.pipeline_order)
        elif format == "compute_comms":
            actions: dict[int, list[_Action | None]] = {}
            with open(filename, newline="") as csvfile:
                reader = csv.reader(csvfile)
                for rank, row in enumerate(reader):
                    actions[rank] = [_Action.from_str(s) for s in row]
                self._prepare_schedule_with_comms(actions, format=format)
        else:
            raise NotImplementedError(f"{format=} is not implemented")

    def _dump_csv(self, filename: str, format: str = "compute_comms"):
        """Dump a CSV representation of the schedule into a file with the provided filename.

        Args:
            filename: destination path; the file is overwritten.
            format: ``"compute_only"`` writes ``self.pipeline_order``; ``"compute_comms"`` writes
                ``self.pipeline_order_with_comms``. One row per rank in both cases.

        Raises:
            AssertionError: if the requested schedule is not available.
            NotImplementedError: for any other ``format`` (consistent with ``_load_csv``).
        """
        if format == "compute_only":
            if self.pipeline_order is None:
                raise AssertionError("Compute only schedule must be available")
            schedule = self.pipeline_order
        elif format == "compute_comms":
            # pipeline_order_with_comms is only assigned in _prepare_schedule_with_comms
            # (not in __init__ here), so a missing attribute must map to the same error.
            if getattr(self, "pipeline_order_with_comms", None) is None:
                raise AssertionError(
                    "Must initialize compute_comms schedule before dump_csv"
                )
            schedule = self.pipeline_order_with_comms
        else:
            raise NotImplementedError(f"{format=} is not implemented")
        with open(filename, "w", newline="") as csvfile:
            writer = csv.writer(csvfile)
            for rank in schedule:
                writer.writerow(schedule[rank])

    def _simulate(self):
        """Run ``_simulate_comms_compute`` on the compute+comms schedule and return its result."""
        return _simulate_comms_compute(
            self.pipeline_order_with_comms,
            lambda s: self.stage_index_to_group_rank[s],
            self._num_stages,
        )

    def _assert_unsharded(self, stage: _PipelineStageBase):
        """Ensure an FSDP-wrapped ``stage`` is unsharded before computing on it.

        If unshard ops are pending for the stage, wait() on them and mark the stage
        unsharded. Stages whose submodule is not an FSDPModule are not checked.

        Raises:
            AssertionError: if the FSDP stage is still sharded.
        """
        stage_uses_fsdp = isinstance(stage.submod, FSDPModule)
        if stage_uses_fsdp:
            stage_idx = stage.stage_index
            if stage_idx in self.unshard_ops:
                for op in self.unshard_ops[stage_idx]:
                    op.wait()
                del self.unshard_ops[stage_idx]
                self.unsharded_stages.add(stage_idx)
            if stage_idx not in self.unsharded_stages:
                raise AssertionError(f"Attempted to compute on sharded {stage_idx=}")

    def _step_microbatches(
        self,
        arg_mbs: list | None = None,
        kwarg_mbs: list | None = None,
        target_mbs: list | None = None,
        losses: list | None = None,
        return_outputs: bool = True,
    ):
        """
        Operate on the microbatches for looped schedules (multiple stages on each rank).

        Executes ``self.pipeline_order_with_comms[self.rank]`` action by action:

        * SEND_F / SEND_B issue batched sends, which are waited on at the end of the step;
        * RECV_F / RECV_B issue batched recvs, waited on right before the compute that uses them;
        * UNSHARD / RESHARD drive FSDP unsharding of a stage;
        * FORWARD / FULL_BACKWARD / BACKWARD_INPUT / BACKWARD_WEIGHT run compute on a local stage,
          handing activations/gradients directly between adjacent stages on this rank;
        * REDUCE_GRAD calls ``stage.perform_reduce_grad``;
        * OVERLAP_F_B runs its ``sub_actions`` in order.

        Top-level actions whose type has a registered custom function are dispatched to it instead.

        TODO: Does not use sorted_batch_isend_irecv(). As a result, this schedule does
        not support models with skip connections.

        Raises:
            AssertionError: if no compute+comms schedule has been prepared, or the schedule is
                inconsistent (compute before recv, double unshard, leftover unshard ops, ...).
        """
        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)

        self._initialize_stages(arg_mbs[0], kwarg_mbs[0])

        # Based on the plan in Step 1 created in __init__:
        # 2. Perform communication based on the pipeline_order
        stage_index_to_stage: dict[int, _PipelineStageBase] = {
            stage.stage_index: stage for stage in self._stages
        }

        # getattr: the attribute is only assigned in _prepare_schedule_with_comms, so a
        # missing attribute must surface as this error rather than an AttributeError.
        if getattr(self, "pipeline_order_with_comms", None) is None:
            raise AssertionError(
                "Must call _prepare_schedule_with_comms() before calling _step_microbatches()"
            )

        # send ops should be waited on before step() exits, mainly for hygiene
        send_ops: list[list[dist.Work]] = []

        def _perform_action(action: _Action) -> None:
            comp_type = action.computation_type
            # -1 marks "no microbatch"; only UNSHARD/RESHARD/REDUCE_GRAD may lack one
            mb_index: int = (
                action.microbatch_index if action.microbatch_index is not None else -1
            )
            if mb_index < 0 and comp_type not in (UNSHARD, RESHARD, REDUCE_GRAD):
                raise AssertionError(f"{action=} missing mb_index")
            stage_idx = action.stage_index
            stage = stage_index_to_stage[stage_idx]
            stage_uses_fsdp = isinstance(stage.submod, FSDPModule)
            # see [Note: V-schedule special case]
            is_next_stage_on_this_rank = stage_idx + 1 in stage_index_to_stage
            is_prev_stage_on_this_rank = stage_idx - 1 in stage_index_to_stage

            # TODO(whc) it's not actually safe to use _batch_p2p here in the uncommon case the model has skip-connections,
            # since we do not want to batch up ops between more than a pair of ranks. _sorted_batch_p2p would be
            # safe to use instead.
            # However, I was wondering if I should avoid calling batched operators at all in the case that there is
            # only one operator per batch. I could iterate through the 'fwd_send_ops' one by one and run them.
            if comp_type == SEND_F:
                send_ops.append(_batch_p2p(stage.get_fwd_send_ops(mb_index)))
            elif comp_type == SEND_B:
                send_ops.append(_batch_p2p(stage.get_bwd_send_ops(mb_index)))
            elif comp_type == RECV_F:
                if (stage_idx, mb_index) in self.fwd_recv_ops:
                    raise AssertionError(
                        f"Recv twice for {stage_idx=} {mb_index=} without executing forward"
                    )
                self.fwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p(
                    stage.get_fwd_recv_ops(mb_index)
                )
            elif comp_type == RECV_B:
                if (stage_idx, mb_index) in self.bwd_recv_ops:
                    raise AssertionError(
                        f"Recv twice for {stage_idx=} {mb_index=} without executing backward"
                    )
                self.bwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p(
                    stage.get_bwd_recv_ops(mb_index)
                )
            elif comp_type == UNSHARD:
                if stage_uses_fsdp:
                    if not (
                        stage_idx not in self.unsharded_stages
                        and stage_idx not in self.unshard_ops
                    ):
                        raise AssertionError(f"Unsharding the same {stage_idx=} twice")
                    for submodule in stage.submod.modules():
                        if not isinstance(submodule, FSDPModule):
                            continue
                        handle = cast(UnshardHandle, submodule.unshard(async_op=True))
                        self.unshard_ops[stage_idx].append(handle)
            elif comp_type == RESHARD:
                if stage_uses_fsdp:
                    if stage_idx not in self.unsharded_stages:
                        raise AssertionError(
                            f"Resharding {stage_idx=} without unsharding"
                        )
                    if stage_idx in self.unshard_ops:
                        raise AssertionError(
                            f"Resharding {stage_idx=} before finishing unshard"
                        )
                    for submodule in stage.submod.modules():
                        if not isinstance(submodule, FSDPModule):
                            continue
                        submodule.reshard()
                    self.unsharded_stages.remove(stage_idx)
            elif comp_type == FORWARD:
                self._assert_unsharded(stage)

                if (
                    not stage.is_first
                    # no recv op expected for V-schedule special case (see [Note: V-schedule special case])
                    and not is_prev_stage_on_this_rank
                ):
                    if (stage_idx, mb_index) not in self.fwd_recv_ops:
                        raise AssertionError(
                            f"Computing {action=} before receiving input"
                        )
                    _wait_batch_p2p(self.fwd_recv_ops.pop((stage_idx, mb_index)))

                output = stage.forward_one_chunk(
                    mb_index,
                    arg_mbs[mb_index],  # type: ignore[index]
                    kwarg_mbs[mb_index],  # type: ignore[index]
                    save_forward_output=return_outputs,
                )
                self._maybe_compute_loss(stage, output, target_mbs, mb_index)

                # SEND/RECV op are avoided for special case with 2 adjacent stages on same rank
                # see [Note: V-schedule special case]
                if is_next_stage_on_this_rank:
                    stage_index_to_stage[stage_idx + 1].set_local_fwd_input(
                        output, mb_index
                    )

            elif comp_type == FULL_BACKWARD:
                self._assert_unsharded(stage)

                if (
                    not stage.is_last
                    # no recv op expected for V-schedule special case (see [Note: V-schedule special case])
                    and not is_next_stage_on_this_rank
                ):
                    if (stage_idx, mb_index) not in self.bwd_recv_ops:
                        raise AssertionError(
                            f"Attempted to run compute {action=} before receiving input"
                        )
                    _wait_batch_p2p(self.bwd_recv_ops.pop((stage_idx, mb_index)))
                loss = self._maybe_get_loss(stage, mb_index)
                self.backward_counter[stage_idx] += 1
                last_backward = self.backward_counter[stage_idx] == self._n_microbatches
                stage.backward_one_chunk(
                    mb_index,
                    loss=loss,
                    full_backward=True,
                    last_backward=last_backward,
                )
                # SEND/RECV op are avoided for special case with 2 adjacent stages on same rank
                # see [Note: V-schedule special case]
                if is_prev_stage_on_this_rank:
                    stage_index_to_stage[stage_idx - 1].set_local_bwd_input(
                        stage.get_local_bwd_output(mb_index), mb_index
                    )
            elif comp_type == BACKWARD_INPUT:
                self._assert_unsharded(stage)

                if not stage.is_last and not is_next_stage_on_this_rank:
                    if (stage_idx, mb_index) not in self.bwd_recv_ops:
                        raise AssertionError(
                            f"Attempted to run compute {action=} before receiving input"
                        )
                    _wait_batch_p2p(self.bwd_recv_ops.pop((stage_idx, mb_index)))
                loss = self._maybe_get_loss(stage, mb_index)
                stage.backward_one_chunk(
                    mb_index,
                    loss=loss,
                    full_backward=False,
                    last_backward=False,
                )
                # SEND/RECV op are avoided for special case with 2 adjacent stages on same rank
                # see [Note: V-schedule special case]
                if is_prev_stage_on_this_rank:
                    stage_index_to_stage[stage_idx - 1].set_local_bwd_input(
                        stage.get_local_bwd_output(mb_index), mb_index
                    )
            elif comp_type == BACKWARD_WEIGHT:
                self._assert_unsharded(stage)
                self.backward_counter[stage_idx] += 1
                last_backward = self.backward_counter[stage_idx] == self._n_microbatches
                stage.backward_weight_one_chunk(
                    mb_index,
                    last_backward=last_backward,
                )
            elif comp_type == REDUCE_GRAD:
                grad_scale_factor = self._n_microbatches if self.scale_grads else 1
                stage.perform_reduce_grad(grad_scale_factor)
            else:
                raise ValueError(f"{action=} is unknown or unsupported")

        # count either full_backward or backward_weight together, to determine when to sync DP grads
        self.backward_counter.clear()
        for time_step, action in enumerate(self.pipeline_order_with_comms[self.rank]):
            logger.debug(
                "_PipelineScheduleRuntime running time_step %d, action %s",
                time_step,
                action,
            )
            try:
                with record_function(_get_profiler_function_name(action)):
                    if action.computation_type in self._comp_type_to_function_map:
                        ctx = _PipelineContext(
                            self,
                            arg_mbs,
                            kwarg_mbs,
                            target_mbs,
                            losses,
                        )
                        self._comp_type_to_function_map[action.computation_type](
                            action, ctx
                        )
                    elif action.computation_type == OVERLAP_F_B:
                        if action.sub_actions is None:
                            raise AssertionError("sub_actions must be set")
                        for sub_a in action.sub_actions:
                            _perform_action(sub_a)
                    else:
                        _perform_action(action)
            except Exception:
                logger.error(
                    "_PipelineScheduleRuntime caught exception at step %s when running action %s. Full Schedule:",
                    time_step,
                    action,
                )
                # Pass the formatted schedule as an argument, not as the format string.
                logger.error(
                    "%s",
                    _format_pipeline_order(
                        self.pipeline_order_with_comms,  # type: ignore[arg-type]
                        error_step_number=time_step,
                    ),
                )
                # Bare raise preserves the original traceback.
                raise

        # Mostly these operations should have finished long ago, but there isn't an obvious time when to wait for them
        while send_ops:
            _wait_batch_p2p(send_ops.pop())

        if self.unshard_ops:
            raise AssertionError("Unused unshard operations")

        # Return losses if there is a container passed in
        self._update_losses(self._stages, losses)
class ScheduleLoopedBFS(_PipelineScheduleRuntime):
    """
    Breadth-First Pipeline Parallelism.
    See https://arxiv.org/abs/2211.05953 for details.

    Like Interleaved 1F1B, Looped BFS supports multiple stages per rank. The
    difference is that when microbatches are ready for several local stages,
    Looped BFS prioritizes the earlier stage and runs all of its available
    microbatches at once.
    """

    def __init__(
        self,
        stages: list[_PipelineStageBase],
        n_microbatches: int,
        loss_fn: Callable | _Loss | None = None,
        output_merge_spec: dict[str, Any] | tuple[Any] | None = None,
        scale_grads: bool = True,
        backward_requires_autograd: bool = True,
    ):
        super().__init__(
            stages=stages,
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            output_merge_spec=output_merge_spec,
            scale_grads=scale_grads,
            backward_requires_autograd=backward_requires_autograd,
        )

        # Every rank computes the compute-only order for *all* ranks, so the
        # state of the whole pipeline is known locally:
        #   pipeline_order[rank] = [_Action(stage_index, computation_type, mb_index) or None, ...]
        self.pipeline_order: dict[int, list[_Action | None]] = {}
        for rank in range(self.pp_group_size):
            self.pipeline_order[rank] = self._calculate_single_rank_operations(rank)

        # Lower to a schedule with explicit communication for _PipelineScheduleRuntime.
        self._prepare_schedule_with_comms(self.pipeline_order)

    def _calculate_single_rank_operations(self, rank):
        """
        Build the compute-only action list for ``rank``.

        Layout (``None`` = idle slot):
          * ``rank`` leading idle slots (warmup offset),
          * every forward of every local stage, stage-ascending then microbatch-ascending,
          * ``2 * (pp_group_size - 1 - rank)`` idle slots while the first backward
            travels back to this rank (2 per hop),
          * every full backward, stage-descending then microbatch-descending.
        """
        local_stage_count = len(self._stages)
        owned_stages = range(
            rank, self.pp_group_size * local_stage_count, self.pp_group_size
        )
        microbatch_ids = range(self._n_microbatches)

        schedule: list[_Action | None] = [None] * rank
        for stage_id in owned_stages:
            for mb in microbatch_ids:
                schedule.append(_Action(stage_id, _ComputationType.FORWARD, mb))

        hops_to_last_rank = self.pp_group_size - 1 - rank
        schedule += [None] * (2 * hops_to_last_rank)

        for stage_id in reversed(owned_stages):
            for mb in reversed(microbatch_ids):
                schedule.append(_Action(stage_id, _ComputationType.FULL_BACKWARD, mb))
        return schedule
def _get_1f1b_rank_ops(
    n_local_stages,
    pp_group_size,
    warmup_ops,
    fwd_bwd_ops,
    cooldown_ops,
    rank,
    forward_stage_index,
    backward_stage_index,
    num_1f1b_microbatches=0,
    enable_zero_bubble=False,
):
    """
    Build the compute-only action list for one rank of a 1F1B-style schedule.

    The op counter ``op`` runs over ``warmup_ops + fwd_bwd_ops + cooldown_ops`` ops:

    * warmup (``op < warmup_ops``): one FORWARD per op; right after the last warmup
      op, ``post_warmup_ops`` idle slots (``None``) are inserted so the first backward
      lines up with the other ranks.
    * steady state (the next ``fwd_bwd_ops`` ops): one FORWARD then one backward per op.
    * cooldown (the remaining ops): one backward per op, preceded by an idle slot
      unless zero-bubble is enabled.

    With ``enable_zero_bubble`` the backward is emitted as BACKWARD_INPUT and the
    matching BACKWARD_WEIGHT ops are emitted in the same order as the backwards: one
    per op once ``op - warmup_ops >= num_1f1b_microbatches``, then any still-pending
    ones at the end. Without it, backwards are FULL_BACKWARD and no weight ops are
    emitted.

    Args:
        n_local_stages: number of local stages per rank (used in the warmup padding).
        pp_group_size: number of pipeline ranks.
        warmup_ops, fwd_bwd_ops, cooldown_ops: number of ops in each phase.
        rank: this rank; also the number of leading idle slots.
        forward_stage_index: callable mapping an op counter to its forward's stage index.
        backward_stage_index: callable mapping an op counter to its backward's stage index.
        num_1f1b_microbatches: zero-bubble only; number of steady-state ops emitted
            before weight ops start being interleaved.
        enable_zero_bubble: split backwards into input and weight parts.

    Returns:
        ``list[_Action | None]`` where ``None`` is an idle slot.
    """
    # All stages start with handling microbatch 0
    fwd_stage_mb_index: dict[int, int] = defaultdict(int)
    bwd_stage_mb_index: dict[int, int] = defaultdict(int)
    weight_stage_mb_index: dict[int, int] = defaultdict(int)

    # Store the list of operations used for that rank
    # Pre-padding, rank starts with no-ops based on the warmup.
    rank_ops: list[_Action | None] = [None for _ in range(rank)]

    # These are used to calculate the number of slots to fill with no-ops, to account for the delay in warmup
    # when we want to wait for the backward to trickle back up and start 1f1b to align all ranks.
    # Formula:
    # pre-padding + warmup_ops + post_warmup_ops = earliest time step of first backward
    # post_warmup_ops = [earliest time step of first backward] - (warmup_ops + pre-padding)
    # earliest time step of first backward = [local_stages * group_size + 2 * (group_size - 1 - rank)]
    # warmup_ops = calculated above
    post_warmup_ops = (
        n_local_stages * pp_group_size + 2 * (pp_group_size - 1 - rank)
    ) - (warmup_ops + rank)

    if enable_zero_bubble:
        post_warmup_ops = pp_group_size - rank - 1

    total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops

    # Op counters of emitted backwards, in order; weight ops consume them FIFO.
    backward_op_ids: list[int] = []
    weight_op_count = 0

    backward_type = BACKWARD_INPUT if enable_zero_bubble else FULL_BACKWARD

    def _next_mb_index(counters: dict[int, int], stage_index: int) -> int:
        """Return the current microbatch index of ``stage_index`` and advance it by one."""
        mb_index = counters[stage_index]
        counters[stage_index] = mb_index + 1
        return mb_index

    def _append_forward(op: int) -> None:
        """Emit the FORWARD belonging to op counter ``op``."""
        stage_index = forward_stage_index(op)
        mb_index = _next_mb_index(fwd_stage_mb_index, stage_index)
        rank_ops.append(_Action(stage_index, _ComputationType.FORWARD, mb_index))

    def _append_backward(op: int) -> None:
        """Emit the (full or input) backward belonging to op counter ``op``."""
        stage_index = backward_stage_index(op)
        mb_index = _next_mb_index(bwd_stage_mb_index, stage_index)
        rank_ops.append(_Action(stage_index, backward_type, mb_index))
        backward_op_ids.append(op)

    def _append_weight() -> None:
        """Emit the BACKWARD_WEIGHT for the oldest backward without one."""
        nonlocal weight_op_count
        stage_index = backward_stage_index(backward_op_ids[weight_op_count])
        mb_index = _next_mb_index(weight_stage_mb_index, stage_index)
        rank_ops.append(
            _Action(stage_index, _ComputationType.BACKWARD_WEIGHT, mb_index)
        )
        weight_op_count += 1

    for op in range(total_ops):
        if op < warmup_ops:
            # Warmup phase: forwards only.
            _append_forward(op)
            if op == warmup_ops - 1:
                # This is the last step in the warmup phase, so we need to wait for the backward to trickle back up
                rank_ops.extend([None] * post_warmup_ops)
        elif op < warmup_ops + fwd_bwd_ops:
            # 1F1B phase (forward and backward)
            _append_forward(op)
            _append_backward(op)
            if enable_zero_bubble and op - warmup_ops >= num_1f1b_microbatches:
                _append_weight()
        else:
            # Cooldown phase
            # During cooldown phase, we need steps to align with 1f1b happening in other ranks
            # TODO: we don't need to always append, after all 1f1b are finished we can stop appending None
            if not enable_zero_bubble:
                rank_ops.append(None)
            _append_backward(op)
            if enable_zero_bubble and op - warmup_ops >= num_1f1b_microbatches:
                _append_weight()

    # Flush weight ops for backwards that have not had one yet.
    while enable_zero_bubble and weight_op_count < len(backward_op_ids):
        _append_weight()

    return rank_ops
- def _get_warmup_ops(
- rank: int,
- n_local_stages: int,
- microbatches_per_round: int,
- pp_group_size: int,
- n_microbatches: int,
- multiply_factor: int = 2,
- ) -> int:
- """
- Calculate the number of warmup operations for interleaved schedules.
- """
- # Warmup operations for last stage
- warmups_ops_last_stage = (n_local_stages - 1) * microbatches_per_round
- # Increment warmup operations by multiply_factor for each hop away from the last stage
- warmup_ops = warmups_ops_last_stage + multiply_factor * ((pp_group_size - 1) - rank)
- # We cannot have more warmup operations than there are number of microbatches, so cap it there
- return min(warmup_ops, n_microbatches * n_local_stages)
class ScheduleInterleaved1F1B(_PipelineScheduleRuntime):
    """
    The Interleaved 1F1B schedule.
    See https://arxiv.org/pdf/2104.04473 for details.

    In steady state each rank alternates one forward and one backward over the
    microbatches, and a rank may host several stages. When microbatches are
    ready for more than one local stage, the earlier microbatch is preferred
    (also called "depth first").

    This schedule mostly follows the original paper, but relaxes its
    requirement that num_microbatch % pp_size == 0. Following the flex_pp
    schedule, num_rounds = max(1, n_microbatches // pp_group_size), and the
    schedule works as long as n_microbatches % num_rounds is 0. Supported
    examples:
    1. pp_group_size = 4, n_microbatches = 10: num_rounds = 2 and n_microbatches % 2 is 0.
    2. pp_group_size = 4, n_microbatches = 3: num_rounds = 1 and n_microbatches % 1 is 0.

    Raises:
        ValueError: if ``n_microbatches`` is not a multiple of the number of rounds.
    """

    def __init__(
        self,
        stages: list[_PipelineStageBase],
        n_microbatches: int,
        loss_fn: Callable | None = None,
        args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None,
        kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None,
        output_merge_spec: dict[str, Any] | tuple[Any] | None = None,
        scale_grads: bool = True,
        backward_requires_autograd: bool = True,
    ):
        # NOTE(review): assigned before super().__init__(); presumably the base
        # initializer relies on it - keep this ordering.
        self.pp_group_size = stages[0].group_size
        super().__init__(
            stages=stages,
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            args_chunk_spec=args_chunk_spec,
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
            scale_grads=scale_grads,
            backward_requires_autograd=backward_requires_autograd,
        )
        rounds = max(1, n_microbatches // self.pp_group_size)
        self.n_local_stages = len(stages)
        self.rank = stages[0].group_rank
        self.number_of_rounds = rounds
        self.microbatches_per_round = n_microbatches // rounds
        if n_microbatches % rounds:
            raise ValueError(
                "Interleaved 1F1B requires the number of microbatches to be a "
                f"multiple of the number of rounds ({rounds}), "
                f"but got {n_microbatches}."
            )

        # Every rank computes the order for the whole pipeline so that all
        # ranks share the same view of it:
        #   pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
        self.pipeline_order: dict[int, list[_Action | None]] = {
            pp_rank: self._calculate_single_rank_operations(pp_rank)
            for pp_rank in range(self.pp_group_size)
        }

        # Add the communication actions that _PipelineScheduleRuntime needs.
        self._prepare_schedule_with_comms(self.pipeline_order)

    def _calculate_single_rank_operations(self, rank) -> list[_Action | None]:
        """
        Build the ordered compute actions for ``rank``.

        Forwards and backwards are split into warmup, steady-state (1F1B) and
        cooldown phases. The global stage index for each step is derived from
        the step number, ``microbatches_per_round`` and ``n_local_stages``.
        Assembly is delegated to ``_get_1f1b_rank_ops``.
        """
        mb_per_round = self.microbatches_per_round
        local_stages = self.n_local_stages
        group_size = self.pp_group_size

        n_warmup = _get_warmup_ops(
            rank,
            local_stages,
            mb_per_round,
            group_size,
            self._n_microbatches,
            multiply_factor=2,
        )
        n_mb_ops = local_stages * self._n_microbatches
        # Steady state covers the forwards left after warmup.
        n_steady = n_mb_ops - n_warmup
        # Cooldown covers the backwards left after steady state.
        n_cooldown = n_mb_ops - n_steady
        # Invariant: n_warmup + n_steady * 2 + n_cooldown == n_mb_ops * 2
        logger.debug(
            "rank %s, warmup_ops %s, 1f1b %s, cooldown_ops %s total_ops %s",
            rank,
            n_warmup,
            n_steady,
            n_cooldown,
            n_warmup + n_steady + n_cooldown,
        )

        def stage_for_forward(step):
            # The local chunk (0 .. local_stages-1) advances once per round.
            chunk = (step // mb_per_round) % local_stages
            return chunk * group_size + rank

        def stage_for_backward(step):
            # Backwards walk the local chunks in reverse, counted from the end
            # of warmup.
            chunk = local_stages - 1 - (((step - n_warmup) // mb_per_round) % local_stages)
            return chunk * group_size + rank

        return _get_1f1b_rank_ops(
            local_stages,
            group_size,
            n_warmup,
            n_steady,
            n_cooldown,
            rank,
            stage_for_forward,
            stage_for_backward,
        )
class ScheduleInterleavedZeroBubble(_PipelineScheduleRuntime):
    """
    The Interleaved Zero Bubble schedule.
    See https://arxiv.org/pdf/2401.10241 for details.

    Will perform one forward and one backward on inputs for the microbatches in steady
    state and supports multiple stages per rank. Uses the backward for weights to fill in
    the pipeline bubble.

    In particular this is implementing the ZB1P schedule in the paper.

    Raises:
        ValueError: if ``n_microbatches`` is not a multiple of the number of
            rounds, ``max(1, n_microbatches // pp_group_size)``.
        RuntimeError: if a stage module was compiled with torch.compile
            (raised by ``_check_torch_compile_compatibility``).
    """

    def __init__(
        self,
        stages: list[_PipelineStageBase],
        n_microbatches: int,
        loss_fn: Callable | None = None,
        args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None,
        kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None,
        output_merge_spec: dict[str, Any] | tuple[Any] | None = None,
        scale_grads: bool = True,
        backward_requires_autograd: bool = True,
    ):
        # TODO: we dont support input/weight backward split with torch.compile
        _check_torch_compile_compatibility(stages, self.__class__.__name__)
        self.pp_group_size = stages[0].group_size
        super().__init__(
            stages=stages,
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            args_chunk_spec=args_chunk_spec,
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
            scale_grads=scale_grads,
            backward_requires_autograd=backward_requires_autograd,
        )
        self.n_local_stages = len(stages)
        self.rank = stages[0].group_rank
        self.number_of_rounds = max(1, n_microbatches // self.pp_group_size)
        self.microbatches_per_round = n_microbatches // self.number_of_rounds
        if n_microbatches % self.number_of_rounds != 0:
            raise ValueError(
                "Zero bubble requires the number of microbatches to be a "
                f"multiple of the number of rounds ({self.number_of_rounds}), "
                f"but got {n_microbatches}."
            )
        # 1. Create the pipeline_order (all ranks do this calculation)
        # This will be used to keep track of the current state of the entire pipeline
        # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
        self.pipeline_order: dict[int, list[_Action | None]] = {}
        for rank in range(self.pp_group_size):
            rank_ops = self._calculate_single_rank_operations(rank)
            self.pipeline_order[rank] = rank_ops
        # This function adds bubbles to the generated schedule based on the
        # dependencies between actions. Note that the ZB1P schedule does not
        # require bubbles to be added manually; this only matters when
        # n_microbatches <= microbatches_per_round.
        self.pipeline_order = self._add_bubbles_to_actions(
            self.n_local_stages * self.pp_group_size,
        )
        # Initialize the pipeline order with communication necessary to run with _PipelineScheduleRuntime
        self._prepare_schedule_with_comms(self.pipeline_order)

    def _calculate_single_rank_operations(self, rank) -> list[_Action | None]:
        """
        Build the ordered compute actions for ``rank`` using ZB1P.

        Structure matches Interleaved 1F1B but uses ``multiply_factor=1`` for
        warmup. It calls ``_get_1f1b_rank_ops`` with ``enable_zero_bubble=True``
        and ``num_1f1b_microbatches=rank``, so weight-gradient actions are
        scheduled separately.
        """
        warmup_ops = _get_warmup_ops(
            rank,
            self.n_local_stages,
            self.microbatches_per_round,
            self.pp_group_size,
            self._n_microbatches,
            multiply_factor=1,
        )
        microbatch_ops = self.n_local_stages * self._n_microbatches
        # fwd_bwd_ops should encompass the remaining forwards
        fwd_bwd_ops = microbatch_ops - warmup_ops
        # cooldown_ops should encompass the remaining backwards
        cooldown_ops = microbatch_ops - fwd_bwd_ops
        # total ops encompass both forward and backward ops
        total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops
        # warmup_ops + fwd_bwd_ops * 2 + cooldown_ops == microbatch_ops * 2
        logger.debug(
            "rank %s, warmup_ops %s, 1f1b %s, cooldown_ops %s total_ops %s",
            rank,
            warmup_ops,
            fwd_bwd_ops,
            cooldown_ops,
            total_ops,
        )

        # Calculates the stage index based on step and pp_group_size
        def forward_stage_index(step):
            # Get the local index from 0 to n_local_stages-1
            local_index = (step // self.microbatches_per_round) % self.n_local_stages
            return (local_index * self.pp_group_size) + rank

        def backward_stage_index(step):
            # Backwards visit the local stages in reverse order, counted from
            # the end of warmup.
            local_index = (
                self.n_local_stages
                - 1
                - ((step - warmup_ops) // self.microbatches_per_round)
                % self.n_local_stages
            )
            return (local_index * self.pp_group_size) + rank

        num_1f1b_microbatches = rank
        return _get_1f1b_rank_ops(
            self.n_local_stages,
            self.pp_group_size,
            warmup_ops,
            fwd_bwd_ops,
            cooldown_ops,
            rank,
            forward_stage_index,
            backward_stage_index,
            num_1f1b_microbatches,
            enable_zero_bubble=True,
        )

    def _add_bubbles_to_actions(self, num_stages_global):
        """
        Re-time ``self.pipeline_order`` by inserting ``None`` bubbles wherever
        an action's dependency has not run yet.

        All ranks advance in lockstep, one timestep per loop iteration. An
        action is emitted only if its dependency was emitted in an earlier
        timestep. Ops emitted in the current timestep become visible only in the
        next one. Otherwise the rank idles (``None``) and retries the same
        action. Existing ``None`` entries are copied through.

        Args:
            num_stages_global: total number of stages across all ranks.

        Returns:
            A new ``dict[rank, list[_Action | None]]`` with bubbles inserted.
            A warning is logged if any bubble was added.
        """
        actions = self.pipeline_order

        def need_bubble(stage, op, microbatch, num_stages_global, seen_ops):
            # NOTE(review): only FORWARD and FULL_BACKWARD dependencies are
            # checked. If the zero-bubble rank ops emit BACKWARD_INPUT in place
            # of FULL_BACKWARD, backward dependencies are never enforced here.
            # Confirm this is intended.
            if op == _ComputationType.FORWARD:
                # A forward needs the previous stage's forward for this microbatch.
                if stage != 0 and (stage - 1, op, microbatch) not in seen_ops:
                    return True
            elif op == _ComputationType.FULL_BACKWARD:
                # The last stage's backward needs its own forward; other stages
                # need the next stage's backward.
                if stage == num_stages_global - 1:
                    return (stage, _ComputationType.FORWARD, microbatch) not in seen_ops
                return (stage + 1, op, microbatch) not in seen_ops
            return False

        seen_ops: set[tuple[int, _ComputationType, int]] = set()
        result: dict[int, list[_Action | None]] = {}
        next_pointer: dict[int, int] = {}
        bubbles_added: dict[int, int] = {}
        total_bubbles_added = 0

        for rank in range(self.pp_group_size):
            result[rank] = []
            next_pointer[rank] = 0
            bubbles_added[rank] = 0

        while True:
            should_stop = True

            # Ops emitted in this timestep; merged into seen_ops only after all
            # ranks have been visited.
            temp_seen_ops: set[tuple[int, _ComputationType, int]] = set()

            for rank in range(self.pp_group_size):
                timestamp = next_pointer[rank]
                if timestamp >= len(actions[rank]):
                    continue

                should_stop = False

                if actions[rank][timestamp] is not None:
                    temp_action = actions[rank][timestamp]
                    if temp_action is None:
                        raise AssertionError(
                            f"Expected temp_action to be not None, got {type(temp_action)}"
                        )
                    stage_index, op, microbatch, _ = temp_action
                    if not need_bubble(
                        stage_index, op, microbatch, num_stages_global, seen_ops
                    ):
                        result[rank].append(actions[rank][timestamp])
                        if microbatch is not None:
                            temp_seen_ops.add((stage_index, op, microbatch))
                        next_pointer[rank] += 1
                    else:
                        # Dependency not satisfied yet: idle this timestep and
                        # retry the same action next time.
                        result[rank].append(None)
                        bubbles_added[rank] += 1
                        # Keep the global count in sync so the warning below
                        # can actually fire.
                        total_bubbles_added += 1
                else:
                    next_pointer[rank] += 1
                    result[rank].append(None)

            seen_ops.update(temp_seen_ops)
            if should_stop:
                break

        if total_bubbles_added > 0:
            logger.warning(
                "Non zero bubbles added: total_bubbles_added=%s bubbles_added=%s",
                total_bubbles_added,
                bubbles_added,
            )
        return result
class ScheduleZBVZeroBubble(_PipelineScheduleRuntime):
    """
    The Zero Bubble schedule (ZBV variant).
    See https://arxiv.org/pdf/2401.10241 Section 6 for details.

    This schedule requires exactly two stages per rank.

    This schedule will perform one forward and one backward on inputs for the microbatches in steady
    state and runs both of its stages on every rank. Uses backward with respect to weights to fill in
    the pipeline bubble.

    This ZB-V schedule would have the "zero bubble" property only if time forward == time backward input == time backward weights.
    In practice, this is not likely true for real models so alternatively
    a greedy scheduler could be implemented for unequal/unbalanced time.

    Raises:
        ValueError: if a rank does not hold exactly 2 stages.
        RuntimeError: if a stage module was compiled with torch.compile
            (raised by ``_check_torch_compile_compatibility``).
    """

    def __init__(
        self,
        stages: list[_PipelineStageBase],
        n_microbatches: int,
        loss_fn: Callable | None = None,
        args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None,
        kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None,
        output_merge_spec: dict[str, Any] | tuple[Any] | None = None,
        scale_grads: bool = True,
        backward_requires_autograd: bool = True,
    ):
        # TODO: we dont support input/weight backward split with torch.compile
        _check_torch_compile_compatibility(stages, self.__class__.__name__)
        self.pp_group_size = stages[0].group_size
        super().__init__(
            stages=stages,
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            args_chunk_spec=args_chunk_spec,
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
            scale_grads=scale_grads,
            backward_requires_autograd=backward_requires_autograd,
        )
        # Use the "v" stage-to-rank placement and share it with every local stage.
        self.stage_index_to_group_rank = generate_stage_to_rank_mapping(
            self.pp_group_size, self._num_stages, style="v"
        )
        for stage in self._stages:
            stage.stage_index_to_group_rank = self.stage_index_to_group_rank
        self.n_local_stages = len(stages)
        if self.n_local_stages != 2:
            raise ValueError(
                "ZBV requires exactly 2 stages per rank, but got "
                f"{self.n_local_stages}."
            )
        self.rank = stages[0].group_rank
        self.num_stages = stages[0].num_stages
        # 1. Create the pipeline_order (all ranks do this calculation)
        # This will be used to keep track of the current state of the entire pipeline
        # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
        self.pipeline_order: dict[int, list[_Action | None]] = {}
        for rank in range(self.pp_group_size):
            rank_ops = self._calculate_single_rank_operations(rank)
            self.pipeline_order[rank] = rank_ops
        # Initialize the pipeline order with communication necessary to run with _PipelineScheduleRuntime
        self._prepare_schedule_with_comms(self.pipeline_order)

    def _calculate_single_rank_operations(self, rank) -> list[_Action | None]:
        """
        Build the ZB-V action list for ``rank`` (idle slots are ``None``).

        Chunk 0 is stage ``rank`` and chunk 1 is stage ``num_stages - 1 - rank``.
        The actions use the F / I / W computation types, presumably forward,
        backward-input and backward-weight. The list is built in phases:
        warmup, stable, then cooldown. Finally, any action whose microbatch
        index is >= ``self._n_microbatches`` is replaced with ``None``.
        """
        # max(2 * self.pp_group_size - 1, ...) ensure the number of microbatches is at least
        # as large of the number of microbatches needed to fully utilize the pipeline
        n_micro = max(2 * self.pp_group_size - 1, self._n_microbatches)
        # Rank r starts with r idle slots.
        rank_ops: list[_Action | None] = [None for _ in range(rank)]

        # Forward and backward action counts for stage chunk 0 and chunk 1
        f0_cnt, f1_cnt, b0_cnt, b1_cnt = 0, 0, 0, 0
        # warm-up phase
        # (a) 2 * (pp_group_size - rank) - 1 forwards on chunk 0.
        warmup_n1 = 2 * (self.pp_group_size - rank) - 1
        stage_id_chunk0 = rank
        stage_id_chunk1 = self.num_stages - 1 - rank

        for _ in range(warmup_n1):
            rank_ops.append(
                _Action(stage_id_chunk0, computation_type=F, microbatch_index=f0_cnt)
            )
            f0_cnt += 1
        # (b) `rank` pairs of (chunk 1 forward, chunk 0 forward).
        warmup_n2 = rank
        for _ in range(warmup_n2):
            rank_ops.append(
                _Action(stage_id_chunk1, computation_type=F, microbatch_index=f1_cnt)
            )
            f1_cnt += 1
            rank_ops.append(
                _Action(stage_id_chunk0, computation_type=F, microbatch_index=f0_cnt)
            )
            f0_cnt += 1
        # (c) pp_group_size - rank rounds of chunk 1 forward followed
        #     immediately by its input and weight backward.
        warmup_n3 = self.pp_group_size - rank
        for _ in range(warmup_n3):
            rank_ops.append(
                _Action(stage_id_chunk1, computation_type=F, microbatch_index=f1_cnt)
            )
            f1_cnt += 1
            rank_ops.append(
                _Action(stage_id_chunk1, computation_type=I, microbatch_index=b1_cnt)
            )
            rank_ops.append(
                _Action(stage_id_chunk1, computation_type=W, microbatch_index=b1_cnt)
            )
            b1_cnt += 1
        # stable phase
        # Continue while chunk 1 forwards lag chunk 0 forwards or chunk 0 still
        # has microbatches to forward. Each round runs an optional chunk 0
        # forward, then chunk 0 I+W, then chunk 1 F+I+W.
        while f1_cnt < f0_cnt or f0_cnt < n_micro:
            if f0_cnt < n_micro:
                rank_ops.append(
                    _Action(
                        stage_id_chunk0, computation_type=F, microbatch_index=f0_cnt
                    )
                )
                f0_cnt += 1
            rank_ops.append(
                _Action(stage_id_chunk0, computation_type=I, microbatch_index=b0_cnt)
            )
            rank_ops.append(
                _Action(stage_id_chunk0, computation_type=W, microbatch_index=b0_cnt)
            )
            b0_cnt += 1

            rank_ops.append(
                _Action(stage_id_chunk1, computation_type=F, microbatch_index=f1_cnt)
            )
            f1_cnt += 1
            rank_ops.append(
                _Action(stage_id_chunk1, computation_type=I, microbatch_index=b1_cnt)
            )
            rank_ops.append(
                _Action(stage_id_chunk1, computation_type=W, microbatch_index=b1_cnt)
            )
            b1_cnt += 1
        # cool-down phase
        # Weight backwards are deferred from here on. w*_cnt tracks the next
        # pending weight backward for each chunk.
        w0_cnt, w1_cnt = b0_cnt, b1_cnt
        # `rank` rounds of input backwards on both chunks (no weight backward).
        cooldown_n1 = rank
        for _ in range(cooldown_n1):
            rank_ops.append(
                _Action(stage_id_chunk0, computation_type=I, microbatch_index=b0_cnt)
            )
            b0_cnt += 1
            rank_ops.append(
                _Action(stage_id_chunk1, computation_type=I, microbatch_index=b1_cnt)
            )
            b1_cnt += 1
        # pp_group_size - rank rounds of chunk 0 input backward plus one
        # deferred chunk 0 weight backward.
        cooldown_n2 = self.pp_group_size - rank
        for _ in range(cooldown_n2):
            rank_ops.append(
                _Action(stage_id_chunk0, computation_type=I, microbatch_index=b0_cnt)
            )
            b0_cnt += 1
            rank_ops.append(
                _Action(stage_id_chunk0, computation_type=W, microbatch_index=w0_cnt)
            )
            w0_cnt += 1
        # Drain the remaining deferred weight backwards: chunk 1 first, then chunk 0.
        while w1_cnt < b1_cnt:
            rank_ops.append(
                _Action(stage_id_chunk1, computation_type=W, microbatch_index=w1_cnt)
            )
            w1_cnt += 1
        while w0_cnt < b0_cnt:
            rank_ops.append(
                _Action(stage_id_chunk0, computation_type=W, microbatch_index=w0_cnt)
            )
            w0_cnt += 1

        # Sanity checks: each chunk issued the same number of F, I and W actions.
        if not (w0_cnt == b0_cnt and b0_cnt == f0_cnt):
            raise AssertionError(
                f"Expected w0_cnt == b0_cnt == f0_cnt, got w0_cnt={w0_cnt}, b0_cnt={b0_cnt}, f0_cnt={f0_cnt}"
            )
        if not (w1_cnt == b1_cnt and b1_cnt == f1_cnt):
            raise AssertionError(
                f"Expected w1_cnt == b1_cnt == f1_cnt, got w1_cnt={w1_cnt}, b1_cnt={b1_cnt}, f1_cnt={f1_cnt}"
            )
        # We use max() in the n_micro computation above, so we may need to
        # remove redundant microbatches
        rank_ops = [
            (
                action
                if action is not None
                and action.microbatch_index is not None
                and action.microbatch_index < self._n_microbatches
                else None
            )
            for action in rank_ops
        ]
        return rank_ops
class ScheduleDualPipeV(_PipelineScheduleRuntime):
    """
    The DualPipeV schedule. A more efficient schedule variant based on the
    DualPipe schedule introduced by DeepSeek in https://arxiv.org/pdf/2412.19437

    Based on the open sourced code from https://github.com/deepseek-ai/DualPipe

    Requirements (validated in ``__init__``):
    - exactly 2 stages per rank, placed with the "v" stage-to-rank mapping;
    - at least as many microbatches as there are stages in total.

    Raises:
        ValueError: if either requirement above is violated.
        RuntimeError: if a stage module was compiled with torch.compile
            (raised by ``_check_torch_compile_compatibility``).
    """

    def __init__(
        self,
        stages: list[_PipelineStageBase],
        n_microbatches: int,
        loss_fn: Callable | None = None,
        args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None,
        kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None,
        output_merge_spec: dict[str, Any] | tuple[Any] | None = None,
        scale_grads: bool = True,
        backward_requires_autograd: bool = True,
    ):
        # TODO: we dont support input/weight backward split with torch.compile
        _check_torch_compile_compatibility(stages, self.__class__.__name__)
        self.pp_group_size = stages[0].group_size
        super().__init__(
            stages=stages,
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            args_chunk_spec=args_chunk_spec,
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
            scale_grads=scale_grads,
            backward_requires_autograd=backward_requires_autograd,
        )
        # Use the "v" stage-to-rank placement and share it with every local stage.
        self.stage_index_to_group_rank = generate_stage_to_rank_mapping(
            self.pp_group_size, self._num_stages, style="v"
        )
        for stage in self._stages:
            stage.stage_index_to_group_rank = self.stage_index_to_group_rank

        self.n_local_stages = len(stages)
        if self.n_local_stages != 2:
            # Message previously named "ZBV" (copy-paste); this is DualPipeV.
            raise ValueError(
                "DualPipeV requires exactly 2 stages per rank, but got "
                f"{self.n_local_stages}."
            )
        if n_microbatches < self._num_stages:
            raise ValueError(
                "DualPipeV requires at least as many microbatches as stages, but got "
                f"{n_microbatches} microbatches and {self._num_stages} stages."
            )

        self.rank = stages[0].group_rank
        self.num_stages = stages[0].num_stages

        # 1. Create the pipeline_order (all ranks do this calculation)
        # This will be used to keep track of the current state of the entire pipeline
        # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
        self.pipeline_order: dict[int, list[_Action | None]] = {}
        for rank in range(self.pp_group_size):
            rank_ops = self._calculate_single_rank_operations(rank)
            self.pipeline_order[rank] = rank_ops

        # Initialize the pipeline order with communication necessary to run with _PipelineScheduleRuntime
        self._prepare_schedule_with_comms(self.pipeline_order)

    def _calculate_single_rank_operations(self, rank) -> list[_Action | None]:
        """
        Build the DualPipeV action list for ``rank`` in eight steps.

        Microbatch indices come from per-(stage, computation type) counters.
        A FULL_BACKWARD advances both the BACKWARD_INPUT and BACKWARD_WEIGHT
        counters. A standalone BACKWARD_INPUT queues a matching BACKWARD_WEIGHT,
        which is emitted later in FIFO order.
        """
        actions: list[_Action | None] = []
        counters: dict[
            tuple[int, _ComputationType], int
        ] = {}  # (stage_index, computation_type) -> mb_index
        # Queue of (stage_index, mb_index) for pending weight actions
        weight_queue: list[tuple[int, int]] = []

        num_ranks = self.pp_group_size
        num_chunks = self._n_microbatches

        rank_to_stages = generate_rank_to_stage_mapping(
            num_ranks, num_ranks * 2, style="v"
        )
        stage0_index, stage1_index = rank_to_stages[rank]

        def increment_backward_counts(stage_index: int):
            """Helper method to increment BACKWARD_INPUT and BACKWARD_WEIGHT counters when FULL_BACKWARD is used."""
            input_key = (stage_index, BACKWARD_INPUT)
            weight_key = (stage_index, BACKWARD_WEIGHT)
            counters[input_key] = counters.get(input_key, 0) + 1
            counters[weight_key] = counters.get(weight_key, 0) + 1

        def add_overlap_f_b(
            actions: list,
            forward_stage: int,
            backward_stage: int,
        ):
            """Helper method to add an overlapped forward+backward action which tracks microbatch index."""
            # Create new overlapped forward+backward action with sub_actions
            forward_key = (forward_stage, FORWARD)
            backward_key = (backward_stage, BACKWARD_INPUT)

            forward_mb = counters.get(forward_key, 0)
            backward_mb = counters.get(backward_key, 0)

            sub_actions = (
                _Action(forward_stage, FORWARD, forward_mb),
                _Action(backward_stage, FULL_BACKWARD, backward_mb),
            )
            # The wrapper action itself has stage -1 and no microbatch index;
            # the real work is described by its sub_actions.
            actions.append(_Action(-1, OVERLAP_F_B, None, sub_actions))

            # Update counters for sub_actions
            counters[forward_key] = forward_mb + 1
            increment_backward_counts(backward_stage)

        def add_action(
            actions: list,
            stage_index: int,
            computation_type: _ComputationType,
        ):
            """Append one action, taking its microbatch index from the counters."""
            # Regular single action, for FULL_BACKWARD we only use the BACKWARD_INPUT counter
            key = (
                (stage_index, computation_type)
                if computation_type != FULL_BACKWARD
                else (stage_index, BACKWARD_INPUT)
            )
            mb_index = counters.get(key, 0)
            actions.append(_Action(stage_index, computation_type, mb_index))

            # If FULL_BACKWARD is used, just increment the separate BACKWARD_INPUT and BACKWARD_WEIGHT counters
            if computation_type == FULL_BACKWARD:
                increment_backward_counts(stage_index)
            else:
                # If BACKWARD_INPUT is updated, add corresponding weight action to queue
                if computation_type == BACKWARD_INPUT:
                    # Add weight action to queue for later processing
                    weight_queue.append((stage_index, mb_index))
                counters[key] = mb_index + 1

        def add_weight_action_if_pending(actions: list):
            """Helper method to add a weight action from the queue."""
            if not weight_queue:
                return  # No pending weight actions, skip
            # Pop the oldest weight action from the queue
            actual_stage_index, weight_mb_index = weight_queue.pop(0)
            actions.append(
                _Action(
                    actual_stage_index,
                    BACKWARD_WEIGHT,
                    weight_mb_index,
                )
            )
            # Update the counter for the actual stage that was processed
            weight_key = (actual_stage_index, BACKWARD_WEIGHT)
            counters[weight_key] = counters.get(weight_key, 0) + 1

        # Step 1: F0
        step_1 = (num_ranks - rank - 1) * 2
        for _ in range(step_1):
            add_action(actions, stage0_index, FORWARD)

        # Step 2: F0F1
        step_2 = rank + 1
        for _ in range(step_2):
            add_action(actions, stage0_index, FORWARD)
            add_action(actions, stage1_index, FORWARD)

        # Step 3: I1W1F1 (Use zero bubble)
        step_3 = num_ranks - rank - 1
        for _ in range(step_3):
            add_action(actions, stage1_index, BACKWARD_INPUT)
            add_weight_action_if_pending(actions)
            add_action(actions, stage1_index, FORWARD)

        # Step 4 (Main step): F0B1-F1B0 (combined, overlapped forward+backward)
        step_4 = num_chunks - num_ranks * 2 + rank + 1
        for i in range(step_4):
            if i == 0 and rank == num_ranks - 1:
                # NOTE: We don't overlap these two chunks to further reduce bubble size.
                add_action(actions, stage0_index, FORWARD)
                add_action(actions, stage1_index, FULL_BACKWARD)
            else:
                add_overlap_f_b(
                    actions,
                    forward_stage=stage0_index,
                    backward_stage=stage1_index,
                )
            add_overlap_f_b(
                actions,
                forward_stage=stage1_index,
                backward_stage=stage0_index,
            )

        # Step 5: B1-F1B0
        step_5 = num_ranks - rank - 1
        for _ in range(step_5):
            add_action(actions, stage1_index, FULL_BACKWARD)
            add_overlap_f_b(
                actions,
                forward_stage=stage1_index,
                backward_stage=stage0_index,
            )

        # Step 6: B1B0 (The second half of the chunks use zero bubble)
        # Halfway through, switch from FULL_BACKWARD to BACKWARD_INPUT. Odd
        # ranks switch before the stage1 backward, even ranks before the
        # stage0 backward.
        step_6 = rank + 1
        enable_zb = False
        for i in range(step_6):
            if i == step_6 // 2 and rank % 2 == 1:
                enable_zb = True
            comp_type = BACKWARD_INPUT if enable_zb else FULL_BACKWARD
            add_action(actions, stage1_index, comp_type)
            if i == step_6 // 2 and rank % 2 == 0:
                enable_zb = True
            comp_type = BACKWARD_INPUT if enable_zb else FULL_BACKWARD
            add_action(actions, stage0_index, comp_type)

        # Step 7: W0B0
        step_7 = num_ranks - rank - 1
        for _ in range(step_7):
            add_weight_action_if_pending(actions)
            comp_type = BACKWARD_INPUT if enable_zb else FULL_BACKWARD
            add_action(actions, stage0_index, comp_type)

        # Step 8: W0
        step_8 = rank + 1
        for _ in range(step_8):
            add_weight_action_if_pending(actions)

        return actions
def get_schedule_class(schedule_name: str):
    """
    Look up a pipeline schedule class by name, ignoring case.

    Args:
        schedule_name (str): The name of the schedule, e.g. "1F1B" or
            "interleaved1f1b".

    Returns:
        The schedule class registered under that name.

    Raises:
        ValueError: If no registered schedule matches ``schedule_name``.
    """
    schedule_map = {
        "1F1B": Schedule1F1B,
        "Interleaved1F1B": ScheduleInterleaved1F1B,
        "GPipe": ScheduleGPipe,
        "LoopedBFS": ScheduleLoopedBFS,
        "InterleavedZeroBubble": ScheduleInterleavedZeroBubble,
        "PipelineScheduleSingle": PipelineScheduleSingle,
        "PipelineScheduleMulti": PipelineScheduleMulti,
        "ZBVZeroBubble": ScheduleZBVZeroBubble,
        "DualPipeV": ScheduleDualPipeV,
    }
    classes_by_lowercase_name = {
        name.lower(): schedule_cls for name, schedule_cls in schedule_map.items()
    }
    found = classes_by_lowercase_name.get(schedule_name.lower())
    if found is None:
        raise ValueError(
            f"Unknown schedule name '{schedule_name}'. The valid options are {list(schedule_map.keys())}"
        )
    return found
def _simulate_comms_compute(
    pipeline_order, stage_to_rank: Callable[[int], int], num_stages: int
):
    """This function dry-run simulates the actions in the schedule from the perspective of all ranks, and flags
    any deadlocks caused by missing or misordered communications. It also simulates any bubbles in time where a rank
    can not execute any action due to waiting for unmet dependencies. The total number of simulator steps can be used
    as a metric for unit tests involving IR optimization passes as reordering and merging of IR can reduce the number
    of simulated steps.

    The simulation is not high-fidelity and does not model overlapping of compute and communication, or cuda streams.
    Future work may be to enhance this and model the compute time, comms overlap, and even memory.

    Args:
        pipeline_order: mapping of rank -> list of actions. ``None`` entries are
            ignored. The caller's dict and lists are not mutated.
        stage_to_rank: maps a stage index to the rank that owns it.
        num_stages: total number of pipeline stages.

    Returns:
        A dict mapping every input rank to its simulated timeline. ``None``
        marks a step in which the rank waited.

    Raises:
        ValueError: if an action type is not supported, or if no rank can make
            progress. In the latter case the partial schedule is printed first.
    """
    # Private copy with None placeholders removed; the simulator decides
    # where idle steps go.
    pipeline_order = {
        rank: [a for a in pipeline_order[rank] if a is not None]
        for rank in sorted(pipeline_order)
    }
    _schedule: dict[int, list[_Action | None]] = {
        rank: [] for rank in sorted(pipeline_order)
    }
    # Actions already executed, per rank; used to resolve dependencies.
    _prev_ops_rank: dict[int, set[_Action]] = {rank: set() for rank in _schedule}
    # A rank with no actions is already finished. Drop it up front. Otherwise an
    # entirely empty schedule makes the first iteration record no progress and
    # is misreported as "Schedule is not progressing".
    pipeline_order = {rank: ops for rank, ops in pipeline_order.items() if ops}

    def add_to_schedule(rank: int, action: _Action | None):
        _schedule[rank].append(action)
        if action is not None:
            _prev_ops_rank[rank].add(action)

    def _ready_to_schedule(action: _Action | None) -> bool:
        """Return True if every dependency of ``action`` has already executed."""
        if action is None:
            return True

        stage_idx = action.stage_index
        prev_ops = _prev_ops_rank[stage_to_rank(stage_idx)]
        if action.computation_type == F:
            # Needs the input: first stage, a matching recv, or the previous
            # stage's forward on this same rank.
            if action.stage_index == 0:
                return True
            elif (
                _Action(action.stage_index, RECV_F, action.microbatch_index) in prev_ops
            ):
                return True
            elif (
                _Action(action.stage_index - 1, F, action.microbatch_index) in prev_ops
            ):
                return True
            return False
        elif action.computation_type in (BACKWARD_INPUT, FULL_BACKWARD):
            # Needs the incoming gradient: last stage, a matching recv, or the
            # next stage's backward on this same rank.
            if action.stage_index == num_stages - 1:
                return True
            if _Action(action.stage_index, RECV_B, action.microbatch_index) in prev_ops:
                return True
            if (
                _Action(action.stage_index + 1, BACKWARD_INPUT, action.microbatch_index)
                in prev_ops
            ):
                return True
            if (
                _Action(action.stage_index + 1, FULL_BACKWARD, action.microbatch_index)
                in prev_ops
            ):
                return True
            return False
        elif action.computation_type == BACKWARD_WEIGHT:
            return True
        elif action.computation_type == SEND_F:
            expected_f = _Action(action.stage_index, F, action.microbatch_index)
            return expected_f in prev_ops
        elif action.computation_type == RECV_F:
            # A recv is ready once the peer rank has issued the matching send.
            peer_stage_idx = stage_idx - 1
            expected_send = _Action(peer_stage_idx, SEND_F, action.microbatch_index)
            return expected_send in _prev_ops_rank[stage_to_rank(peer_stage_idx)]
        elif action.computation_type == SEND_B:
            expected_b = _Action(
                action.stage_index, BACKWARD_INPUT, action.microbatch_index
            )
            expected_bw = _Action(
                action.stage_index, FULL_BACKWARD, action.microbatch_index
            )
            return expected_b in prev_ops or expected_bw in prev_ops
        elif action.computation_type == RECV_B:
            peer_stage_idx = stage_idx + 1
            expected_send = _Action(peer_stage_idx, SEND_B, action.microbatch_index)
            return expected_send in _prev_ops_rank[stage_to_rank(peer_stage_idx)]
        else:
            raise ValueError(f"Unsupported action type {action}")

    while pipeline_order:
        progress = False
        # First pass: each rank either runs its next action or idles.
        for rank in sorted(pipeline_order):
            if len(pipeline_order[rank]) == 0:
                continue

            action = pipeline_order[rank][0]
            if _ready_to_schedule(action):
                if action is not None:
                    add_to_schedule(rank, action)
                pipeline_order[rank].pop(0)
                progress = True
            else:
                add_to_schedule(rank, None)

        for i in sorted(pipeline_order, reverse=True):
            if len(pipeline_order[i]) == 0:
                del pipeline_order[i]

        # hacky, but do a second pass to replace any 'none' at this timestep with a real action, if it got unblocked
        # by one of the later ranks
        for rank in sorted(pipeline_order):
            if len(pipeline_order[rank]) == 0:
                continue

            if _schedule[rank][-1] is not None:
                continue

            action = pipeline_order[rank][0]
            if _ready_to_schedule(action):
                if action is not None:
                    _schedule[rank][-1] = action
                    _prev_ops_rank[rank].add(action)
                pipeline_order[rank].pop(0)

        for i in sorted(pipeline_order, reverse=True):
            if len(pipeline_order[i]) == 0:
                del pipeline_order[i]

        if not progress:
            print("WIP comms schedule:\n", _format_pipeline_order(_schedule))
            for rank in pipeline_order:
                print(f"{rank=} next action= {pipeline_order[rank][0]}")
            raise ValueError("Schedule is not progressing")

    return _schedule
def _dump_chrometrace(schedule, filename):
    """
    Write a schedule IR to ``filename`` in Chrome trace (JSON) format for visualization.

    Each non-None action becomes one complete ("X") event of duration 1. Its
    timestamp is the action's timestep index, and both pid and tid are the
    rank. F, B and W actions are categorized as "computation"; everything else
    is "communication".

    This is intentionally basic: a graphical alternative to dumping the
    schedule IR as text. Possible future work: more accurate duration
    heuristics or user-supplied durations, 'flow events' linking sends to
    recvs, and modeling cuda streams for comm/compute as separate tracks.
    """
    import json

    def _make_event(rank, timestep, action):
        # Key order matches the emitted JSON layout.
        is_compute = action.computation_type in (F, B, W)
        return {
            "name": str(action),
            "cat": "computation" if is_compute else "communication",
            "ph": "X",
            "pid": rank,
            "tid": rank,
            "ts": timestep,
            "dur": 1,
        }

    events = [
        _make_event(rank, timestep, action)
        for rank in sorted(schedule)
        for timestep, action in enumerate(schedule[rank])
        if action is not None
    ]

    with open(filename, "w") as f:
        json.dump({"traceEvents": events}, f)
def _check_torch_compile_compatibility(
    stages: list[_PipelineStageBase], schedule_name: str
):
    """
    Reject stages whose modules have been wrapped by torch.compile.

    Args:
        stages: Pipeline stages to inspect. A stage whose ``submod`` is not a
            ``torch.nn.Module`` is skipped.
        schedule_name: Schedule name used in the error message.

    Raises:
        RuntimeError: If any submodule of any stage is an ``OptimizedModule``.
    """
    for stage in stages:
        submodule = stage.submod
        if not isinstance(submodule, torch.nn.Module):
            continue
        # First compiled module found anywhere in this stage's module tree.
        compiled = next(
            (m for m in submodule.modules() if isinstance(m, OptimizedModule)),
            None,
        )
        if compiled is not None:
            raise RuntimeError(
                f"The {schedule_name} schedule is not supported with "
                "stage modules that have used torch.compile. "
                f"Found OptimizedModule in {type(compiled).__name__}"
            )
|