meta_utils.py 97 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041204220432044204520462047204820492050205120522053205420552056205720582059206020612062206320642065206620672068206920702071207220732074207520762077207820792080208120822083208420852086208720882089209020912092209320942095209620972098209921002101210221032104210521062107210821092110211121122113211421152116211721182119212021212122212321242125212621272128212921302131213221332134213521362137213821392140214121422143214421452146214721482149215021512152
  1. from __future__ import annotations
  2. import contextlib
  3. import dataclasses
  4. import functools
  5. import threading
  6. import typing
  7. import weakref
  8. from abc import abstractmethod
  9. from contextlib import AbstractContextManager, contextmanager
  10. from dataclasses import dataclass
  11. from typing import (
  12. Any,
  13. ClassVar,
  14. Generic,
  15. NewType,
  16. Optional,
  17. Protocol,
  18. TYPE_CHECKING,
  19. TypeGuard,
  20. TypeVar,
  21. Union,
  22. )
  23. from typing_extensions import override, TypedDict, TypeIs, Unpack
  24. import torch
  25. from torch._C._autograd import CreationMeta
  26. from torch._C._functorch import (
  27. _add_batch_dim,
  28. _unwrap_functional_tensor,
  29. _wrap_functional_tensor,
  30. get_unwrapped,
  31. is_batchedtensor,
  32. is_functorch_wrapped_tensor,
  33. is_gradtrackingtensor,
  34. is_legacy_batchedtensor,
  35. maybe_get_bdim,
  36. maybe_get_level,
  37. peek_interpreter_stack,
  38. )
  39. from torch._dispatch.python import enable_python_dispatcher
  40. from torch._logging import trace_structured
  41. from torch.utils._mode_utils import no_dispatch
  42. from torch.utils._python_dispatch import is_traceable_wrapper_subclass
  43. from torch.utils.weak import WeakIdKeyDictionary
  44. if TYPE_CHECKING:
  45. from collections.abc import Callable, Generator
  46. from torch._C._functorch import CInterpreter
  47. from torch._guards import Source
  48. from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
  49. # Import here to avoid cycle
  50. # Import the following modules during type checking to enable code intelligence features,
  51. # Do not import unconditionally, as they import sympy and importing sympy is very slow
  52. from torch.fx.experimental.symbolic_shapes import ShapeEnv, SymbolicContext
  53. def _is_fake_tensor(t: object) -> TypeIs[FakeTensor]:
  54. from torch._subclasses.fake_tensor import FakeTensor
  55. return isinstance(t, FakeTensor)
# Plain alias for the builtin list type.
DimList = list
# Constrained TypeVar: either a tensor description (MetaTensorDesc) or a real
# torch.Tensor. Used by safe_grad/_expect_safe_grad so the return type follows
# the argument type.
_TensorLikeT = TypeVar("_TensorLikeT", "MetaTensorDesc[Any]", torch.Tensor)
# Unconstrained TypeVar for small generic helpers (assert_eq, _checked_cast).
_T = TypeVar("_T")
# Any torch.Tensor subclass.
_TensorT = TypeVar("_TensorT", bound=torch.Tensor)
# Covariant variant of _TensorT; used as the return type of the callback
# protocols declared later in this file.
_TensorT_cov = TypeVar("_TensorT_cov", bound=torch.Tensor, covariant=True)
  61. def safe_is_leaf(t: Union[MetaTensorDesc[Any], torch.Tensor]) -> bool:
  62. try:
  63. return t.is_leaf
  64. except RuntimeError:
  65. # inference mode can trigger this
  66. return False
  67. def safe_grad(t: _TensorLikeT) -> Optional[_TensorLikeT]:
  68. with torch._logging.hide_warnings(torch._logging._internal.safe_grad_filter):
  69. # pyrefly: ignore [bad-return]
  70. return t.grad
  71. def _expect_safe_grad(t: _TensorLikeT) -> _TensorLikeT:
  72. grad = safe_grad(t)
  73. if grad is None:
  74. raise AssertionError("Expected tensor to have a gradient but grad is None")
  75. return grad
  76. def assert_eq(a: _T, b: _T) -> None:
  77. if a != b:
  78. raise AssertionError(f"{a} != {b}")
# Thread-local state for this module.
tls = threading.local()
# Turns off inference mode for fake tensor propagation. This is turned to True
# only for `torch.compile`. Also look at
# _dynamo.config.fake_tensor_disable_inference_mode
# NB: attributes on a threading.local are per-thread, so this assignment only
# initializes the thread that imports the module. Readers in this file use
# getattr(tls, "disable_inference_mode", False), so other threads also see
# False until they set the flag themselves.
tls.disable_inference_mode = False
  84. @contextmanager
  85. def disable_inference_mode_for_fake_prop() -> Generator[None, None, None]:
  86. prior = getattr(tls, "disable_inference_mode", False)
  87. tls.disable_inference_mode = True
  88. try:
  89. yield
  90. finally:
  91. tls.disable_inference_mode = prior
  92. def assert_metadata_eq(
  93. assert_eq: Callable[[object, object], None],
  94. m1: Union[MetaTensorDesc[Any], torch.Tensor],
  95. m2: torch.Tensor,
  96. *,
  97. skip_symbolic: bool = False,
  98. skip_leaf: bool = False,
  99. ) -> None:
  100. m1 = (
  101. MetaTensorDescriber().describe_tensor(m1)
  102. if isinstance(m1, torch.Tensor)
  103. else m1
  104. )
  105. def go(m1: MetaTensorDesc[Any], m2: torch.Tensor) -> None:
  106. assert_eq(m1.dtype, m2.dtype)
  107. if not skip_symbolic:
  108. assert_eq(m1.shape, m2.shape)
  109. assert_eq(m1.requires_grad, m2.requires_grad)
  110. if not skip_leaf:
  111. assert_eq(m1.is_leaf, m2.is_leaf)
  112. # MetaTensorDesc doesn't store grad_fn; inferred from leaf
  113. # assert_eq(m1.grad_fn is None, m2.grad_fn is None)
  114. assert_eq(m1.is_sparse, m2.is_sparse)
  115. if not getattr(tls, "disable_inference_mode", False):
  116. assert_eq(m1.is_inference, m2.is_inference())
  117. else:
  118. assert_eq(m1.is_inference, False)
  119. assert_eq(m1.is_conj, m2.is_conj())
  120. assert_eq(m1.is_neg, m2.is_neg())
  121. assert_eq(m1.grad is not None, safe_grad(m2) is not None)
  122. if m1.grad is not None:
  123. go(m1.grad, _expect_safe_grad(m2))
  124. # TODO: move "assert_eq(m1.layout, m2.layout)" out of sparse
  125. # branches (but not ready for prime time yet)...
  126. if m1.is_sparse:
  127. assert_eq(m1.layout, m2.layout)
  128. assert_eq(m1.dense_dim, m2.dense_dim())
  129. assert_eq(m1.sparse_dim, m2.sparse_dim())
  130. assert_eq(m1.is_coalesced, m2.is_coalesced())
  131. elif is_sparse_compressed(m1):
  132. assert_eq(m1.layout, m2.layout)
  133. assert_eq(m1.dense_dim, m2.dense_dim())
  134. assert_eq(m1.sparse_dim, m2.sparse_dim())
  135. else:
  136. if not skip_symbolic:
  137. assert_eq(m1.stride, m2.stride())
  138. assert_eq(m1.storage_offset, m2.storage_offset())
  139. assert_eq(m1.is_view, m2._is_view())
  140. if m1.is_view:
  141. if m1.base is None:
  142. raise AssertionError("m1.base must not be None for a view tensor")
  143. if m2._base is None:
  144. raise AssertionError("m2._base must not be None for a view tensor")
  145. go(m1.base, m2._base)
  146. # TODO: test if is resizable (no direct query for this atm)
  147. # TODO: audit AutogradMeta to see if it matches
  148. # TODO: test forward AD
  149. return go(m1, m2)
  150. # TypeGuard (not TypeIs): False does not imply !torch.Tensor
  151. def is_sparse_coo(t: object) -> TypeGuard[torch.Tensor]:
  152. return isinstance(t, torch.Tensor) and t.layout is torch.sparse_coo
  153. def is_sparse_compressed_layout(layout: torch.layout) -> bool:
  154. return layout in {
  155. torch.sparse_csr,
  156. torch.sparse_csc,
  157. torch.sparse_bsr,
  158. torch.sparse_bsc,
  159. }
  160. # TypeGuard (not TypeIs): False does not imply !torch.Tensor
  161. def is_sparse_compressed(t: object) -> TypeGuard[torch.Tensor]:
  162. return isinstance(t, torch.Tensor) and is_sparse_compressed_layout(t.layout)
  163. # TypeGuard (not TypeIs): False does not imply !torch.Tensor
  164. def is_sparse_any(t: object) -> TypeGuard[torch.Tensor]:
  165. return is_sparse_coo(t) or is_sparse_compressed(t)
  166. def _checked_cast(ty: type[_T], obj: object) -> _T:
  167. if not isinstance(obj, ty):
  168. raise AssertionError(f"expected {ty} but got {type(obj)}")
  169. return obj
  170. def _get_real_storage(base: torch.UntypedStorage) -> torch.UntypedStorage:
  171. return base.real_storage # type: ignore[attr-defined]
  172. def _set_real_storage(
  173. base: torch.UntypedStorage, real_storage: torch.UntypedStorage
  174. ) -> None:
  175. base.real_storage = real_storage # type: ignore[attr-defined]
# Don't use id() directly, because those can get reallocated over time.
# Distinct int-based NewTypes keep storage ids, tensor ids and describer ids
# from being mixed up by the type checker.
MetaStorageId = NewType("MetaStorageId", int)
MetaTensorId = NewType("MetaTensorId", int)
_DescriberId = NewType("_DescriberId", int)
# Module-level counter: each MetaTensorDescriber takes the current value as
# its own id in __init__ and then increments it.
DESCRIBER_NEXT_ID = _DescriberId(0)
class MetaTensorDescriber:
    """
    Given a Tensor/Storage, generate a MetaTensorDesc/MetaStorageDesc
    for it, which is enough information to reconstruct a meta tensor/fake tensor
    corresponding to a Tensor as faithfully as possible.

    This is a stateful conversion object because we keep track of the IDs
    of the tensors/storages passed to us, so we can consistently give
    the same ID when we see the same tensor/storage.
    """

    def __init__(self, *, copy_data: bool = False) -> None:
        """Create a describer with fresh id counters.

        Args:
            copy_data: When True, descriptions keep a reference to the
                original tensor/storage in their ``data`` field.
        """
        global DESCRIBER_NEXT_ID
        # Take a unique id for this describer from the module-level counter.
        self.id = DESCRIBER_NEXT_ID
        DESCRIBER_NEXT_ID = _DescriberId(DESCRIBER_NEXT_ID + 1)
        self.next_tensor_id: MetaTensorId = MetaTensorId(0)
        self.next_storage_id: MetaStorageId = MetaStorageId(0)
        # Tensor -> int
        self.lookup_tensor = WeakIdKeyDictionary()
        # Storage -> int
        self.lookup_storage = WeakIdKeyDictionary()
        self.copy_data = copy_data
        # Ids already emitted via trace_structured, so each is logged once.
        self.traced_tensors: set[int] = set()
        self.traced_storages: set[int] = set()

    def get_tensor_id(self, t: torch.Tensor) -> MetaTensorId:
        """Return the id assigned to ``t``, allocating the next one on first sight."""
        if t not in self.lookup_tensor:
            self.lookup_tensor[t] = self.next_tensor_id
            self.next_tensor_id = MetaTensorId(self.next_tensor_id + 1)
        return self.lookup_tensor[t]

    def get_storage_id(self, s: torch.UntypedStorage) -> MetaStorageId:
        """Return the id assigned to ``s``, allocating the next one on first sight."""
        if s not in self.lookup_storage:
            self.lookup_storage[s] = self.next_storage_id
            self.next_storage_id = MetaStorageId(self.next_storage_id + 1)
        return self.lookup_storage[s]

    def describe_storage(
        self, s: torch.UntypedStorage, *, trace: bool = False
    ) -> MetaStorageDesc:
        """Build a ``MetaStorageDesc`` for ``s``.

        Args:
            s: The storage to describe.
            trace: When True, emit a ``describe_storage`` structured trace
                entry the first time this storage id is described.
        """
        r = MetaStorageDesc(
            id=self.get_storage_id(s),
            size=s.size(),
            # NB: We don't do the copy yet; copy happens when we start
            # creating the new storages
            data=s if self.copy_data else None,
        )
        if trace and r.id not in self.traced_storages:
            trace_structured(
                "describe_storage",
                metadata_fn=lambda: r.as_json(self.id),
            )
            self.traced_storages.add(r.id)
        return r

    def describe_tensor(
        self, t: torch.Tensor, *, recurse: bool = True, trace: bool = False
    ) -> MetaTensorDesc[Any]:
        """Build a ``MetaTensorDesc`` for ``t``.

        Args:
            t: The tensor to describe.
            recurse: When False, the view base and the compressed-sparse
                index/values tensors are not described (their fields are
                None). Note that this flag is not forwarded to the nested
                descriptions of unwrapped tensors, subclass attrs, or grad.
            trace: When True, emit a ``describe_tensor`` structured trace
                entry the first time this tensor id is described; forwarded
                to all nested descriptions.

        Raises:
            RuntimeError: If ``t`` is a functional tensor (not on an xla/lazy
                device) that is also a view.
            AssertionError: If ``t`` is a traceable wrapper subclass without
                a ``__tensor_flatten__`` method.
        """
        # Query the basic classification flags once, up front.
        is_leaf = safe_is_leaf(t)
        is_view = t._is_view()
        is_sparse = t.is_sparse
        layout = t.layout
        is_nested = t.is_nested
        is_traceable_wrapper_subclass_v = is_traceable_wrapper_subclass(t)
        is_functorch_wrapped = is_functorch_wrapped_tensor(t)
        is_mkldnn = t.is_mkldnn
        is_batchedtensor_v = is_batchedtensor(t)
        is_legacy_batchedtensor_v = is_legacy_batchedtensor(t)
        is_gradtrackingtensor_v = is_gradtrackingtensor(t)
        is_functional = torch._is_functional_tensor(t)

        storage = None
        # NB: For compatibility, I default this to zero, as sometimes people
        # still have stuffed zero into storage offset even though the tensor
        # doesn't meaningfully have an offset
        storage_offset = 0
        if not (
            is_sparse
            or is_sparse_compressed_layout(layout)
            or (is_nested and not is_traceable_wrapper_subclass_v)
            or is_mkldnn
            # TODO: TBH, functorch wrapped tensors probably should have
            # storage associated with them
            or is_functorch_wrapped
            or is_legacy_batchedtensor_v
        ):
            # NB: We actually don't use storage to do views, but might as well
            # put it in for accuracy
            storage = self.describe_storage(t.untyped_storage(), trace=trace)
            storage_offset = t.storage_offset()  # type: ignore[assignment]

        stride = None
        if not (
            is_sparse
            or is_sparse_compressed_layout(layout)
            or (is_nested and not is_traceable_wrapper_subclass_v)
        ):
            # stride/storage_offset are called from is_functorch_wrapped,
            # view_from_base, empty_create_subclass,
            # sym_sizes_strides_storage_offset (empty_create)
            stride = t.stride()

        # NB: this technically should refer to functorch unwrapped tensor, but
        # I am (perhaps abusively) using it to store both the functorch and
        # non-functorch functional tensor
        unwrapped = None
        autograd_meta_from = None
        current_level = None
        if is_batchedtensor_v or is_gradtrackingtensor_v:
            unwrapped = self.describe_tensor(get_unwrapped(t), trace=trace)
        # xla and lazy tensors present as functional tensors, but we want them
        # to be handled specially
        elif is_functional and t.device.type not in ("xla", "lazy"):
            if t._is_view():
                raise RuntimeError(
                    "Cannot safely fakify a view because this process drops the view information right now."
                )
            if not is_functorch_wrapped:
                # Non-functorch functional tensor: sync it, describe the
                # tensor it wraps, and remember the wrapper in
                # autograd_meta_from.
                torch._sync(t)
                unwrapped = self.describe_tensor(
                    torch._from_functional_tensor(t), trace=trace
                )
                autograd_meta_from = t
            else:
                reapply_views = torch._C._functionalization_reapply_views_tls()
                # NB: has side effects!
                unwrapped = self.describe_tensor(
                    _unwrap_functional_tensor(t, reapply_views), trace=trace
                )
                # TODO: It's pretty suspicious that functional tensors don't have
                # valid level and thus we just grab whatever the current level
                # is
                current_level = torch._C._functorch.current_level()

        maybe_functorch_stack = None
        if is_functorch_wrapped:
            # The context manager is entered and exited immediately; it is
            # used only to capture the value it yields (presumably the current
            # interpreter stack -- confirm against pyfunctorch).
            with (
                torch._functorch.pyfunctorch.temporarily_clear_interpreter_stack()
            ) as maybe_functorch_stack:
                pass

        attrs = None
        ctx = None
        type_v = None
        if is_traceable_wrapper_subclass_v:
            if not hasattr(t, "__tensor_flatten__"):
                raise AssertionError(
                    "Traceable wrapper subclass must have __tensor_flatten__ method"
                )
            # Describe every inner tensor attribute named by __tensor_flatten__.
            raw_attrs, ctx = t.__tensor_flatten__()
            attrs = {
                attr: self.describe_tensor(getattr(t, attr), trace=trace)
                for attr in raw_attrs
            }
            type_v = type(t)

        # Imported at call time rather than at module import.
        from torch.nested._internal.nested_tensor import _tensor_symint_registry

        view_func = ViewFunc.from_tensor(t)

        # TODO: Is it important to enable torch.inference_mode before querying
        # these values?
        is_inference_mode_disabled = getattr(tls, "disable_inference_mode", False)
        r: MetaTensorDesc[Any] = MetaTensorDesc(
            id=self.get_tensor_id(t),
            storage=storage,
            # When the thread-local flag is set, always record is_inference=False.
            is_inference=False if is_inference_mode_disabled else t.is_inference(),
            is_leaf=is_leaf,
            requires_grad=t.requires_grad,
            # NB: ndim should be OK too but there is a disaster at
            # python test/dynamo/test_subclasses.py -k test_user_overridden_property_unsupported
            # Actually, this means that we have a little bit of a problem
            # here, which is that there is some sensitivity to how exactly an
            # access is done if you have a __torch_function__ subclass. Maybe
            # should disable torch function before doing accesses?
            ndim=t.dim(),
            dtype=t.dtype,
            is_sparse=is_sparse,
            is_mkldnn=is_mkldnn,
            is_functorch_wrapped=is_functorch_wrapped,
            is_batchedtensor=is_batchedtensor_v,
            is_legacy_batchedtensor=is_legacy_batchedtensor_v,
            is_gradtrackingtensor=is_gradtrackingtensor_v,
            is_view=is_view,
            is_conj=t.is_conj(),
            is_neg=t.is_neg(),
            is_parameter=isinstance(t, torch.nn.Parameter),
            is_traceable_wrapper_subclass=is_traceable_wrapper_subclass_v,
            is_nested=is_nested,
            nested_int=(
                _tensor_symint_registry[t].node.nested_int()
                if t in _tensor_symint_registry
                else None
            ),
            is_functional=is_functional,
            layout=layout,
            device=t.device,
            size=t.size(),
            stride=stride,
            # pyrefly: ignore [bad-argument-type]
            storage_offset=storage_offset,
            dynamo_dynamic_indices=list(getattr(t, "_dynamo_dynamic_indices", set())),
            dynamo_hint_overrides=getattr(t, "_dynamo_hint_overrides", {}),
            sparse_dim=(
                t.sparse_dim() if t.is_sparse or is_sparse_compressed(t) else None
            ),
            dense_dim=t.dense_dim() if t.is_sparse or is_sparse_compressed(t) else None,
            is_coalesced=t.is_coalesced() if t.is_sparse else None,
            # TODO: I actually think recursing here is correct, but we have at
            # least an infinite cycle from base -> values -> base
            # https://github.com/pytorch/pytorch/issues/122089
            crow_indices=(
                self.describe_tensor(t.crow_indices(), recurse=False, trace=trace)
                if recurse and t.layout in {torch.sparse_csr, torch.sparse_bsr}
                else None
            ),
            col_indices=(
                self.describe_tensor(t.col_indices(), recurse=False, trace=trace)
                if recurse and t.layout in {torch.sparse_csr, torch.sparse_bsr}
                else None
            ),
            ccol_indices=(
                self.describe_tensor(t.ccol_indices(), recurse=False, trace=trace)
                if recurse and t.layout in {torch.sparse_csc, torch.sparse_bsc}
                else None
            ),
            row_indices=(
                self.describe_tensor(t.row_indices(), recurse=False, trace=trace)
                if recurse and t.layout in {torch.sparse_csc, torch.sparse_bsc}
                else None
            ),
            values=(
                self.describe_tensor(t.values(), recurse=False, trace=trace)
                if recurse and is_sparse_compressed(t)
                else None
            ),
            grad=(
                self.describe_tensor(grad, trace=trace)
                if (grad := safe_grad(t)) is not None
                else None
            ),
            creation_meta=(
                torch._C._autograd._get_creation_meta(t) if t._is_view() else None
            ),
            unwrapped=unwrapped,
            level=(
                maybe_get_level(t)
                if is_batchedtensor_v or is_gradtrackingtensor_v
                else None
            ),
            bdim=maybe_get_bdim(t) if is_batchedtensor_v else None,
            base=(
                self.describe_tensor(t._base, trace=trace)
                if recurse and t._is_view() and t._base is not None
                else None
            ),
            fake_mode=torch._subclasses.fake_tensor.maybe_get_fake_mode(t),
            view_func=view_func,
            # pyrefly: ignore [bad-argument-type]
            attrs=attrs,
            ctx=ctx,
            type=type_v,
            # NB: even if functorch is enabled, don't actually save the
            # interpreter stack here unless we are actually functorch wrapped;
            # it's irrelevant for non-functorch stuff
            functorch_stack=maybe_functorch_stack,
            autograd_meta_from=autograd_meta_from,
            current_level=current_level,
            data=t if self.copy_data else None,
        )
        # Log each tensor id at most once per describer.
        if trace and r.id not in self.traced_tensors:
            trace_structured(
                "describe_tensor",
                metadata_fn=lambda: r.as_json(self.id),
            )
            self.traced_tensors.add(r.id)
        return r
  444. @dataclass(frozen=True)
  445. class MetaStorageDesc:
  446. id: MetaStorageId
  447. size: int
  448. # NB: this is only populated with copy_data True, it is not directly
  449. # serializable in JSON, you want to do something special here anyway
  450. data: Optional[torch.UntypedStorage]
  451. def as_json(self, describer_id: _DescriberId) -> dict[str, object]:
  452. return {
  453. "id": self.id,
  454. "describer_id": describer_id,
  455. "size": self.size if isinstance(self.size, int) else repr(self.size),
  456. }
  457. @dataclass(frozen=True)
  458. class ViewFunc(Generic[_TensorT]):
  459. @abstractmethod
  460. def apply(
  461. self,
  462. t: _TensorT,
  463. new_base: _TensorT,
  464. symint_visitor_fn: Optional[Callable[[int], int]] = None,
  465. tensor_visitor_fn: Optional[Callable[[torch.Tensor], _TensorT]] = None,
  466. ) -> _TensorT: ...
  467. @staticmethod
  468. def from_tensor(t: torch.Tensor) -> ViewFunc[Any]:
  469. if _is_fake_tensor(t):
  470. return _FakeTensorViewFunc()
  471. else:
  472. return _CustomViewFunc(t._view_func_unsafe)
  473. @dataclass(frozen=True)
  474. class _FakeTensorViewFunc(ViewFunc["FakeTensor"]):
  475. @override
  476. def apply(
  477. self,
  478. t: torch.Tensor,
  479. new_base: torch.Tensor,
  480. symint_visitor_fn: Optional[Callable[[int], int]] = None,
  481. tensor_visitor_fn: Optional[Callable[[torch.Tensor], FakeTensor]] = None,
  482. ) -> FakeTensor:
  483. return torch._subclasses.fake_tensor.FakeTensor._view_func_unsafe(
  484. # pyrefly: ignore [bad-argument-type]
  485. t,
  486. new_base,
  487. symint_visitor_fn,
  488. tensor_visitor_fn,
  489. )
  490. @dataclass(frozen=True)
  491. class _CustomViewFunc(ViewFunc[_TensorT], Generic[_TensorT]):
  492. func: Callable[
  493. [
  494. torch.Tensor,
  495. Optional[Callable[[int], int]],
  496. Optional[Callable[[torch.Tensor], _TensorT]],
  497. ],
  498. _TensorT,
  499. ]
  500. @override
  501. def apply(
  502. self,
  503. t: torch.Tensor,
  504. new_base: torch.Tensor,
  505. symint_visitor_fn: Optional[Callable[[int], int]] = None,
  506. tensor_visitor_fn: Optional[Callable[[torch.Tensor], _TensorT]] = None,
  507. ) -> _TensorT:
  508. # ignore `t`
  509. return self.func(new_base, symint_visitor_fn, tensor_visitor_fn)
# A callback where the device is either optional or required.
# All of these satisfy this protocol:
# def mk(arg: Callable[[], torch.Tensor], device: Union[torch.device, str])
# def mk(arg: Callable[[], torch.Tensor], device: Union[torch.device, str] = "meta")
# def mk(arg: Callable[[], torch.Tensor], device: Optional[Union[torch.device, str]] = None)
class _MetaTensorCallback(Protocol, Generic[_TensorT_cov]):
    """Structural type for a tensor-building callback that is always given ``device``.

    The callback receives a zero-argument callable producing a tensor
    (positional-only) and a keyword-only ``device``.
    """

    def __call__(
        self, arg: Callable[[], torch.Tensor], /, *, device: Union[torch.device, str]
    ) -> _TensorT_cov: ...
class _MetaTensorCallbackKwargs(TypedDict, total=False):
    """Keyword arguments accepted by ``_MetaTensorCallbackOptDevice``.

    ``total=False`` makes every key optional, so ``device`` may be omitted.
    """

    device: Union[torch.device, str]
# A callback where the device may not be provided (is optional).
# All of these satisfy this protocol:
# def mk(arg: Callable[[], torch.Tensor], device: Union[torch.device, str] = "meta")
# def mk(arg: Callable[[], torch.Tensor], device: Optional[Union[torch.device, str]] = None)
class _MetaTensorCallbackOptDevice(Protocol, Generic[_TensorT_cov]):
    """Structural type for a tensor-building callback whose ``device`` kwarg is optional.

    The callback receives a zero-argument callable producing a tensor
    (positional-only) plus the keyword arguments described by
    ``_MetaTensorCallbackKwargs``.
    """

    def __call__(
        self,
        arg: Callable[[], torch.Tensor],
        /,
        **kwargs: Unpack[_MetaTensorCallbackKwargs],
    ) -> _TensorT_cov: ...
  532. @dataclass(frozen=True)
  533. class MetaTensorDesc(Generic[_TensorT]):
  534. id: MetaTensorId
  535. ndim: int
  536. dtype: torch.dtype
  537. device: torch.device
  538. # NB: Sometimes, size, stride and storage_offset contain SymInt, in which
  539. # case this is NOT serializable. That only happens when you're
  540. # re-fakeifying a fake tensor with an existing ShapeEnv... maybe we
  541. # can get rid of this use case entirely. Notably, even if we are
  542. # fakeifying a real tensor into a fake tensor with symbolic shapes, the
  543. # size here is NOT dynamic
  544. # NB: These also contain SymInt because wrap_meta_outputs_with_default_device_logic
  545. # goes through this codepath. But it really should not LOL.
  546. # NB: size could potentially be None as you can override it and make it
  547. # throw an error, but we don't currently have any subclasses that do this
  548. # except C++ nested tensor but we're going to have nested int to make this
  549. # defined on NJT
  550. size: tuple[int, ...]
  551. dynamo_dynamic_indices: list[int]
  552. dynamo_hint_overrides: dict[int, int]
  553. layout: torch.layout = torch.strided
  554. is_inference: bool = False
  555. is_leaf: bool = False
  556. requires_grad: bool = False
  557. is_sparse: bool = False
  558. is_mkldnn: bool = False
  559. is_functorch_wrapped: bool = False
  560. is_batchedtensor: bool = False
  561. is_legacy_batchedtensor: bool = False
  562. is_gradtrackingtensor: bool = False
  563. is_view: bool = False
  564. is_nested: bool = False
  565. # We eagerly symbolicize the associated nested int for e.g. offsets / lengths
  566. # metadata if that offsets is already associated with a nested int.
  567. # See test_construct_from_jagged_with_input_offsets_mixed_case.
  568. nested_int: Optional[int] = None
  569. is_traceable_wrapper_subclass: bool = False
  570. is_functional: bool = False
  571. is_conj: bool = False
  572. is_neg: bool = False
  573. is_parameter: bool = False
  574. stride: Optional[tuple[int, ...]] = None
  575. storage_offset: int = 0
  576. # NB: We have a choice whether or not to store the id or a direct pointer
  577. # to the data structure. For ease of use, we store the data structure,
  578. # but this means that when we serialize, we have to swizzle these pointers
  579. # back into ids (so we have accurate aliasing relationships)
  580. storage: Optional[MetaStorageDesc] = None
  581. sparse_dim: Optional[int] = None # is_sparse, is_sparse_compressed
  582. dense_dim: Optional[int] = None # is_sparse, is_sparse_compressed
  583. is_coalesced: Optional[bool] = None # is_sparse
  584. crow_indices: Optional[MetaTensorDesc[Any]] = None # is_sparse_compressed
  585. col_indices: Optional[MetaTensorDesc[Any]] = None # is_sparse_compressed
  586. ccol_indices: Optional[MetaTensorDesc[Any]] = None # is_sparse_compressed
  587. row_indices: Optional[MetaTensorDesc[Any]] = None # is_sparse_compressed
  588. values: Optional[MetaTensorDesc[Any]] = None # is_sparse_compressed
  589. unwrapped: Optional[MetaTensorDesc[Any]] = None # is_functorch_wrapped
  590. bdim: Optional[int] = None # is_functorch_wrapped
  591. base: Optional[MetaTensorDesc[Any]] = None # is_view
  592. attrs: Optional[dict[str, MetaTensorDesc[Any]]] = (
  593. None # is_traceable_wrapper_subclass
  594. )
  595. creation_meta: Optional[CreationMeta] = None
  596. grad: Optional[MetaTensorDesc[Any]] = None
  597. # Everything below is NOT serializable, need some more work
  598. _UNSERIALIZABLE: ClassVar[set[str]] = {
  599. "ctx",
  600. "type",
  601. "fake_mode",
  602. # view_func isn't serializable when it's a _CustomViewFunc
  603. "view_func",
  604. "level",
  605. "current_level",
  606. "functorch_stack",
  607. "autograd_meta_from",
  608. "data",
  609. "nested_int",
  610. }
  611. ctx: Optional[object] = None # is_traceable_wrapper_subclass
  612. type: Optional[type] = None # is_traceable_wrapper_subclass
  613. fake_mode: Optional[FakeTensorMode] = None
  614. view_func: Optional[ViewFunc[Any]] = None
  615. # level looks serializable, but actually it is meaningless without
  616. # the functorch_stack below
  617. level: Optional[int] = None # is_functorch_wrapped
  618. current_level: Optional[int] = None
  619. functorch_stack: Optional[list[CInterpreter]] = None
  620. autograd_meta_from: Optional[torch.Tensor] = None
  621. # This is only populated on copy_data, and typically is not used at all,
  622. # except for some of our meta-ification paths that don't properly use
  623. # storage (pro-tip: you should use storage)
  624. data: Optional[torch.Tensor] = None
  625. # Faithfully serializing functorch tensors will not be too difficult.
  626. # We only need to consider grad/vmap interpreters, and their internal
  627. # state is only bools (mostly what the grad enabled/disabled state
  628. # should be in the lower layer). Beyond that, tensors just need to
  629. # precisely indicate which particular interpreter they correspond
  630. # to (we then replace level with a pointer to the interpreter stack.)
  631. # However, this use of functorch is very "non-lexical" so it's not
  632. # entirely clear how to make it all lexical again, so we haven't done
  633. # it for now.
  634. # NB: This will reference numeric IDs, and it is assumed that you've
  635. # already serialized everything this recursively references
  636. def as_json(self, describer_id: _DescriberId) -> dict[str, object]:
  637. def json(k: str, v: object) -> object:
  638. # Some best-effort debugging serialization for unserializable
  639. # fields (feel free to add other special cases as appropriate)
  640. if k in ["data", "autograd_meta_from"]:
  641. return None # never repr these
  642. if k in MetaTensorDesc._UNSERIALIZABLE:
  643. return repr(v)
  644. if isinstance(v, (torch.device, torch.dtype, torch.layout)):
  645. return repr(v)
  646. if isinstance(v, torch.SymInt):
  647. return repr(v)
  648. if isinstance(v, (tuple, list)):
  649. return [json(k, v1) for v1 in v]
  650. if isinstance(v, (MetaStorageDesc, MetaTensorDesc)):
  651. return v.id
  652. if isinstance(v, CreationMeta):
  653. return str(v)
  654. if k == "attrs" and isinstance(v, dict):
  655. return {k1: v1.id for k1, v1 in v.items()}
  656. return v
  657. r = {
  658. field.name: json(field.name, getattr(self, field.name))
  659. for field in dataclasses.fields(self)
  660. if not (
  661. getattr(self, field.name) is field.default
  662. or (
  663. field.name == "dynamo_dynamic_indices"
  664. and not getattr(self, field.name)
  665. )
  666. )
  667. }
  668. r.update({"describer_id": describer_id})
  669. return r
  670. @property
  671. def shape(self) -> tuple[int, ...]:
  672. return self.size
  673. # A more faithful reproduction would do a copy on the entire
  674. # storage, but this needs to be done carefully because the
  675. # underlying storage could have larger extent than is implied
  676. # by size/stride. The real fix is to properly call
  677. # meta_storage recursively here.
  678. #
  679. # These "safe" functions are intended to be used under no_dispatch() mode.
  680. # The no_dispatch() here is intended to prevent ambient fake tensor mode from
  681. # fakeifying the operation. But if we are given an honest to goodness
  682. # FakeTensor as src, we MUST NOT run the copy/clone operation. A better way
  683. # to do this would be to not use no_dispatch and instead just disable fake
  684. # tensor mode only (allowing for subclass dispatch to occur)
  685. def _safe_copy(dst: torch.Tensor, src: Optional[torch.Tensor]) -> None:
  686. if type(src) is not torch.Tensor:
  687. return
  688. dst.copy_(src)
  689. def _safe_clone(src: torch.Tensor) -> Optional[torch.Tensor]:
  690. if type(src) is not torch.Tensor:
  691. return None
  692. return src.clone()
  693. # This is a class for converting multiple tensors into meta tensors which
  694. # share the same view/storage structure. The operation model is you allocate
  695. # one of these, and then call it repeatedly on all the tensors you want to
  696. # convert. It's important to use the same object for tensors you want to
  697. # share storage because this is how we correlate shared storages to the same
# meta storages. This class will hold weak references to cached tensors
  699. # and tensor storages.
  700. class MetaConverter(Generic[_TensorT]):
  701. def __init__(self, *, copy_data: bool = False) -> None:
  702. # Maps MetaStorageId to UntypedStorage
  703. self.storage_memo: weakref.WeakValueDictionary[
  704. MetaStorageId, torch.UntypedStorage
  705. ] = weakref.WeakValueDictionary()
  706. # Maps MetaTensorId to torch.Tensor (typically a meta tensor or
  707. # FakeTensor)
  708. self.tensor_memo: weakref.WeakValueDictionary[MetaTensorId, _TensorT] = (
  709. weakref.WeakValueDictionary()
  710. )
  711. self.hit = 0
  712. self.miss = 0
  713. self.del_hook = None
  714. self.arg_cnt = 0
  715. # Ensures real_storage/real_tensor are populated on the resulting
  716. # metaified storage/tensor. The naming of this attribute is load
  717. # bearing: FakeTensor relies on real tensor being set to exactly this
  718. # value
  719. self.copy_data = copy_data
  720. self.describer = MetaTensorDescriber(copy_data=copy_data)
  721. def successful(self) -> bool:
  722. return self.hit > 0 and self.miss == 0
  723. def get_tensor_memo(self, t: MetaTensorDesc[Any]) -> Optional[torch.Tensor]:
  724. return self.tensor_memo.get(t.id, None)
  725. def _checked_get_tensor_memo(self, t: MetaTensorDesc[Any]) -> _TensorT:
  726. r = self.tensor_memo.get(t.id, None)
  727. if r is None:
  728. raise AssertionError(f"Tensor memo for id {t.id} is None")
  729. return r
  730. def set_tensor_memo(self, t: MetaTensorDesc[Any], v: _TensorT) -> None:
  731. self.tensor_memo[t.id] = v
  732. def get_storage_memo(self, s: MetaStorageDesc) -> Optional[torch.UntypedStorage]:
  733. return self.storage_memo.get(s.id, None)
  734. def set_storage_memo(self, s: MetaStorageDesc, v: torch.UntypedStorage) -> None:
  735. self.storage_memo[s.id] = v
  736. def meta_storage(
  737. self,
  738. s: MetaStorageDesc,
  739. callback: Callable[[Callable[[], torch.Tensor]], _TensorT],
  740. ) -> torch.UntypedStorage:
  741. # If we are fakeifying a tensor that has a secretly-zero-sized storage,
  742. # Need to make sure to resize the meta storage too.
  743. if (memo := self.get_storage_memo(s)) is None:
  744. r_s = callback(
  745. lambda: torch.empty(s.size, dtype=torch.uint8, device="meta"),
  746. ).untyped_storage()
  747. if self.copy_data:
  748. # NB: no_dispatch is needed because internally storage copy is
  749. # implemented as Tensor operations
  750. with torch.no_grad(), no_dispatch():
  751. if s.data is None:
  752. raise AssertionError(
  753. "s.data must not be None when copy_data is True"
  754. )
  755. _set_real_storage(r_s, s.data.clone())
  756. self.set_storage_memo(s, r_s)
  757. return r_s
  758. else:
  759. return memo
  760. @classmethod
  761. def _checked_cast_tensor_t(cls, t: torch.Tensor) -> _TensorT:
  762. # TODO: how to check _TensorT?
  763. return typing.cast(_TensorT, t)
  764. @classmethod
  765. def _identity_callable(
  766. cls,
  767. t: Callable[[], torch.Tensor],
  768. device: Optional[Union[torch.device, str]] = None,
  769. ) -> _TensorT:
  770. return cls._checked_cast_tensor_t(t())
  771. @classmethod
  772. def _backward_error(cls, t: _TensorT) -> _TensorT:
  773. errfn = torch._C._functions.DelayedError(
  774. "Internal error: Tried to backward() through example input",
  775. 1,
  776. )
  777. err = errfn(t)
  778. return typing.cast(_TensorT, err)
  779. # This function assumes that it's possible to do the conversion
  780. # NB: name here is used in a conventional way by Dynamo; it corresponds
  781. # precisely to the Source.name of the tensor we're fakeifying and
  782. # corresponds to a valid Python expression. When we construct sub-names
  783. # as part of this process, we will maintain this invariant! (Even though
# other users of this may not need this property to be upheld.)
  785. def meta_tensor(
  786. self,
  787. t: MetaTensorDesc[Any],
  788. shape_env: Optional[ShapeEnv],
  789. callback_: _MetaTensorCallback[_TensorT],
  790. source: Optional[Source],
  791. symbolic_context: Optional[SymbolicContext],
  792. ) -> _TensorT:
  793. callback: _MetaTensorCallbackOptDevice[_TensorT] = functools.partial(
  794. callback_, device=t.device
  795. )
  796. if source is None:
  797. from torch._dynamo.source import ConstantSource
  798. # TODO: make a dedicated UnknownSource for this?
  799. source = ConstantSource(
  800. f"__meta_utils_unknown_tensor{len(self.tensor_memo)}"
  801. )
  802. msg = (
  803. " This indicates you set no_dispatch() before calling into this"
  804. " function. This is an error: we may be creating fake tensors and"
  805. " will perform operations on them which need fake tensor mode to"
  806. " be active. You will segfault if you are in a no_dispatch() block."
  807. )
  808. if torch._C._dispatch_tls_local_exclude_set().has(torch._C.DispatchKey.Python):
  809. raise AssertionError(msg)
  810. self.arg_cnt += 1
  811. # When we make as_strided calls, we end up generating a guard
  812. # that the new as_strided tensor is in bounds for the old storage
  813. # for the base (since as_strided calls can "bust" out of their
  814. # bounding box.) This guard is unnecessary: if a user is able
  815. # to provide us a tensor with the view base setup this way, we
  816. # don't need to produce a guard, because the fact that they
  817. # were able to produce the view base means its in bounds.
  818. #
  819. # Now, ordinarily, this guard would be harmless. However, the
  820. # generated guard refers to variables bound on the base variable.
  821. # At the moment, Dynamo doesn't actually guard on x._base, because
  822. # according to Voz this results in a lot of spurious invalidations,
  823. # and also if the user doesn't directly make use of _base, its
  824. # pointless anyway (because programs should be parametric over
  825. # whether or not the input tensor is a view or not--unless you're
  826. # mutating the input, but that's a whole 'nother ballgame). So
  827. # for expediency, we suppress these guards so we don't have to
  828. # deal with this (yet, anyway.)
  829. #
  830. # NB: An old version of this code suppressed guards for ALL operations
  831. # happening during meta conversion, not just as_strided calls.
  832. # This is too aggressive: we do duck sizing and 0/1 simplification
  833. # as we allocate variables, and we do need to register guards for
  834. # these cases.
  835. maybe_suppress: Callable[[], Any] = contextlib.nullcontext
  836. if shape_env is not None:
  837. maybe_suppress = shape_env.suppress_guards
  838. def sym_sizes_strides_storage_offset(
  839. t: MetaTensorDesc[Any],
  840. src: torch._guards.Source,
  841. symbolic_context: Optional[
  842. torch.fx.experimental.symbolic_shapes.SymbolicContext
  843. ] = symbolic_context,
  844. ) -> tuple[tuple[int, ...], tuple[int, ...], int]:
  845. # local import to prevent circular import
  846. from torch.fx.experimental.symbolic_shapes import is_symbolic
  847. if t.stride is None:
  848. raise AssertionError("t.stride must not be None")
  849. if shape_env is not None:
  850. fake_mode = t.fake_mode
  851. has_symbolic = (
  852. any(is_symbolic(sz) for sz in t.size)
  853. or any(is_symbolic(sd) for sd in t.stride)
  854. or is_symbolic(t.storage_offset)
  855. )
  856. if fake_mode is not None and fake_mode.shape_env is shape_env:
  857. # Don't reallocate the sizes; the shape envs are the same,
  858. # so reuse the old sizes/strides/etc
  859. return (t.size, t.stride, t.storage_offset)
  860. elif (
  861. fake_mode is not None
  862. and not has_symbolic
  863. and symbolic_context is None
  864. ):
  865. return (t.size, t.stride, t.storage_offset)
  866. else:
  867. # TODO: deduplicate this
  868. t_size = tuple(
  869. shape_env._maybe_specialize_sym_int_with_hint(sz)
  870. for sz in t.size
  871. )
  872. t_stride = tuple(
  873. shape_env._maybe_specialize_sym_int_with_hint(sd)
  874. for sd in t.stride
  875. )
  876. t_storage_offset = shape_env._maybe_specialize_sym_int_with_hint(
  877. t.storage_offset
  878. )
  879. return shape_env._create_symbolic_sizes_strides_storage_offset(
  880. t_size,
  881. t_stride,
  882. t_storage_offset,
  883. [d in t.dynamo_dynamic_indices for d in range(t.ndim)],
  884. src,
  885. symbolic_context=symbolic_context,
  886. hint_overrides=t.dynamo_hint_overrides,
  887. )
  888. else:
  889. return (t.size, t.stride, t.storage_offset)
  890. def empty_create(
  891. inner_t: MetaTensorDesc[Any],
  892. inner_src: torch._guards.Source,
  893. symbolic_context: Optional[
  894. torch.fx.experimental.symbolic_shapes.SymbolicContext
  895. ] = symbolic_context,
  896. ) -> torch.Tensor:
  897. (
  898. inner_sizes,
  899. inner_strides,
  900. _inner_storage_offset,
  901. ) = sym_sizes_strides_storage_offset(inner_t, inner_src, symbolic_context)
  902. return torch.empty_strided(
  903. inner_sizes,
  904. inner_strides,
  905. dtype=inner_t.dtype,
  906. device="meta",
  907. )
  908. # Creates a subclass instance with empty inner tensors according to the specified
  909. # symbolic context.
  910. def empty_create_subclass(
  911. t: MetaTensorDesc[Any],
  912. outer_size: tuple[int, ...],
  913. outer_stride: tuple[int, ...],
  914. symbolic_context: Optional[
  915. torch.fx.experimental.symbolic_shapes.SymbolicContext
  916. ] = symbolic_context,
  917. source: Optional[torch._guards.Source] = source,
  918. ) -> _TensorT:
  919. from torch._dynamo.source import AttrSource
  920. from torch.fx.experimental.symbolic_shapes import SubclassSymbolicContext
  921. if t.attrs is None:
  922. raise AssertionError("t.attrs must not be None for subclass")
  923. if t.type is None:
  924. raise AssertionError("t.type must not be None for subclass")
  925. # NB: t.ctx could be None if the subclass in question has no
  926. # meaningful context
  927. # Note: transform_subclass will use __tensor_unflatten__ to generate
  928. # a fresh subclass wrapper with outer sizes / strides according to the
  929. # outer symbolic context (passed in to this function). Inner size / stride
  930. # / storage offset symbols are allocated according to the appropriate inner
  931. # symbolic contexts, after which the checks in transform_subclass() will
  932. # relate them to the outer metadata as possible.
  933. #
  934. # Morally, the code here is same as transform_subclass, but we've
  935. # written it from scratch to read EmptyCreateSubclass
  936. outer_size = outer_size if outer_size is not None else t.size
  937. # pyrefly: ignore [bad-assignment]
  938. outer_stride = outer_stride if outer_stride is not None else t.stride
  939. if symbolic_context is not None and not isinstance(
  940. symbolic_context, SubclassSymbolicContext
  941. ):
  942. raise AssertionError(
  943. f"Expected SubclassSymbolicContext or None, got {type(symbolic_context)}"
  944. )
  945. def _empty_create_subclass(
  946. t: MetaTensorDesc[Any],
  947. outer_size: Optional[tuple[int, ...]],
  948. outer_stride: Optional[tuple[int, ...]],
  949. symbolic_context: Optional[
  950. torch.fx.experimental.symbolic_shapes.SymbolicContext
  951. ],
  952. callback: _MetaTensorCallbackOptDevice[_TensorT],
  953. source: torch._guards.Source,
  954. ) -> _TensorT:
  955. # We are hitting plain meta_desc tensor so actually
  956. # create a tensor here.
  957. if t.attrs is None:
  958. return self.meta_tensor(
  959. t,
  960. shape_env,
  961. callback,
  962. source,
  963. symbolic_context,
  964. )
  965. inner_tensors = {}
  966. for attr, meta_tensor_desc in t.attrs.items():
  967. current_context = None
  968. if symbolic_context is not None:
  969. if not isinstance(symbolic_context, SubclassSymbolicContext):
  970. raise AssertionError(
  971. f"Expected SubclassSymbolicContext, got {type(symbolic_context)}"
  972. )
  973. if (
  974. current_context_ := symbolic_context.inner_contexts[attr]
  975. ) is not None:
  976. current_context = _checked_cast(
  977. torch.fx.experimental.symbolic_shapes.SymbolicContext,
  978. current_context_,
  979. )
  980. current_source = AttrSource(source, attr)
  981. inner_callback = functools.partial(
  982. callback, device=meta_tensor_desc.device
  983. )
  984. new_empty_tensor = _empty_create_subclass(
  985. meta_tensor_desc,
  986. meta_tensor_desc.size,
  987. meta_tensor_desc.stride,
  988. current_context,
  989. inner_callback,
  990. current_source,
  991. )
  992. inner_tensors[attr] = new_empty_tensor
  993. if t.type is None:
  994. raise AssertionError("t.type must not be None for subclass")
  995. return t.type.__tensor_unflatten__( # type: ignore[attr-defined]
  996. inner_tensors, t.ctx, outer_size, outer_stride
  997. )
  998. if source is None:
  999. raise AssertionError("source must not be None")
  1000. sub = _empty_create_subclass(
  1001. t, outer_size, outer_stride, symbolic_context, callback, source
  1002. )
  1003. # NB: Purposefully guard here to simplify the inner / outer symbols.
  1004. # Using sym_eq() for symbolic comparison can result in an expression that's too
  1005. # difficult to guard on, so we use == here.
  1006. if sub.shape != outer_size:
  1007. raise AssertionError(
  1008. f"Expected return value from {t.type}__tensor_unflatten__() to have "
  1009. f"shape equal to {outer_size}, but got: {sub.shape}"
  1010. )
  1011. if sub.stride() != outer_stride:
  1012. raise AssertionError(
  1013. f"Expected return value from {t.type}__tensor_unflatten__() to have "
  1014. f"stride equal to {outer_stride}, but got: {sub.stride()}"
  1015. )
  1016. return sub
  1017. # Returns an all-dynamic symbolic context used for metafying the given tensor with
  1018. # fully dynamic dims. This is useful when fake-ifying intermediate tensors in
  1019. # closed-over ViewFunc state, as we don't have symbolic contexts for them, but we
  1020. # don't want to over-specialize during view replay.
  1021. def all_dynamic_symbolic_context(
  1022. t: MetaTensorDesc[Any],
  1023. source: torch._guards.Source,
  1024. shape_env: Optional[torch.fx.experimental.symbolic_shapes.ShapeEnv],
  1025. callback: _MetaTensorCallback[_TensorT],
  1026. ) -> torch.fx.experimental.symbolic_shapes.SymbolicContext:
  1027. from torch._dynamo.source import AttrSource
  1028. from torch.fx.experimental.symbolic_shapes import (
  1029. DimDynamic,
  1030. StatelessSymbolicContext,
  1031. SubclassSymbolicContext,
  1032. )
  1033. view_base_context: Optional[
  1034. torch.fx.experimental.symbolic_shapes.SymbolicContext
  1035. ] = None
  1036. if t.is_view:
  1037. if t.base is None:
  1038. raise AssertionError("t.base must not be None for view tensor")
  1039. view_base_context = all_dynamic_symbolic_context(
  1040. t.base, AttrSource(source, "_base"), shape_env, callback
  1041. )
  1042. t_symbolic_context: torch.fx.experimental.symbolic_shapes.SymbolicContext
  1043. t_dynamic_sizes = [DimDynamic.DYNAMIC] * t.ndim
  1044. if t.is_traceable_wrapper_subclass:
  1045. if t.attrs is None:
  1046. raise AssertionError("t.attrs must not be None for subclass")
  1047. inner_contexts: dict[
  1048. str, torch.fx.experimental.symbolic_shapes.SymbolicContext
  1049. ] = {}
  1050. for attr, inner in t.attrs.items():
  1051. if not isinstance(attr, str):
  1052. raise AssertionError(
  1053. f"Expected attr to be str, got {type(attr)}"
  1054. )
  1055. inner_contexts[attr] = all_dynamic_symbolic_context(
  1056. inner, AttrSource(source, attr), shape_env, callback
  1057. )
  1058. t_symbolic_context = SubclassSymbolicContext(
  1059. dynamic_sizes=t_dynamic_sizes,
  1060. constraint_sizes=[None] * t.ndim,
  1061. inner_contexts=inner_contexts, # type: ignore[arg-type]
  1062. tensor_source=source,
  1063. view_base_context=view_base_context,
  1064. )
  1065. else:
  1066. t_symbolic_context = StatelessSymbolicContext(
  1067. dynamic_sizes=t_dynamic_sizes,
  1068. constraint_sizes=[None] * t.ndim,
  1069. view_base_context=view_base_context,
  1070. )
  1071. return t_symbolic_context
  1072. # Returns a fake-ified version of an input view tensor t, given an already fake-ified
  1073. # base. At a high level, we want two things:
  1074. # 1. fake_t should have the same view relationship to the given fake base as the
  1075. # input t has to its _base.
  1076. # 2. fake_t should have symbolic sizes / strides / storage offset according to the
  1077. # appropriate symbolic context (i.e. from the automatic dynamic algorithm).
  1078. #
  1079. # We currently take different strategies across view types:
  1080. # * For dense -> dense views, accomplish both (1) and (2) simultaneously via an
  1081. # as_strided() call on the fake-ified base, passing symbolic metadata.
  1082. # * For views involving subclasses, perform view replay using view funcs to
  1083. # achieve (1). It's necessary for (2) to swap out any closed-over state in
  1084. # the view funcs with symbolicized SymInts and fake-ified tensors. Doing this
  1085. # avoids specialization (and thus over-eager simplification of symbols) that
  1086. # could occur during view replay on the fake-ified base.
  1087. #
  1088. # Examples:
  1089. # * t.unsqueeze(-1) with dense t is a dense -> dense view. It can be modeled
  1090. # with an as_strided() call on the fake base passing symbolic metadata.
  1091. # * sub.select(dim=0, index=3) is a subclass -> subclass view. The index arg
  1092. # is made symbolic to avoid invalid specialization and view replay is then
  1093. # done to reconstruct the view.
  1094. # * _nested_from_jagged(values, offsets) is a dense -> subclass view
  1095. # that returns a subclass instance from a dense values tensor. The offsets
  1096. # tensor is closed over in the view func, as it can be considered view metadata.
  1097. # First, the offsets tensor is fake-ified according to the inner symbolic
  1098. # context and with the correct relationship to the outer size / stride metadata.
  1099. # Then view replay is done, swapping in the fake offsets so the view replay output
  1100. # is fully fake with no invalid specialization.
def view_from_base(
    base: _TensorT,
    t: MetaTensorDesc[Any],
    shape_env: Optional[
        torch.fx.experimental.symbolic_shapes.ShapeEnv
    ] = shape_env,
) -> _TensorT:
    """Rebuild view ``t`` on top of the already fake-ified ``base``.

    When neither ``t`` nor ``base`` is a traceable wrapper subclass, the view
    is produced by ``base.as_strided(...)`` under ``maybe_suppress()``.
    Otherwise the view is replayed through ``t.view_func.apply``, with
    closed-over ints and tensors rewritten by ``symint_visitor_fn`` and
    ``tensor_visitor_fn``; the result's size / stride / storage offset are
    then tied to the outer symbolic metadata via ``torch._check``.

    Raises:
        AssertionError: If ``t`` is a subclass without ``attrs`` / ``type``,
            or if view replay is needed but ``t.view_func`` is None.
    """
    with enable_python_dispatcher():
        # fake-ify t's metadata according to the outer symbolic context
        (sizes, strides, storage_offset) = sym_sizes_strides_storage_offset(
            t, source
        )
        if (
            not t.is_traceable_wrapper_subclass
            and not is_traceable_wrapper_subclass(base)
        ):
            # Dense -> Dense view case uses as_strided() to construct view relationship.
            # TODO: Change this logic to use view replay for consistency?
            # It's likely there is no view func available.
            with maybe_suppress():
                return self._checked_cast_tensor_t(
                    base.as_strided(sizes, strides, storage_offset)
                )

        from torch._dynamo.source import EphemeralSource
        from torch.fx.experimental.symbolic_shapes import (
            StatelessSymbolicContext,
            sym_eq,
        )

        # Visitor for ints closed over by the view func: returns ``s`` as-is
        # when all dims are static or there is no shape env, otherwise a new
        # SymInt backed by an ephemeral source.
        def symint_visitor_fn(s: int) -> int:
            nonlocal symbolic_context
            from torch.fx.experimental.symbolic_shapes import DimDynamic

            all_static_sizes = (
                symbolic_context is not None
                and isinstance(symbolic_context, StatelessSymbolicContext)
                and all(
                    x is DimDynamic.STATIC
                    for x in symbolic_context.dynamic_sizes
                )
            )
            # Can't just rely on shape env being None - dynamo always initializes it
            if all_static_sizes or shape_env is None:
                return s

            # NB: The symbol here is expected to be simplified out because we a priori
            # allocate inner and outer symbols according to the appropriate symbolic
            # contexts and prefer those over this symbol during symbol simplification
            # (via usage of EphemeralSource below). This -shouldn't- happen, but if
            # this symbol somehow leaks out beyond the view tensor's shape metadata, our
            # assumption of it being simplified out will fail and it may be guarded on,
            # which will hard error.
            sym_source = EphemeralSource("symint_visitor_fn")

            symbol = shape_env.create_symbol(s, sym_source, positive=None)
            return shape_env.create_symintnode(
                symbol, hint=s, source=sym_source
            )

        # Maps the id of each real inner tensor description of ``t`` to its
        # fake counterpart; consulted by tensor_visitor_fn below.
        real_to_fake_mapping = {}
        if t.is_traceable_wrapper_subclass:
            if t.attrs is None:
                raise AssertionError("t.attrs must not be None for subclass")
            # NB: t.ctx could be None if the subclass in question has no
            # meaningful context
            if t.type is None:
                raise AssertionError("t.type must not be None for subclass")

            # Fake-ify t naively here; this is only done so we can get fake-ified inner
            # tensors with the correct relationships to the outer sizes / strides for use
            # in view replay. It's done beforehand here because it's not easy to do when
            # visiting tensors one-by-one during view replay.
            #
            # Example:
            #   Consider a Dense -> NJT view. NJT has (values, offsets) components and we
            #   want a view of values with the offsets closed over. As the offsets component
            #   is needed to describe the output view, it's important that it's fakeified
            #   correctly.
            fake_t: _TensorT = empty_create_subclass(
                t, outer_size=sizes, outer_stride=strides
            )
            attrs, _ = fake_t.__tensor_flatten__()  # type: ignore[attr-defined]
            for attr in attrs:
                real_to_fake_mapping[t.attrs[attr].id] = getattr(fake_t, attr)

        # Visitor for tensors closed over by the view func: returns the
        # pre-built fake inner tensor when one is known, otherwise fake-ifies
        # the tensor as all-dynamic with an ephemeral source.
        def tensor_visitor_fn(
            visited_t: torch.Tensor,
            # These arguments are never passed, we just use them to close
            # over these relevant values
            shape_env: Optional[
                torch.fx.experimental.symbolic_shapes.ShapeEnv
            ] = shape_env,
            callback: _MetaTensorCallbackOptDevice[_TensorT] = callback,
        ) -> torch.Tensor:
            # It's possible to close over an undefined tensor (e.g. NJT's lengths).
            if visited_t is None:
                # pyrefly: ignore [bad-return]
                return None

            # NB: visited_t being a Tensor here is very naughty!  Should
            # have already been described

            # Fake inner tensors of view subclasses will come from the mapping built above.
            visited_id = self.describer.get_tensor_id(visited_t)
            fake_visited_t = real_to_fake_mapping.get(visited_id)
            if fake_visited_t is not None:
                return fake_visited_t

            visited_desc = self.describer.describe_tensor(visited_t)

            # For other closed-over tensor state, fake-ify it as all dynamic with an
            # ephemeral source. This avoids invalid specialization during view replay.
            # If we find that in practice the usage of ephemeral sources isn't enough
            # to guarantee that we don't have guards on these symbols, we may need to
            # explicitly suppress guards (as is done for _base in the dense -> dense
            # view case).
            temp_source = EphemeralSource("tensor_visitor_fn")
            return self.meta_tensor(
                visited_desc,
                shape_env,
                callback,
                temp_source,
                all_dynamic_symbolic_context(
                    visited_desc, temp_source, shape_env, callback
                ),
            )

        # Replay the view, swapping out any non-symbolic SymInts or real tensors
        # for symbolic SymInts or fake tensors.
        if t.view_func is None:
            raise AssertionError("t.view_func must not be None for view replay")
        # NB: we do NOT suppress guards here, we need to remove ephemeral
        # sources
        fake_t = t.view_func.apply(
            t, base, symint_visitor_fn, tensor_visitor_fn
        )

        # Ensure the output has symbolic shapes according to the outer symbolic context.
        # These checks should simplify out any symbols created for closed-over view func
        # SymInts.
        torch._check(sym_eq(fake_t.size(), sizes))
        torch._check(sym_eq(fake_t.stride(), strides))
        torch._check(sym_eq(fake_t.storage_offset(), storage_offset))
        return fake_t
  1232. if self.get_tensor_memo(t) is None:
  1233. GRAD_TENSOR_SENTINEL_VALUE = -2
  1234. with torch.inference_mode(t.is_inference):
  1235. if t.is_sparse:
  1236. is_leaf = t.is_leaf
  1237. # The lambda function below is similar to
  1238. # `t.to(device='meta')` except the latter
  1239. # preserves nnz value
  1240. r = callback(
  1241. lambda: torch.ops.aten._sparse_coo_tensor_with_dims(
  1242. t.sparse_dim,
  1243. t.dense_dim,
  1244. t.size,
  1245. dtype=t.dtype,
  1246. layout=torch.sparse_coo,
  1247. device="meta",
  1248. )
  1249. )
  1250. if self.copy_data:
  1251. # Pray that sparse clone doesn't lose information
  1252. if t.data is None:
  1253. raise AssertionError(
  1254. "t.data must not be None when copy_data is True"
  1255. )
  1256. with torch.no_grad(), no_dispatch():
  1257. if not _is_fake_tensor(r):
  1258. raise AssertionError("Expected r to be a FakeTensor")
  1259. # pyrefly: ignore[bad-assignment]
  1260. r.real_tensor = _safe_clone(t.data)
  1261. if not safe_is_leaf(r):
  1262. raise AssertionError(
  1263. "the callback you passed in doesn't detach"
  1264. )
  1265. # Note [is_coalesced is dispatched]
  1266. # Strangely enough, is_coalesced() is a dispatched operator,
  1267. # which means that it will get caught by fake tensor mode.
  1268. # Ordinarily this would error, but there's some logic in
  1269. # fake tensor ensure this doesn't happen.
  1270. r._coalesced_(bool(t.is_coalesced))
  1271. if t.requires_grad:
  1272. r.requires_grad = True
  1273. if t.requires_grad and not is_leaf:
  1274. # This should probably use DelayedError,
  1275. # but clone is fine for now for sparse tensors.
  1276. # (DelayedError does not work for sparse because it causes
  1277. # the Fake sparse tensor to "lose" its fakeness)
  1278. r = self._checked_cast_tensor_t(r.clone())
  1279. with torch.enable_grad():
  1280. r._coalesced_(bool(t.is_coalesced))
  1281. elif is_sparse_compressed_layout(t.layout):
  1282. is_leaf = t.is_leaf
  1283. if t.layout in {torch.sparse_bsr, torch.sparse_bsc}:
  1284. if t.sparse_dim is None:
  1285. raise AssertionError(
  1286. "t.sparse_dim must not be None for sparse block layout"
  1287. )
  1288. if t.dense_dim is None:
  1289. raise AssertionError(
  1290. "t.dense_dim must not be None for sparse block layout"
  1291. )
  1292. if t.values is None:
  1293. raise AssertionError(
  1294. "t.values must not be None for sparse block layout"
  1295. )
  1296. batch_dim = t.ndim - t.sparse_dim - t.dense_dim
  1297. blocksize = t.values.shape[batch_dim + 1 : batch_dim + 3]
  1298. else:
  1299. blocksize = ()
  1300. if t.layout in {torch.sparse_csr, torch.sparse_bsr}:
  1301. if t.crow_indices is None:
  1302. raise AssertionError(
  1303. "t.crow_indices must not be None for sparse csr/bsr layout"
  1304. )
  1305. index_dtype = t.crow_indices.dtype
  1306. else:
  1307. if t.ccol_indices is None:
  1308. raise AssertionError(
  1309. "t.ccol_indices must not be None for sparse csc/bsc layout"
  1310. )
  1311. index_dtype = t.ccol_indices.dtype
  1312. r = callback(
  1313. lambda: torch.ops.aten._sparse_compressed_tensor_with_dims(
  1314. 0,
  1315. t.dense_dim,
  1316. t.shape,
  1317. blocksize,
  1318. index_dtype,
  1319. layout=t.layout,
  1320. dtype=t.dtype,
  1321. device="meta",
  1322. )
  1323. )
  1324. if self.copy_data:
  1325. # Pray sparse clone doesn't lose information
  1326. if t.data is None:
  1327. raise AssertionError(
  1328. "t.data must not be None when copy_data is True"
  1329. )
  1330. with torch.no_grad(), no_dispatch():
  1331. if not _is_fake_tensor(r):
  1332. raise AssertionError("Expected r to be a FakeTensor")
  1333. # pyrefly: ignore[bad-assignment]
  1334. r.real_tensor = _safe_clone(t.data)
  1335. if not safe_is_leaf(r):
  1336. raise AssertionError(
  1337. "the callback you passed in doesn't detach"
  1338. )
  1339. if t.requires_grad:
  1340. r.requires_grad = True
  1341. if t.requires_grad and not is_leaf:
  1342. r = self._backward_error(r)
  1343. elif t.is_nested and not t.is_traceable_wrapper_subclass:
  1344. # TODO: Handle this better in Dynamo?
  1345. # There are checks there now, but this can still be triggered by a dense
  1346. # tensor graph input that is a view of a strided NT.
  1347. from torch._dynamo.exc import unimplemented
  1348. # NOTE this graph break will NOT be present in Dynamo's graph break registry
  1349. unimplemented(
  1350. gb_type="attempted to apply meta conversion to strided nested tensor",
  1351. context=str(t),
  1352. explanation="This is not supported.",
  1353. hints=[],
  1354. )
  1355. elif t.is_mkldnn:
  1356. is_leaf = t.is_leaf
  1357. (
  1358. sizes,
  1359. strides,
  1360. _storage_offset,
  1361. ) = sym_sizes_strides_storage_offset(t, source)
  1362. # TODO: This doesn't seem right, where's the MKLDNN'ness
  1363. # lol
  1364. r = callback(
  1365. lambda: torch.empty_strided(
  1366. sizes, strides, dtype=t.dtype, device="meta"
  1367. )
  1368. )
  1369. if self.copy_data:
  1370. with torch.no_grad(), no_dispatch():
  1371. if t.size is None:
  1372. raise AssertionError(
  1373. "t.size must not be None when copy_data is True"
  1374. )
  1375. if t.stride is None:
  1376. raise AssertionError(
  1377. "t.stride must not be None when copy_data is True"
  1378. )
  1379. if not _is_fake_tensor(r):
  1380. raise AssertionError("Expected r to be a FakeTensor")
  1381. # pyrefly: ignore[bad-assignment]
  1382. r.real_tensor = torch.empty_strided(
  1383. t.size, t.stride, dtype=t.dtype, device=t.device
  1384. )
  1385. if t.data is None:
  1386. raise AssertionError(
  1387. "t.data must not be None when copy_data is True"
  1388. )
  1389. _safe_copy(r.real_tensor, t.data)
  1390. if not safe_is_leaf(r):
  1391. raise AssertionError(
  1392. "the callback you passed in doesn't detach"
  1393. )
  1394. if t.requires_grad:
  1395. r.requires_grad = True
  1396. if t.requires_grad and not is_leaf:
  1397. r = self._backward_error(r)
  1398. elif t.is_functorch_wrapped:
  1399. if t.is_view:
  1400. from torch._dynamo.exc import unimplemented
  1401. unimplemented(
  1402. gb_type="attempted to apply meta conversion to view functorch tensor",
  1403. context=str(t),
  1404. explanation="This is not supported.",
  1405. hints=[],
  1406. )
  1407. # Wraps a functorch tensor class (BatchedTensor, GradTrackingTensor)
  1408. # in a FakeTensor
def _to_fake_tensor(t: MetaTensorDesc[Any]) -> _TensorT:
    """Rebuild a functorch-wrapped tensor description, layer by layer.

    Each call handles one layer of ``t``:

    * batched tensor: convert ``t.unwrapped`` recursively, then re-add the
      batch dimension with ``_add_batch_dim`` at ``t.level`` / ``t.bdim``;
    * grad-tracking tensor: convert ``t.unwrapped`` recursively, then
      re-wrap with ``_wrap_for_grad`` unless the level equals
      ``GRAD_TENSOR_SENTINEL_VALUE``;
    * functional tensor: convert ``t.unwrapped`` through
      ``self.meta_tensor`` and re-wrap with ``_wrap_functional_tensor``;
    * anything else: create a strided meta tensor through ``callback``.

    This is a closure: ``self``, ``shape_env``, ``callback``, ``source``,
    ``symbolic_context`` and ``GRAD_TENSOR_SENTINEL_VALUE`` all come from
    the enclosing scope.

    Raises:
        AssertionError: if a description field required by the branch
            being taken (``unwrapped``, ``level``, ``bdim``,
            ``current_level``, ``stride``, or ``data`` when copying data)
            is ``None``.
    """
    # TODO: why aren't the recursive calls going to
    # meta_tensor
    r: _TensorT
    if t.is_batchedtensor:
        # Validate up front every field the batched branch reads below.
        if t.unwrapped is None:
            raise AssertionError(
                "t.unwrapped must not be None for batchedtensor"
            )
        if t.level is None:
            raise AssertionError(
                "t.level must not be None for batchedtensor"
            )
        if t.bdim is None:
            raise AssertionError(
                "t.bdim must not be None for batchedtensor"
            )
        # Convert the wrapped value first, then re-apply this layer.
        ft = _to_fake_tensor(t.unwrapped)
        lvl = t.level
        bdim = t.bdim
        # You cannot create functorch tensors without
        # having the ambient functorch interpreter stack
        # available, as the level refers to things in the
        # stack
        with torch._functorch.pyfunctorch.temporarily_restore_interpreter_stack(
            t.functorch_stack
        ):
            r = self._checked_cast_tensor_t(
                _add_batch_dim(ft, bdim, lvl)
            )
    elif t.is_gradtrackingtensor:
        if t.unwrapped is None:
            raise AssertionError(
                "t.unwrapped must not be None for gradtrackingtensor"
            )
        if t.level is None:
            raise AssertionError(
                "t.level must not be None for gradtrackingtensor"
            )
        # NOTE(review): presumably functorch is disabled here so the
        # inner conversion is not intercepted by it -- confirm.
        disable_functorch = torch._C._DisableFuncTorch
        with disable_functorch():
            ft = _to_fake_tensor(t.unwrapped)
        lvl = t.level
        if lvl == GRAD_TENSOR_SENTINEL_VALUE:
            # Sentinel level: use the converted inner tensor as-is,
            # without re-wrapping it for grad.
            r = ft
        else:
            # As in the batched branch, the recorded interpreter stack
            # is restored while the wrapper at level `lvl` is created.
            with torch._functorch.pyfunctorch.temporarily_restore_interpreter_stack(
                t.functorch_stack
            ):
                r = self._checked_cast_tensor_t(
                    torch._C._functorch._wrap_for_grad(ft, lvl),
                )

        # Mirror requires_grad: set it directly when the rebuilt tensor
        # is a leaf; otherwise, if the described tensor was a non-leaf,
        # go through self._backward_error.
        is_leaf = t.is_leaf
        if t.requires_grad and safe_is_leaf(r):
            r.requires_grad = True
        elif t.requires_grad and not is_leaf:
            r = self._backward_error(r)
    elif t.is_functional:
        if t.unwrapped is None:
            raise AssertionError(
                "t.unwrapped must not be None for functional tensor"
            )
        if t.current_level is None:
            raise AssertionError(
                "t.current_level must not be None for functional tensor"
            )
        # Unlike the two branches above, the inner value goes through the
        # full meta_tensor path instead of recursing into this helper.
        ft = self.meta_tensor(
            t.unwrapped,
            shape_env,
            callback,
            # NB: reuse these exactly, we treat the
            # functional tensor as "invisible".
            # TODO: Actually this all probably doesn't
            # work, take a closer look.
            source,
            symbolic_context,
        )
        r = self._checked_cast_tensor_t(
            _wrap_functional_tensor(ft, t.current_level),
        )
        # TODO: is_leaf/requires_grad?
    else:
        # Base case: none of the wrapper flags above is set, so build a
        # strided meta tensor from the described size and stride.
        if t.stride is None:
            raise AssertionError("t.stride must not be None")

        sizes = t.size
        strides = t.stride
        r = callback(
            lambda: torch.empty_strided(
                sizes,
                strides,
                dtype=t.dtype,
                device="meta",
            ),
            # device="meta",
        )
        if self.copy_data:
            # Attach a real tensor on the original device, with the same
            # size and stride, and copy the described data into it.
            with torch.no_grad(), no_dispatch():
                r.real_tensor = torch.empty_strided(  # type: ignore[attr-defined]
                    t.size,
                    t.stride,
                    dtype=t.dtype,
                    device=t.device,
                )
                if t.data is None:
                    raise AssertionError(
                        "t.data must not be None when copy_data is True"
                    )
                _safe_copy(r.real_tensor, t.data)  # type: ignore[attr-defined]
    # pyrefly: ignore [bad-return]
    return r
  1519. r = _to_fake_tensor(t)
  1520. elif t.is_functional and t.device.type not in ["xla", "lazy"]:
  1521. if t.unwrapped is None:
  1522. raise AssertionError(
  1523. "t.unwrapped must not be None for functional tensor"
  1524. )
  1525. if t.is_functorch_wrapped: # handled above
  1526. raise AssertionError(
  1527. "Expected non-functorch wrapped functional tensor"
  1528. )
  1529. unwrapped = self.meta_tensor(
  1530. t.unwrapped,
  1531. shape_env,
  1532. callback,
  1533. source,
  1534. symbolic_context,
  1535. )
  1536. r = self._checked_cast_tensor_t(
  1537. torch._to_functional_tensor(unwrapped)
  1538. )
  1539. torch._mirror_autograd_meta_to(t.autograd_meta_from, r) # type: ignore[attr-defined]
  1540. elif t.is_view:
  1541. # Construct views in two steps: recursively meta-fy their
  1542. # base, and then create view(s) off that. NB: doing it
  1543. # directly from storage is WRONG because this won't cause
  1544. # version counters to get shared.
  1545. if t.base is None:
  1546. raise AssertionError("t.base must not be None for view tensor")
  1547. base_symbolic_context = None
  1548. if shape_env and symbolic_context is not None:
  1549. from torch.fx.experimental.symbolic_shapes import (
  1550. StatelessSymbolicContext,
  1551. )
  1552. if not isinstance(symbolic_context, StatelessSymbolicContext):
  1553. raise AssertionError(
  1554. f"Expected StatelessSymbolicContext, got {type(symbolic_context)}"
  1555. )
  1556. # NB: This should generally be set when the input is a view,
  1557. # but the exception right now is for fake-ifying grads, which is
  1558. # a work in progress.
  1559. if symbolic_context.view_base_context is not None:
  1560. base_symbolic_context = symbolic_context.view_base_context
  1561. base = self.meta_tensor(
  1562. t.base,
  1563. shape_env,
  1564. callback,
  1565. torch._dynamo.source.AttrSource(source, "_base"),
  1566. base_symbolic_context,
  1567. )
  1568. # If the base tensor has unbacked symbols (e.g., from mark_unbacked),
  1569. # we need to bind them now. Otherwise they'll be left pending and
  1570. # compute_unbacked_bindings will fail when called for the view
  1571. # (since the view may have concrete shapes, not the unbacked symbols).
  1572. from torch.fx.experimental.symbolic_shapes import (
  1573. compute_unbacked_bindings,
  1574. )
  1575. compute_unbacked_bindings(shape_env, base)
  1576. def is_c_of_r(
  1577. complex_dtype: torch.dtype, real_dtype: torch.dtype
  1578. ) -> bool:
  1579. return (
  1580. utils.is_complex_dtype(complex_dtype)
  1581. and utils.corresponding_real_dtype(complex_dtype)
  1582. == real_dtype
  1583. )
  1584. # In some situations, MetaConverter may be called in a
  1585. # context where autograd is disabled. For the _is_view
  1586. # assert to pass, we have to setup the autograd view
  1587. # metadata anyway. Do this by reenabling the
  1588. # ADInplaceOrView key. This is kind of a hack.
  1589. old_exclude = torch._C._dispatch_tls_is_dispatch_key_excluded(
  1590. torch._C.DispatchKey.ADInplaceOrView
  1591. )
  1592. torch._C._dispatch_tls_set_dispatch_key_excluded(
  1593. torch._C.DispatchKey.ADInplaceOrView, False
  1594. )
  1595. try:
  1596. if base.dtype == t.dtype:
  1597. pass
  1598. elif is_c_of_r(base.dtype, t.dtype):
  1599. base = self._checked_cast_tensor_t(torch.view_as_real(base))
  1600. elif is_c_of_r(t.dtype, base.dtype):
  1601. base = self._checked_cast_tensor_t(
  1602. torch.view_as_complex(base)
  1603. )
  1604. else:
  1605. # This is not guaranteed to succeed. If it fails, it
  1606. # means there is another dtype-converting view function
  1607. # that hasn't been handled here
  1608. base = self._checked_cast_tensor_t(base.view(t.dtype))
  1609. # This is very tricky. Naively, you might expect this
  1610. # to hold:
  1611. #
  1612. # if t.requires_grad and not safe_is_leaf(t)
  1613. # assert t._base.requires_grad
  1614. #
  1615. # But it's not true! As you can see in the following
  1616. # program:
  1617. #
  1618. # x = torch.zeros(4)
  1619. # y = x.view(1, 4)
  1620. # y.requires_grad = True
  1621. # z = y.view(1, 1, 4)
  1622. # assert z._base is x
  1623. #
  1624. # So we may have to do *two* views out of the base to
  1625. # recreate this situation.
  1626. if t.is_leaf:
  1627. # Leaf views that track view metadata are created by
  1628. # creating a view inside a no_grad block
  1629. with torch.no_grad():
  1630. r = view_from_base(base, t)
  1631. # As it's a leaf, we can directly assign requires_grad
  1632. r.requires_grad = t.requires_grad
  1633. else:
  1634. if t.base.requires_grad == t.requires_grad:
  1635. # Easy case, just run the view op
  1636. with torch.enable_grad():
  1637. r = view_from_base(base, t)
  1638. # NB: We don't actually faithfully replicate
  1639. # autograd connectivity, but that doesn't matter
  1640. # today. See following for more info:
  1641. # https://gist.github.com/soulitzer/e03f015b314c3f5fcf80888c69390913
  1642. else:
  1643. # Obscure case. Create a leaf view and give it the
  1644. # correct requires_grad, then do the final view.
  1645. # NB: Can't have a non-leaf without requiring grad!
  1646. if not t.requires_grad:
  1647. raise AssertionError(
  1648. "t.requires_grad must be True for non-leaf view"
  1649. )
  1650. with torch.no_grad(), enable_python_dispatcher():
  1651. mid = self._checked_cast_tensor_t(
  1652. base.view(base.shape)
  1653. )
  1654. mid.requires_grad = t.requires_grad
  1655. with torch.enable_grad():
  1656. r = view_from_base(mid, t)
  1657. # The CreationMeta influences whether or not inplace
  1658. # mutation is an error or not. So we need to make
  1659. # sure we properly propagate this as well.
  1660. if t.creation_meta is None:
  1661. raise AssertionError(
  1662. "t.creation_meta must not be None for view tensor"
  1663. )
  1664. torch._C._autograd._set_creation_meta(r, t.creation_meta)
  1665. finally:
  1666. torch._C._dispatch_tls_set_dispatch_key_excluded(
  1667. torch._C.DispatchKey.ADInplaceOrView, old_exclude
  1668. )
  1669. r.fake_device = t.device # type: ignore[attr-defined]
  1670. else:
  1671. is_leaf = t.is_leaf
  1672. # Graph-Break for wrapped tensors
  1673. if (
  1674. not (t.is_batchedtensor or t.is_gradtrackingtensor)
  1675. and t.is_functorch_wrapped
  1676. ) or t.is_legacy_batchedtensor:
  1677. # pyrefly: ignore [bad-return]
  1678. return NotImplemented
  1679. (
  1680. sizes,
  1681. strides,
  1682. storage_offset,
  1683. ) = sym_sizes_strides_storage_offset(t, source, symbolic_context)
  1684. # If we have a subclass that desugars into dense tensors,
  1685. # perform our callback on each inner tensor.
  1686. if t.is_traceable_wrapper_subclass:
  1687. r = empty_create_subclass(
  1688. t, outer_size=sizes, outer_stride=strides
  1689. )
  1690. else:
  1691. r = callback(
  1692. lambda: torch.empty_strided(
  1693. sizes,
  1694. strides,
  1695. dtype=t.dtype,
  1696. device="meta",
  1697. )
  1698. )
  1699. if self.copy_data:
  1700. with torch.no_grad(), no_dispatch():
  1701. if t.size is None:
  1702. raise AssertionError(
  1703. "t.size must not be None when copy_data is True"
  1704. )
  1705. if t.stride is None:
  1706. raise AssertionError(
  1707. "t.stride must not be None when copy_data is True"
  1708. )
  1709. if not _is_fake_tensor(r):
  1710. raise AssertionError(
  1711. "Expected r to be a FakeTensor"
  1712. )
  1713. # pyrefly: ignore[bad-assignment]
  1714. r.real_tensor = torch.empty_strided(
  1715. t.size, t.stride, dtype=t.dtype, device=t.device
  1716. )
  1717. _safe_copy(r.real_tensor, t.data)
  1718. if not safe_is_leaf(r):
  1719. raise AssertionError(
  1720. "the callback you passed in doesn't detach"
  1721. )
  1722. if t.requires_grad:
  1723. r.requires_grad = t.requires_grad
  1724. if not is_leaf:
  1725. # Fake up some autograd history.
  1726. # Note: we *used* to call .clone() here to mock up some autograd history.
  1727. # This is bad for subclasses.
  1728. # Consider the case where you have a wrapper subclass that is contiguous,
  1729. # but its inner tensor is noncontiguous().
  1730. # .clone() (or other ops) will have the side effect of changing
  1731. # the metadata of the inner tensor.
  1732. # So instead, we now have a dedicated fn to set autograd history,
  1733. # without inadvertently changing other metadata.
  1734. # pyrefly: ignore [bad-argument-type]
  1735. r = self._backward_error(r)
  1736. s = t.storage
  1737. if s is None:
  1738. raise AssertionError("t.storage must not be None")
  1739. if s.id not in self.storage_memo and (
  1740. r.is_nested
  1741. or (
  1742. r.stride() == strides
  1743. and r.storage_offset() == storage_offset
  1744. )
  1745. ):
  1746. # You're normal and happy, install the fresh storage into the memo
  1747. self.set_storage_memo(s, r.untyped_storage())
  1748. if self.copy_data:
  1749. if not _is_fake_tensor(r):
  1750. raise AssertionError("Expected r to be a FakeTensor")
  1751. if r.real_tensor is None:
  1752. raise AssertionError(
  1753. "r.real_tensor must not be None when copy_data is True"
  1754. )
  1755. _set_real_storage(
  1756. r.untyped_storage(), r.real_tensor.untyped_storage()
  1757. )
  1758. else:
  1759. # You're in crazy town; somehow you gave us a tensor
  1760. # that wasn't a view, but had nonzero storage offset,
  1761. # nontrivial strides (such that clone() couldn't
  1762. # preserve them), or already aliases with another
  1763. # tensor's storage. The most typical way to end
  1764. # up here is with set_. So use set_ to bludgeon this
  1765. # in.
  1766. r_s = self.meta_storage(s, callback=callback)
  1767. # NB: In principle, this should always work, but there
  1768. # is some subtle difference in the autograd metadata
  1769. # that means we will backprop the set_ call, even if
  1770. # r is declared as an input to grad.
  1771. # See https://github.com/pytorch/pytorch/issues/87956
  1772. # for the reproducer.
  1773. # NB: The in_kernel_invocation_manager here is necessary
  1774. # for fake tensor. If we run the set_ call with fake
  1775. # tensor on, r will improperly report that it is NOT a
  1776. # meta tensor but a cpu tensor, and then the set_ call
  1777. # will fail due to device mismatch. no_dispatch() is
  1778. # not enough, because the fake tensor will still claim
  1779. # to be a CPU tensor and you'll end up in the CPU
  1780. # kernel. Arguably this is a hack; a cleaner way to
  1781. # solve this is to have a FakeStorage concept which
  1782. # would report its CPU device--no problem now! But
  1783. # this is difficult to do because we don't have storage
  1784. # subclasses. Relevant test is
  1785. # DynamicShapesFunctionTests::test_add_dynamic_shapes in
  1786. # test/dynamo/test_dynamic_shapes.py
  1787. maybe_fake_mgr: AbstractContextManager[None] = (
  1788. contextlib.nullcontext()
  1789. )
  1790. from torch._subclasses.fake_tensor import (
  1791. in_kernel_invocation_manager,
  1792. maybe_get_fake_mode,
  1793. )
  1794. mb_fake_mode = maybe_get_fake_mode(r)
  1795. if mb_fake_mode is not None:
  1796. maybe_fake_mgr = in_kernel_invocation_manager(mb_fake_mode)
  1797. with torch.no_grad(), maybe_suppress():
  1798. with maybe_fake_mgr:
  1799. r.set_(r_s, storage_offset, sizes, strides)
  1800. if self.copy_data:
  1801. with torch.no_grad(), no_dispatch():
  1802. if not _is_fake_tensor(r):
  1803. raise AssertionError(
  1804. "Expected r to be a FakeTensor"
  1805. )
  1806. if r.real_tensor is None:
  1807. raise AssertionError(
  1808. "r.real_tensor must not be None when copy_data is True"
  1809. )
  1810. if t.stride is None:
  1811. raise AssertionError(
  1812. "t.stride must not be None when copy_data is True"
  1813. )
  1814. r.real_tensor.set_(
  1815. _get_real_storage(r_s),
  1816. t.storage_offset,
  1817. t.size,
  1818. t.stride,
  1819. )
  1820. if t.grad is not None:
  1821. from torch._dynamo.source import AttrSource
  1822. # TODO: Use a valid grad-specific symbolic context instead of recycling
  1823. # the one from t. This isn't correct if e.g. t._is_view() != t.grad._is_view().
  1824. # pyrefly: ignore [unbound-name]
  1825. r.grad = self.meta_tensor(
  1826. t.grad,
  1827. shape_env,
  1828. callback,
  1829. AttrSource(source, "grad"),
  1830. symbolic_context,
  1831. )
  1832. # pyrefly: ignore [unbound-name]
  1833. torch._C._set_conj(r, t.is_conj)
  1834. # pyrefly: ignore [unbound-name]
  1835. torch._C._set_neg(r, t.is_neg)
  1836. # This can be skipped if necessary for performance reasons
  1837. skip_leaf = (
  1838. t.is_gradtrackingtensor and t.level == GRAD_TENSOR_SENTINEL_VALUE
  1839. )
  1840. # pyrefly: ignore [unbound-name]
  1841. assert_metadata_eq(assert_eq, t, r, skip_symbolic=True, skip_leaf=skip_leaf)
  1842. # Thanks to storage resizing, it's possible to end up with a tensor
  1843. # that advertises a real size, but has a storage that actually has zero bytes.
  1844. # Need to reflect this in the generated FakeTensor.
  1845. from torch.fx.experimental.symbolic_shapes import guard_or_false
  1846. if t.storage is not None and guard_or_false(t.storage.size == 0):
  1847. # pyrefly: ignore [unbound-name]
  1848. r.untyped_storage().resize_(0)
  1849. if t.is_parameter:
  1850. # pyrefly: ignore [unbound-name]
  1851. r._is_param = True
  1852. # See Note: [Creating symbolic nested int]
  1853. if t.nested_int is not None:
  1854. # pyrefly: ignore [unbound-name]
  1855. if not _is_fake_tensor(r):
  1856. raise AssertionError("Expected r to be a FakeTensor for nested int")
  1857. # pyrefly: ignore [unbound-name]
  1858. r.nested_int_memo = r.fake_mode.create_symbolic_nested_int(
  1859. nt_tensor_id=t.nested_int
  1860. )
  1861. # pyrefly: ignore [bad-argument-type, unbound-name]
  1862. self.set_tensor_memo(t, r)
  1863. return self._checked_get_tensor_memo(t)
  1864. def __call__(
  1865. self,
  1866. t: torch.Tensor,
  1867. shape_env: Optional[ShapeEnv] = None,
  1868. *,
  1869. callback: Optional[_MetaTensorCallback[_TensorT]] = None,
  1870. source: Optional[Source] = None,
  1871. symbolic_context: Optional[SymbolicContext] = None,
  1872. # Controls whether or not we should dump the tensor metadata to structured logs
  1873. # when source is not None. Because we refakify after Dynamo is done,
  1874. # we don't want to dump info again from AOTAutograd, it is redundant.
  1875. trace: bool = True,
  1876. ) -> _TensorT:
  1877. callback_: _MetaTensorCallback[_TensorT]
  1878. if callback is None:
  1879. callback_ = self._identity_callable
  1880. else:
  1881. callback_ = callback
  1882. # TODO: zero tensors? We appear to have eliminated them by
  1883. # excluding complex for now
  1884. # Filter out cases we don't support
  1885. # TODO: This can probably be simplified quite a bit
  1886. if isinstance(t, torch.Tensor):
  1887. if (
  1888. # Lazy tensors are not supported. Note that XLA is
  1889. # implemented on top of lazy tensor, not excluded here; we
  1890. # have some special handling for it; this is for XLA Dynamo
  1891. # integration
  1892. t.device.type == "lazy"
  1893. or
  1894. # Quantization is not supported
  1895. t.is_quantized
  1896. or
  1897. # Views out of sparse tensors not currently supported (plain
  1898. # sparse is supported htough)
  1899. (t._is_view() and t._base is not None and t._base.is_sparse)
  1900. ):
  1901. self.miss += 1
  1902. # pyrefly: ignore [bad-return]
  1903. return NotImplemented
  1904. else:
  1905. self.hit += 1
  1906. elif torch.overrides.is_tensor_like(t):
  1907. self.miss += 1
  1908. # pyrefly: ignore [bad-return]
  1909. return NotImplemented
  1910. else:
  1911. # non-Tensor types don't count as hit or miss
  1912. return t
  1913. if source is None:
  1914. trace = False
  1915. # Describe the tensor. NB: do NOT disable ambient modes, we may need
  1916. # to query them when figuring out what to put in here
  1917. t_desc = self.describer.describe_tensor(t, trace=trace)
  1918. if trace:
  1919. if source is None:
  1920. raise AssertionError("source must not be None when trace is True")
  1921. trace_structured(
  1922. "describe_source",
  1923. metadata_fn=lambda: {
  1924. "describer_id": self.describer.id,
  1925. "id": t_desc.id,
  1926. "source": source.name,
  1927. },
  1928. )
  1929. # Do the meta-fication. Here, we disable all the ambient modes, to
  1930. # better simulate what would be like to re-fakeify from a fresh
  1931. # process
  1932. with contextlib.ExitStack() as exit_stack:
  1933. exit_stack.enter_context(torch._dispatch.python.suspend_functionalization())
  1934. st = peek_interpreter_stack()
  1935. if st is not None:
  1936. exit_stack.enter_context(
  1937. torch._functorch.pyfunctorch.temporarily_clear_interpreter_stack()
  1938. )
  1939. r = self.meta_tensor(
  1940. t_desc,
  1941. shape_env,
  1942. callback_,
  1943. source,
  1944. symbolic_context,
  1945. )
  1946. if type(t) is torch.nn.Parameter:
  1947. # NB: Cannot directly use Parameter constructor
  1948. # because that would force a detach, not desirable
  1949. r._is_param = True
  1950. # forward the 'is_buffer' metadata if present (for nn.Buffer checks)
  1951. if getattr(t, "_is_buffer", False):
  1952. # pyrefly: ignore [missing-attribute]
  1953. r._is_buffer = True
  1954. if hasattr(t, "persistent"):
  1955. # pyrefly: ignore [missing-attribute]
  1956. r.persistent = t.persistent
  1957. # TODO: return the description for later
  1958. return r
  1959. import torch._prims_common as utils