_utils.py 48 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353
  1. # mypy: allow-untyped-defs
  2. import copyreg
  3. import functools
  4. import importlib
  5. import logging
  6. import math
  7. import os
  8. import pickle
  9. import re
  10. import sys
  11. import traceback
  12. import warnings
  13. from collections import defaultdict
  14. from collections.abc import Callable
  15. from types import ModuleType
  16. from typing import Any, cast, Generic, TYPE_CHECKING, TypedDict
  17. from typing_extensions import deprecated, NotRequired, ParamSpec
  18. import torch
  19. def _type(self, dtype=None, non_blocking=False, **kwargs):
  20. """Returns the type if `dtype` is not provided, else casts this object to
  21. the specified type.
  22. If this is already of the correct type, no copy is performed and the
  23. original object is returned.
  24. Args:
  25. dtype (type or string): The desired type
  26. non_blocking (bool): If ``True``, and the source is in pinned memory
  27. and destination is on the GPU or vice versa, the copy is performed
  28. asynchronously with respect to the host. Otherwise, the argument
  29. has no effect.
  30. **kwargs: For compatibility, may contain the key ``async`` in place of
  31. the ``non_blocking`` argument. The ``async`` arg is deprecated.
  32. """
  33. non_blocking = _get_async_or_non_blocking("type", non_blocking, kwargs)
  34. if dtype is None:
  35. return self.__module__ + "." + self.__class__.__name__
  36. if isinstance(dtype, str):
  37. dtype = _import_dotted_name(dtype)
  38. if dtype is type(self):
  39. return self
  40. if self.is_sparse:
  41. if not dtype.is_sparse:
  42. raise RuntimeError("Cannot cast sparse tensor to dense tensor")
  43. new_module_name = dtype.__module__.replace(".sparse", "")
  44. new_values_type_name = new_module_name + "." + dtype.__name__
  45. new_values = torch.Tensor._values(self).type(new_values_type_name, non_blocking)
  46. new_indices_type_name = new_module_name + ".LongTensor"
  47. new_indices = torch.Tensor._indices(self).type(
  48. new_indices_type_name, non_blocking
  49. )
  50. return dtype(new_indices, new_values, self.size())
  51. if dtype.is_sparse:
  52. raise RuntimeError("Cannot cast dense tensor to sparse tensor")
  53. return dtype(self.size()).copy_(self, non_blocking)
  54. def _to(self, device, non_blocking=False):
  55. """Returns a copy of this object in device memory.
  56. If this object is already on the correct device, then no copy is performed
  57. and the original object is returned.
  58. Args:
  59. device (int): The destination device.
  60. non_blocking (bool): If ``True`` and the source is in pinned memory,
  61. the copy will be asynchronous with respect to the host. Otherwise,
  62. the argument has no effect.
  63. """
  64. if self.device == device:
  65. return self
  66. if device.type == "cpu":
  67. pin_memory = non_blocking and self.device.type in (
  68. "cuda",
  69. torch._C._get_privateuse1_backend_name(),
  70. )
  71. untyped_storage = torch.empty(
  72. self.nbytes(), dtype=torch.uint8, device=device, pin_memory=pin_memory
  73. ).untyped_storage()
  74. untyped_storage.copy_(self, non_blocking)
  75. return untyped_storage
  76. device_module = getattr(torch, device.type, None)
  77. if device_module is None:
  78. raise AssertionError(f"{device.type.upper()} device module is not loaded")
  79. with device_module.device(device):
  80. if self.is_sparse and hasattr(device_module, "sparse"):
  81. new_type = getattr(device_module.sparse, self.__class__.__name__)
  82. indices = getattr(torch.Tensor._indices(self), device.type)(
  83. device, non_blocking
  84. )
  85. values = getattr(torch.Tensor._values(self), device.type)(
  86. device, non_blocking
  87. )
  88. return new_type(indices, values, self.size())
  89. else:
  90. if self.is_sparse:
  91. raise AssertionError(
  92. f"sparse storage is not supported for {device.type.upper()} tensors"
  93. )
  94. untyped_storage = torch.UntypedStorage(self.size(), device=device)
  95. untyped_storage.copy_(self, non_blocking)
  96. return untyped_storage
  97. def _get_async_or_non_blocking(function_name, non_blocking, kwargs):
  98. """Return the non-blocking flag given the function name and kwargs.
  99. Args:
  100. function_name (str): the name of the function being used.
  101. non_blocking (bool): the default value.
  102. **kwargs (dict): the kwargs passed to the function.
  103. """
  104. if not kwargs:
  105. return non_blocking
  106. if len(kwargs) != 1 or "async" not in kwargs:
  107. message = "{}() got an unexpected keyword argument '{}'"
  108. argument = list(kwargs.keys()).pop()
  109. raise TypeError(message.format(function_name, argument))
  110. warnings.warn("'async' is deprecated; use 'non_blocking'", stacklevel=2)
  111. return kwargs["async"]
  112. def _get_restore_location(device):
  113. """Return the map_location location.
  114. Used for rebuild functions where the tensor device is distinct from the storage
  115. """
  116. map_location = torch.serialization._serialization_tls.map_location
  117. if map_location is None:
  118. return device
  119. else:
  120. if isinstance(map_location, dict):
  121. return map_location.get(device, device)
  122. elif isinstance(map_location, (str, torch.device)):
  123. return map_location
  124. else:
  125. if not callable(map_location):
  126. raise AssertionError(
  127. f"expected callable map_location, got {type(map_location).__name__}"
  128. )
  129. raise RuntimeError(
  130. "Callable map_location not supported with _rebuild_wrapper_subclass "
  131. "or _rebuild_device_tensor_from_numpy"
  132. )
  133. # Note [Don't serialize hooks]
  134. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  135. # Since time immemorial, we have serialized the backward hooks associated with
  136. # variables. This kind of half-worked--Python can pickle global functions
  137. # (but not closures!)--but there were problems.
  138. #
  139. # - It's fragile. If you serialize a backward hook into a saved
  140. # model, and then you rename the function associated with the hook,
  141. # now your saved model is broken and you can't load it anymore.
  142. #
  143. # - It's not actually used. The standard recommendation is to
  144. # serialize the *state_dict* of a model, not the model itself
  145. # (since this is more stable to code changes affecting the model
  146. # serialization), and the state dict saves "data" only, thus
  147. # stripping the backward hooks. In some cases, hooks are
  148. # essential to the well-functioning of a model (e.g., DDP),
  149. # but DDP already manages re-adding the hooks!
  150. #
  151. # - We didn't serialize them in many cases. Prior to #10220, we
  152. # were dropping backward hooks in ForkingPickler. We "fixed" this
  153. # to be convenient with other serialization sites, but lack of
  154. # serializing backward hooks wasn't actually the root cause of
  155. # the bug.
  156. #
  157. # With these cases in mind, we have decided that a better strategy
  158. # is to just NOT serialize hooks at all.
  159. #
# Since this is a BC-breaking change, we should warn when we previously
# serialized a hook but no longer do so. This will be done by adding a special
# sentinel property to hooks that will be used to suppress this warning. If a
# hook has the property _torch_serialize_ignore, we will not emit a warning if
# we attempt to serialize a Tensor with this hook attached to it.
  165. #
# By the way, when _backward_hooks is skipped, we must give an EMPTY
# OrderedDict(); if you pass None, you'll run afoul of #12219.
  168. # TODO: Once we decide to break serialization FC, `storage` no longer needs to
  169. # be a TypedStorage
  170. def _rebuild_tensor(storage, storage_offset, size, stride):
  171. # first construct a tensor with the correct dtype/device
  172. t = torch.empty((0,), dtype=storage.dtype, device=storage._untyped_storage.device)
  173. return t.set_(storage._untyped_storage, storage_offset, size, stride)
  174. def get_tensor_metadata(tensor):
  175. # Tensor's Metadata for serializing.
  176. # Currently, this only returns a dict[string, bool] specifying whether
  177. # `conj` or `neg` bit is set.
  178. if not isinstance(tensor, torch.Tensor):
  179. raise AssertionError(f"expected torch.Tensor, got {type(tensor).__name__}")
  180. return torch._C._get_tensor_metadata(tensor) # type: ignore[attr-defined]
  181. def set_tensor_metadata(tensor, metadata):
  182. # See `get_tensor_metadata` above
  183. if not isinstance(metadata, dict):
  184. raise AssertionError(f"expected dict, got {type(metadata).__name__}")
  185. if not isinstance(tensor, torch.Tensor):
  186. raise AssertionError(f"expected torch.Tensor, got {type(tensor).__name__}")
  187. torch._C._set_tensor_metadata(tensor, metadata) # type: ignore[attr-defined]
  188. def _restore_device_fake_mode(tensor):
  189. if torch._guards.detect_fake_mode(None) is not None:
  190. if tensor.untyped_storage()._fake_device is not None:
  191. device = _get_restore_location(tensor.untyped_storage()._fake_device)
  192. if not isinstance(device, torch.device):
  193. device = torch.device(device)
  194. tensor.fake_device = torch.device(device)
  195. return tensor
  196. def _rebuild_tensor_v2(
  197. storage,
  198. storage_offset,
  199. size,
  200. stride,
  201. requires_grad,
  202. backward_hooks,
  203. metadata=None,
  204. ):
  205. tensor = _rebuild_tensor(storage, storage_offset, size, stride)
  206. tensor.requires_grad = requires_grad
  207. if metadata:
  208. set_tensor_metadata(tensor, metadata)
  209. # NB: This line exists only for backwards compatibility; the
  210. # general expectation is that backward_hooks is an empty
  211. # OrderedDict. See Note [Don't serialize hooks]
  212. tensor._backward_hooks = backward_hooks
  213. tensor = _restore_device_fake_mode(tensor)
  214. return tensor
  215. def _rebuild_tensor_v3(
  216. storage,
  217. storage_offset,
  218. size,
  219. stride,
  220. requires_grad,
  221. backward_hooks,
  222. dtype,
  223. metadata=None,
  224. ):
  225. t = torch.empty(
  226. (0,),
  227. dtype=dtype,
  228. device=storage._untyped_storage.device,
  229. requires_grad=requires_grad,
  230. )
  231. t.set_(storage._untyped_storage, storage_offset, size, stride)
  232. if metadata:
  233. set_tensor_metadata(t, metadata)
  234. t._backward_hooks = backward_hooks
  235. t = _restore_device_fake_mode(t)
  236. return t
# Queue of sparse tensors built by _rebuild_sparse_tensor() with
# check_invariants=False. _validate_loaded_sparse_tensors() validates them
# (when invariant checking is enabled) and always empties the list.
_sparse_tensors_to_validate: list["torch.Tensor"] = []
  238. # In _legacy_load() in serialization.py we unpickle storages after the sparse
  239. # tensors have been already unpickled. Those storages contain data necessary for
  240. # validating sparse tensors: indices and values. That's why sparse tensors are
  241. # first unpickled without any validation, and then this function is called just
  242. # before _legacy_load() returns, so that all the sparse tensors can be validated
  243. # in bulk.
  244. #
  245. # The same procedure must be followed by _load() in serialization.py because due
  246. # to Pickler semantics, we have to use the same (non-validating) function for
  247. # unpickling sparse tensors, regardless of the caller.
  248. def _validate_loaded_sparse_tensors():
  249. if not torch.sparse.check_sparse_tensor_invariants().is_enabled():
  250. # Skip sparse tensor invariants validation for better
  251. # performance. See check_sparse_tensor_invariants
  252. # documentation for how to control sparse tensor invariants
  253. # checking.
  254. _sparse_tensors_to_validate.clear()
  255. return
  256. try:
  257. # We disable pinning check (see check_pinning=False below) to
  258. # avoid gh-153143. In fact, pinning check is unnecessary
  259. # anywhy when loading sparse data from external sources.
  260. for t in _sparse_tensors_to_validate:
  261. if t.layout is torch.sparse_coo:
  262. torch._validate_sparse_coo_tensor_args(
  263. t._indices(),
  264. t._values(),
  265. t.size(),
  266. t.is_coalesced(),
  267. check_pinning=False,
  268. )
  269. elif t.layout in {
  270. torch.sparse_csr,
  271. torch.sparse_csc,
  272. torch.sparse_bsr,
  273. torch.sparse_bsc,
  274. }:
  275. # TODO: Validation currently involves an expensive traversal
  276. # on CPU, which may include a device transfer.
  277. if t.layout in {torch.sparse_csr, torch.sparse_bsr}:
  278. compressed_indices, plain_indices = (
  279. t.crow_indices(),
  280. t.col_indices(),
  281. )
  282. else:
  283. compressed_indices, plain_indices = (
  284. t.ccol_indices(),
  285. t.row_indices(),
  286. )
  287. torch._validate_sparse_compressed_tensor_args(
  288. compressed_indices,
  289. plain_indices,
  290. t.values(),
  291. t.size(),
  292. t.layout,
  293. check_pinning=False,
  294. )
  295. else:
  296. raise NotImplementedError(
  297. f"_validate_loaded_sparse_tensors for layout `{t.layout}`"
  298. )
  299. finally:
  300. _sparse_tensors_to_validate.clear()
  301. def _rebuild_sparse_tensor(layout, data):
  302. """
  303. Rebuilds a sparse tensor from its sparse storage representation.
  304. Args:
  305. layout (str): The sparse storage layout of the tensor.
  306. data (tuple): The tensor's sparse storage representation.
  307. """
  308. if layout == torch.sparse_coo:
  309. if len(data) == 3:
  310. # For BC:
  311. indices, values, size = data
  312. is_coalesced = None
  313. else:
  314. indices, values, size, is_coalesced = data
  315. result = torch.sparse_coo_tensor(
  316. indices, values, size, check_invariants=False, is_coalesced=is_coalesced
  317. )
  318. _sparse_tensors_to_validate.append(result)
  319. return result
  320. elif layout in {
  321. torch.sparse_csr,
  322. torch.sparse_csc,
  323. torch.sparse_bsr,
  324. torch.sparse_bsc,
  325. }:
  326. compressed_indices, plain_indices, values, size = data
  327. result = torch.sparse_compressed_tensor(
  328. compressed_indices,
  329. plain_indices,
  330. values,
  331. size,
  332. layout=layout,
  333. check_invariants=False,
  334. )
  335. _sparse_tensors_to_validate.append(result)
  336. return result
  337. raise NotImplementedError(f"rebuilding sparse tensor for layout {layout}")
  338. def _rebuild_nested_tensor(buffer, sizes, strides, storage_offsets):
  339. return torch._nested_view_from_buffer(buffer, sizes, strides, storage_offsets)
  340. def _rebuild_device_tensor_from_cpu_tensor(data, dtype, device, requires_grad):
  341. device = _get_restore_location(device)
  342. tensor = data.to(dtype=dtype, device=device)
  343. tensor.requires_grad = requires_grad
  344. return tensor
  345. def _rebuild_device_tensor_from_numpy(data, dtype, device, requires_grad):
  346. device = _get_restore_location(device)
  347. tensor = torch.from_numpy(data).to(dtype=dtype, device=device)
  348. tensor.requires_grad = requires_grad
  349. return tensor
# Should not be used. It is only here to be able to load Tensors serialized
# with older versions of PyTorch (presumably those pickles reference this name
# directly). It is an alias of _rebuild_device_tensor_from_numpy.
_rebuild_xla_tensor = _rebuild_device_tensor_from_numpy
  352. def _rebuild_meta_tensor_no_storage(dtype, size, stride, requires_grad):
  353. return torch.empty_strided(
  354. size, stride, dtype=dtype, device="meta", requires_grad=requires_grad
  355. )
  356. def _rebuild_wrapper_subclass(
  357. cls,
  358. dtype,
  359. size,
  360. stride,
  361. storage_offset,
  362. layout,
  363. device,
  364. requires_grad,
  365. ):
  366. device = _get_restore_location(device)
  367. return torch.Tensor._make_wrapper_subclass(
  368. cls,
  369. size,
  370. strides=stride,
  371. dtype=dtype,
  372. storage_offset=storage_offset,
  373. layout=layout,
  374. device=device,
  375. requires_grad=requires_grad,
  376. )
  377. # TODO: Once we decide to break serialization FC, `storage` no longer needs to
  378. # be a TypedStorage
  379. def _rebuild_qtensor(
  380. storage,
  381. storage_offset,
  382. size,
  383. stride,
  384. quantizer_params,
  385. requires_grad,
  386. backward_hooks,
  387. ):
  388. qscheme = quantizer_params[0]
  389. if qscheme == torch.per_tensor_affine:
  390. _, scale, zero_point = quantizer_params
  391. tensor = torch._empty_affine_quantized(
  392. size,
  393. scale=scale,
  394. zero_point=zero_point,
  395. dtype=storage.dtype,
  396. device=storage.device,
  397. )
  398. elif qscheme in (torch.per_channel_affine, torch.per_channel_affine_float_qparams):
  399. _, scales, zero_points, axis = quantizer_params
  400. if type(scales) is list and type(zero_points) is list:
  401. if qscheme == torch.per_channel_affine:
  402. scales = torch.tensor(scales, dtype=torch.double, device=storage.device)
  403. zero_points = torch.tensor(
  404. zero_points, dtype=torch.long, device=storage.device
  405. )
  406. else:
  407. scales = torch.tensor(scales, dtype=torch.float, device=storage.device)
  408. zero_points = torch.tensor(
  409. zero_points, dtype=torch.float, device=storage.device
  410. )
  411. tensor = torch._empty_per_channel_affine_quantized(
  412. size,
  413. scales=scales,
  414. zero_points=zero_points,
  415. axis=axis,
  416. dtype=storage.dtype,
  417. device=storage.device,
  418. )
  419. else:
  420. raise RuntimeError(f"Can't deserialize quantized tensor with qscheme {qscheme}")
  421. tensor.set_(storage, storage_offset, size, stride)
  422. tensor.requires_grad = requires_grad
  423. # NB: This line exists only for backwards compatibility; the
  424. # general expectation is that backward_hooks is an empty
  425. # OrderedDict. See Note [Don't serialize hooks]
  426. tensor._backward_hooks = backward_hooks
  427. return tensor
  428. def _rebuild_parameter(data, requires_grad, backward_hooks):
  429. param = torch.nn.Parameter(data, requires_grad)
  430. # NB: This line exists only for backwards compatibility; the
  431. # general expectation is that backward_hooks is an empty
  432. # OrderedDict. See Note [Don't serialize hooks]
  433. param._backward_hooks = backward_hooks
  434. return param
  435. def _rebuild_parameter_with_state(data, requires_grad, backward_hooks, state):
  436. param = torch.nn.Parameter(data, requires_grad)
  437. # NB: This line exists only for backwards compatibility; the
  438. # general expectation is that backward_hooks is an empty
  439. # OrderedDict. See Note [Don't serialize hooks]
  440. param._backward_hooks = backward_hooks
  441. # Restore state on Parameter like python attr.
  442. param = _set_obj_state(param, state)
  443. return param
  444. def _get_obj_state(obj):
  445. # Get the state of the python subclass
  446. # This loosely mimics the function on the object class but since Tensor do not inherit
  447. # from it, we cannot call that function directly
  448. # https://github.com/python/cpython/blob/c83919bd635f4433f1c6ae8504996a9fe3c215e5/Objects/typeobject.c#L4891
  449. # Note that starting with Python 3.11, this `__getstate__` is always defined and thus
  450. # the else branch will never be taken.
  451. getstate_fn = getattr(obj, "__getstate__", None)
  452. if getstate_fn:
  453. state = getstate_fn()
  454. else:
  455. slots_to_save = copyreg._slotnames(obj.__class__) # type: ignore[attr-defined]
  456. if slots_to_save:
  457. state = (
  458. obj.__dict__,
  459. {
  460. name: getattr(obj, name)
  461. for name in slots_to_save
  462. if hasattr(obj, name)
  463. },
  464. )
  465. else:
  466. state = obj.__dict__
  467. return state
  468. def _set_obj_state(obj, state):
  469. if isinstance(state, tuple):
  470. if not len(state) == 2:
  471. raise RuntimeError(f"Invalid serialized state: {state}")
  472. dict_state = state[0]
  473. slots_state = state[1]
  474. else:
  475. dict_state = state
  476. slots_state = None
  477. # Starting with Python 3.11, the __dict__ attribute is lazily created
  478. # and is serialized as None when not needed.
  479. if dict_state:
  480. for k, v in dict_state.items():
  481. setattr(obj, k, v)
  482. if slots_state:
  483. for k, v in slots_state.items():
  484. setattr(obj, k, v)
  485. return obj
  486. def _import_dotted_name(name):
  487. components = name.split(".")
  488. obj = __import__(components[0])
  489. for component in components[1:]:
  490. obj = getattr(obj, component)
  491. return obj
  492. def _flatten_dense_tensors(tensors):
  493. """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of
  494. same dense type.
  495. Since inputs are dense, the resulting tensor will be a concatenated 1D
  496. buffer. Element-wise operation on this buffer will be equivalent to
  497. operating individually.
  498. Args:
  499. tensors (Iterable[Tensor]): dense tensors to flatten.
  500. Returns:
  501. A contiguous 1D buffer containing input tensors.
  502. """
  503. return torch._C._nn.flatten_dense_tensors(tensors)
  504. def _flatten_sparse_tensors(tensors):
  505. """Flatten sparse tensors into two contiguous 1D buffers, one of indices and
  506. one of values. Assume tensors are of same sparse type.
  507. Args:
  508. tensors (Iterable[Tensor]): sparse tensors to flatten.
  509. Returns:
  510. A tuple of two contiguous 1D buffers, one containing input tensors'
  511. indices and the other containing the values.
  512. """
  513. flat_indices = torch._C._nn.flatten_dense_tensors(
  514. [torch.Tensor._indices(t) for t in tensors]
  515. )
  516. flat_values = torch._C._nn.flatten_dense_tensors(
  517. [torch.Tensor._values(t) for t in tensors]
  518. )
  519. return flat_indices, flat_values
  520. def _unflatten_dense_tensors(flat, tensors):
  521. """View a flat buffer using the sizes of tensors. Assume that tensors are of
  522. same dense type, and that flat is given by _flatten_dense_tensors.
  523. Args:
  524. flat (Tensor): flattened dense tensors to unflatten.
  525. tensors (Iterable[Tensor]): dense tensors whose sizes will be used to
  526. unflatten flat.
  527. Returns:
  528. Unflattened dense tensors with sizes same as tensors and values from
  529. flat.
  530. """
  531. return torch._C._nn.unflatten_dense_tensors(flat, tensors)
  532. def _unflatten_sparse_tensors(flat, tensors):
  533. """View flat buffer (containing indices and values) using the sizes of
  534. tensors. Assume that tensors are of same sparse type, and that flat is given
  535. by _flatten_sparse_tensors.
  536. Args:
  537. flat (tuple(Tensor, Tensor)): flattened indices and values of sparse
  538. tensors to unflatten.
  539. tensors (Iterable[Tensor]): sparse tensors whose sizes will be used to
  540. unflatten flat.
  541. Returns:
  542. Unflattened sparse tensors with sizes same as tensors and values from
  543. flat.
  544. """
  545. flat_indices, flat_values = flat
  546. indices = torch._C._nn.unflatten_dense_tensors(
  547. flat_indices, [torch.Tensor._indices(t) for t in tensors]
  548. )
  549. values = torch._C._nn.unflatten_dense_tensors(
  550. flat_values, [torch.Tensor._values(t) for t in tensors]
  551. )
  552. outputs = []
  553. for t, i, v in zip(tensors, indices, values):
  554. outputs.append(t.new(i, v, t.size()))
  555. return tuple(outputs)
  556. def _reorder_tensors_as(tensors, ordered_tensors):
  557. """Assume that tensors are of same order as ordered_tensors within their
  558. types, e.g., from _take_tensors. Reorder them to be of same order as
  559. ordered_tensors.
  560. Args:
  561. tensors (Iterable[Tensor]): tensors to be reordered. They should be of
  562. the same order as ordered_tensors within their own types.
  563. ordered_tensors (Iterable[Tensor]): tensors whose order will be the
  564. reference.
  565. Returns:
  566. Ordered tuple of tensors with contents from tensors and order of
  567. ordered_tensors.
  568. """
  569. type_dict = defaultdict(list)
  570. for tensor in tensors:
  571. type_dict[tensor.type()].append(tensor)
  572. type_dict_ = {t: iter(coll) for t, coll in type_dict.items()}
  573. return tuple(next(type_dict_[tensor.type()]) for tensor in ordered_tensors)
  574. def _take_tensors(tensors, size_limit):
  575. """Group tensors into chunks. This generator yields a chunk at each time,
  576. each containing tensors of same type up to certain byte limit in total size.
  577. Args:
  578. tensors (Sequence): A sequence of tensors to be separated into chunks.
  579. size_limit (int): The limit of each chunk in bytes.
  580. Yields:
  581. Blocks of tensors of same type and within size_limit. The yielded
  582. tensors are only ordered as the original sequence within its types.
  583. """
  584. buf_dict: defaultdict[str, list] = defaultdict(lambda: [[], 0])
  585. for tensor in tensors:
  586. t = tensor.type()
  587. if tensor.is_sparse:
  588. indices = torch.Tensor._indices(tensor)
  589. values = torch.Tensor._values(tensor)
  590. size = (
  591. indices.numel() * indices.element_size()
  592. + values.numel() * values.element_size()
  593. )
  594. else:
  595. size = tensor.numel() * tensor.element_size()
  596. buf_and_size = buf_dict[t]
  597. if buf_and_size[1] + size > size_limit and buf_and_size[1] > 0:
  598. yield buf_and_size[0]
  599. buf_and_size = buf_dict[t] = [[], 0]
  600. buf_and_size[0].append(tensor) # pyrefly: ignore [missing-attribute]
  601. buf_and_size[1] += size # pyrefly: ignore [unsupported-operation]
  602. for buf, _ in buf_dict.values():
  603. if len(buf) > 0:
  604. yield buf
  605. # annotation decorator to get annotations in a way that is compatible
  606. # with both Python 2 and 3
  607. def annotate(ret, **kwargs):
  608. def dec(fun):
  609. fun.__annotations__ = dict(kwargs)
  610. fun.__annotations__["return"] = ret
  611. return fun
  612. return dec
  613. def render_call(fn, args, kwargs):
  614. str_fn = torch.overrides.resolve_name(fn)
  615. if str_fn is None:
  616. str_fn = str(fn)
  617. str_args: list[str] = []
  618. with torch._tensor_str.printoptions(threshold=0, edgeitems=0):
  619. str_args.extend(repr(a) for a in args)
  620. str_args.extend(f"{k}={repr(v)}" for k, v in kwargs.items())
  621. r = f"{str_fn}({', '.join(str_args)})"
  622. return r
  623. # NOTE [ Python Traceback Reference Cycle Problem ]
  624. #
  625. # When using sys.exc_info(), it is important to **not** store the exc_info[2],
  626. # which is the traceback, because otherwise you will run into the traceback
# reference cycle problem, i.e., the traceback holding a reference to the frame,
# and the frame (which holds references to all the objects in its temporary
# scope) holding a reference to the traceback.
  630. class KeyErrorMessage(str):
  631. r"""str subclass that returns itself in repr"""
  632. __slots__ = ()
  633. def __repr__(self):
  634. return self
  class ExceptionWrapper:
      r"""Wraps an exception plus traceback to communicate across threads.

      Only the exception *type* and the formatted traceback text are stored
      (see NOTE [ Python Traceback Reference Cycle Problem ]); the exception is
      rebuilt from them by :meth:`reraise`.
      """

      def __init__(self, exc_info=None, where="in background"):
          """
          Args:
              exc_info: ``(type, value, traceback)`` triple as returned by
                  :func:`sys.exc_info`. Defaults to the exception currently
                  being handled.
              where: Human-readable location inserted into the re-raised
                  message, e.g. ``"in DataLoader worker process 2"``.
          """
          # It is important that we don't store exc_info, see
          # NOTE [ Python Traceback Reference Cycle Problem ]
          if exc_info is None:
              exc_info = sys.exc_info()
          self.exc_type = exc_info[0]
          self.exc_msg = "".join(traceback.format_exception(*exc_info))
          self.where = where

      def reraise(self):
          r"""Reraises the wrapped exception in the current thread.

          The raised exception carries the message
          ``"Caught <Type> <where>.\nOriginal <formatted traceback>"``.
          Construction of the original type is attempted with ``message=msg``
          (only for types exposing a truthy ``message`` attribute), then with
          ``msg`` as the sole positional argument. If neither works, a
          :class:`RuntimeError` with the same message is raised instead.
          """
          # Format a message such as: "Caught ValueError in DataLoader worker
          # process 2. Original Traceback:", followed by the traceback.
          msg = f"Caught {self.exc_type.__name__} {self.where}.\nOriginal {self.exc_msg}"  # pyrefly: ignore [missing-attribute]
          if self.exc_type is KeyError:
              # KeyError calls repr() on its argument (usually a dict key). This
              # makes stack traces unreadable. It will not be changed in Python
              # (https://bugs.python.org/issue2651), so we work around it.
              msg = KeyErrorMessage(msg)
          elif getattr(self.exc_type, "message", None):
              # Some exceptions have first argument as non-str but explicitly
              # have message field
              try:
                  # pyrefly: ignore [not-callable]
                  exception = self.exc_type(
                      # pyrefly: ignore [unexpected-keyword]
                      message=msg
                  )
              except Exception:
                  # Having a ``message`` attribute does not guarantee that the
                  # constructor accepts ``message=``. Fall through to the
                  # positional / RuntimeError fallback below instead of letting
                  # an unrelated TypeError hide the original failure.
                  pass
              else:
                  raise exception
          try:
              exception = self.exc_type(msg)  # pyrefly: ignore [not-callable]
          except Exception:
              # If the exception takes multiple arguments or otherwise can't
              # be constructed, don't try to instantiate since we don't know how to
              raise RuntimeError(msg) from None
          raise exception
  670. def _get_available_device_type():
  671. if torch.cuda.is_available():
  672. return "cuda"
  673. if torch.backends.mps.is_available():
  674. return "mps"
  675. if hasattr(torch, "xpu") and torch.xpu.is_available(): # type: ignore[attr-defined]
  676. return "xpu"
  677. if hasattr(torch, "mtia") and torch.mtia.is_available():
  678. return "mtia"
  679. custom_backend_name = torch._C._get_privateuse1_backend_name()
  680. custom_device_mod = getattr(torch, custom_backend_name, None)
  681. if custom_device_mod and custom_device_mod.is_available():
  682. return custom_backend_name
  683. # add more available device types here
  684. return None
  def _get_device_attr(get_member):
      """Apply ``get_member`` to the ``torch`` submodule of the available accelerator.

      Args:
          get_member: Callable taking a device module such as ``torch.cuda``.

      Returns:
          Whatever ``get_member`` returns, or ``None`` when no supported
          accelerator is available.
      """
      device_type = _get_available_device_type()
      if device_type is None:
          return None
      lowered = device_type.lower()
      if lowered in ("cuda", "mps", "xpu", "mtia"):
          # Resolve the submodule lazily so missing backends are never touched.
          return get_member(getattr(torch, lowered))
      if device_type == torch._C._get_privateuse1_backend_name():
          return get_member(getattr(torch, device_type))
      # add more available device types here
      return None
  def _get_current_device_index():
      """Return the current device index of the available accelerator, or ``None``."""

      def _current(device_module):
          return device_module.current_device()

      return _get_device_attr(_current)
  def _get_all_device_indices():
      """Return ``[0, ..., device_count - 1]`` for the available accelerator, or ``None``."""

      def _all_indices(device_module):
          return list(range(device_module.device_count()))

      return _get_device_attr(_all_indices)
  def _get_devices_properties(device_ids):
      """Return the device properties for each index in ``device_ids``, in order.

      Each entry is ``None`` when no supported accelerator is available.
      """

      def _properties_of(index):
          return _get_device_attr(lambda mod: mod.get_device_properties(index))

      return [_properties_of(index) for index in device_ids]
  def get_current_device_index() -> int:
      r"""Return the index of the current default CUDA device, or ``-1``.

      Checks if there are CUDA devices available and returns the device index
      of the current default CUDA device. Returns -1 in case there are no CUDA
      devices available. This variant uses no lambdas, so it is usable from
      TorchScript.

      Arguments: ``None``
      """
      if torch.cuda.device_count() <= 0:
          return -1
      return torch.cuda.current_device()
  def _get_device_index(
      device: Any,
      optional: bool = False,
      allow_cpu: bool = False,
  ) -> int:
      r"""Gets the device index from :attr:`device`, which can be a torch.device
      object, a device string (converted with ``torch.device``), a Python
      integer, or ``None``.

      If :attr:`device` is a torch.device object, returns the device index if it
      has index. Note that for a device without a specified index,
      i.e., ``torch.device('xxx')``, this will return the current default
      device index if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``,
      CPU devices will be accepted and ``-1`` will be returned in this case.

      If :attr:`device` is a Python integer, it is returned as is.

      If :attr:`device` is ``None``, this will return the current default
      device of the supported runtime platform if :attr:`optional` is ``True``.
      i.e., the current default CUDA device will be returned if CUDA runtime is supported.

      Raises:
          ValueError: if :attr:`device` is a CPU device and :attr:`allow_cpu`
              is ``False``, or if no index can be determined and
              :attr:`optional` is ``False``.

      NOTE(review): the ``optional`` fallback looks up the current device of the
      first *available* accelerator in eager mode (CUDA only when scripted),
      which is not necessarily the same type as :attr:`device`. In eager mode,
      if no accelerator is available the fallback yields ``None`` despite the
      ``int`` annotation. Confirm that callers tolerate this.
      """
      # Accept device strings such as "cuda:1" by normalizing them first.
      if isinstance(device, str):
          device = torch.device(device)
      device_idx: int | None = None
      if isinstance(device, torch.device):
          if not allow_cpu and device.type == "cpu":
              raise ValueError(f"Expected a non cpu device, but got: {device}")
          # CPU maps to -1; any other device keeps its (possibly None) index.
          device_idx = -1 if device.type == "cpu" else device.index
      if isinstance(device, int):
          device_idx = device
      # No index could be read from the argument: use the current device or fail.
      if device_idx is None:
          if optional:
              # The eager API _get_current_device_index uses `lambda` functions which are
              # not supported in JIT and hence not scriptable. The JIT equivalent API to get
              # the current device index is `get_current_device_index()` which can
              # be scripted. We use is_scripting to check the mode we are in and call the
              # appropriate API.
              if torch.jit.is_scripting():
                  device_idx = get_current_device_index()
              else:
                  device_idx = _get_current_device_index()
          else:
              raise ValueError(
                  f"Expected a torch.device with a specified index or an integer, but got:{device}"
              )
      return device_idx
  759. def _handle_complex(tensor):
  760. """
  761. Returns a real view of a tensor if complex dtype else just the tensor
  762. need to check if a UninitializedParameter because otherwise checking is_complex is an error for a LazyModule
  763. """
  764. return (
  765. torch.view_as_real(tensor)
  766. if not isinstance(tensor, torch.nn.UninitializedParameter)
  767. and tensor.is_complex()
  768. else tensor
  769. )
  770. def _element_size(dtype):
  771. """
  772. Returns the element size for a dtype, in bytes
  773. """
  774. if not isinstance(dtype, torch.dtype):
  775. raise RuntimeError(f"expected torch.dtype, but got {type(dtype)}")
  776. if dtype.is_complex:
  777. return torch.finfo(dtype).bits >> 2
  778. elif dtype.is_floating_point:
  779. return torch.finfo(dtype).bits >> 3
  780. elif dtype == torch.bool:
  781. # NOTE: torch.bool is not supported in torch.iinfo()
  782. return 1
  783. else:
  784. return torch.iinfo(dtype).bits >> 3
  785. class _ClassPropertyDescriptor:
  786. def __init__(self, fget, fset=None):
  787. self.fget = fget
  788. def __get__(self, instance, owner=None):
  789. if owner is None:
  790. owner = type(instance)
  791. return self.fget.__get__(instance, owner)()
  def classproperty(func):
      """Decorator turning ``func`` into a read-only property on the class.

      A plain function is wrapped in ``classmethod`` first; existing
      ``classmethod``/``staticmethod`` objects are used as-is.
      """
      getter = func if isinstance(func, (classmethod, staticmethod)) else classmethod(func)
      return _ClassPropertyDescriptor(getter)
  if TYPE_CHECKING:
      # TorchScript does not support `@deprecated`
      # This is a workaround to avoid breaking TorchScript
      #
      # Only static type checkers evaluate this branch (TYPE_CHECKING is False at
      # runtime), so the `@deprecated` marker lets them flag callers; the `else`
      # branch below is the definition actually used at runtime.
      @deprecated(
          "`torch._utils.is_compiling` is deprecated. Use `torch.compiler.is_compiling` instead.",
          category=FutureWarning,
      )
      def is_compiling() -> bool:
          """Deprecated alias of :func:`torch.compiler.is_compiling`."""
          return torch.compiler.is_compiling()

  else:

      def is_compiling() -> bool:
          """
          Indicates whether we are tracing/compiling with torch.compile() or torch.export().

          Deprecated: warns on every call, then delegates to
          :func:`torch.compiler.is_compiling`; call that function directly.
          """
          warnings.warn(  # use `warnings.warn` instead of `@deprecated`
              "`torch._utils.is_compiling` is deprecated. Use `torch.compiler.is_compiling` instead.",
              # FutureWarning,  # TorchScript does not support Warning type
              # (with no category argument, warnings.warn emits a UserWarning)
              stacklevel=2,
              # stacklevel=2 attributes the warning to the caller's line
          )
          return torch.compiler.is_compiling()
  816. def _functionalize_sync(t):
  817. # This code lives in python instead of C++ since conditioning on a certain python subclass
  818. # is much more of a pain in C++.
  819. from torch._subclasses.functional_tensor import FunctionalTensor
  820. if isinstance(t, FunctionalTensor):
  821. # If a FunctionalTensorMode is active while syncing, we don't want it to intercept any ops that get called
  822. # when we sync our inner tensor.
  823. # Why?
  824. # (1) If there are input mutations in the graph, then they will be re-applied during
  825. # AOTAutograd when we call _sync() from inside of our functionalization kernels.
  826. # (2) _sync() causes us to regenerate our updated the tensor from the updated base,
  827. # which dispatches to a bunch of view ops
  828. # (3) The input to these view ops is our inner FunctionalTensorWrapper
  829. # (since the sync was called from C++), not the python FunctionalTensor
  830. # (4) if a python FunctionalTensorMode is active, it will complain when it intercepts
  831. # the view op, since it will see an input that is a C++ FunctionalTensorWrapper
  832. # (aka a normal torch.Tensor) instead of a python `FunctionalTensor).
  833. maybe_functional_mode = torch._C._unset_dispatch_mode(
  834. torch._C._TorchDispatchModeKey.FUNCTIONAL
  835. )
  836. try:
  837. torch._functionalize_sync(t.elem) # type: ignore[attr-defined]
  838. finally:
  839. if maybe_functional_mode is not None:
  840. torch._C._set_dispatch_mode(maybe_functional_mode)
  841. else:
  842. torch._functionalize_sync(t) # type: ignore[attr-defined]
  843. @functools.lru_cache(2)
  844. def _get_device_module(device_type: str):
  845. device_module = getattr(torch, device_type, None)
  846. if device_module is None:
  847. raise RuntimeError(
  848. f"Device '{device_type}' does not have a corresponding module registered as 'torch.{device_type}'."
  849. )
  850. return device_module
  851. def _dummy_type(name: str) -> type:
  852. def get_err_fn(is_init: bool):
  853. def err_fn(obj, *args, **kwargs):
  854. if is_init:
  855. class_name = obj.__class__.__name__
  856. else:
  857. class_name = obj.__name__
  858. raise RuntimeError(f"Tried to instantiate dummy base class {class_name}")
  859. return err_fn
  860. return type(
  861. name, (object,), {"__init__": get_err_fn(True), "__new__": get_err_fn(False)}
  862. )
  863. class _LazySeedTracker:
  864. # Since seeding is memory-less, only track the latest seed.
  865. # Note: `manual_seed_all` followed by `manual_seed` overwrites
  866. # the seed on current device. We track the order of **latest**
  867. # calls between these two API.
  868. def __init__(self):
  869. self.manual_seed_all_cb = None
  870. self.manual_seed_cb = None
  871. self.call_order = []
  872. def queue_seed_all(self, cb, traceback):
  873. self.manual_seed_all_cb = (cb, traceback) # pyrefly: ignore [bad-assignment]
  874. # update seed_all to be latest
  875. self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb]
  876. def queue_seed(self, cb, traceback):
  877. self.manual_seed_cb = (cb, traceback) # pyrefly: ignore [bad-assignment]
  878. # update seed to be latest
  879. self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb]
  880. def get_calls(self) -> list:
  881. return self.call_order
  882. logger = logging.getLogger(__name__)
  883. P = ParamSpec("P")
  884. class CallbackRegistry(Generic[P]):
  885. def __init__(self, name: str):
  886. self.name = name
  887. self.callback_list: list[Callable[P, None]] = []
  888. def add_callback(self, cb: Callable[P, None]) -> None:
  889. self.callback_list.append(cb)
  890. def fire_callbacks(self, *args: P.args, **kwargs: P.kwargs) -> None:
  891. for cb in self.callback_list:
  892. try:
  893. cb(*args, **kwargs)
  894. except Exception:
  895. logger.exception(
  896. "Exception in callback for %s registered with gpu trace", self.name
  897. )
  898. def try_import(module_name: str) -> ModuleType | None:
  899. # Implementation based on
  900. # https://docs.python.org/3/library/importlib.html#checking-if-a-module-can-be-imported
  901. if (module := sys.modules.get(module_name, None)) is not None:
  902. return module
  903. if (spec := importlib.util.find_spec(module_name)) is not None:
  904. module = importlib.util.module_from_spec(spec)
  905. sys.modules[module_name] = module
  906. # https://docs.python.org/3/library/importlib.html#importlib.machinery.ModuleSpec.loader
  907. # "The finder should always set this attribute"
  908. if spec.loader is None:
  909. raise AssertionError("The loader attribute should always be set")
  910. spec.loader.exec_module(module)
  911. return module
  912. return None
  # IMPORT_MAPPING and NAME_MAPPING are adapted from https://github.com/python/cpython/blob/main/Lib/_compat_pickle.py
  # for use in the weights_only Unpickler.
  #
  # IMPORT_MAPPING maps Python 2 module names, as they may appear in old
  # pickles, to the Python 3 module that provides the same contents.
  IMPORT_MAPPING = {
      "__builtin__": "builtins",
      "copy_reg": "copyreg",
      "Queue": "queue",
      "repr": "reprlib",
      "_abcoll": "collections.abc",
      # Non-mutual mappings.
      # (Several Python 2 modules below map onto the same Python 3 module, so
      # these entries cannot be inverted into a unique reverse mapping.)
      "UserDict": "collections",
      "UserList": "collections",
      "UserString": "collections",
      "whichdb": "dbm",
      "StringIO": "io",
      "cStringIO": "io",
  }
  # This contains rename rules that are easy to handle. We ignore the more
  # complex stuff (e.g. mapping the names in the urllib and types modules).
  # These rules should be run before import names are fixed.
  #
  # Keys and values are (module, name) pairs: a global referenced as
  # `module.name` in a Python 2 pickle is resolved via the value pair instead.
  # NOTE(review): "run before import names are fixed" presumably means a
  # (module, name) hit here takes precedence over the module-only
  # IMPORT_MAPPING rewrite, as in CPython's pickle find_class; confirm against
  # the weights_only Unpickler.
  NAME_MAPPING = {
      ("__builtin__", "xrange"): ("builtins", "range"),
      ("__builtin__", "reduce"): ("functools", "reduce"),
      ("__builtin__", "intern"): ("sys", "intern"),
      ("__builtin__", "unichr"): ("builtins", "chr"),
      ("__builtin__", "unicode"): ("builtins", "str"),
      ("__builtin__", "long"): ("builtins", "int"),
      ("itertools", "izip"): ("builtins", "zip"),
      ("itertools", "imap"): ("builtins", "map"),
      ("itertools", "ifilter"): ("builtins", "filter"),
      ("itertools", "ifilterfalse"): ("itertools", "filterfalse"),
      ("itertools", "izip_longest"): ("itertools", "zip_longest"),
      ("UserDict", "IterableUserDict"): ("collections", "UserDict"),
      ("UserList", "UserList"): ("collections", "UserList"),
      ("UserString", "UserString"): ("collections", "UserString"),
      # Non-mutual mappings.
      # (Their targets are shared with other entries, so they cannot be
      # inverted into a unique reverse mapping.)
      ("__builtin__", "basestring"): ("builtins", "str"),
      ("exceptions", "StandardError"): ("builtins", "Exception"),
      ("UserDict", "UserDict"): ("collections", "UserDict"),
  }
  952. def _maybe_view_chunk_cat(
  953. res: "torch.Tensor", group_size: int, gather_dim: int
  954. ) -> "torch.Tensor":
  955. """
  956. This is intuitively the same as torch.cat(torch.chunk(res, group_size,
  957. dim=0), dim=gather_dim), but returns a view if data movement is not
  958. necessary. This operation arises in NCCL all_gather, where you always get
  959. a result which is concatenated on dim=0, even though actually you may need
  960. to undo this concatenation and then re-cat on the gather dim.
  961. When is data-movement not necessary? Intuitively, we need to understand if
  962. the unflatten in this reference implementation of this code triggers a
  963. copy or not:
  964. chunks = torch.unflatten(res, 0, [group_size, -1])
  965. return torch.flatten(torch.movedim(chunks, 0, gather_dim), gather_dim, gather_dim + 1)
  966. Assume res is contiguous (it will be coming out of the collective). We
  967. essentially need to know if the movedim maintains the contiguity of the
  968. tensor. Moving a dimension typically does NOT preserve contiguity, unless
  969. EVERY dimension it is moved across is size 1.
  970. Example: shape [4, d1, d2] with group_size=4, gather_dim=1 -> [1, 4*d1, d2]
  971. [4, d1, d2] -> [4, 1, d1, d2] -> [1, 4, d1, d2] (contiguous!)
  972. Example: shape [4, 2, d2] with group_size=4, gather_dim=2 -> [1, 2, 4*d2]
  973. [4, 2, d2] -> [4, 1, 2, d2] -> [1, 2, 4, d2] (not contiguous!)
  974. Args:
  975. res: Tensor with gathered data in dim 0, shape [group_size, ...]
  976. group_size: Number of ranks in the group
  977. gather_dim: Dimension to gather along in the output
  978. Returns:
  979. Tensor with data rearranged to gather along gather_dim
  980. """
  981. if gather_dim == 0:
  982. # When gather_dim is 0, chunk+cat is a no-op
  983. return res
  984. shape = list(res.shape)
  985. # Optimization: Can use view instead of split+cat when:
  986. # 1. res.shape[0] == group_size (invariant after all_gather)
  987. # 2. All dims between 0 and gather_dim (exclusive) have size 1
  988. numel_between = math.prod(shape[1:gather_dim]) if gather_dim > 1 else 1
  989. if shape[0] == group_size and numel_between == 1:
  990. # View optimization: reshape to collapse dim 0 into gather_dim
  991. final_shape = (
  992. [1] # Dim 0 becomes 1
  993. + shape[1:gather_dim] # Dims 1 to gather_dim-1 unchanged
  994. + [shape[0] * shape[gather_dim]] # gather_dim gets multiplied by group_size
  995. + shape[gather_dim + 1 :] # Rest unchanged
  996. )
  997. return res.view(final_shape)
  998. else:
  999. # General case: fall back to split + cat
  1000. # This is better than torch.flatten as cat can be vectorized, whereas
  1001. # the contiguous kernel is always bad.
  1002. return torch.cat(torch.chunk(res, group_size, dim=0), dim=gather_dim)
  class _Frame(TypedDict):
      """Frame information from memory profiler snapshots.

      One element of the ``frames`` list carried by snapshot blocks and trace
      entries. The ``fx_*`` keys are optional; ``_augment_frames`` fills them
      in for frames that resolve to an FX graph node.
      """

      filename: str  # source file of the frame
      line: int  # line number within ``filename``
      name: str  # presumably the frame's function name (TODO confirm with the snapshot producer)
      # Fields added by FX augmentation (optional)
      fx_node_op: NotRequired[str]  # FX node ``op``; may be None if the FX metadata lacks it
      fx_node_name: NotRequired[str]  # FX node name; may be None if the FX metadata lacks it
      fx_node_target: NotRequired[str]  # ``str()`` of the FX node target
      fx_original_trace: NotRequired[str]  # only present when the node recorded a stack trace
  class _Block(TypedDict):
      """Memory block information.

      One entry of a segment's ``blocks`` list.
      """

      size: int  # presumably in bytes (TODO confirm units)
      requested_size: int  # presumably the size originally requested, which may differ from ``size``; confirm
      address: int
      state: str  # allocation state label; the set of values is not defined here
      frames: list[_Frame]  # stack frames recorded for this block
  class _Segment(TypedDict):
      """Memory segment information.

      One entry of ``_Snapshot['segments']``; holds the segment's ``blocks``.
      """

      address: int
      total_size: int  # presumably in bytes (TODO confirm units)
      stream: int
      segment_type: str  # segment kind label; the set of values is not defined here
      allocated_size: int
      active_size: int
      blocks: list[_Block]
  class _TraceEntry(TypedDict):
      """Memory trace entry information.

      One event in one of the lists of ``_Snapshot['device_traces']``.
      """

      action: str  # event kind; the set of values is not defined here
      addr: NotRequired[int]  # optional; presumably the address the event refers to (confirm)
      frames: list[_Frame]
      size: int
      stream: int
      device_free: NotRequired[int]  # optional; meaning not visible here (TODO document)
  class _Snapshot(TypedDict):
      """Memory snapshot structure.

      Top-level dict accepted and returned by
      ``_augment_memory_snapshot_stack_traces``.
      """

      segments: list[_Segment]
      # Optional list of trace lists; presumably one list per device (confirm).
      device_traces: NotRequired[list[list[_TraceEntry]]]
  def _augment_frames(frames: list[_Frame]) -> int:
      """
      Attach FX node metadata to the frames that point into FX-generated code.

      A frame qualifies when its file's basename matches the FX generated-file
      prefix and the file has an entry in the in-process FX metadata registry
      that maps the frame's line to a node. Qualifying frames are updated in
      place with ``fx_node_op``, ``fx_node_name``, ``fx_node_target`` and, when
      recorded, ``fx_original_trace``.

      Args:
          frames (list[_Frame]): Frame dictionaries to augment in place.

      Returns:
          int: How many frames were augmented.
      """
      from torch.fx.graph_module import FX_GRAPH_MODULE_FILE_PREFIX
      from torch.fx.traceback import _FX_METADATA_REGISTRY

      # Matches basenames of FX generated files.
      fx_file_re = re.compile(rf"{re.escape(FX_GRAPH_MODULE_FILE_PREFIX)}.*\.py$")

      def _node_info_for(frame):
          # Resolve ``frame`` to its FX node metadata dict, or None if it has none.
          filename = frame.get("filename")
          lineno = frame.get("line")
          if not (filename and lineno):
              return None
          if fx_file_re.search(os.path.basename(filename)) is None:
              return None
          metadata = _FX_METADATA_REGISTRY.get(filename)
          if metadata is None:
              return None
          # lineno_map is keyed relative to the start of the generated prologue.
          offset = metadata.get("prologue_start", 0)
          node_idx = metadata.get("lineno_map", {}).get(lineno - offset)
          if node_idx is None:
              return None
          return metadata.get("node_metadata", {}).get(node_idx)

      augmented = 0
      for frame in frames:
          node_info = _node_info_for(frame)
          if node_info is None:
              continue
          frame["fx_node_op"] = node_info.get("op")
          frame["fx_node_name"] = node_info.get("name")
          frame["fx_node_target"] = str(node_info.get("target"))
          if original_trace := node_info.get("stack_trace"):
              frame["fx_original_trace"] = original_trace
          augmented += 1
      return augmented
  def _augment_memory_snapshot_stack_traces(
      snapshot: str | os.PathLike | _Snapshot,
  ) -> _Snapshot:
      """
      Augment a memory snapshot with original source stack traces from FX metadata.

      IMPORTANT: This function reads from a global in-memory registry (_FX_METADATA_REGISTRY)
      that is populated during graph module compilation. It must be called in the same
      Python process where the FX graphs were compiled. It cannot be used to augment
      snapshots loaded from disk in a different process.

      Frames are modified in place, so when a dict is passed in, that same dict
      is returned.

      Args:
          snapshot (str, os.PathLike or _Snapshot): Either a memory snapshot dict
              or path to a snapshot pickle file. Only pass files you trust:
              unpickling can execute arbitrary code.

      Returns:
          _Snapshot: The augmented snapshot dictionary. Frames that resolve to an
          FX node gain ``fx_node_op``, ``fx_node_name`` and ``fx_node_target``
          fields, plus ``fx_original_trace`` when a stack trace was recorded.
      """
      snapshot_dict: _Snapshot
      if isinstance(snapshot, (str, os.PathLike)):
          # Load the memory snapshot.
          # SECURITY: pickle.load runs arbitrary code embedded in the file;
          # never point this at untrusted input.
          with open(snapshot, "rb") as f:
              snapshot_dict = cast(_Snapshot, pickle.load(f))
      else:
          snapshot_dict = snapshot

      # Process blocks in segments (for regular allocations)
      for segment in snapshot_dict.get("segments", []):
          for block in segment.get("blocks", []):
              if "frames" in block:
                  _augment_frames(block["frames"])

      # Process device traces (for memory history)
      for trace_list in snapshot_dict.get("device_traces", []):
          for trace_entry in trace_list:
              if isinstance(trace_entry, dict) and "frames" in trace_entry:
                  _augment_frames(trace_entry["frames"])

      return snapshot_dict