streams.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549
  1. import collections
  2. from collections.abc import Callable
  3. from typing import Any, Optional
  4. import torch
  5. from torch._dynamo.variables.dicts import ConstDictVariable
  6. from torch._dynamo.variables.lists import TupleVariable
  7. from torch.fx import has_side_effect, Proxy
  8. from .. import graph_break_hints
  9. from ..bytecode_transformation import create_call_function
  10. from ..exc import TYPE_CHECKING, unimplemented
  11. from ..graph_bytecode_inputs import (
  12. get_external_object_by_index,
  13. register_graph_created_object,
  14. )
  15. from ..source import CurrentStreamSource
  16. from .base import VariableTracker
  17. from .constant import CONSTANT_VARIABLE_NONE, ConstantVariable
  18. from .ctx_manager import FxTracebackAnnotateVariable
  19. from .lazy import LazyVariableTracker
  20. if TYPE_CHECKING:
  21. from torch._dynamo.symbolic_convert import InstructionTranslator
  22. from ..codegen import PyCodegen
  23. from torch._library.custom_ops import custom_op
  24. Tensor = torch.Tensor
  25. def new_event(*args: Any, **kwargs: Any) -> int:
  26. event = torch.Event(*args, **kwargs)
  27. return register_graph_created_object(
  28. event,
  29. EventVariable.make_construct_in_graph_event_fn(
  30. TupleVariable([]), ConstDictVariable({})
  31. ),
  32. )
  33. def new_stream(*args: tuple[Any], **kwargs: Any) -> int:
  34. stream = torch.Stream(*args, **kwargs) # type: ignore[no-matching-overload,call-overload]
  35. return register_graph_created_object(
  36. stream,
  37. StreamVariable.make_construct_in_graph_stream_fn(
  38. TupleVariable([]), ConstDictVariable({})
  39. ),
  40. )
  41. def _codegen_current_stream(device: torch.device, cg: "PyCodegen") -> None:
  42. cg.add_push_null(
  43. lambda: cg.load_import_from(
  44. torch._dynamo.graph_bytecode_inputs.__name__, # type: ignore[implicit-imports]
  45. "stash_graph_created_object",
  46. )
  47. )
  48. cg(CurrentStreamSource(device))
  49. cg.extend_output(create_call_function(1, False))
  50. def get_current_stream(device: torch.device) -> int:
  51. stream = torch.accelerator.current_stream(device)
  52. return register_graph_created_object(
  53. stream, lambda _, cg: _codegen_current_stream(device, cg)
  54. )
  55. def _get_stream_by_index(index: int) -> torch.Stream:
  56. stream = get_external_object_by_index(index)
  57. assert isinstance(stream, torch.Stream), (
  58. f"Fork/join stream expected a stream object at index {index}"
  59. )
  60. return stream
  61. def _get_event_by_index(index: int) -> torch.Event:
  62. event = get_external_object_by_index(index)
  63. assert isinstance(event, torch.Event), (
  64. f"Record/wait event expected an event object at index {index}"
  65. )
  66. return event
  67. @custom_op("streams::fork", mutates_args=())
  68. def fork_stream(
  69. from_index: int, # kept to make stream transitions clearer
  70. to_index: int,
  71. ) -> None:
  72. torch.accelerator.set_stream(_get_stream_by_index(to_index))
  73. @fork_stream.register_fake
  74. def _(
  75. from_index: int, # kept to make stream transitions clearer
  76. to_index: int,
  77. ) -> None:
  78. pass
  79. has_side_effect(torch.ops.streams.fork.default)
  80. @custom_op("streams::join", mutates_args=())
  81. def join_stream(from_index: int, to_index: int) -> None:
  82. torch.accelerator.set_stream(_get_stream_by_index(to_index))
  83. @join_stream.register_fake
  84. def _(
  85. from_index: int,
  86. to_index: int,
  87. ) -> None:
  88. pass
  89. has_side_effect(torch.ops.streams.join.default)
  90. @custom_op("streams::record_event", mutates_args=())
  91. def record_event(event_index: int, stream_index: int) -> None:
  92. event = _get_event_by_index(event_index)
  93. stream = _get_stream_by_index(stream_index)
  94. stream.record_event(event)
  95. @record_event.register_fake
  96. def _(
  97. event_index: int,
  98. stream_index: int,
  99. ) -> None:
  100. pass
  101. has_side_effect(torch.ops.streams.record_event.default)
  102. @custom_op("streams::wait_event", mutates_args=())
  103. def wait_event(event_index: int, stream_index: int) -> None:
  104. event = _get_event_by_index(event_index)
  105. stream = _get_stream_by_index(stream_index)
  106. stream.wait_event(event)
  107. @wait_event.register_fake
  108. def _(
  109. event_index: int,
  110. stream_index: int,
  111. ) -> None:
  112. pass
  113. has_side_effect(torch.ops.streams.wait_event.default)
  114. @custom_op("streams::wait_stream", mutates_args=())
  115. def wait_stream(waiting_stream_index: int, waited_on_stream_index: int) -> None:
  116. waiting = _get_stream_by_index(waiting_stream_index)
  117. waited_on = _get_stream_by_index(waited_on_stream_index)
  118. waiting.wait_stream(waited_on)
  119. @wait_stream.register_fake
  120. def _(
  121. event_index: int,
  122. stream_index: int,
  123. ) -> None:
  124. pass
  125. has_side_effect(torch.ops.streams.wait_stream.default)
  126. @custom_op("streams::sync_dealloc", mutates_args=())
  127. def sync_dealloc(
  128. wait_event_index: int, src_stream_index: int, to_dealloc: torch.Tensor
  129. ) -> None:
  130. """An op which waits on an event and moves the last usage of to_dealloc
  131. after the wait, so that after the sync occurs, the deallocation or
  132. subsequent reuse of the tensor's memory will be guaranteed to happen
  133. after a side stream is finished using it.
  134. See https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html#torch.Tensor.record_stream
  135. for more details"""
  136. torch.ops.streams.wait_event.default(wait_event_index, src_stream_index)
  137. has_side_effect(torch.ops.streams.sync_dealloc.default)
  138. @custom_op("streams::record_stream", mutates_args=())
  139. def record_stream(tensor: torch.Tensor, stream_index: int) -> None:
  140. tensor.record_stream(_get_stream_by_index(stream_index))
  141. @record_stream.register_fake
  142. def _(
  143. src_stream_index: int,
  144. wait_event_index: int,
  145. to_dealloc: torch.Tensor,
  146. ) -> None:
  147. pass
  148. class SymbolicStreamState:
  149. """Track the currently entered stream if any"""
  150. def __init__(self) -> None:
  151. from ..source import CurrentStreamSource
  152. cur_stack: list[StreamVariable] = []
  153. if torch.accelerator.is_available():
  154. stream_var = LazyVariableTracker.create(
  155. torch.accelerator.current_stream(),
  156. source=CurrentStreamSource(torch.accelerator.current_stream().device),
  157. )
  158. cur_stack = [stream_var] # type: ignore[list-item]
  159. self.cur_stream_stack: collections.deque[StreamVariable] = collections.deque(
  160. cur_stack
  161. )
  162. def enter_stream(self, stream: "StreamVariable") -> None:
  163. self.cur_stream_stack.append(stream)
  164. def exit_stream(self) -> None:
  165. self.cur_stream_stack.pop()
  166. def cur_stream(self, device: Optional[torch.device] = None) -> "StreamVariable":
  167. if device is not None:
  168. for stream in reversed(self.cur_stream_stack):
  169. if stream.device == device:
  170. return stream
  171. return self.cur_stream_stack[-1]
  172. def in_stream_context(self) -> bool:
  173. return len(self.cur_stream_stack) > 0
  174. class StreamContextVariable(FxTracebackAnnotateVariable):
  175. """This represents torch.cuda.StreamContext"""
  176. @staticmethod
  177. def create(
  178. tx: "InstructionTranslator",
  179. stream_to_enter: "StreamVariable",
  180. **kwargs: dict[str, Any],
  181. ) -> "StreamContextVariable":
  182. return StreamContextVariable(
  183. stream_to_enter,
  184. **kwargs,
  185. )
  186. def __init__(self, stream: Optional["StreamVariable"], **kwargs: Any) -> None:
  187. self.stream = stream
  188. super().__init__(
  189. target_values={"stream": self.get_stream().user_object_index},
  190. initial_values=None,
  191. **kwargs,
  192. )
  193. def enter(
  194. self, tx: "InstructionTranslator", *args: VariableTracker
  195. ) -> VariableTracker:
  196. # to stream, from stream is the order of the arguments
  197. # we are entering the target, and leaving the initial stream
  198. tx.symbolic_stream_state.enter_stream(self.get_stream())
  199. return super().enter(tx)
  200. def exit(
  201. self, tx: "InstructionTranslator", *args: VariableTracker
  202. ) -> VariableTracker:
  203. # to stream, from stream is the order of the arguments
  204. # we are leaving the target, and entering the initial stream
  205. tx.symbolic_stream_state.exit_stream()
  206. return super().exit(tx, *args)
  207. def supports_graph_breaks(self) -> bool:
  208. return True
  209. def get_stream(self) -> "StreamVariable":
  210. assert self.stream, "Stream context should have a separate stream"
  211. return self.stream
  212. class StreamVariable(StreamContextVariable):
  213. """Represents the device-agnostic torch.Stream class"""
  214. def __init__(
  215. self,
  216. proxy: Proxy,
  217. value: torch.Stream,
  218. user_object_index: Optional[int] = None,
  219. **kwargs: Any,
  220. ) -> None:
  221. # Index into the user object table
  222. # used to pass arbitrary objects to the graph
  223. if proxy is not None and "example_value" in proxy.node.meta:
  224. assert proxy.node.meta["example_value"] == value
  225. self.proxy = proxy
  226. self.value = value
  227. # pyrefly: ignore [read-only]
  228. self.device = value.device
  229. self.user_object_index = user_object_index
  230. super().__init__(None, **kwargs)
  231. def python_type(self) -> type:
  232. return torch.Stream
  233. def call_method(
  234. self,
  235. tx: "InstructionTranslator",
  236. name: str,
  237. args: list[VariableTracker],
  238. kwargs: dict[str, VariableTracker],
  239. ) -> VariableTracker:
  240. assert hasattr(self.value, name), f"no stream method found named {name}"
  241. from ..utils import cmp_name_to_op_mapping, proxy_args_kwargs
  242. from .builder import wrap_fx_proxy_cls
  243. if name in ("wait_stream", "synchronize", "wait_event"):
  244. tx.output.create_proxy(
  245. "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
  246. )
  247. return CONSTANT_VARIABLE_NONE
  248. elif name == "query":
  249. return wrap_fx_proxy_cls(
  250. target_cls=ConstantVariable,
  251. tx=tx,
  252. proxy=tx.output.create_proxy(
  253. "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
  254. ),
  255. )
  256. elif name == "record_event":
  257. return wrap_fx_proxy_cls(
  258. target_cls=EventVariable,
  259. tx=tx,
  260. proxy=tx.output.create_proxy(
  261. "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
  262. ),
  263. )
  264. elif name in cmp_name_to_op_mapping and len(args) == 1 and not kwargs:
  265. from ..guards import GuardBuilder, install_guard
  266. if self.source:
  267. install_guard(self.source.make_guard(GuardBuilder.EQUALS_MATCH))
  268. # NB : Checking for mutation is necessary because we compare
  269. # constant values
  270. other = args[0]
  271. if not isinstance(other, StreamVariable):
  272. return ConstantVariable.create(NotImplemented)
  273. if other.source:
  274. assert self.source is not None
  275. install_guard(self.source.make_guard(GuardBuilder.EQUALS_MATCH))
  276. return ConstantVariable.create(
  277. cmp_name_to_op_mapping[name](self.value, other.value) # type: ignore[arg-type]
  278. )
  279. return super().call_method(tx, name, args, kwargs)
  280. def as_proxy(self) -> Proxy:
  281. return self.proxy
  282. def module_name(self) -> str:
  283. return "torch._C"
  284. def fn_name(self) -> str:
  285. return "Stream"
  286. def reconstruct(self, codegen: "PyCodegen") -> None:
  287. # If we got here, this stream is fully subsumed by the graph - this means it is
  288. # not an input or global
  289. assert not self.source
  290. if self.user_object_index is not None:
  291. codegen.add_push_null(
  292. lambda: codegen.load_import_from(
  293. torch._dynamo.graph_bytecode_inputs.__name__,
  294. "get_external_object_by_index",
  295. )
  296. )
  297. codegen.append_output(codegen.create_load_const(self.user_object_index))
  298. codegen.extend_output(create_call_function(1, False))
  299. else:
  300. # This will support the legacy behavior
  301. prefix = f"_stream_{self.device}"
  302. name = codegen.tx.output.install_global_by_id(prefix, self.value)
  303. codegen.append_output(codegen.create_load_global(name, add=True))
  304. def get_stream(self) -> "StreamVariable":
  305. return self
  306. @staticmethod
  307. def make_construct_in_graph_stream_fn(
  308. args: TupleVariable, kwargs: ConstDictVariable
  309. ) -> Callable[[int, "PyCodegen"], None]:
  310. def fn(index: int, codegen: "PyCodegen") -> None:
  311. codegen.add_push_null(
  312. lambda: codegen.load_import_from(
  313. torch._dynamo.graph_bytecode_inputs.__name__, # type: ignore[implicit-imports]
  314. "stash_graph_created_object",
  315. )
  316. )
  317. codegen.add_push_null(
  318. lambda: codegen.load_import_from(
  319. torch._dynamo.utils.__name__, "build_stream"
  320. )
  321. )
  322. codegen(args)
  323. codegen(kwargs)
  324. codegen.extend_output(create_call_function(2, False))
  325. codegen.extend_output(create_call_function(1, False))
  326. return fn
  327. class EventVariable(VariableTracker):
  328. def __init__(
  329. self,
  330. proxy: Proxy,
  331. value: torch.Event,
  332. user_object_index: Optional[int],
  333. **kwargs: Any,
  334. ) -> None:
  335. if proxy is not None and "example_value" in proxy.node.meta:
  336. assert proxy.node.meta["example_value"] == value
  337. super().__init__(**kwargs)
  338. self.proxy = proxy
  339. self.value = value
  340. self.user_object_index = user_object_index
  341. def call_method(
  342. self,
  343. tx: "InstructionTranslator",
  344. name: str,
  345. args: list[VariableTracker],
  346. kwargs: dict[str, VariableTracker],
  347. ) -> VariableTracker:
  348. from ..utils import proxy_args_kwargs
  349. from .builder import wrap_fx_proxy_cls
  350. if name == "wait":
  351. tx.output.create_proxy(
  352. "call_function",
  353. torch.ops.streams.wait_event,
  354. (
  355. self.user_object_index,
  356. EventVariable._get_stream_arg(tx, args, kwargs).user_object_index,
  357. ),
  358. {},
  359. )
  360. return CONSTANT_VARIABLE_NONE
  361. elif name == "record":
  362. tx.output.create_proxy(
  363. "call_function",
  364. torch.ops.streams.record_event,
  365. (
  366. self.user_object_index,
  367. EventVariable._get_stream_arg(tx, args, kwargs).user_object_index,
  368. ),
  369. {},
  370. )
  371. return CONSTANT_VARIABLE_NONE
  372. elif name == "synchronize":
  373. tx.output.create_proxy(
  374. "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
  375. )
  376. return CONSTANT_VARIABLE_NONE
  377. elif name == "query":
  378. return wrap_fx_proxy_cls(
  379. target_cls=ConstantVariable,
  380. tx=tx,
  381. proxy=tx.output.create_proxy(
  382. "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
  383. ),
  384. )
  385. else:
  386. method_name = (
  387. f"{type(self.value).__module__}.{type(self.value).__qualname__}.{name}"
  388. )
  389. unimplemented(
  390. gb_type="Unsupported event method",
  391. context=str(name),
  392. explanation=f"Dynamo doesn't support tracing the {method_name} method. "
  393. f"We currently support wait, record, synchronize, and query.",
  394. hints=[
  395. *graph_break_hints.SUPPORTABLE,
  396. ],
  397. )
  398. def as_proxy(self) -> Proxy:
  399. return self.proxy
  400. @staticmethod
  401. def _get_stream_arg(
  402. tx: "InstructionTranslator",
  403. args: list[VariableTracker],
  404. kwargs: dict[str, VariableTracker],
  405. ) -> "StreamVariable":
  406. stream_arg = None
  407. if args:
  408. stream_arg = args[0]
  409. elif kwargs:
  410. stream_arg = kwargs.get("stream")
  411. if not stream_arg:
  412. stream_arg = tx.symbolic_stream_state.cur_stream()
  413. return stream_arg # type: ignore[return-value]
  414. @staticmethod
  415. def make_construct_in_graph_event_fn(
  416. args: TupleVariable, kwargs: ConstDictVariable
  417. ) -> Callable[[int, "PyCodegen"], None]:
  418. def fn(index: int, codegen: "PyCodegen") -> None:
  419. codegen.add_push_null(
  420. lambda: codegen.load_import_from(
  421. torch._dynamo.graph_bytecode_inputs.__name__, # type: ignore[implicit-imports]
  422. "stash_graph_created_object",
  423. )
  424. )
  425. codegen.add_push_null(
  426. lambda: codegen.load_import_from(
  427. torch._dynamo.utils.__name__, "build_event"
  428. )
  429. )
  430. codegen(args)
  431. codegen(kwargs)
  432. codegen.extend_output(create_call_function(2, False))
  433. codegen.extend_output(create_call_function(1, False))
  434. return fn
  435. def reconstruct(self, codegen: "PyCodegen") -> None:
  436. # If we got here, this event is fully subsumed by the graph - this means it is
  437. # not an input or global
  438. assert not self.source
  439. # Similar to stream handling, we lift the event into a global and then codegen bytecode to load it from there.
  440. prefix = "_event"
  441. name = codegen.tx.output.install_global_by_id(prefix, self.value)
  442. codegen.append_output(codegen.create_load_global(name, add=True))