| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549 |
- import collections
- from collections.abc import Callable
- from typing import Any, Optional
- import torch
- from torch._dynamo.variables.dicts import ConstDictVariable
- from torch._dynamo.variables.lists import TupleVariable
- from torch.fx import has_side_effect, Proxy
- from .. import graph_break_hints
- from ..bytecode_transformation import create_call_function
- from ..exc import TYPE_CHECKING, unimplemented
- from ..graph_bytecode_inputs import (
- get_external_object_by_index,
- register_graph_created_object,
- )
- from ..source import CurrentStreamSource
- from .base import VariableTracker
- from .constant import CONSTANT_VARIABLE_NONE, ConstantVariable
- from .ctx_manager import FxTracebackAnnotateVariable
- from .lazy import LazyVariableTracker
- if TYPE_CHECKING:
- from torch._dynamo.symbolic_convert import InstructionTranslator
- from ..codegen import PyCodegen
- from torch._library.custom_ops import custom_op
- Tensor = torch.Tensor
def new_event(*args: Any, **kwargs: Any) -> int:
    """Create a ``torch.Event`` and register it as a graph-created object.

    All positional and keyword arguments are forwarded to ``torch.Event``.
    The construction callback registered alongside the event is built from
    an empty ``TupleVariable`` and an empty ``ConstDictVariable``.

    Returns:
        Whatever ``register_graph_created_object`` returns for the new event.
    """
    created_event = torch.Event(*args, **kwargs)
    construct_in_graph = EventVariable.make_construct_in_graph_event_fn(
        TupleVariable([]), ConstDictVariable({})
    )
    return register_graph_created_object(created_event, construct_in_graph)
def new_stream(*args: tuple[Any], **kwargs: Any) -> int:
    """Create a ``torch.Stream`` and register it as a graph-created object.

    All positional and keyword arguments are forwarded to ``torch.Stream``.
    The construction callback registered alongside the stream is built from
    an empty ``TupleVariable`` and an empty ``ConstDictVariable``.

    Returns:
        Whatever ``register_graph_created_object`` returns for the new stream.
    """
    created_stream = torch.Stream(*args, **kwargs)  # type: ignore[no-matching-overload,call-overload]
    construct_in_graph = StreamVariable.make_construct_in_graph_stream_fn(
        TupleVariable([]), ConstDictVariable({})
    )
    return register_graph_created_object(created_stream, construct_in_graph)
def _codegen_current_stream(device: torch.device, cg: "PyCodegen") -> None:
    """Emit code that passes the current stream of ``device`` to
    ``stash_graph_created_object``.

    Args:
        device: Device used to build the ``CurrentStreamSource``.
        cg: Code generator that receives the emitted instructions.
    """

    def load_stash_fn() -> None:
        # Import ``stash_graph_created_object`` from the graph_bytecode_inputs module.
        cg.load_import_from(
            torch._dynamo.graph_bytecode_inputs.__name__,  # type: ignore[implicit-imports]
            "stash_graph_created_object",
        )

    cg.add_push_null(load_stash_fn)
    # The single argument of the call: the stream described by the source.
    cg(CurrentStreamSource(device))
    cg.extend_output(create_call_function(1, False))
def get_current_stream(device: torch.device) -> int:
    """Register the accelerator's current stream for ``device`` as a
    graph-created object.

    The registered callback regenerates the stream through
    ``_codegen_current_stream`` for the same device.

    Returns:
        Whatever ``register_graph_created_object`` returns for the stream.
    """
    current = torch.accelerator.current_stream(device)

    def codegen_stream(_: int, cg: "PyCodegen") -> None:
        _codegen_current_stream(device, cg)

    return register_graph_created_object(current, codegen_stream)
def _get_stream_by_index(index: int) -> torch.Stream:
    """Return the external object stored at ``index``, asserting that it is a
    ``torch.Stream``."""
    candidate = get_external_object_by_index(index)
    is_stream = isinstance(candidate, torch.Stream)
    assert is_stream, f"Fork/join stream expected a stream object at index {index}"
    return candidate
def _get_event_by_index(index: int) -> torch.Event:
    """Return the external object stored at ``index``, asserting that it is a
    ``torch.Event``."""
    candidate = get_external_object_by_index(index)
    is_event = isinstance(candidate, torch.Event)
    assert is_event, f"Record/wait event expected an event object at index {index}"
    return candidate
@custom_op("streams::fork", mutates_args=())
def fork_stream(
    from_index: int,  # kept to make stream transitions clearer
    to_index: int,
) -> None:
    """Make the stream registered at ``to_index`` the accelerator's current stream.

    ``from_index`` is not read by the implementation.
    """
    target_stream = _get_stream_by_index(to_index)
    torch.accelerator.set_stream(target_stream)


@fork_stream.register_fake
def _(
    from_index: int,  # kept to make stream transitions clearer
    to_index: int,
) -> None:
    """Fake implementation of ``streams::fork``: does nothing."""
    return None


# Mark the op as side-effecting for torch.fx (presumably so that dead-code
# elimination keeps this None-returning call in the graph).
has_side_effect(torch.ops.streams.fork.default)
@custom_op("streams::join", mutates_args=())
def join_stream(from_index: int, to_index: int) -> None:
    """Make the stream registered at ``to_index`` the accelerator's current stream.

    ``from_index`` is not read by the implementation.
    """
    target_stream = _get_stream_by_index(to_index)
    torch.accelerator.set_stream(target_stream)


@join_stream.register_fake
def _(
    from_index: int,
    to_index: int,
) -> None:
    """Fake implementation of ``streams::join``: does nothing."""
    return None


# Mark the op as side-effecting for torch.fx (presumably so that dead-code
# elimination keeps this None-returning call in the graph).
has_side_effect(torch.ops.streams.join.default)
@custom_op("streams::record_event", mutates_args=())
def record_event(event_index: int, stream_index: int) -> None:
    """Record the event registered at ``event_index`` on the stream registered
    at ``stream_index``."""
    # The event is looked up before the stream, so a bad event index is
    # reported first.
    target_event = _get_event_by_index(event_index)
    _get_stream_by_index(stream_index).record_event(target_event)


@record_event.register_fake
def _(
    event_index: int,
    stream_index: int,
) -> None:
    """Fake implementation of ``streams::record_event``: does nothing."""
    return None


# Mark the op as side-effecting for torch.fx (presumably so that dead-code
# elimination keeps this None-returning call in the graph).
has_side_effect(torch.ops.streams.record_event.default)
@custom_op("streams::wait_event", mutates_args=())
def wait_event(event_index: int, stream_index: int) -> None:
    """Make the stream registered at ``stream_index`` wait on the event
    registered at ``event_index``."""
    # The event is looked up before the stream, so a bad event index is
    # reported first.
    target_event = _get_event_by_index(event_index)
    _get_stream_by_index(stream_index).wait_event(target_event)


@wait_event.register_fake
def _(
    event_index: int,
    stream_index: int,
) -> None:
    """Fake implementation of ``streams::wait_event``: does nothing."""
    return None


# Mark the op as side-effecting for torch.fx (presumably so that dead-code
# elimination keeps this None-returning call in the graph).
has_side_effect(torch.ops.streams.wait_event.default)
@custom_op("streams::wait_stream", mutates_args=())
def wait_stream(waiting_stream_index: int, waited_on_stream_index: int) -> None:
    """Make one registered stream wait on another.

    Args:
        waiting_stream_index: Index of the stream that will wait.
        waited_on_stream_index: Index of the stream being waited on.
    """
    waiting = _get_stream_by_index(waiting_stream_index)
    waited_on = _get_stream_by_index(waited_on_stream_index)
    waiting.wait_stream(waited_on)


@wait_stream.register_fake
def _(
    waiting_stream_index: int,
    waited_on_stream_index: int,
) -> None:
    """Fake implementation of ``streams::wait_stream``: does nothing.

    The parameter names mirror the real op so that keyword-argument calls
    bind the same way in both implementations.
    """
    pass


# Mark the op as side-effecting for torch.fx (presumably so that dead-code
# elimination keeps this None-returning call in the graph).
has_side_effect(torch.ops.streams.wait_stream.default)
@custom_op("streams::sync_dealloc", mutates_args=())
def sync_dealloc(
    wait_event_index: int, src_stream_index: int, to_dealloc: torch.Tensor
) -> None:
    """An op which waits on an event and moves the last usage of to_dealloc
    after the wait, so that after the sync occurs, the deallocation or
    subsequent reuse of the tensor's memory will be guaranteed to happen
    after a side stream is finished using it.
    See https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html#torch.Tensor.record_stream
    for more details

    Args:
        wait_event_index: Index of the event to wait on.
        src_stream_index: Index of the stream that performs the wait.
        to_dealloc: Tensor whose last use is tied to this op; the
            implementation itself does not read it.
    """
    torch.ops.streams.wait_event.default(wait_event_index, src_stream_index)


@sync_dealloc.register_fake
def _(
    wait_event_index: int,
    src_stream_index: int,
    to_dealloc: torch.Tensor,
) -> None:
    """Fake implementation of ``streams::sync_dealloc``: does nothing.

    The op returns ``None`` and has no tensor outputs, so there is nothing
    to compute under fake tensors.
    """
    pass


# Mark the op as side-effecting for torch.fx (presumably so that dead-code
# elimination keeps this None-returning call in the graph).
has_side_effect(torch.ops.streams.sync_dealloc.default)
@custom_op("streams::record_stream", mutates_args=())
def record_stream(tensor: torch.Tensor, stream_index: int) -> None:
    """Call ``tensor.record_stream`` with the stream registered at
    ``stream_index``.

    Args:
        tensor: Tensor on which ``record_stream`` is invoked.
        stream_index: Index of the stream passed to ``record_stream``.
    """
    tensor.record_stream(_get_stream_by_index(stream_index))


@record_stream.register_fake
def _(
    tensor: torch.Tensor,
    stream_index: int,
) -> None:
    """Fake implementation of ``streams::record_stream``: does nothing.

    The signature must match the real op ``(tensor, stream_index)``; a fake
    with a different arity cannot be called with the op's arguments.
    """
    pass
class SymbolicStreamState:
    """Track the currently entered stream if any"""

    def __init__(self) -> None:
        """Start with an empty stack, seeded with the accelerator's current
        stream (as a lazy variable) when an accelerator is available."""
        from ..source import CurrentStreamSource

        self.cur_stream_stack: collections.deque[StreamVariable] = collections.deque()
        if not torch.accelerator.is_available():
            return

        initial_stream = LazyVariableTracker.create(
            torch.accelerator.current_stream(),
            source=CurrentStreamSource(torch.accelerator.current_stream().device),
        )
        self.cur_stream_stack.append(initial_stream)  # type: ignore[arg-type]

    def enter_stream(self, stream: "StreamVariable") -> None:
        """Push ``stream`` so that it becomes the innermost stream."""
        self.cur_stream_stack.append(stream)

    def exit_stream(self) -> None:
        """Pop the innermost stream."""
        self.cur_stream_stack.pop()

    def cur_stream(self, device: Optional[torch.device] = None) -> "StreamVariable":
        """Return the innermost stream, preferring one on ``device`` if given.

        When ``device`` is provided, the stack is searched from the innermost
        entry outwards for a stream whose ``device`` equals it. If there is
        no match (or no ``device`` was given) the innermost stream is
        returned.
        """
        if device is not None:
            on_device = next(
                (
                    candidate
                    for candidate in reversed(self.cur_stream_stack)
                    if candidate.device == device
                ),
                None,
            )
            if on_device is not None:
                return on_device
        return self.cur_stream_stack[-1]

    def in_stream_context(self) -> bool:
        """Return True when the stack holds at least one stream."""
        return bool(self.cur_stream_stack)
class StreamContextVariable(FxTracebackAnnotateVariable):
    """This represents torch.cuda.StreamContext"""

    @staticmethod
    def create(
        tx: "InstructionTranslator",
        stream_to_enter: "StreamVariable",
        **kwargs: Any,
    ) -> "StreamContextVariable":
        """Build a context variable for ``stream_to_enter``.

        ``tx`` is accepted for signature compatibility and is not used.
        Extra keyword arguments are forwarded to the constructor.
        """
        return StreamContextVariable(
            stream_to_enter,
            **kwargs,
        )

    def __init__(self, stream: Optional["StreamVariable"], **kwargs: Any) -> None:
        """Store ``stream`` and initialise the annotation base class.

        ``self.stream`` must be assigned before calling ``super().__init__``
        because ``get_stream()`` is evaluated while building
        ``target_values``.
        """
        self.stream = stream
        super().__init__(
            target_values={"stream": self.get_stream().user_object_index},
            initial_values=None,
            **kwargs,
        )

    def enter(
        self, tx: "InstructionTranslator", *args: VariableTracker
    ) -> VariableTracker:
        """Push this context's stream onto the symbolic stream state, then
        delegate to the base class ``enter``."""
        tx.symbolic_stream_state.enter_stream(self.get_stream())
        return super().enter(tx)

    def exit(
        self, tx: "InstructionTranslator", *args: VariableTracker
    ) -> VariableTracker:
        """Pop the innermost stream from the symbolic stream state, then
        delegate to the base class ``exit``."""
        tx.symbolic_stream_state.exit_stream()
        return super().exit(tx, *args)

    def supports_graph_breaks(self) -> bool:
        """Always returns True."""
        return True

    def get_stream(self) -> "StreamVariable":
        """Return the stream this context enters; asserts that one is set."""
        assert self.stream, "Stream context should have a separate stream"
        return self.stream
class StreamVariable(StreamContextVariable):
    """Represents the device-agnostic torch.Stream class"""

    def __init__(
        self,
        proxy: Proxy,
        value: torch.Stream,
        user_object_index: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        """Wrap ``value`` together with its FX ``proxy``.

        Args:
            proxy: FX proxy for the stream; if its node carries an
                ``example_value`` it must equal ``value``.
            value: The concrete ``torch.Stream``.
            user_object_index: Index into the user object table used to pass
                arbitrary objects to the graph, or ``None``.
            **kwargs: Forwarded to the base class constructor.
        """
        if proxy is not None and "example_value" in proxy.node.meta:
            assert proxy.node.meta["example_value"] == value

        self.proxy = proxy
        self.value = value
        # pyrefly: ignore [read-only]
        self.device = value.device
        # Index into the user object table
        # used to pass arbitrary objects to the graph
        self.user_object_index = user_object_index
        # The base constructor calls get_stream(), which returns ``self`` and
        # reads ``user_object_index``, so the attributes above must be set first.
        super().__init__(None, **kwargs)

    def python_type(self) -> type:
        """Return ``torch.Stream``."""
        return torch.Stream

    def call_method(
        self,
        tx: "InstructionTranslator",
        name: str,
        args: list[VariableTracker],
        kwargs: dict[str, VariableTracker],
    ) -> VariableTracker:
        """Trace a method call on the stream.

        ``wait_stream``, ``synchronize`` and ``wait_event`` are emitted as
        ``call_method`` nodes and yield ``None``; ``query`` and
        ``record_event`` are emitted as ``call_method`` nodes whose result is
        wrapped as a ``ConstantVariable`` / ``EventVariable`` respectively.
        Comparison methods with a single positional argument are evaluated on
        the underlying values, guarding every operand that has a source.
        Anything else is delegated to the base class.
        """
        assert hasattr(self.value, name), f"no stream method found named {name}"

        from ..utils import cmp_name_to_op_mapping, proxy_args_kwargs
        from .builder import wrap_fx_proxy_cls

        if name in ("wait_stream", "synchronize", "wait_event"):
            tx.output.create_proxy(
                "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
            )
            return CONSTANT_VARIABLE_NONE
        elif name == "query":
            return wrap_fx_proxy_cls(
                target_cls=ConstantVariable,
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
                ),
            )
        elif name == "record_event":
            return wrap_fx_proxy_cls(
                target_cls=EventVariable,
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
                ),
            )
        elif name in cmp_name_to_op_mapping and len(args) == 1 and not kwargs:
            from ..guards import GuardBuilder, install_guard

            if self.source:
                install_guard(self.source.make_guard(GuardBuilder.EQUALS_MATCH))

            # NB : Checking for mutation is necessary because we compare
            # constant values
            other = args[0]
            if not isinstance(other, StreamVariable):
                return ConstantVariable.create(NotImplemented)

            # The result is baked in as a constant, so the other operand
            # needs its own guard; guard it through its own source rather
            # than re-guarding ``self``.
            if other.source:
                install_guard(other.source.make_guard(GuardBuilder.EQUALS_MATCH))
            return ConstantVariable.create(
                cmp_name_to_op_mapping[name](self.value, other.value)  # type: ignore[arg-type]
            )

        return super().call_method(tx, name, args, kwargs)

    def as_proxy(self) -> Proxy:
        """Return the FX proxy for this stream."""
        return self.proxy

    def module_name(self) -> str:
        """Return the module name string ``"torch._C"``."""
        return "torch._C"

    def fn_name(self) -> str:
        """Return the name string ``"Stream"``."""
        return "Stream"

    def reconstruct(self, codegen: "PyCodegen") -> None:
        """Emit bytecode that loads this stream at runtime.

        With a ``user_object_index`` the stream is fetched through
        ``get_external_object_by_index``; otherwise it is installed as a
        global and loaded by name.
        """
        # If we got here, this stream is fully subsumed by the graph - this means it is
        # not an input or global
        assert not self.source
        if self.user_object_index is not None:
            codegen.add_push_null(
                lambda: codegen.load_import_from(
                    torch._dynamo.graph_bytecode_inputs.__name__,
                    "get_external_object_by_index",
                )
            )
            codegen.append_output(codegen.create_load_const(self.user_object_index))
            codegen.extend_output(create_call_function(1, False))
        else:
            # This will support the legacy behavior
            prefix = f"_stream_{self.device}"
            name = codegen.tx.output.install_global_by_id(prefix, self.value)
            codegen.append_output(codegen.create_load_global(name, add=True))

    def get_stream(self) -> "StreamVariable":
        """A stream variable is its own stream."""
        return self

    @staticmethod
    def make_construct_in_graph_stream_fn(
        args: TupleVariable, kwargs: ConstDictVariable
    ) -> Callable[[int, "PyCodegen"], None]:
        """Return a codegen callback that emits
        ``stash_graph_created_object(build_stream(args, kwargs))``.

        The callback's ``index`` parameter is not used.
        """

        def fn(index: int, codegen: "PyCodegen") -> None:
            codegen.add_push_null(
                lambda: codegen.load_import_from(
                    torch._dynamo.graph_bytecode_inputs.__name__,  # type: ignore[implicit-imports]
                    "stash_graph_created_object",
                )
            )
            codegen.add_push_null(
                lambda: codegen.load_import_from(
                    torch._dynamo.utils.__name__, "build_stream"
                )
            )
            codegen(args)
            codegen(kwargs)
            # Inner call: build_stream(args, kwargs)
            codegen.extend_output(create_call_function(2, False))
            # Outer call: stash_graph_created_object(<stream>)
            codegen.extend_output(create_call_function(1, False))

        return fn
class EventVariable(VariableTracker):
    """Variable tracker wrapping a ``torch.Event`` and its FX proxy."""

    def __init__(
        self,
        proxy: Proxy,
        value: torch.Event,
        user_object_index: Optional[int],
        **kwargs: Any,
    ) -> None:
        """Wrap ``value`` together with its FX ``proxy``.

        If the proxy's node carries an ``example_value`` it must equal
        ``value``. ``user_object_index`` is stored as given and remaining
        keyword arguments go to the base class.
        """
        has_example = proxy is not None and "example_value" in proxy.node.meta
        if has_example:
            assert proxy.node.meta["example_value"] == value
        super().__init__(**kwargs)
        self.proxy = proxy
        self.value = value
        self.user_object_index = user_object_index

    def call_method(
        self,
        tx: "InstructionTranslator",
        name: str,
        args: list[VariableTracker],
        kwargs: dict[str, VariableTracker],
    ) -> VariableTracker:
        """Trace a method call on the event.

        ``wait`` and ``record`` become calls to the ``streams.wait_event`` /
        ``streams.record_event`` ops on the object indices of this event and
        the resolved stream. ``synchronize`` and ``query`` are emitted as
        ``call_method`` nodes. Any other name is reported through
        ``unimplemented``.
        """
        from ..utils import proxy_args_kwargs
        from .builder import wrap_fx_proxy_cls

        if name in ("wait", "record"):
            stream_op = (
                torch.ops.streams.wait_event
                if name == "wait"
                else torch.ops.streams.record_event
            )
            stream_var = EventVariable._get_stream_arg(tx, args, kwargs)
            tx.output.create_proxy(
                "call_function",
                stream_op,
                (self.user_object_index, stream_var.user_object_index),
                {},
            )
            return CONSTANT_VARIABLE_NONE

        if name == "synchronize":
            tx.output.create_proxy(
                "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
            )
            return CONSTANT_VARIABLE_NONE

        if name == "query":
            query_proxy = tx.output.create_proxy(
                "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
            )
            return wrap_fx_proxy_cls(
                target_cls=ConstantVariable,
                tx=tx,
                proxy=query_proxy,
            )

        event_type = type(self.value)
        method_name = f"{event_type.__module__}.{event_type.__qualname__}.{name}"
        unimplemented(
            gb_type="Unsupported event method",
            context=str(name),
            explanation=f"Dynamo doesn't support tracing the {method_name} method. "
            f"We currently support wait, record, synchronize, and query.",
            hints=[
                *graph_break_hints.SUPPORTABLE,
            ],
        )

    def as_proxy(self) -> Proxy:
        """Return the FX proxy for this event."""
        return self.proxy

    @staticmethod
    def _get_stream_arg(
        tx: "InstructionTranslator",
        args: list[VariableTracker],
        kwargs: dict[str, VariableTracker],
    ) -> "StreamVariable":
        """Resolve the stream argument of an event method call.

        Uses the first positional argument if any, otherwise the ``stream``
        keyword argument; when that yields nothing truthy, falls back to the
        translator's current symbolic stream.
        """
        stream_arg = args[0] if args else kwargs.get("stream")
        if not stream_arg:
            stream_arg = tx.symbolic_stream_state.cur_stream()
        return stream_arg  # type: ignore[return-value]

    @staticmethod
    def make_construct_in_graph_event_fn(
        args: TupleVariable, kwargs: ConstDictVariable
    ) -> Callable[[int, "PyCodegen"], None]:
        """Return a codegen callback that emits
        ``stash_graph_created_object(build_event(args, kwargs))``.

        The callback's ``index`` parameter is not used.
        """

        def fn(index: int, codegen: "PyCodegen") -> None:
            def load_stash() -> None:
                codegen.load_import_from(
                    torch._dynamo.graph_bytecode_inputs.__name__,  # type: ignore[implicit-imports]
                    "stash_graph_created_object",
                )

            def load_build_event() -> None:
                codegen.load_import_from(torch._dynamo.utils.__name__, "build_event")

            codegen.add_push_null(load_stash)
            codegen.add_push_null(load_build_event)
            codegen(args)
            codegen(kwargs)
            # Inner call: build_event(args, kwargs)
            codegen.extend_output(create_call_function(2, False))
            # Outer call: stash_graph_created_object(<event>)
            codegen.extend_output(create_call_function(1, False))

        return fn

    def reconstruct(self, codegen: "PyCodegen") -> None:
        """Emit bytecode that loads this event from an installed global."""
        # If we got here, this event is fully subsumed by the graph - this means it is
        # not an input or global
        assert not self.source
        # Similar to stream handling, we lift the event into a global and then codegen bytecode to load it from there.
        global_name = codegen.tx.output.install_global_by_id("_event", self.value)
        codegen.append_output(codegen.create_load_global(global_name, add=True))
|