serialization.py 85 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201
  1. # mypy: allow-untyped-defs
  2. import copyreg
  3. import difflib
  4. import functools
  5. import io
  6. import os
  7. import pickle
  8. import re
  9. import shutil
  10. import struct
  11. import sys
  12. import tarfile
  13. import tempfile
  14. import threading
  15. import warnings
  16. from collections.abc import Callable
  17. from contextlib import closing, contextmanager
  18. from enum import Enum
  19. from typing import Any, cast, Generic, IO, TypeAlias, TypeVar
  20. from typing_extensions import TypeIs
  21. import torch
  22. import torch._weights_only_unpickler as _weights_only_unpickler
  23. from torch._sources import get_source_lines_and_file
  24. from torch._utils import _import_dotted_name
  25. from torch.storage import _get_dtype_from_pickle_storage_type
  26. from torch.types import FileLike, Storage
# Public API of this module: the names exported by ``from <this module> import *``.
__all__ = [
    "SourceChangeWarning",
    "mkdtemp",
    "register_package",
    "check_module_version_greater_or_equal",
    "validate_cuda_device",
    "validate_hpu_device",
    "location_tag",
    "default_restore_location",
    "normalize_storage_type",
    "storage_to_tensor_type",
    "save",
    "load",
    "StorageType",
    "LoadEndianness",
    "get_crc32_options",
    "set_crc32_options",
    "get_default_load_endianness",
    "set_default_load_endianness",
    "get_default_mmap_options",
    "set_default_mmap_options",
    "clear_safe_globals",
    "get_safe_globals",
    "add_safe_globals",
    "safe_globals",
    "get_unsafe_globals_in_checkpoint",
    "skip_data",
]
# Default pickle protocol number (presumably the default used by ``save`` — confirm
# against its signature).
DEFAULT_PROTOCOL = 2
# Sizes in bytes of the struct codes "l", "i" and "h". The "=" prefix selects the
# *standard* sizes (4, 4 and 2), not the platform's native C sizes.
LONG_SIZE = struct.Struct("=l").size
INT_SIZE = struct.Struct("=i").size
SHORT_SIZE = struct.Struct("=h").size
# NOTE(review): presumably these identify the legacy (non-zip) checkpoint format —
# confirm against the save/load code.
MAGIC_NUMBER = 0x1950A86A20F9469CFC6C
PROTOCOL_VERSION = 1001
# Separator character; presumably used when composing/splitting storage keys — confirm.
STORAGE_KEY_SEPARATOR = ","
# Accepted forms of a ``map_location`` value: a callable taking (storage, location tag)
# and returning a storage, a torch.device, a device string, a ``dict[str, str]``, or None.
MAP_LOCATION: TypeAlias = (
    Callable[[Storage, str], Storage] | torch.device | str | dict[str, str] | None
)
# Union of storage types accepted by the tagger/deserializer callbacks in the registry.
STORAGE: TypeAlias = Storage | torch.storage.TypedStorage | torch.UntypedStorage
# True when running on Windows; the mmap option handling is disabled there.
IS_WINDOWS = sys.platform == "win32"
# Explanation of the PyTorch 2.6 ``weights_only`` default change and of the risk of
# re-running ``torch.load`` with ``weights_only=False``.
UNSAFE_MESSAGE = (
    "In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` "
    "from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, "
    "but it can result in arbitrary code execution. Do it only if you got the file from a "
    "trusted source."
)
# The mmap flag constants are imported only off Windows. On Windows both names are set
# to None, and set_default_mmap_options raises before ever comparing against them.
if not IS_WINDOWS:
    from mmap import MAP_PRIVATE, MAP_SHARED
else:
    MAP_SHARED, MAP_PRIVATE = None, None  # type: ignore[assignment]
  77. def _default_to_weights_only(pickle_module):
  78. is_fbcode = not hasattr(torch.version, "git_version")
  79. return pickle_module is None and not is_fbcode
  80. # _serialization_tls is used to store thread local state specific to serialization
  81. # that needs to be propagated to other files, in particular we use this for
  82. # (1) map_location (needed for wrapper subclasses/third party devices to torch._utils)
  83. # (2) skip_data (needed for torch.Tensor.__reduce_ex__ for skip_data ctx)
  84. # (3) materialize_fake_tensors (needed for torch.Tensor.__reduce_ex__ for skip_data ctx)
  85. class _SerializationLocal(threading.local):
  86. def __init__(self):
  87. super().__init__()
  88. self.map_location: MAP_LOCATION | None = None
  89. self.skip_data: bool = False
  90. self.materialize_fake_tensors: bool = False
  91. _serialization_tls = _SerializationLocal()
  92. class SourceChangeWarning(Warning):
  93. pass
  94. @contextmanager
  95. def mkdtemp():
  96. path = tempfile.mkdtemp()
  97. try:
  98. yield path
  99. finally:
  100. shutil.rmtree(path)
# Registry of (priority, tagger, deserializer) entries. register_package appends to it and
# re-sorts it, so lower priority values come first. location_tag and
# default_restore_location walk it in that order.
_package_registry: list[
    tuple[
        int,
        Callable[[STORAGE], str | None],
        Callable[[STORAGE, str], STORAGE | None],
    ]
] = []
class LoadEndianness(Enum):
    """Byte-order choices for loading.

    get_default_load_endianness / set_default_load_endianness use these values as the
    fallback byte order for checkpoints that carry no byte-order mark.
    """

    NATIVE = 1
    LITTLE = 2
    BIG = 3
  112. def get_default_load_endianness() -> LoadEndianness | None:
  113. """
  114. Get fallback byte order for loading files
  115. If byteorder mark is not present in saved checkpoint,
  116. this byte order is used as fallback.
  117. By default, it's "native" byte order.
  118. Returns:
  119. default_load_endian: Optional[LoadEndianness]
  120. """
  121. from torch.utils.serialization import config
  122. return config.load.endianness
  123. def set_default_load_endianness(endianness):
  124. """
  125. Set fallback byte order for loading files
  126. If byteorder mark is not present in saved checkpoint,
  127. this byte order is used as fallback.
  128. By default, it's "native" byte order.
  129. Args:
  130. endianness: the new fallback byte order
  131. """
  132. if not isinstance(endianness, LoadEndianness) and endianness is not None:
  133. raise TypeError("Invalid argument type in function set_default_load_endianness")
  134. from torch.utils.serialization import config
  135. config.load.endianness = endianness
  136. def get_crc32_options() -> bool:
  137. """
  138. Get whether :func:`torch.save` computes and writes crc32 for each record.
  139. Defaults to ``True``.
  140. """
  141. from torch.utils.serialization import config
  142. return config.save.compute_crc32
  143. def set_crc32_options(compute_crc32: bool):
  144. """
  145. Set whether :func:`torch.save` computes and writes crc32 for each record.
  146. .. note::
  147. Setting this to ``False`` may make unzipping of the ``torch.save`` output
  148. fail or warn due to corrupted CRC32. However ``torch.load`` will be
  149. able to load the file.
  150. Args:
  151. compute_crc32 (bool): set crc32 computation flag
  152. """
  153. from torch.utils.serialization import config
  154. config.save.compute_crc32 = compute_crc32
  155. def get_default_mmap_options() -> int | None:
  156. """
  157. Get default mmap options for :func:`torch.load` with ``mmap=True``.
  158. Defaults to ``mmap.MAP_PRIVATE``.
  159. Returns:
  160. default_mmap_options: int
  161. """
  162. from torch.utils.serialization import config
  163. return config.load.mmap_flags
  164. def _get_storage_alignment() -> int:
  165. """
  166. Gets alignment for storages in torch.save files/
  167. Defaults to 64.
  168. Returns:
  169. storage_alginment: int
  170. """
  171. from torch.utils.serialization import config
  172. return config.save.storage_alignment
  173. class set_default_mmap_options:
  174. """
  175. Context manager or function to set default mmap options for :func:`torch.load` with ``mmap=True`` to flags.
  176. For now, only either ``mmap.MAP_PRIVATE`` or ``mmap.MAP_SHARED`` are supported.
  177. Please open an issue if you need any other option to be added here.
  178. .. note::
  179. This feature is currently not supported for Windows.
  180. Args:
  181. flags: ``mmap.MAP_PRIVATE`` or ``mmap.MAP_SHARED``
  182. """
  183. def __init__(self, flags: int) -> None:
  184. if IS_WINDOWS:
  185. raise RuntimeError(
  186. "Changing the default mmap options is currently not supported for Windows"
  187. )
  188. if flags != MAP_PRIVATE and flags != MAP_SHARED:
  189. raise ValueError(
  190. "Invalid argument in function set_default_mmap_options, "
  191. f"expected mmap.MAP_PRIVATE or mmap.MAP_SHARED, but got {flags}"
  192. )
  193. # global config
  194. from torch.utils.serialization import config
  195. self.prev = config.load.mmap_flags
  196. config.load.mmap_flags = flags
  197. def __enter__(self) -> None:
  198. pass
  199. def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
  200. from torch.utils.serialization import config
  201. config.load.mmap_flags = self.prev
  202. def clear_safe_globals() -> None:
  203. """
  204. Clears the list of globals that are safe for ``weights_only`` load.
  205. """
  206. _weights_only_unpickler._clear_safe_globals()
  207. def get_safe_globals() -> list[Callable | tuple[Callable, str]]:
  208. """
  209. Returns the list of user-added globals that are safe for ``weights_only`` load.
  210. """
  211. return _weights_only_unpickler._get_safe_globals()
  212. def add_safe_globals(safe_globals: list[Callable | tuple[Callable, str]]) -> None:
  213. """
  214. Marks the given globals as safe for ``weights_only`` load. For example, functions
  215. added to this list can be called during unpickling, classes could be instantiated
  216. and have state set.
  217. Each item in the list can either be a function/class or a tuple of the form
  218. (function/class, string) where string is the full path of the function/class.
  219. Within the serialized format, each function is identified with its full
  220. path as ``{__module__}.{__qualname__}``. When calling this API, you can provide this
  221. full path that should match the one in the checkpoint otherwise the default
  222. ``{fn.__module__}.{fn.__qualname__}`` will be used.
  223. Args:
  224. safe_globals (List[Union[Callable, Tuple[Callable, str]]]): list of globals to mark as safe
  225. Example:
  226. >>> # xdoctest: +SKIP("Can't torch.save(t, ...) as doctest thinks MyTensor is defined on torch.serialization")
  227. >>> import tempfile
  228. >>> class MyTensor(torch.Tensor):
  229. ... pass
  230. >>> t = MyTensor(torch.randn(2, 3))
  231. >>> with tempfile.NamedTemporaryFile() as f:
  232. ... torch.save(t, f.name)
  233. # Running `torch.load(f.name, weights_only=True)` will fail with
  234. # Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default.
  235. # Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint.
  236. ... torch.serialization.add_safe_globals([MyTensor])
  237. ... torch.load(f.name, weights_only=True)
  238. # MyTensor([[-0.5024, -1.8152, -0.5455],
  239. # [-0.8234, 2.0500, -0.3657]])
  240. """
  241. _weights_only_unpickler._add_safe_globals(safe_globals)
class safe_globals(_weights_only_unpickler._safe_globals):
    r"""Context-manager that adds certain globals as safe for ``weights_only`` load.

    All behavior is inherited from ``_weights_only_unpickler._safe_globals``. As the
    final assertion in the example shows, the added globals are no longer reported by
    ``get_safe_globals()`` once the ``with`` block has exited.

    Args:
        safe_globals: List of globals for weights_only load.

    Example:
        >>> # xdoctest: +SKIP("Can't torch.save(t, ...) as doctest thinks MyTensor is defined on torch.serialization")
        >>> import tempfile
        >>> class MyTensor(torch.Tensor):
        ...     pass
        >>> t = MyTensor(torch.randn(2, 3))
        >>> with tempfile.NamedTemporaryFile() as f:
        ...     torch.save(t, f.name)
        ...     # Calling torch.load(f.name, weights_only=True) here would fail with
        ...     # Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default.
        ...     # Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint.
        ...     with torch.serialization.safe_globals([MyTensor]):
        ...         torch.load(f.name, weights_only=True)
        # MyTensor([[-0.5024, -1.8152, -0.5455],
        #           [-0.8234,  2.0500, -0.3657]])
        >>> assert torch.serialization.get_safe_globals() == []
    """
  257. ... with torch.serialization.safe_globals([MyTensor]):
  258. ... torch.load(f.name, weights_only=True)
  259. # MyTensor([[-0.5024, -1.8152, -0.5455],
  260. # [-0.8234, 2.0500, -0.3657]])
  261. >>> assert torch.serialization.get_safe_globals() == []
  262. """
  263. def get_unsafe_globals_in_checkpoint(f: FileLike) -> list[str]:
  264. """Returns a list of strings of functions/classes in a ``torch.save`` object that are not safe for ``weights_only``.
  265. For a given function or class ``f``, the corresponding string will be of the form
  266. ``{f.__module__}.{f.__name__}``.
  267. This function will return any GLOBALs in the checkpoint that are not in the set marked safe
  268. for ``weights_only`` (either via :func:`add_safe_globals` or :class:`safe_globals` context or
  269. allowlisted by ``torch`` by default).
  270. .. note::
  271. This function will statically disassemble the pickle file in the checkpoint.
  272. The implication is any classes dynamically pushed onto the stack during unpickling
  273. will not be included in the output.
  274. Args:
  275. f: File-like object or string containing the checkpoint object saved via ``torch.save``
  276. Returns:
  277. A list of strings of pickle GLOBALs in the checkpoint that are not allowlisted for ``weights_only``.
  278. """
  279. default_safe_globals_strings = set(
  280. _weights_only_unpickler._get_allowed_globals().keys()
  281. )
  282. user_safe_global_strings = set(
  283. _weights_only_unpickler._get_user_allowed_globals().keys()
  284. )
  285. safe_global_strings = default_safe_globals_strings.union(user_safe_global_strings)
  286. with _open_file_like(f, "rb") as opened_file:
  287. if not _is_zipfile(opened_file):
  288. raise ValueError("Expected input to be a checkpoint returned by torch.save")
  289. with _open_zipfile_reader(opened_file) as zip_file:
  290. if _is_torchscript_zip(zip_file):
  291. raise ValueError(
  292. "Expected input to be a checkpoint returned by torch.save but got a torchscript checkpoint"
  293. )
  294. data_file = io.BytesIO(zip_file.get_record("data.pkl"))
  295. all_globals = _weights_only_unpickler.get_globals_in_pkl(data_file)
  296. return list(all_globals.difference(safe_global_strings))
  297. class skip_data:
  298. """
  299. Context-manager that skips writing/reading storage bytes for ``torch.save`` / ``torch.load`` calls.
  300. For the save path, storages will still be saved, but the space that their bytes would usually be written to
  301. will be empty space. The storage bytes can then be populated in a separate pass.
  302. For the load path, tensors will be loaded per the checkpoint but their storages will not be populated with data.
  303. .. warning::
  304. The ``skip_data`` context manager is an early prototype and is subject to change.
  305. Args:
  306. materialize_fake_tensors: Whether to materialize FakeTensors during save. This is a no-op for the load path.
  307. Example:
  308. >>> # xdoctest: +SKIP("NamedTemporaryFile on Windows")
  309. >>> import tempfile
  310. >>> t = torch.randn(2, 3)
  311. >>> with tempfile.NamedTemporaryFile() as f:
  312. ... with torch.serialization.skip_data():
  313. ... torch.save(t, f.name)
  314. ... torch.load(f.name, weights_only=True)
  315. tensor([[0., 0., 0.],
  316. [0., 0., 0.]])
  317. """
  318. def __init__(self, materialize_fake_tensors: bool = False):
  319. self.materialize_fake_tensors = materialize_fake_tensors
  320. def __enter__(self):
  321. global _serialization_tls
  322. self._old_skip_data = _serialization_tls.skip_data
  323. self._old_materialize_fake_tensors = _serialization_tls.materialize_fake_tensors
  324. _serialization_tls.skip_data = True
  325. _serialization_tls.materialize_fake_tensors = self.materialize_fake_tensors
  326. def __exit__(self, type, value, tb):
  327. global _serialization_tls
  328. _serialization_tls.skip_data = self._old_skip_data
  329. _serialization_tls.materialize_fake_tensors = self._old_materialize_fake_tensors
  330. def _is_zipfile(f) -> bool:
  331. # This is a stricter implementation than zipfile.is_zipfile().
  332. # zipfile.is_zipfile() is True if the magic number appears anywhere in the
  333. # binary. Since we expect the files here to be generated by torch.save or
  334. # torch.jit.save, it's safe to only check the start bytes and avoid
  335. # collisions and assume the zip has only 1 file.
  336. # See bugs.python.org/issue28494.
  337. start = f.tell()
  338. # Read the first few bytes and match against the ZIP file signature
  339. local_header_magic_number = b"PK\x03\x04"
  340. read_bytes = f.read(len(local_header_magic_number))
  341. f.seek(start)
  342. return read_bytes == local_header_magic_number
  343. def register_package(
  344. priority: int,
  345. tagger: Callable[[STORAGE], str | None],
  346. deserializer: Callable[[STORAGE, str], STORAGE | None],
  347. ):
  348. """
  349. Registers callables for tagging and deserializing storage objects with an associated priority.
  350. Tagging associates a device with a storage object at save time while deserializing moves a
  351. storage object to an appropriate device at load time. :attr:`tagger` and :attr:`deserializer`
  352. are run in the order given by their :attr:`priority` until a tagger/deserializer returns a
  353. value that is not `None`.
  354. To override the deserialization behavior for a device in the global registry, one can register a
  355. tagger with a higher priority than the existing tagger.
  356. This function can also be used to register a tagger and deserializer for new devices.
  357. Args:
  358. priority: Indicates the priority associated with the tagger and deserializer, where a lower
  359. value indicates higher priority.
  360. tagger: Callable that takes in a storage object and returns its tagged device as a string
  361. or None.
  362. deserializer: Callable that takes in storage object and a device string and returns a storage
  363. object on the appropriate device or None.
  364. Returns:
  365. `None`
  366. Example:
  367. >>> def ipu_tag(obj):
  368. >>> if obj.device.type == 'ipu':
  369. >>> return 'ipu'
  370. >>> def ipu_deserialize(obj, location):
  371. >>> if location.startswith('ipu'):
  372. >>> ipu = getattr(torch, "ipu", None)
  373. >>> assert ipu is not None, "IPU device module is not loaded"
  374. >>> assert torch.ipu.is_available(), "ipu is not available"
  375. >>> return obj.ipu(location)
  376. >>> torch.serialization.register_package(11, ipu_tag, ipu_deserialize)
  377. """
  378. queue_elem = (priority, tagger, deserializer)
  379. _package_registry.append(queue_elem)
  380. _package_registry.sort()
  381. def check_module_version_greater_or_equal(
  382. module,
  383. req_version_tuple,
  384. error_if_malformed=True,
  385. ):
  386. """
  387. Check if a module's version satisfies requirements
  388. Usually, a module's version string will be like 'x.y.z', which would be represented
  389. as a tuple (x, y, z), but sometimes it could be an unexpected format. If the version
  390. string does not match the given tuple's format up to the length of the tuple, then
  391. error and exit or emit a warning.
  392. Args:
  393. module: the module to check the version of
  394. req_version_tuple: tuple (usually of ints) representing the required version
  395. error_if_malformed: whether we should exit if module version string is malformed
  396. Returns:
  397. requirement_is_met: bool
  398. """
  399. try:
  400. version_strs = module.__version__.split(".")
  401. # Cast module version fields to match the types of the required version
  402. module_version = tuple(
  403. type(req_field)(version_strs[idx])
  404. for idx, req_field in enumerate(req_version_tuple)
  405. )
  406. requirement_is_met = module_version >= req_version_tuple
  407. except Exception as e:
  408. message = (
  409. f"'{module.__name__}' module version string is malformed '{module.__version__}' and cannot be compared"
  410. f" with tuple {str(req_version_tuple)}"
  411. )
  412. if error_if_malformed:
  413. raise RuntimeError(message) from e
  414. else:
  415. warnings.warn(
  416. message + ", but continuing assuming that requirement is met",
  417. stacklevel=2,
  418. )
  419. requirement_is_met = True
  420. return requirement_is_met
  421. def _cpu_tag(obj):
  422. if obj.device.type == "cpu":
  423. return "cpu"
  424. def _mps_tag(obj):
  425. if obj.device.type == "mps":
  426. return "mps"
  427. def _meta_tag(obj):
  428. if obj.device.type == "meta":
  429. return "meta"
  430. def _backend_tag(backend_name, obj):
  431. if backend_name == "privateuse1":
  432. backend_name = torch._C._get_privateuse1_backend_name()
  433. if obj.device.type == backend_name:
  434. if obj.device.index is None:
  435. return backend_name
  436. else:
  437. return backend_name + ":" + str(obj.device.index)
  438. def _cpu_deserialize(obj, location):
  439. if location == "cpu":
  440. return obj
  441. def _mps_deserialize(obj, location):
  442. if location.startswith("mps"):
  443. return obj.mps()
  444. def _meta_deserialize(obj, location):
  445. if location == "meta":
  446. return torch.UntypedStorage(obj.nbytes(), device="meta")
  447. def _is_meta_location(map_location):
  448. """
  449. Check if map_location specifies the meta device.
  450. This is used to skip reading storage data from disk when loading
  451. to the meta device, since meta tensors don't hold actual data.
  452. Args:
  453. map_location: The map_location argument passed to torch.load
  454. Returns:
  455. True if map_location is definitively the meta device, False otherwise.
  456. For dict or callable map_location, returns False since we can't
  457. easily determine the target device without evaluating it.
  458. """
  459. if map_location is None:
  460. return False
  461. if isinstance(map_location, str):
  462. return map_location == "meta"
  463. if isinstance(map_location, torch.device):
  464. return map_location.type == "meta"
  465. # dict or callable - can't easily determine
  466. return False
  467. def _validate_device(location, backend_name):
  468. """
  469. Check whether the device index of specified backend is valid
  470. In case of privateuse1 backend, your must first register a device_module for
  471. privateuse1 using torch._register_device_module. Implement the following
  472. methods in device_module like cuda: device_module._utils._get_device_index(location, True),
  473. device_module.device_count().
  474. Args:
  475. location: string of device
  476. backend_name: the backend name or the name of privateuse1, which can be renamed
  477. Returns:
  478. device_index: int
  479. """
  480. if not hasattr(torch, backend_name):
  481. raise RuntimeError(
  482. f"The {backend_name.upper()} device module is not registered. "
  483. "If you are running on a CPU-only machine, "
  484. "please use torch.load with map_location=torch.device('cpu') "
  485. "to map your storages to the CPU."
  486. )
  487. device_module = getattr(torch, backend_name)
  488. if hasattr(device_module, "_utils") and hasattr(
  489. device_module._utils, "_get_device_index"
  490. ):
  491. device_index = device_module._utils._get_device_index(location, True)
  492. device = torch.device(backend_name, device_index)
  493. else:
  494. device = torch.device(location)
  495. device_index = device.index if device.index else 0
  496. if hasattr(device_module, "is_available") and not device_module.is_available():
  497. raise RuntimeError(
  498. f"Attempting to deserialize object on a {backend_name.upper()} "
  499. f"device but torch.{backend_name}.is_available() is False. "
  500. "If you are running on a CPU-only machine, "
  501. "please use torch.load with map_location=torch.device('cpu') "
  502. "to map your storages to the CPU."
  503. )
  504. if hasattr(device_module, "device_count"):
  505. device_count = device_module.device_count()
  506. if device_index >= device_count:
  507. raise RuntimeError(
  508. f"Attempting to deserialize object on {backend_name.upper()} device "
  509. f"{device_index} but torch.{backend_name}.device_count() is {device_count}. "
  510. "Please use torch.load with map_location to map your storages "
  511. "to an existing device."
  512. )
  513. return device
  514. def validate_cuda_device(location):
  515. return _validate_device(location, "cuda").index
  516. def validate_hpu_device(location):
  517. return _validate_device(location, "hpu").index
  518. def _deserialize(backend_name, obj, location):
  519. if backend_name == "privateuse1":
  520. backend_name = torch._C._get_privateuse1_backend_name()
  521. if location == backend_name or bool(
  522. re.match(f"{backend_name}(:|[0-9]+)", location)
  523. ):
  524. device = _validate_device(location, backend_name)
  525. return obj.to(device=device)
# Built-in tagger/deserializer registrations. location_tag and default_restore_location
# consult them in ascending priority order.
register_package(10, _cpu_tag, _cpu_deserialize)
# Generic device backends share _backend_tag/_deserialize, with the backend name bound
# via functools.partial.
register_package(
    20,
    functools.partial(_backend_tag, "cuda"),
    functools.partial(_deserialize, "cuda"),
)
register_package(21, _mps_tag, _mps_deserialize)
register_package(22, _meta_tag, _meta_deserialize)
# "privateuse1" is translated to the currently registered privateuse1 backend name
# inside _backend_tag/_deserialize each time they run.
register_package(
    23,
    functools.partial(_backend_tag, "privateuse1"),
    functools.partial(_deserialize, "privateuse1"),
)
register_package(
    24,
    functools.partial(_backend_tag, "hpu"),
    functools.partial(_deserialize, "hpu"),
)
register_package(
    25,
    functools.partial(_backend_tag, "xpu"),
    functools.partial(_deserialize, "xpu"),
)
register_package(
    26,
    functools.partial(_backend_tag, "mtia"),
    functools.partial(_deserialize, "mtia"),
)
  554. def location_tag(
  555. storage: Storage | torch.storage.TypedStorage | torch.UntypedStorage,
  556. ):
  557. for _, tagger, _ in _package_registry:
  558. location = tagger(storage)
  559. if location:
  560. return location
  561. raise RuntimeError(
  562. "don't know how to determine data location of " + torch.typename(storage)
  563. )
  564. def default_restore_location(storage, location):
  565. """
  566. Restores `storage` using a deserializer function registered for the `location`.
  567. This function looks in the registry for deserializer functions that match the `location`.
  568. If found, it attempts to use them, in priority order, to restore `storage` until one
  569. returns a not `None` result. If no deserializer can be found in the registry, or all found fail
  570. to bear a result, it raises a `RuntimeError`.
  571. Args:
  572. storage (STORAGE): the storage object to restore
  573. location (str): the location tag associated with the storage object
  574. Returns:
  575. storage: Optional[STORAGE]
  576. Raises:
  577. RuntimeError: If no deserializer matching `location` is found in the registry or if
  578. all matching ones return `None`.
  579. """
  580. for _, _, fn in _package_registry:
  581. result = fn(storage, location)
  582. if result is not None:
  583. return result
  584. raise RuntimeError(
  585. "don't know how to restore data location of "
  586. + torch.typename(storage)
  587. + " (tagged with "
  588. + location
  589. + ")"
  590. )
  591. def normalize_storage_type(storage_type):
  592. return getattr(torch, storage_type.__name__)
  593. def storage_to_tensor_type(storage):
  594. storage_type = type(storage)
  595. module = _import_dotted_name(storage_type.__module__)
  596. return getattr(module, storage_type.__name__.replace("Storage", "Tensor"))
  597. def _is_path(name_or_buffer: object) -> TypeIs[str | os.PathLike]:
  598. return isinstance(name_or_buffer, (str, os.PathLike))
  599. T = TypeVar("T")
  600. class _opener(Generic[T]):
  601. def __init__(self, file_like: T) -> None:
  602. self.file_like: T = file_like
  603. def __enter__(self):
  604. return self.file_like
  605. def __exit__(self, *args):
  606. pass
  607. class _open_file(_opener[IO[bytes]]):
  608. def __init__(self, name: str | os.PathLike[str], mode: str) -> None:
  609. super().__init__(open(name, mode)) # noqa: SIM115
  610. def __exit__(self, *args):
  611. self.file_like.close()
  612. class _open_buffer_reader(_opener[IO[bytes]]):
  613. def __init__(self, buffer: IO[bytes]) -> None:
  614. super().__init__(buffer)
  615. _check_seekable(buffer)
  616. class _open_buffer_writer(_opener[IO[bytes]]):
  617. def __exit__(self, *args):
  618. self.file_like.flush()
  619. def _open_file_like(name_or_buffer: FileLike, mode: str) -> _opener[IO[bytes]]:
  620. if _is_path(name_or_buffer):
  621. return _open_file(name_or_buffer, mode)
  622. else:
  623. if "w" in mode:
  624. return _open_buffer_writer(name_or_buffer)
  625. elif "r" in mode:
  626. return _open_buffer_reader(name_or_buffer)
  627. else:
  628. raise RuntimeError(f"Expected 'r' or 'w' in mode but got {mode}")
  629. class _open_zipfile_reader(_opener[torch._C.PyTorchFileReader]):
  630. def __init__(self, name_or_buffer: str | IO[bytes]) -> None:
  631. super().__init__(torch._C.PyTorchFileReader(name_or_buffer))
  632. class _open_zipfile_writer_file(_opener[torch._C.PyTorchFileWriter]):
  633. def __init__(self, name: str) -> None:
  634. self.file_stream = None
  635. self.name = name
  636. try:
  637. self.name.encode("ascii")
  638. except UnicodeEncodeError:
  639. # PyTorchFileWriter only supports ascii filename.
  640. # For filenames with non-ascii characters, we rely on Python
  641. # for writing out the file.
  642. # pyrefly: ignore [bad-assignment]
  643. self.file_stream = io.FileIO(self.name, mode="w")
  644. super().__init__(
  645. torch._C.PyTorchFileWriter( # pyrefly: ignore # no-matching-overload
  646. self.file_stream, get_crc32_options(), _get_storage_alignment()
  647. )
  648. )
  649. else:
  650. super().__init__(
  651. torch._C.PyTorchFileWriter(
  652. self.name, get_crc32_options(), _get_storage_alignment()
  653. )
  654. )
  655. def __exit__(self, *args) -> None:
  656. self.file_like.write_end_of_file()
  657. if self.file_stream is not None:
  658. self.file_stream.close()
  659. class _open_zipfile_writer_buffer(_opener[torch._C.PyTorchFileWriter]):
  660. def __init__(self, buffer: IO[bytes]) -> None:
  661. if not callable(getattr(buffer, "write", None)):
  662. msg = f"Buffer of {str(type(buffer)).strip('<>')} has no callable attribute 'write'"
  663. if not hasattr(buffer, "write"):
  664. raise AttributeError(msg)
  665. raise TypeError(msg)
  666. self.buffer = buffer
  667. super().__init__(
  668. torch._C.PyTorchFileWriter(
  669. buffer, get_crc32_options(), _get_storage_alignment()
  670. )
  671. )
  672. def __exit__(self, *args) -> None:
  673. self.file_like.write_end_of_file()
  674. self.buffer.flush()
  675. def _open_zipfile_writer(name_or_buffer: str | IO[bytes]) -> _opener:
  676. container: type[_opener]
  677. if _is_path(name_or_buffer):
  678. container = _open_zipfile_writer_file
  679. else:
  680. container = _open_zipfile_writer_buffer
  681. return container(name_or_buffer) # type: ignore[arg-type]
  682. def _is_compressed_file(f) -> bool:
  683. compress_modules = ["gzip"]
  684. try:
  685. return f.__module__ in compress_modules
  686. except AttributeError:
  687. return False
  688. def _should_read_directly(f):
  689. """
  690. Checks if f is a file that should be read directly. It should be read
  691. directly if it is backed by a real file (has a fileno) and is not a
  692. a compressed file (e.g. gzip)
  693. """
  694. if _is_compressed_file(f):
  695. return False
  696. try:
  697. return f.fileno() >= 0
  698. except io.UnsupportedOperation:
  699. return False
  700. except AttributeError:
  701. return False
  702. def _check_seekable(f) -> bool:
  703. def raise_err_msg(patterns, e):
  704. for p in patterns:
  705. if p in str(e):
  706. msg = (
  707. str(e)
  708. + ". You can only torch.load from a file that is seekable."
  709. + " Please pre-load the data into a buffer like io.BytesIO and"
  710. + " try to load from it instead."
  711. )
  712. raise type(e)(msg)
  713. raise e
  714. try:
  715. f.seek(f.tell())
  716. return True
  717. except (io.UnsupportedOperation, AttributeError) as e:
  718. raise_err_msg(["seek", "tell"], e)
  719. return False
  720. def _check_dill_version(pickle_module) -> None:
  721. """Checks if using dill as the pickle module, and if so, checks if it is the correct version.
  722. If dill version is lower than 0.3.1, a ValueError is raised.
  723. Args:
  724. pickle_module: module used for pickling metadata and objects
  725. """
  726. if pickle_module is not None and pickle_module.__name__ == "dill":
  727. required_dill_version = (0, 3, 1)
  728. if not check_module_version_greater_or_equal(
  729. pickle_module, required_dill_version, False
  730. ):
  731. raise ValueError(
  732. (
  733. "'torch' supports dill >= {}, but you have dill {}."
  734. " Please upgrade dill or switch to 'pickle'"
  735. ).format(
  736. ".".join([str(num) for num in required_dill_version]),
  737. pickle_module.__version__,
  738. )
  739. )
  740. def _check_save_filelike(f):
  741. if not _is_path(f) and not hasattr(f, "write"):
  742. raise AttributeError(
  743. "expected 'f' to be string, path, or a file-like object with "
  744. "a 'write' attribute"
  745. )
  746. def save(
  747. obj: object,
  748. f: FileLike,
  749. pickle_module: Any = pickle,
  750. pickle_protocol: int = DEFAULT_PROTOCOL,
  751. _use_new_zipfile_serialization: bool = True,
  752. _disable_byteorder_record: bool = False,
  753. ) -> None:
  754. # Reference: https://github.com/pytorch/pytorch/issues/54354
  755. # The first line of this docstring overrides the one Sphinx generates for the
  756. # documentation. We need it so that Sphinx doesn't leak `pickle`s path from
  757. # the build environment (e.g. `<module 'pickle' from '/leaked/path').
  758. """save(obj, f, pickle_module=pickle, pickle_protocol=2, _use_new_zipfile_serialization=True)
  759. Saves an object to a disk file.
  760. See also: :ref:`saving-loading-tensors`
  761. See :ref:`layout-control` for more advanced tools to manipulate a checkpoint.
  762. Args:
  763. obj: saved object
  764. f: a file-like object (has to implement write and flush) or a string or
  765. os.PathLike object containing a file name
  766. pickle_module: module used for pickling metadata and objects
  767. pickle_protocol: can be specified to override the default protocol
  768. .. note::
  769. A common PyTorch convention is to save tensors using .pt file extension.
  770. .. note::
  771. PyTorch preserves storage sharing across serialization. See
  772. :ref:`preserve-storage-sharing` for more details.
  773. .. note::
  774. The 1.6 release of PyTorch switched ``torch.save`` to use a new
  775. zipfile-based file format. ``torch.load`` still retains the ability to
  776. load files in the old format. If for any reason you want ``torch.save``
  777. to use the old format, pass the kwarg ``_use_new_zipfile_serialization=False``.
  778. Example:
  779. >>> # xdoctest: +SKIP("makes cwd dirty")
  780. >>> # Save to file
  781. >>> x = torch.tensor([0, 1, 2, 3, 4])
  782. >>> torch.save(x, "tensor.pt")
  783. >>> # Save to io.BytesIO buffer
  784. >>> buffer = io.BytesIO()
  785. >>> torch.save(x, buffer)
  786. """
  787. torch._C._log_api_usage_once("torch.save")
  788. _check_dill_version(pickle_module)
  789. _check_save_filelike(f)
  790. if isinstance(f, (str, os.PathLike)):
  791. f = os.fspath(f)
  792. if _use_new_zipfile_serialization:
  793. with _open_zipfile_writer(f) as opened_zipfile:
  794. _save(
  795. obj,
  796. opened_zipfile,
  797. pickle_module,
  798. pickle_protocol,
  799. _disable_byteorder_record,
  800. )
  801. return
  802. else:
  803. global _serialization_tls
  804. if _serialization_tls.skip_data:
  805. raise RuntimeError(
  806. "Cannot use skip_data=True with _use_new_zipfile_serialization=False"
  807. )
  808. with _open_file_like(f, "wb") as opened_file:
  809. _legacy_save(obj, opened_file, pickle_module, pickle_protocol)
def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
    """Serialize ``obj`` into the open binary file ``f`` using the legacy (non-zipfile) format.

    The stream is written in this order:

    1. Pickles of ``MAGIC_NUMBER``, ``PROTOCOL_VERSION``, and a ``sys_info``
       dict (protocol version, whether the host is little-endian, and the
       short/int/long sizes).
    2. ``obj`` itself, pickled with a ``persistent_id`` hook that replaces
       ``nn.Module`` subclasses and storages by reference tuples.
    3. The sorted list of storage keys.
    4. After a flush, the raw bytes of every referenced storage, in that
       sorted key order.

    Args:
        obj: the object to save.
        f: an already-opened, writable binary file object.
        pickle_module: module providing ``dump`` and ``Pickler``.
        pickle_protocol: pickle protocol passed to every dump.

    Raises:
        RuntimeError: if two storages sharing a non-zero data pointer are saved
            with different dtypes.
        TypeError: if a storage-like object is neither a ``TypedStorage`` nor an
            ``UntypedStorage``.
    """
    import torch.nn as nn

    # nn.Module subclasses already emitted, so each class is recorded only once.
    serialized_container_types = {}
    # storage_key -> (untyped storage, dtype used for its element size on write).
    serialized_storages: dict[str, tuple[torch.UntypedStorage, torch.dtype]] = {}

    # Since loading storages that view the same data with different dtypes is
    # not supported, we need to keep track of the dtype associated with each
    # storage data_ptr and throw an error if the dtype is ever different.
    # TODO: This feature could be added in the future
    storage_dtypes: dict[int, torch.dtype] = {}

    def persistent_id(obj: Any) -> tuple | None:
        # FIXME: the docs say that persistent_id should only return a string
        # but torch store returns tuples. This works only in the binary protocol
        # see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        if isinstance(obj, type) and issubclass(obj, nn.Module):
            if obj in serialized_container_types:
                return None
            serialized_container_types[obj] = True
            source_file = source = None
            try:
                source_lines, _, source_file = get_source_lines_and_file(obj)
                source = "".join(source_lines)
            except (
                Exception
            ):  # saving the source is optional, so we can ignore any errors
                warnings.warn(
                    "Couldn't retrieve source code for container of "
                    "type " + obj.__name__ + ". It won't be checked "
                    "for correctness upon loading.",
                    stacklevel=2,
                )
            return ("module", obj, source_file, source)

        if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
            storage: torch.UntypedStorage

            if isinstance(obj, torch.storage.TypedStorage):
                # TODO: Once we decide to break serialization FC, this case
                # can be deleted
                storage = obj._untyped_storage
                storage_dtype = obj.dtype
                storage_type_str = obj._pickle_storage_type()
                storage_type = getattr(torch, storage_type_str)
                dtype = obj.dtype
                storage_numel = obj._size()

            elif isinstance(obj, torch.UntypedStorage):
                storage = obj
                storage_dtype = torch.uint8
                storage_type = normalize_storage_type(type(obj))
                dtype = torch.uint8
                storage_numel = storage.nbytes()
            else:
                raise TypeError(f"type not recognized: {type(obj)}")

            # If storage is allocated, ensure that any other saved storages
            # pointing to the same data all have the same dtype. If storage is
            # not allocated, don't perform this check
            if storage.data_ptr() != 0:
                if storage.data_ptr() in storage_dtypes:
                    if storage_dtype != storage_dtypes[storage.data_ptr()]:
                        raise RuntimeError(
                            "Cannot save multiple tensors or storages that "
                            "view the same data as different types"
                        )
                else:
                    storage_dtypes[storage.data_ptr()] = storage_dtype

            view_metadata: tuple[str, int, int] | None

            # Offset is always 0, but we keep it for backwards compatibility
            # with the old serialization format (which supported storage views)
            offset = 0
            storage_key = str(storage._cdata)
            location = location_tag(storage)

            # TODO: There's an issue here with FC. It might be impossible to
            # solve, but it's worth noting. Imagine we save a list `[storage,
            # tensor]`, where `tensor.storage()` is the same as `storage`, and
            # `tensor.element_size() > 1`. Let's say that `tensor.dtype ==
            # torch.float`. The storage will be serialized with element size
            # of 1, since we're choosing to serialize the first occurrence of
            # a duplicate storage. Since this legacy serialization format saves
            # the numel of the storage, rather than nbytes directly, we'll be
            # effectively saving nbytes in this case. We'll be able to load it
            # and the tensor back up with no problems in _this_ and future
            # versions of pytorch, but in older versions, here's the problem:
            # the storage will be loaded up as a UntypedStorage, and then the
            # FloatTensor will loaded and the UntypedStorage will be assigned to
            # it. Since the storage dtype does not match the tensor dtype, this
            # will cause an error. If we reverse the list, like `[tensor,
            # storage]`, then we will save the `tensor.storage()` as a faked
            # `FloatStorage`, and the saved size will be the correct
            # dtype-specific numel count that old versions expect. `tensor`
            # will be able to load up properly in old versions, pointing to
            # a FloatStorage. However, `storage` is still being translated to
            # a UntypedStorage, and it will try to resolve to the same
            # FloatStorage that `tensor` contains. This will also cause an
            # error. It doesn't seem like there's any way around this.
            # Probably, we just cannot maintain FC for the legacy format if the
            # saved list contains both a tensor and a storage that point to the
            # same data. We should still be able to maintain FC for lists of
            # just tensors, as long as all views share the same dtype as the
            # tensor they are viewing.

            # Only the first occurrence of a storage (and its dtype) is kept.
            if storage_key not in serialized_storages:
                serialized_storages[storage_key] = (storage, dtype)
            # NOTE(review): this compares a value with itself, so is_view is
            # always False and view_metadata is always None. The slot is kept
            # so the persistent-id tuple keeps its legacy 6-element layout.
            is_view = storage._cdata != storage._cdata
            if is_view:
                view_metadata = (str(storage._cdata), offset, storage.nbytes())
            else:
                view_metadata = None

            res = (
                "storage",
                storage_type,
                storage_key,
                location,
                storage_numel,
                view_metadata,
            )
            return res
        # Any other object is pickled normally.
        return None

    sys_info = {
        "protocol_version": PROTOCOL_VERSION,
        "little_endian": sys.byteorder == "little",
        "type_sizes": {
            "short": SHORT_SIZE,
            "int": INT_SIZE,
            "long": LONG_SIZE,
        },
    }

    # Header pickles; their order is part of the legacy file layout.
    pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol)
    pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol)
    pickle_module.dump(sys_info, f, protocol=pickle_protocol)

    class PyTorchLegacyPickler(pickle_module.Pickler):
        def persistent_id(self, obj):
            return persistent_id(obj)  # noqa: F821

    pickler = PyTorchLegacyPickler(f, protocol=pickle_protocol)
    pickler.dump(obj)

    # The class def keeps the persistent_id closure alive, leaking memory.
    del persistent_id

    # Keys are str(storage._cdata), so this is a lexicographic string sort; the
    # storage payloads below are written in the same order.
    serialized_storage_keys = sorted(serialized_storages.keys())
    pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol)
    # Flush buffered pickle output before the storages write their raw bytes.
    f.flush()
    for key in serialized_storage_keys:
        storage, dtype = serialized_storages[key]
        # NOTE(review): the meaning of the positional flags is defined by the
        # C++ _write_file binding. The second argument is whether f is backed by
        # a real, uncompressed file (see _should_read_directly).
        storage._write_file(
            f, _should_read_directly(f), True, torch._utils._element_size(dtype)
        )
def _save(
    obj,
    zip_file,
    pickle_module,
    pickle_protocol,
    _disable_byteorder_record,
):
    """Serialize ``obj`` into ``zip_file`` using the zipfile-based format.

    ``zip_file`` is the writer that ``save`` obtains from
    ``_open_zipfile_writer``. Only its ``write_record`` and
    ``write_record_metadata`` methods are used here.

    Records are written in this order:

    * ``data.pkl``: the pickled ``obj``, with every storage replaced by the
      persistent id ``("storage", storage_type, key, location, numel)``.
    * ``.format_version`` (the string ``"1"``) and ``.storage_alignment``.
    * ``byteorder`` (``sys.byteorder``), unless ``_disable_byteorder_record``
      is set.
    * ``data/<key>`` for each distinct storage. Keys are ``"0"``, ``"1"``, ...
      in first-seen order. Non-CPU storages are copied to CPU before writing.
      When ``_serialization_tls.skip_data`` is set, only
      ``write_record_metadata(name, num_bytes)`` is called for these entries.

    Raises:
        RuntimeError: if two non-meta storages sharing a non-zero data pointer
            are saved with different dtypes.
        ValueError: if the byteorder record is enabled and ``sys.byteorder`` is
            neither ``"little"`` nor ``"big"``.
    """
    # storage_key -> untyped storage whose bytes go into data/<storage_key>.
    serialized_storages: dict[str, torch.storage.UntypedStorage] = {}
    # storage._cdata -> storage_key; gives each distinct storage one key.
    id_map: dict[int, str] = {}

    # Since loading storages that view the same data with different dtypes is
    # not supported, we need to keep track of the dtype associated with each
    # storage data_ptr and throw an error if the dtype is ever different.
    # TODO: This feature could be added in the future
    storage_dtypes: dict[int, torch.dtype] = {}

    def persistent_id(obj):
        # FIXME: the docs say that persistent_id should only return a string
        # but torch store returns tuples. This works only in the binary protocol
        # see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
            if isinstance(obj, torch.storage.TypedStorage):
                # TODO: Once we decide to break serialization FC, this case
                # can be deleted
                storage = obj._untyped_storage
                storage_dtype = obj.dtype
                storage_type_str = obj._pickle_storage_type()
                storage_type = getattr(torch, storage_type_str)
                storage_numel = obj._size()

            else:
                storage = obj
                storage_dtype = torch.uint8
                storage_type = normalize_storage_type(type(obj))
                storage_numel = storage.nbytes()

            # If storage is allocated, ensure that any other saved storages
            # pointing to the same data all have the same dtype. If storage is
            # not allocated, don't perform this check.
            # Meta-device storages are also excluded (presumably because their
            # data_ptr does not identify real memory).
            if str(storage.device) != "meta" and storage.data_ptr() != 0:
                if storage.data_ptr() in storage_dtypes:
                    if storage_dtype != storage_dtypes[storage.data_ptr()]:
                        raise RuntimeError(
                            "Cannot save multiple tensors or storages that "
                            "view the same data as different types"
                        )
                else:
                    storage_dtypes[storage.data_ptr()] = storage_dtype

            # First sighting of a storage assigns the next integer key, as a string.
            storage_key = id_map.setdefault(storage._cdata, str(len(id_map)))
            # Objects carrying a _fake_device report that device as their
            # location instead of asking the registered taggers.
            if hasattr(obj, "_fake_device") and obj._fake_device is not None:
                location = str(obj._fake_device)
            else:
                location = location_tag(storage)
            serialized_storages[storage_key] = storage

            return ("storage", storage_type, storage_key, location, storage_numel)

        # Any other object is pickled normally.
        return None

    # Write the pickle data for `obj`
    data_buf = io.BytesIO()

    class PyTorchPickler(pickle_module.Pickler):  # type: ignore[name-defined]
        def persistent_id(self, obj):
            return persistent_id(obj)  # noqa: F821

    pickler = PyTorchPickler(data_buf, protocol=pickle_protocol)
    pickler.dump(obj)

    # The class def keeps the persistent_id closure alive, leaking memory.
    del persistent_id

    data_value = data_buf.getvalue()
    zip_file.write_record("data.pkl", data_value, len(data_value))
    # .format_version is used to track
    #     1. version 1 represents the order of storages being changed from
    #        lexicographical based on keys to numerically ordered based on keys
    #     2. version 2 represents including storage_alignment as a record
    #        within the zipfile
    # NOTE(review): the value written is still "1" even though the
    # .storage_alignment record is written just below. Confirm this is intended.
    zip_file.write_record(".format_version", "1", len("1"))
    storage_alignment = str(_get_storage_alignment())
    zip_file.write_record(
        ".storage_alignment", storage_alignment, len(storage_alignment)
    )

    # Write byte order marker
    if not _disable_byteorder_record:
        if sys.byteorder not in ["little", "big"]:
            raise ValueError("Unknown endianness type: " + sys.byteorder)

        zip_file.write_record("byteorder", sys.byteorder, len(sys.byteorder))

    # Write each tensor to a file named tensor/the_tensor_key in the zip archive
    for key in serialized_storages:
        name = f"data/{key}"
        storage = serialized_storages[key]
        num_bytes = storage.nbytes()
        # `global` applies to the whole function wherever it appears. The name
        # is only read here, so this declaration does not change lookup.
        global _serialization_tls
        if _serialization_tls.skip_data:
            zip_file.write_record_metadata(name, num_bytes)
        else:
            # given that we copy things around anyway, we might use storage.cpu()
            # this means that to get tensors serialized, you need to implement
            # .cpu() on the underlying Storage
            if storage.device.type != "cpu":
                from torch.utils.serialization import config

                # When enabled by config and the storage lives on the current
                # accelerator, stage the device->host copy through a pinned CPU
                # buffer. The stream is synchronized before the bytes are used
                # (presumably because the copy into pinned memory may be
                # asynchronous).
                if (
                    config.save.use_pinned_memory_for_d2h
                    and (
                        acc := torch.accelerator.current_accelerator(
                            check_available=True
                        )
                    )
                    is not None
                    and acc.type == storage.device.type
                ):
                    new_storage = torch.empty(
                        num_bytes, dtype=torch.uint8, device="cpu", pin_memory=True
                    ).untyped_storage()
                    new_storage.copy_(storage)
                    torch.accelerator.current_stream(storage.device.index).synchronize()
                    storage = new_storage
                else:
                    storage = storage.cpu()
            # Now that it is on the CPU we can directly copy it into the zip file
            zip_file.write_record(name, storage, num_bytes)
  1066. def load(
  1067. f: FileLike,
  1068. map_location: MAP_LOCATION = None,
  1069. pickle_module: Any = None,
  1070. *,
  1071. weights_only: bool | None = None,
  1072. mmap: bool | None = None,
  1073. **pickle_load_args: Any,
  1074. ) -> Any:
  1075. # Reference: https://github.com/pytorch/pytorch/issues/54354
  1076. # The first line of this docstring overrides the one Sphinx generates for the
  1077. # documentation. We need it so that Sphinx doesn't leak `pickle`s path from
  1078. # the build environment (e.g. `<module 'pickle' from '/leaked/path').
  1079. """load(f, map_location=None, pickle_module=pickle, *, weights_only=True, mmap=None, **pickle_load_args)
  1080. Loads an object saved with :func:`torch.save` from a file.
  1081. .. warning::
  1082. :func:`torch.load()` uses an unpickler under the hood. **Never load data from an untrusted source.**
  1083. See :ref:`weights-only-security` for more details.
  1084. :func:`torch.load` uses Python's unpickling facilities but treats storages,
  1085. which underlie tensors, specially. They are first deserialized on the
  1086. CPU and are then moved to the device they were saved from. If this fails
  1087. (e.g. because the run time system doesn't have certain devices), an exception
  1088. is raised. However, storages can be dynamically remapped to an alternative
  1089. set of devices using the :attr:`map_location` argument.
  1090. If :attr:`map_location` is a callable, it will be called once for each serialized
  1091. storage with two arguments: storage and location. The storage argument
  1092. will be the initial deserialization of the storage, residing on the CPU.
  1093. Each serialized storage has a location tag associated with it which
  1094. identifies the device it was saved from, and this tag is the second
  1095. argument passed to :attr:`map_location`. The builtin location tags are ``'cpu'``
  1096. for CPU tensors and ``'cuda:device_id'`` (e.g. ``'cuda:2'``) for CUDA tensors.
  1097. :attr:`map_location` should return either ``None`` or a storage. If
  1098. :attr:`map_location` returns a storage, it will be used as the final deserialized
  1099. object, already moved to the right device. Otherwise, :func:`torch.load` will
  1100. fall back to the default behavior, as if :attr:`map_location` wasn't specified.
  1101. If :attr:`map_location` is a :class:`torch.device` object or a string containing
  1102. a device tag, it indicates the location where all tensors should be loaded.
  1103. Otherwise, if :attr:`map_location` is a dict, it will be used to remap location tags
  1104. appearing in the file (keys), to ones that specify where to put the
  1105. storages (values).
  1106. User extensions can register their own location tags and tagging and
  1107. deserialization methods using :func:`torch.serialization.register_package`.
  1108. See :ref:`layout-control` for more advanced tools to manipulate a checkpoint.
  1109. Args:
  1110. f: a file-like object (has to implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`),
  1111. or a string or os.PathLike object containing a file name
  1112. map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage
  1113. locations
  1114. pickle_module: module used for unpickling metadata and objects (has to
  1115. match the :attr:`pickle_module` used to serialize file)
  1116. weights_only: Indicates whether unpickler should be restricted to
  1117. loading only tensors, primitive types, dictionaries
  1118. and any types added via :func:`torch.serialization.add_safe_globals`.
  1119. See :ref:`weights-only` for more details.
  1120. mmap: Indicates whether the file should be mapped rather than loading all the storages into memory.
  1121. Typically, tensor storages in the file will first be moved from disk to CPU memory, after which they
  1122. are moved to the location that they were tagged with when saving, or specified by ``map_location``. This
  1123. second step is a no-op if the final location is CPU. When the ``mmap`` flag is set, instead of copying the
  1124. tensor storages from disk to CPU memory in the first step, ``f`` is mapped, which means tensor storages
  1125. will be lazily loaded when their data is accessed.
  1126. pickle_load_args: (Python 3 only) optional keyword arguments passed over to
  1127. :func:`pickle_module.load` and :func:`pickle_module.Unpickler`,
  1128. only works if :attr:`weights_only=False`, e.g., :attr:`errors=...`.
  1129. .. note::
  1130. When you call :func:`torch.load()` on a file which contains GPU tensors, those tensors
  1131. will be loaded to GPU by default. You can call ``torch.load(.., map_location='cpu')``
  1132. and then :meth:`load_state_dict` to avoid GPU RAM surge when loading a model checkpoint.
  1133. .. note::
  1134. By default, we decode byte strings as ``utf-8``. This is to avoid a common error
  1135. case ``UnicodeDecodeError: 'ascii' codec can't decode byte 0x...``
  1136. when loading files saved by Python 2 in Python 3. If this default
  1137. is incorrect, you may use an extra :attr:`encoding` keyword argument to specify how
  1138. these objects should be loaded, e.g., :attr:`encoding='latin1'` decodes them
  1139. to strings using ``latin1`` encoding, and :attr:`encoding='bytes'` keeps them
  1140. as byte arrays which can be decoded later with ``byte_array.decode(...)``.
  1141. Example:
  1142. >>> # xdoctest: +SKIP("undefined filepaths")
  1143. >>> torch.load("tensors.pt", weights_only=True)
  1144. # Load all tensors onto the CPU
  1145. >>> torch.load(
  1146. ... "tensors.pt",
  1147. ... map_location=torch.device("cpu"),
  1148. ... weights_only=True,
  1149. ... )
  1150. # Load all tensors onto the CPU, using a function
  1151. >>> torch.load(
  1152. ... "tensors.pt",
  1153. ... map_location=lambda storage, loc: storage,
  1154. ... weights_only=True,
  1155. ... )
  1156. # Load all tensors onto GPU 1
  1157. >>> torch.load(
  1158. ... "tensors.pt",
  1159. ... map_location=lambda storage, loc: storage.cuda(1), # type: ignore[attr-defined]
  1160. ... weights_only=True,
  1161. ... ) # type: ignore[attr-defined]
  1162. # Map tensors from GPU 1 to GPU 0
  1163. >>> torch.load(
  1164. ... "tensors.pt",
  1165. ... map_location={"cuda:1": "cuda:0"},
  1166. ... weights_only=True,
  1167. ... )
  1168. # Load tensor from io.BytesIO object
  1169. # Loading from a buffer setting weights_only=False, warning this can be unsafe
  1170. >>> with open("tensor.pt", "rb") as f:
  1171. ... buffer = io.BytesIO(f.read())
  1172. >>> torch.load(buffer, weights_only=False)
  1173. # Load a module with 'ascii' encoding for unpickling
  1174. # Loading from a module setting weights_only=False, warning this can be unsafe
  1175. >>> torch.load("module.pt", encoding="ascii", weights_only=False)
  1176. """
  1177. torch._C._log_api_usage_once("torch.load")
  1178. DOCS_MESSAGE = (
  1179. "\n\nCheck the documentation of torch.load to learn more about types accepted by default with "
  1180. "weights_only https://pytorch.org/docs/stable/generated/torch.load.html."
  1181. )
  1182. def _get_wo_message(message: str) -> str:
  1183. unsafe_global_pattern = r"GLOBAL (\S+) was not an allowed global by default."
  1184. has_unsafe_global = re.search(unsafe_global_pattern, message) is not None
  1185. blocklist_pattern = r"whose module (\S+) is blocked"
  1186. has_blocklist = re.search(blocklist_pattern, message) is not None
  1187. import_pattern = r"(\S+) must be (\S+) to load"
  1188. has_import = re.search(import_pattern, message) is not None
  1189. if has_unsafe_global:
  1190. updated_message = (
  1191. "Weights only load failed. This file can still be loaded, to do so you have two options, "
  1192. "\033[1mdo those steps only if you trust the source of the checkpoint\033[0m. "
  1193. f"\n\t(1) {UNSAFE_MESSAGE}\n\t(2) Alternatively, to load with `weights_only=True` please check "
  1194. "the recommended steps in the following error message.\n\tWeightsUnpickler error: "
  1195. + message
  1196. )
  1197. else:
  1198. if has_import:
  1199. return f"Weights only load failed. {message}\n {UNSAFE_MESSAGE}\n"
  1200. else:
  1201. updated_message = f"Weights only load failed. {UNSAFE_MESSAGE}\n"
  1202. if not has_blocklist:
  1203. updated_message += (
  1204. "Please file an issue with the following so that we can make "
  1205. "`weights_only=True` compatible with your use case: WeightsUnpickler error: "
  1206. )
  1207. updated_message += "\n\n" + message
  1208. return updated_message + DOCS_MESSAGE
  1209. weights_only_not_set = weights_only is None
  1210. if weights_only_not_set:
  1211. weights_only = _default_to_weights_only(pickle_module)
  1212. true_values = ["1", "y", "yes", "true"]
  1213. # Add ability to force safe only or non-safe weight loads via environment variables
  1214. force_weights_only_load = (
  1215. os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0") in true_values
  1216. )
  1217. force_no_weights_only_load = (
  1218. os.getenv("TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD", "0") in true_values
  1219. )
  1220. if force_weights_only_load and force_no_weights_only_load:
  1221. raise RuntimeError(
  1222. "Only one of `TORCH_FORCE_WEIGHTS_ONLY_LOAD` or `TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD` "
  1223. "should be set, but both were set."
  1224. )
  1225. elif force_weights_only_load:
  1226. weights_only = True
  1227. elif force_no_weights_only_load:
  1228. # TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD can only override if callsite did not explicitly set weights_only
  1229. if weights_only_not_set:
  1230. warnings.warn(
  1231. "Environment variable TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD detected, since the"
  1232. "`weights_only` argument was not explicitly passed to `torch.load`, forcing weights_only=False.",
  1233. UserWarning,
  1234. stacklevel=2,
  1235. )
  1236. weights_only = False
  1237. if weights_only:
  1238. if pickle_module is not None:
  1239. raise RuntimeError(
  1240. "Can not safely load weights when explicit pickle_module is specified"
  1241. )
  1242. else:
  1243. if pickle_module is None:
  1244. pickle_module = pickle
  1245. if pickle_load_args != {} and weights_only:
  1246. warnings.warn("pickle_load_args only works if `weights_only=False`.")
  1247. # make flipping default BC-compatible
  1248. if mmap is None:
  1249. from torch.utils.serialization import config
  1250. mmap = config.load.mmap
  1251. _check_dill_version(pickle_module)
  1252. if "encoding" not in pickle_load_args:
  1253. pickle_load_args["encoding"] = "utf-8"
  1254. with _open_file_like(f, "rb") as opened_file:
  1255. if _is_zipfile(opened_file):
  1256. # The zipfile reader is going to advance the current file position.
  1257. # If we want to actually tail call to torch.jit.load, we need to
  1258. # reset back to the original position.
  1259. orig_position = opened_file.tell()
  1260. overall_storage = None
  1261. with _open_zipfile_reader(opened_file) as opened_zipfile:
  1262. if _is_torchscript_zip(opened_zipfile):
  1263. warnings.warn(
  1264. "'torch.load' received a zip file that looks like a TorchScript archive"
  1265. " dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to"
  1266. " silence this warning)",
  1267. UserWarning,
  1268. stacklevel=2,
  1269. )
  1270. if weights_only:
  1271. raise RuntimeError(
  1272. "Cannot use ``weights_only=True`` with TorchScript archives passed to "
  1273. "``torch.load``. " + UNSAFE_MESSAGE
  1274. )
  1275. opened_file.seek(orig_position)
  1276. return torch.jit.load(opened_file, map_location=map_location)
  1277. if mmap:
  1278. if not _is_path(f):
  1279. raise ValueError(
  1280. "f must be a file path in order to use the mmap argument"
  1281. )
  1282. size = os.path.getsize(f)
  1283. if not IS_WINDOWS:
  1284. shared = get_default_mmap_options() == MAP_SHARED
  1285. else:
  1286. shared = False
  1287. overall_storage = torch.UntypedStorage.from_file(
  1288. os.fspath(f),
  1289. shared,
  1290. size,
  1291. )
  1292. if weights_only:
  1293. try:
  1294. return _load(
  1295. opened_zipfile,
  1296. map_location,
  1297. _weights_only_unpickler,
  1298. overall_storage=overall_storage,
  1299. **pickle_load_args,
  1300. )
  1301. except pickle.UnpicklingError as e:
  1302. raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
  1303. return _load(
  1304. opened_zipfile,
  1305. map_location,
  1306. pickle_module,
  1307. overall_storage=overall_storage,
  1308. **pickle_load_args,
  1309. )
  1310. if mmap:
  1311. f_name = "" if not isinstance(f, str) else f"{f}, "
  1312. raise RuntimeError(
  1313. "mmap can only be used with files saved with "
  1314. f"`torch.save({f_name}_use_new_zipfile_serialization=True), "
  1315. "please torch.save your checkpoint with this option in order to use mmap."
  1316. )
  1317. if weights_only:
  1318. try:
  1319. return _legacy_load(
  1320. opened_file,
  1321. map_location,
  1322. _weights_only_unpickler,
  1323. **pickle_load_args,
  1324. )
  1325. except pickle.UnpicklingError as e:
  1326. raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
  1327. return _legacy_load(
  1328. opened_file, map_location, pickle_module, **pickle_load_args
  1329. )
  1330. # Register pickling support for layout instances such as
  1331. # torch.sparse_coo, etc
  1332. def _get_layout(name):
  1333. """Get layout extension object from its string representation."""
  1334. cache = _get_layout.cache # type: ignore[attr-defined]
  1335. if not cache:
  1336. for v in torch.__dict__.values():
  1337. if isinstance(v, torch.layout):
  1338. cache[str(v)] = v
  1339. return cache[name]
  1340. # There are yet not good way to type annotate function attributes https://github.com/python/mypy/issues/2087
  1341. _get_layout.cache = {} # type: ignore[attr-defined]
  1342. copyreg.pickle(torch.layout, lambda obj: (_get_layout, (str(obj),)))
def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
    """Load a checkpoint written in one of the pre-zipfile ``torch.save`` formats.

    Two layouts are handled:

    * the legacy ``.tar`` format, attempted first and only when ``f`` should be
      read directly and is positioned at offset 0; and
    * the legacy pickle-stream format: magic number, protocol version, system
      info, the pickled object, the list of storage keys, then the raw bytes of
      those storages.

    Args:
        f: file-like object holding the checkpoint (``_check_seekable`` is
            called on it).
        map_location: passed to ``_get_restore_location`` to build the
            callback that decides where each storage is placed.
        pickle_module: module providing ``load`` and ``Unpickler``; the
            ``.tar`` path refuses ``_weights_only_unpickler``.
        **pickle_load_args: extra keyword arguments passed to every
            ``pickle_module.load`` call and to the unpickler constructor.

    Returns:
        The unpickled object.

    Raises:
        RuntimeError: weights-only load of a ``.tar`` checkpoint, a zip archive
            given to this loader, a bad magic number or protocol version, or an
            unknown persistent-id type.
        AssertionError: a non-tuple persistent id or a storage key that was
            never deserialized.
    """
    deserialized_objects: dict[int, Any] = {}

    restore_location = _get_restore_location(map_location)

    class UnpicklerWrapper(pickle_module.Unpickler):  # type: ignore[name-defined]
        # Resolve legacy "*Storage" class names to a dtype-carrying StorageType
        # placeholder; names StorageType rejects with KeyError fall back to the
        # normal class lookup.
        def find_class(self, mod_name, name):
            if type(name) is str and "Storage" in name:
                try:
                    return StorageType(name)
                except KeyError:
                    pass
            return super().find_class(mod_name, name)

    def _check_container_source(container_type, source_file, original_source):
        """Warn if the saved source of ``container_type`` differs from the current one.

        When ``container_type.dump_patches`` is set, also try to write a reverse
        patch named ``<ClassName>.patch`` into the current working directory.
        """
        try:
            current_source = "".join(get_source_lines_and_file(container_type)[0])
        except Exception:  # saving the source is optional, so we can ignore any errors
            warnings.warn(
                "Couldn't retrieve source code for container of "
                "type " + container_type.__name__ + ". It won't be checked "
                "for correctness upon loading.",
                stacklevel=2,
            )
            return
        if original_source != current_source:
            if container_type.dump_patches:
                file_name = container_type.__name__ + ".patch"
                # Diff goes from current -> original, so applying it reverts
                # the local changes.
                diff = difflib.unified_diff(
                    current_source.split("\n"),
                    original_source.split("\n"),
                    source_file,
                    source_file,
                    lineterm="",
                )
                lines = "\n".join(diff)
                try:
                    # "a+" keeps an existing file; seek(0, 2) yields its size.
                    # An empty file receives the patch; a non-empty one must
                    # already hold exactly this patch, otherwise fall through to
                    # the OSError message below.
                    with open(file_name, "a+") as f:
                        file_size = f.seek(0, 2)
                        f.seek(0)
                        if file_size == 0:
                            f.write(lines)
                        elif file_size != len(lines) or f.read() != lines:
                            raise OSError
                    msg = (
                        "Saved a reverse patch to " + file_name + ". "
                        "Run `patch -p0 < " + file_name + "` to revert your "
                        "changes."
                    )
                except OSError:
                    msg = (
                        "Tried to save a patch, but couldn't create a "
                        "writable file " + file_name + ". Make sure it "
                        "doesn't exist and your working directory is "
                        "writable."
                    )
            else:
                msg = (
                    "you can retrieve the original source code by "
                    "accessing the object's source attribute or set "
                    "`torch.nn.Module.dump_patches = True` and use the "
                    "patch tool to revert the changes."
                )
            msg = f"source code of class '{torch.typename(container_type)}' has changed. {msg}"
            warnings.warn(msg, SourceChangeWarning, stacklevel=2)

    def legacy_load(f):
        """Load the legacy ``.tar`` layout (members ``storages``, ``tensors``, ``pickle``)."""
        # Shadows the outer dict on purpose: the tar path keeps its own table.
        deserialized_objects: dict[int, Any] = {}

        def persistent_load(saved_id):
            # Tuples are module references (container_type, source_file,
            # original_source); any other id is a key into deserialized_objects.
            if isinstance(saved_id, tuple):
                # Ignore containers that don't have any sources saved
                if all(saved_id[1:]):
                    _check_container_source(*saved_id)
                return saved_id[0]
            return deserialized_objects[int(saved_id)]

        with (
            closing(
                tarfile.open(fileobj=f, mode="r:", format=tarfile.PAX_FORMAT)
            ) as tar,
            mkdtemp() as tmpdir,
        ):
            if pickle_module is _weights_only_unpickler:
                raise RuntimeError(
                    "Cannot use ``weights_only=True`` with files saved in the "
                    "legacy .tar format. " + UNSAFE_MESSAGE
                )
            # "storages" member: a count, then per storage a pickled
            # (key, location, storage_type) record followed by its data
            # (presumably consumed from f by _new_with_file), then the list of
            # storage views.
            tar.extract("storages", path=tmpdir)
            with open(os.path.join(tmpdir, "storages"), "rb", 0) as f:
                num_storages = pickle_module.load(f, **pickle_load_args)
                for _ in range(num_storages):
                    args = pickle_module.load(f, **pickle_load_args)
                    key, location, storage_type = args
                    dtype = storage_type._dtype
                    obj = cast(Storage, torch.UntypedStorage)._new_with_file(
                        f, torch._utils._element_size(dtype)
                    )
                    obj = restore_location(obj, location)
                    # TODO: Once we decide to break serialization FC, we can
                    # stop wrapping with TypedStorage
                    deserialized_objects[key] = torch.storage.TypedStorage(
                        wrap_storage=obj, dtype=dtype, _internal=True
                    )

                # Each view is (target_cdata, root_cdata, offset, numel) and
                # becomes a byte slice of the root's untyped storage.
                storage_views = pickle_module.load(f, **pickle_load_args)
                for target_cdata, root_cdata, offset, numel in storage_views:
                    root = deserialized_objects[root_cdata]
                    element_size = torch._utils._element_size(root.dtype)
                    offset_bytes = offset * element_size
                    # TODO: Once we decide to break serialization FC, we can
                    # stop wrapping with TypedStorage
                    deserialized_objects[target_cdata] = torch.storage.TypedStorage(
                        wrap_storage=root._untyped_storage[
                            offset_bytes : offset_bytes + numel * element_size
                        ],
                        dtype=root.dtype,
                        _internal=True,
                    )

            # "tensors" member: a count, then per tensor a pickled
            # (key, storage_id, tensor_type) record followed by little-endian
            # int32 ndim, 4 padding bytes, ndim int64 sizes, ndim int64 strides
            # and an int64 storage offset.
            tar.extract("tensors", path=tmpdir)
            with open(os.path.join(tmpdir, "tensors"), "rb", 0) as f:
                num_tensors = pickle_module.load(f, **pickle_load_args)
                for _ in range(num_tensors):
                    args = pickle_module.load(f, **pickle_load_args)
                    key, storage_id, _original_tensor_type = args
                    storage = deserialized_objects[storage_id]
                    (ndim,) = struct.unpack("<i", f.read(4))
                    # skip next 4 bytes; legacy encoding treated ndim as 8 bytes
                    f.read(4)
                    numel = struct.unpack(f"<{ndim}q", f.read(8 * ndim))
                    stride = struct.unpack(f"<{ndim}q", f.read(8 * ndim))
                    (storage_offset,) = struct.unpack("<q", f.read(8))
                    tensor = torch.empty((0,), dtype=storage.dtype).set_(
                        storage._untyped_storage, storage_offset, numel, stride
                    )
                    deserialized_objects[key] = tensor

            # "pickle" member: the saved object, whose persistent ids refer to
            # the storages/tensors built above.
            pickle_file = tar.extractfile("pickle")
            unpickler = UnpicklerWrapper(pickle_file, **pickle_load_args)
            unpickler.persistent_load = persistent_load
            result = unpickler.load()
            return result

    deserialized_objects = {}

    def persistent_load(saved_id):
        """Resolve a persistent id of the legacy pickle-stream format.

        ``saved_id`` is ``(typename, *data)`` with typename ``"module"`` or
        ``"storage"``. Storages are allocated here; their bytes are filled
        later from the file, after the pickled object has been read.
        """
        if not isinstance(saved_id, tuple):
            raise AssertionError(
                f"saved_id must be a tuple, got {type(saved_id).__name__}"
            )
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        if typename == "module":
            # Ignore containers that don't have any sources saved
            if all(data[1:]):
                _check_container_source(*data)
            return data[0]
        elif typename == "storage":
            storage_type, root_key, location, numel, view_metadata = data
            location = _maybe_decode_ascii(location)
            dtype = storage_type.dtype

            nbytes = numel * torch._utils._element_size(dtype)

            if root_key not in deserialized_objects:
                if torch._guards.active_fake_mode() is not None:
                    # Under fake mode only a meta storage of the right size is
                    # made; restore_location is not applied.
                    obj = cast(Storage, torch.UntypedStorage(nbytes, device="meta"))
                elif _serialization_tls.skip_data:
                    obj = cast(Storage, torch.UntypedStorage(nbytes))
                    obj = restore_location(obj, location)
                else:
                    obj = cast(Storage, torch.UntypedStorage(nbytes))
                    # Flag the storage as not yet holding file data; the bytes
                    # are read in by _set_from_file at the end of this function.
                    obj._torch_load_uninitialized = True
                    obj = restore_location(obj, location)
                # TODO: Once we decide to break serialization FC, we can
                # stop wrapping with TypedStorage
                typed_storage = torch.storage.TypedStorage(
                    wrap_storage=obj, dtype=dtype, _internal=True
                )
                deserialized_objects[root_key] = typed_storage
            else:
                typed_storage = deserialized_objects[root_key]
                # A cached storage with a null data pointer is replaced by a
                # fresh TypedStorage of the requested dtype on the same device
                # (the replacement is not written back to the cache).
                if typed_storage._data_ptr() == 0:
                    typed_storage = torch.storage.TypedStorage(
                        device=typed_storage._untyped_storage.device,
                        dtype=dtype,
                        _internal=True,
                    )

            if view_metadata is not None:
                # view_metadata is (view_key, offset, view_size) in elements;
                # the view is a byte slice of the root storage, created once.
                view_key, offset, view_size = view_metadata
                offset_bytes = offset * torch._utils._element_size(dtype)
                view_size_bytes = view_size * torch._utils._element_size(dtype)
                if view_key not in deserialized_objects:
                    # TODO: Once we decide to break serialization FC, we can
                    # stop wrapping with TypedStorage
                    deserialized_objects[view_key] = torch.storage.TypedStorage(
                        wrap_storage=typed_storage._untyped_storage[
                            offset_bytes : offset_bytes + view_size_bytes
                        ],
                        dtype=dtype,
                        _internal=True,
                    )
                res = deserialized_objects[view_key]

            else:
                res = typed_storage
            return res
        else:
            raise RuntimeError(f"Unknown saved id type: {saved_id[0]}")

    _check_seekable(f)
    f_should_read_directly = _should_read_directly(f)

    if f_should_read_directly and f.tell() == 0:
        # legacy_load requires that f has fileno()
        # only if offset is zero we can attempt the legacy tar file loader
        try:
            return legacy_load(f)
        except tarfile.TarError:
            if _is_zipfile(f):
                # .zip is used for torch.jit.save and will throw an un-pickling error here
                raise RuntimeError(
                    f"{f.name} is a zip archive (did you mean to use torch.jit.load()?)"
                ) from None
            # if not a tarfile, reset file offset and proceed
            f.seek(0)

    # Legacy pickle-stream header: magic number, protocol version, system info.
    magic_number = pickle_module.load(f, **pickle_load_args)
    if magic_number != MAGIC_NUMBER:
        raise RuntimeError("Invalid magic number; corrupt file?")
    protocol_version = pickle_module.load(f, **pickle_load_args)
    if protocol_version != PROTOCOL_VERSION:
        raise RuntimeError(f"Invalid protocol version: {protocol_version}")

    _sys_info = pickle_module.load(f, **pickle_load_args)
    unpickler = UnpicklerWrapper(f, **pickle_load_args)
    unpickler.persistent_load = persistent_load
    result = unpickler.load()

    deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)

    # Storage bytes follow in the order of deserialized_storage_keys. They are
    # not read under fake mode or when data loading is skipped.
    if torch._guards.active_fake_mode() is None and not _serialization_tls.skip_data:
        # With a directly-readable file an explicit offset is tracked and
        # advanced after each storage; otherwise offset stays None.
        offset = f.tell() if f_should_read_directly else None
        for key in deserialized_storage_keys:
            if key not in deserialized_objects:
                raise AssertionError(
                    f"storage key {key!r} not found in deserialized_objects"
                )
            typed_storage = deserialized_objects[key]
            typed_storage._untyped_storage._set_from_file(
                f,
                offset,
                f_should_read_directly,
                torch._utils._element_size(typed_storage.dtype),
            )
            if offset is not None:
                offset = f.tell()

    torch._utils._validate_loaded_sparse_tensors()

    return result
  1583. def _maybe_decode_ascii(bytes_str: bytes | str) -> str:
  1584. # When using encoding='bytes' in Py3, some **internal** keys stored as
  1585. # strings in Py2 are loaded as bytes. This function decodes them with
  1586. # ascii encoding, one that Py3 uses by default.
  1587. #
  1588. # NOTE: This should only be used on internal keys (e.g., `typename` and
  1589. # `location` in `persistent_load` below!
  1590. if isinstance(bytes_str, bytes):
  1591. return bytes_str.decode("ascii")
  1592. return bytes_str
  1593. def _get_restore_location(map_location):
  1594. if map_location is None:
  1595. restore_location = default_restore_location
  1596. elif isinstance(map_location, dict):
  1597. def restore_location(storage, location):
  1598. location = map_location.get(location, location)
  1599. return default_restore_location(storage, location)
  1600. elif isinstance(map_location, (str, bytes)):
  1601. def restore_location(storage, location):
  1602. return default_restore_location(storage, map_location)
  1603. elif isinstance(map_location, torch.device):
  1604. def restore_location(storage, location):
  1605. return default_restore_location(storage, str(map_location))
  1606. else:
  1607. def restore_location(storage, location):
  1608. result = map_location(storage, location)
  1609. if result is None:
  1610. result = default_restore_location(storage, location)
  1611. return result
  1612. return restore_location
  1613. class StorageType:
  1614. def __init__(self, name):
  1615. self._dtype = _get_dtype_from_pickle_storage_type(name)
  1616. @property
  1617. def dtype(self):
  1618. return self._dtype
  1619. def __str__(self):
  1620. return f"StorageType(dtype={self.dtype})"
  1621. def _load(
  1622. zip_file,
  1623. map_location,
  1624. pickle_module,
  1625. pickle_file="data.pkl",
  1626. overall_storage=None,
  1627. **pickle_load_args,
  1628. ):
  1629. restore_location = _get_restore_location(map_location)
  1630. loaded_storages = {}
  1631. is_meta_map_location = _is_meta_location(map_location)
  1632. can_calculate_storage_offsets = False
  1633. if zip_file.has_record(".format_version"):
  1634. version = zip_file.get_record(".format_version")
  1635. can_calculate_storage_offsets = version >= b"1"
  1636. # check if byteswapping is needed
  1637. byteordername = "byteorder"
  1638. byteorderdata = None
  1639. if zip_file.has_record(byteordername):
  1640. byteorderdata = zip_file.get_record(byteordername)
  1641. if byteorderdata not in [b"little", b"big"]:
  1642. raise ValueError("Unknown endianness type: " + byteorderdata.decode())
  1643. elif (
  1644. get_default_load_endianness() == LoadEndianness.LITTLE
  1645. or get_default_load_endianness() is None
  1646. ):
  1647. byteorderdata = b"little"
  1648. elif get_default_load_endianness() == LoadEndianness.BIG:
  1649. byteorderdata = b"big"
  1650. elif get_default_load_endianness() == LoadEndianness.NATIVE:
  1651. pass
  1652. else:
  1653. raise ValueError("Invalid load endianness type")
  1654. storage_alignment = 64
  1655. if zip_file.has_record(".storage_alignment"):
  1656. storage_alignment = int(zip_file.get_record(".storage_alignment"))
  1657. if (
  1658. not zip_file.has_record(byteordername)
  1659. and get_default_load_endianness() is None
  1660. and sys.byteorder == "big"
  1661. ):
  1662. # Default behaviour was changed
  1663. # See https://github.com/pytorch/pytorch/issues/101688
  1664. warnings.warn(
  1665. "The default load endianness for checkpoints without a byteorder mark "
  1666. "on big endian machines was changed from 'native' to 'little' endian, "
  1667. "to avoid this behavior please use "
  1668. "torch.serialization.set_default_load_endianness to set "
  1669. "the desired default load endianness",
  1670. UserWarning,
  1671. stacklevel=2,
  1672. )
  1673. from torch.utils.serialization import config
  1674. calculate_storage_offsets = config.load.calculate_storage_offsets
  1675. run_debug_asserts = os.environ.get("TORCH_SERIALIZATION_DEBUG", "0") == "1"
  1676. current_offset = None
  1677. # constants from miniz.h/miniz.c
  1678. data_descripter_size64 = 24
  1679. data_descripter_size32 = 16
  1680. mz_uint32_max = 0xFFFFFFFF
  1681. offsets: dict[str, int] = dict()
  1682. def _get_offset(key, name, numel):
  1683. """
  1684. Return the offset of the storage associated with key with record name `name` and size numel.
  1685. It is expected that the zipfile header of this storage starts at current_offset.
  1686. WARNING: This function relies on the behavior of the zipwriter in miniz.c. In particular,
  1687. the behavior of `mz_zip_writer_add_mem_ex_v2`. The behavior of this function must be kept
  1688. in sync with that of miniz!
  1689. After reading a storage of size numel that starts at storage_offset
  1690. if it is the first time that storage was read, update nonlocal variable
  1691. current_offset to the start of the next zipfile header by incrementing
  1692. it by numel and the data descriptor size.
  1693. """
  1694. nonlocal current_offset, offsets
  1695. if name in offsets:
  1696. storage_offset = offsets[name]
  1697. return storage_offset
  1698. if current_offset is None:
  1699. if key != "0":
  1700. raise AssertionError(f"expected key '0', got {key!r}")
  1701. current_offset = zip_file.get_record_offset(name)
  1702. local_header_offset = zip_file.get_record_header_offset(name)
  1703. storage_offset = current_offset
  1704. else:
  1705. storage_offset = zip_file.get_record_offset_no_read(
  1706. current_offset, name, numel, storage_alignment
  1707. )
  1708. local_header_offset = current_offset
  1709. # This is only actually needed for storages that have typed_storage._data_ptr() == 0
  1710. # after being read. Otherwise persistent_load would never "re-call" load_tensor
  1711. # for a given key.
  1712. offsets[name] = storage_offset
  1713. # Increment current_offset to offset where next zipfile header starts
  1714. current_offset = storage_offset + numel
  1715. # add size of data descriptor after payload
  1716. if numel > 0:
  1717. if local_header_offset >= mz_uint32_max or numel >= mz_uint32_max:
  1718. current_offset += data_descripter_size64
  1719. else:
  1720. current_offset += data_descripter_size32
  1721. return storage_offset
  1722. def load_tensor(dtype, nbytes, key, location):
  1723. name = f"data/{key}"
  1724. if torch._guards.detect_fake_mode(None) is not None or is_meta_map_location:
  1725. storage = torch.UntypedStorage(nbytes, device="meta")
  1726. if can_calculate_storage_offsets:
  1727. storage._checkpoint_offset = _get_offset(key, name, nbytes)
  1728. else:
  1729. storage._checkpoint_offset = zip_file.get_record_offset(name)
  1730. elif _serialization_tls.skip_data:
  1731. storage = torch.UntypedStorage(nbytes)
  1732. elif overall_storage is not None:
  1733. if can_calculate_storage_offsets and calculate_storage_offsets:
  1734. storage_offset = _get_offset(key, name, nbytes)
  1735. if run_debug_asserts:
  1736. if storage_offset != zip_file.get_record_offset(name):
  1737. raise RuntimeError(
  1738. "This is a debug assert that was run as the `TORCH_SERIALIZATION_DEBUG` environment "
  1739. f"variable was set: Incorrect offset for {name}, got {storage_offset} expected "
  1740. f"{zip_file.get_record_offset(name)}"
  1741. )
  1742. else:
  1743. storage_offset = zip_file.get_record_offset(name)
  1744. storage = overall_storage[storage_offset : storage_offset + nbytes]
  1745. else:
  1746. if can_calculate_storage_offsets and run_debug_asserts:
  1747. # This is debug code that we use to test the validity of
  1748. # torch.utils.serialization.config.load.calculate_storage_offsets throughout CI
  1749. storage_offset = _get_offset(key, name, nbytes)
  1750. if storage_offset != zip_file.get_record_offset(name):
  1751. raise RuntimeError(
  1752. "This is a debug assert that was run as the `TORCH_SERIALIZATION_DEBUG` environment "
  1753. f"variable was set: Incorrect offset for {name}, got {storage_offset} expected "
  1754. f"{zip_file.get_record_offset(name)}"
  1755. )
  1756. storage = (
  1757. zip_file.get_storage_from_record(name, nbytes, torch.UntypedStorage)
  1758. ._typed_storage()
  1759. ._untyped_storage
  1760. )
  1761. # swap here if byteswapping is needed
  1762. if byteorderdata is not None:
  1763. if byteorderdata.decode() != sys.byteorder:
  1764. storage.byteswap(dtype)
  1765. # TODO: Once we decide to break serialization FC, we can
  1766. # stop wrapping with TypedStorage
  1767. if is_meta_map_location:
  1768. # Skip restore_location for meta map_location. Since we already created
  1769. # a meta storage above, calling restore_location would just redundantly
  1770. # call _meta_deserialize which creates another meta storage with the same
  1771. # size.
  1772. wrap_storage = storage
  1773. elif torch._guards.detect_fake_mode(None) is None:
  1774. wrap_storage = restore_location(storage, location)
  1775. else:
  1776. storage._fake_device = location
  1777. wrap_storage = storage
  1778. typed_storage = torch.storage.TypedStorage(
  1779. wrap_storage=wrap_storage,
  1780. dtype=dtype,
  1781. _internal=True,
  1782. )
  1783. if typed_storage._data_ptr() != 0:
  1784. loaded_storages[key] = typed_storage
  1785. return typed_storage
  1786. def persistent_load(saved_id):
  1787. if not isinstance(saved_id, tuple):
  1788. raise AssertionError(
  1789. f"saved_id must be a tuple, got {type(saved_id).__name__}"
  1790. )
  1791. typename = _maybe_decode_ascii(saved_id[0])
  1792. data = saved_id[1:]
  1793. if typename != "storage":
  1794. raise AssertionError(
  1795. f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
  1796. )
  1797. storage_type, key, location, numel = data
  1798. if storage_type is torch.UntypedStorage:
  1799. dtype = torch.uint8
  1800. else:
  1801. dtype = storage_type.dtype
  1802. if key in loaded_storages:
  1803. typed_storage = loaded_storages[key]
  1804. else:
  1805. nbytes = numel * torch._utils._element_size(dtype)
  1806. typed_storage = load_tensor(
  1807. dtype, nbytes, key, _maybe_decode_ascii(location)
  1808. )
  1809. return typed_storage
  1810. load_module_mapping: dict[str, str] = {
  1811. # See https://github.com/pytorch/pytorch/pull/51633
  1812. "torch.tensor": "torch._tensor"
  1813. }
  1814. # Need to subclass Unpickler instead of directly monkey-patching the find_class method
  1815. # because it's marked readonly in pickle.
  1816. # The type: ignore is because mypy can't statically determine the type of this class.
  1817. class UnpicklerWrapper(pickle_module.Unpickler): # type: ignore[name-defined]
  1818. # from https://stackoverflow.com/questions/13398462/unpickling-python-objects-with-a-changed-module-path/13405732
  1819. # Lets us override the imports that pickle uses when unpickling an object.
  1820. # This is useful for maintaining BC if we change a module path that tensor instantiation relies on.
  1821. def find_class(self, mod_name, name):
  1822. if type(name) is str and "Storage" in name:
  1823. try:
  1824. return StorageType(name)
  1825. except KeyError:
  1826. pass
  1827. mod_name = load_module_mapping.get(mod_name, mod_name)
  1828. return super().find_class(mod_name, name)
  1829. # Load the data (which may in turn use `persistent_load` to load tensors)
  1830. data_file = io.BytesIO(zip_file.get_record(pickle_file))
  1831. unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
  1832. unpickler.persistent_load = persistent_load
  1833. # Needed for tensors where storage device and rebuild tensor device are
  1834. # not connected (wrapper subclasses and tensors rebuilt using numpy)
  1835. global _serialization_tls
  1836. _serialization_tls.map_location = map_location
  1837. result = unpickler.load()
  1838. _serialization_tls.map_location = None
  1839. torch._utils._validate_loaded_sparse_tensors()
  1840. torch._C._log_api_usage_metadata(
  1841. "torch.load.metadata", {"serialization_id": zip_file.serialization_id()}
  1842. )
  1843. return result
  1844. def _is_torchscript_zip(zip_file):
  1845. return "constants.pkl" in zip_file.get_all_records()