| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
7771778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201 |
- # mypy: allow-untyped-defs
- import copyreg
- import difflib
- import functools
- import io
- import os
- import pickle
- import re
- import shutil
- import struct
- import sys
- import tarfile
- import tempfile
- import threading
- import warnings
- from collections.abc import Callable
- from contextlib import closing, contextmanager
- from enum import Enum
- from typing import Any, cast, Generic, IO, TypeAlias, TypeVar
- from typing_extensions import TypeIs
- import torch
- import torch._weights_only_unpickler as _weights_only_unpickler
- from torch._sources import get_source_lines_and_file
- from torch._utils import _import_dotted_name
- from torch.storage import _get_dtype_from_pickle_storage_type
- from torch.types import FileLike, Storage
__all__ = [
    # Warning type and temporary-directory helper.
    "SourceChangeWarning",
    "mkdtemp",
    # Device tagging / restoring registry.
    "register_package",
    "check_module_version_greater_or_equal",
    "validate_cuda_device",
    "validate_hpu_device",
    "location_tag",
    "default_restore_location",
    "normalize_storage_type",
    "storage_to_tensor_type",
    # Entry points.
    "save",
    "load",
    "StorageType",
    # Global configuration knobs.
    "LoadEndianness",
    "get_crc32_options",
    "set_crc32_options",
    "get_default_load_endianness",
    "set_default_load_endianness",
    "get_default_mmap_options",
    "set_default_mmap_options",
    # weights_only allowlist management.
    "clear_safe_globals",
    "get_safe_globals",
    "add_safe_globals",
    "safe_globals",
    "get_unsafe_globals_in_checkpoint",
    # Storage-byte skipping.
    "skip_data",
]
- DEFAULT_PROTOCOL = 2
- LONG_SIZE = struct.Struct("=l").size
- INT_SIZE = struct.Struct("=i").size
- SHORT_SIZE = struct.Struct("=h").size
- MAGIC_NUMBER = 0x1950A86A20F9469CFC6C
- PROTOCOL_VERSION = 1001
- STORAGE_KEY_SEPARATOR = ","
- MAP_LOCATION: TypeAlias = (
- Callable[[Storage, str], Storage] | torch.device | str | dict[str, str] | None
- )
- STORAGE: TypeAlias = Storage | torch.storage.TypedStorage | torch.UntypedStorage
- IS_WINDOWS = sys.platform == "win32"
- UNSAFE_MESSAGE = (
- "In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` "
- "from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, "
- "but it can result in arbitrary code execution. Do it only if you got the file from a "
- "trusted source."
- )
- if not IS_WINDOWS:
- from mmap import MAP_PRIVATE, MAP_SHARED
- else:
- MAP_SHARED, MAP_PRIVATE = None, None # type: ignore[assignment]
def _default_to_weights_only(pickle_module):
    """Return whether ``weights_only`` should default to True.

    That is the case only when no custom ``pickle_module`` was supplied and
    this is not an fbcode build (detected by ``torch.version`` lacking a
    ``git_version`` attribute).
    """
    if pickle_module is not None:
        return False
    return hasattr(torch.version, "git_version")
# Thread-local serialization state that must be visible to other files:
#   * map_location             -- needed by torch._utils (wrapper subclasses /
#                                 third-party devices)
#   * skip_data                -- needed by torch.Tensor.__reduce_ex__ for the
#                                 skip_data context
#   * materialize_fake_tensors -- needed by torch.Tensor.__reduce_ex__ for the
#                                 skip_data context
class _SerializationLocal(threading.local):
    def __init__(self):
        super().__init__()
        # Every thread starts out with the defaults below.
        self.map_location: MAP_LOCATION | None = None
        self.skip_data: bool = False
        self.materialize_fake_tensors: bool = False


_serialization_tls = _SerializationLocal()
class SourceChangeWarning(Warning):
    """Warning category exported by ``torch.serialization``.

    NOTE(review): it is emitted elsewhere in this module, presumably when saved
    source differs from the current source; confirm against its call sites.
    """
@contextmanager
def mkdtemp():
    """Yield the path of a fresh temporary directory.

    The directory and everything inside it are removed with
    :func:`shutil.rmtree` when the ``with`` block exits, even on error.
    """
    scratch_dir = tempfile.mkdtemp()
    try:
        yield scratch_dir
    finally:
        shutil.rmtree(scratch_dir)
# Entries are (priority, tagger, deserializer). register_package() appends and
# re-sorts; location_tag() and default_restore_location() walk it in order.
_package_registry: list[
    tuple[
        int,
        Callable[[STORAGE], str | None],  # tagger
        Callable[[STORAGE, str], STORAGE | None],  # deserializer
    ]
] = []
class LoadEndianness(Enum):
    """Fallback byte orders accepted by :func:`set_default_load_endianness`."""

    NATIVE = 1
    LITTLE = 2
    BIG = 3
def get_default_load_endianness() -> LoadEndianness | None:
    """
    Return the fallback byte order used when loading files.

    The fallback applies when a saved checkpoint carries no byteorder mark.
    By default it is the "native" byte order.

    Returns:
        default_load_endian: Optional[LoadEndianness]
    """
    from torch.utils.serialization import config as serialization_config

    return serialization_config.load.endianness
def set_default_load_endianness(endianness):
    """
    Set the fallback byte order used when loading files.

    The fallback applies when a saved checkpoint carries no byteorder mark.
    By default it is the "native" byte order.

    Args:
        endianness: the new fallback byte order (a ``LoadEndianness`` or ``None``)

    Raises:
        TypeError: if ``endianness`` is neither ``None`` nor a ``LoadEndianness``.
    """
    if endianness is not None and not isinstance(endianness, LoadEndianness):
        raise TypeError("Invalid argument type in function set_default_load_endianness")
    from torch.utils.serialization import config as serialization_config

    serialization_config.load.endianness = endianness
- def get_crc32_options() -> bool:
- """
- Get whether :func:`torch.save` computes and writes crc32 for each record.
- Defaults to ``True``.
- """
- from torch.utils.serialization import config
- return config.save.compute_crc32
- def set_crc32_options(compute_crc32: bool):
- """
- Set whether :func:`torch.save` computes and writes crc32 for each record.
- .. note::
- Setting this to ``False`` may make unzipping of the ``torch.save`` output
- fail or warn due to corrupted CRC32. However ``torch.load`` will be
- able to load the file.
- Args:
- compute_crc32 (bool): set crc32 computation flag
- """
- from torch.utils.serialization import config
- config.save.compute_crc32 = compute_crc32
- def get_default_mmap_options() -> int | None:
- """
- Get default mmap options for :func:`torch.load` with ``mmap=True``.
- Defaults to ``mmap.MAP_PRIVATE``.
- Returns:
- default_mmap_options: int
- """
- from torch.utils.serialization import config
- return config.load.mmap_flags
- def _get_storage_alignment() -> int:
- """
- Gets alignment for storages in torch.save files/
- Defaults to 64.
- Returns:
- storage_alginment: int
- """
- from torch.utils.serialization import config
- return config.save.storage_alignment
class set_default_mmap_options:
    """
    Set the default mmap flags for :func:`torch.load` with ``mmap=True``.

    Usable as a plain call, where the new flags stay in effect, or as a context
    manager, where the previous flags are restored on exit. The new flags take
    effect as soon as the object is constructed.

    Only ``mmap.MAP_PRIVATE`` and ``mmap.MAP_SHARED`` are accepted for now.
    Please open an issue if you need any other option to be added here.

    .. note::
        This feature is currently not supported for Windows.

    Args:
        flags: ``mmap.MAP_PRIVATE`` or ``mmap.MAP_SHARED``
    """

    def __init__(self, flags: int) -> None:
        if IS_WINDOWS:
            raise RuntimeError(
                "Changing the default mmap options is currently not supported for Windows"
            )
        if flags not in (MAP_PRIVATE, MAP_SHARED):
            raise ValueError(
                "Invalid argument in function set_default_mmap_options, "
                f"expected mmap.MAP_PRIVATE or mmap.MAP_SHARED, but got {flags}"
            )
        from torch.utils.serialization import config

        # Swap in the new global value, keeping the old one for __exit__.
        self.prev, config.load.mmap_flags = config.load.mmap_flags, flags

    def __enter__(self) -> None:
        return None

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        from torch.utils.serialization import config

        config.load.mmap_flags = self.prev
- def clear_safe_globals() -> None:
- """
- Clears the list of globals that are safe for ``weights_only`` load.
- """
- _weights_only_unpickler._clear_safe_globals()
- def get_safe_globals() -> list[Callable | tuple[Callable, str]]:
- """
- Returns the list of user-added globals that are safe for ``weights_only`` load.
- """
- return _weights_only_unpickler._get_safe_globals()
- def add_safe_globals(safe_globals: list[Callable | tuple[Callable, str]]) -> None:
- """
- Marks the given globals as safe for ``weights_only`` load. For example, functions
- added to this list can be called during unpickling, classes could be instantiated
- and have state set.
- Each item in the list can either be a function/class or a tuple of the form
- (function/class, string) where string is the full path of the function/class.
- Within the serialized format, each function is identified with its full
- path as ``{__module__}.{__qualname__}``. When calling this API, you can provide this
- full path that should match the one in the checkpoint otherwise the default
- ``{fn.__module__}.{fn.__qualname__}`` will be used.
- Args:
- safe_globals (List[Union[Callable, Tuple[Callable, str]]]): list of globals to mark as safe
- Example:
- >>> # xdoctest: +SKIP("Can't torch.save(t, ...) as doctest thinks MyTensor is defined on torch.serialization")
- >>> import tempfile
- >>> class MyTensor(torch.Tensor):
- ... pass
- >>> t = MyTensor(torch.randn(2, 3))
- >>> with tempfile.NamedTemporaryFile() as f:
- ... torch.save(t, f.name)
- # Running `torch.load(f.name, weights_only=True)` will fail with
- # Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default.
- # Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint.
- ... torch.serialization.add_safe_globals([MyTensor])
- ... torch.load(f.name, weights_only=True)
- # MyTensor([[-0.5024, -1.8152, -0.5455],
- # [-0.8234, 2.0500, -0.3657]])
- """
- _weights_only_unpickler._add_safe_globals(safe_globals)
- class safe_globals(_weights_only_unpickler._safe_globals):
- r"""Context-manager that adds certain globals as safe for ``weights_only`` load.
- Args:
- safe_globals: List of globals for weights_only load.
- Example:
- >>> # xdoctest: +SKIP("Can't torch.save(t, ...) as doctest thinks MyTensor is defined on torch.serialization")
- >>> import tempfile
- >>> class MyTensor(torch.Tensor):
- ... pass
- >>> t = MyTensor(torch.randn(2, 3))
- >>> with tempfile.NamedTemporaryFile() as f:
- ... torch.save(t, f.name)
- # Running `torch.load(f.name, weights_only=True)` will fail with
- # Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default.
- # Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint.
- ... with torch.serialization.safe_globals([MyTensor]):
- ... torch.load(f.name, weights_only=True)
- # MyTensor([[-0.5024, -1.8152, -0.5455],
- # [-0.8234, 2.0500, -0.3657]])
- >>> assert torch.serialization.get_safe_globals() == []
- """
def get_unsafe_globals_in_checkpoint(f: FileLike) -> list[str]:
    """Returns a list of strings of functions/classes in a ``torch.save`` object that are not safe for ``weights_only``.

    For a given function or class ``f``, the corresponding string will be of the form
    ``{f.__module__}.{f.__name__}``.

    This function will return any GLOBALs in the checkpoint that are not in the set marked safe
    for ``weights_only`` (either via :func:`add_safe_globals` or :class:`safe_globals` context or
    allowlisted by ``torch`` by default).

    .. note::
        This function will statically disassemble the pickle file in the checkpoint.
        The implication is any classes dynamically pushed onto the stack during unpickling
        will not be included in the output.

    Args:
        f: File-like object or string containing the checkpoint object saved via ``torch.save``

    Returns:
        A sorted list of strings of pickle GLOBALs in the checkpoint that are not allowlisted
        for ``weights_only``.

    Raises:
        ValueError: if ``f`` is not a zip checkpoint produced by ``torch.save``
            (including TorchScript archives).
    """
    default_safe_globals_strings = set(
        _weights_only_unpickler._get_allowed_globals().keys()
    )
    user_safe_global_strings = set(
        _weights_only_unpickler._get_user_allowed_globals().keys()
    )
    safe_global_strings = default_safe_globals_strings.union(user_safe_global_strings)

    with _open_file_like(f, "rb") as opened_file:
        if not _is_zipfile(opened_file):
            raise ValueError("Expected input to be a checkpoint returned by torch.save")
        with _open_zipfile_reader(opened_file) as zip_file:
            if _is_torchscript_zip(zip_file):
                raise ValueError(
                    "Expected input to be a checkpoint returned by torch.save but got a torchscript checkpoint"
                )
            data_file = io.BytesIO(zip_file.get_record("data.pkl"))
            all_globals = _weights_only_unpickler.get_globals_in_pkl(data_file)
    # Sort so the result does not depend on set iteration order, which for
    # strings varies between interpreter runs (hash randomization).
    return sorted(all_globals.difference(safe_global_strings))
class skip_data:
    """
    Context manager that skips writing/reading storage bytes for ``torch.save`` / ``torch.load`` calls.

    On save, storages are still recorded, but the space their bytes would
    normally occupy is left empty; those bytes can be filled in by a separate
    pass. On load, tensors are reconstructed per the checkpoint, but their
    storages are not populated with data.

    On entry the thread-local ``skip_data`` / ``materialize_fake_tensors`` settings
    are saved and overridden; on exit the saved values are restored.

    .. warning::
        The ``skip_data`` context manager is an early prototype and is subject to change.

    Args:
        materialize_fake_tensors: Whether to materialize FakeTensors during save. This is a no-op for the load path.

    Example:
        >>> # xdoctest: +SKIP("NamedTemporaryFile on Windows")
        >>> import tempfile
        >>> t = torch.randn(2, 3)
        >>> with tempfile.NamedTemporaryFile() as f:
        ...     with torch.serialization.skip_data():
        ...         torch.save(t, f.name)
        ...     torch.load(f.name, weights_only=True)
        tensor([[0., 0., 0.],
                [0., 0., 0.]])
    """

    def __init__(self, materialize_fake_tensors: bool = False):
        self.materialize_fake_tensors = materialize_fake_tensors

    def __enter__(self):
        # Only attributes of the module-level object are mutated, so no
        # ``global`` declaration is needed.
        tls = _serialization_tls
        self._old_skip_data = tls.skip_data
        self._old_materialize_fake_tensors = tls.materialize_fake_tensors
        tls.skip_data = True
        tls.materialize_fake_tensors = self.materialize_fake_tensors

    def __exit__(self, type, value, tb):
        tls = _serialization_tls
        tls.skip_data = self._old_skip_data
        tls.materialize_fake_tensors = self._old_materialize_fake_tensors
def _is_zipfile(f) -> bool:
    """Return True if ``f`` starts, at its current position, with a ZIP local-file header.

    This is stricter than :func:`zipfile.is_zipfile`, which accepts the magic
    number anywhere in the data (see bugs.python.org/issue28494). Files here
    are expected to come from ``torch.save`` or ``torch.jit.save``, so only the
    leading bytes are checked, which avoids accidental collisions. The stream
    position is restored before returning.
    """
    zip_signature = b"PK\x03\x04"
    origin = f.tell()
    head = f.read(len(zip_signature))
    f.seek(origin)
    return head == zip_signature
def register_package(
    priority: int,
    tagger: Callable[[STORAGE], str | None],
    deserializer: Callable[[STORAGE, str], STORAGE | None],
):
    """
    Registers callables for tagging and deserializing storage objects with an associated priority.

    Tagging associates a device with a storage object at save time, while deserializing moves a
    storage object to an appropriate device at load time. :attr:`tagger` and :attr:`deserializer`
    are run in the order given by their :attr:`priority` until a tagger/deserializer returns a
    value that is not `None`. Entries registered with the same priority are tried in
    registration order.

    To override the deserialization behavior for a device in the global registry, one can register a
    tagger with a higher priority than the existing tagger.

    This function can also be used to register a tagger and deserializer for new devices.

    Args:
        priority: Indicates the priority associated with the tagger and deserializer, where a lower
            value indicates higher priority.
        tagger: Callable that takes in a storage object and returns its tagged device as a string
            or None.
        deserializer: Callable that takes in storage object and a device string and returns a storage
            object on the appropriate device or None.

    Returns:
        `None`

    Example:
        >>> def ipu_tag(obj):
        >>>     if obj.device.type == 'ipu':
        >>>         return 'ipu'
        >>> def ipu_deserialize(obj, location):
        >>>     if location.startswith('ipu'):
        >>>         ipu = getattr(torch, "ipu", None)
        >>>         assert ipu is not None, "IPU device module is not loaded"
        >>>         assert torch.ipu.is_available(), "ipu is not available"
        >>>         return obj.ipu(location)
        >>> torch.serialization.register_package(11, ipu_tag, ipu_deserialize)
    """
    queue_elem = (priority, tagger, deserializer)
    _package_registry.append(queue_elem)
    # Sort on the priority only. Sorting whole tuples would compare the
    # callables on a priority tie and raise TypeError. list.sort is stable,
    # so ties keep their registration order.
    _package_registry.sort(key=lambda entry: entry[0])
def check_module_version_greater_or_equal(
    module,
    req_version_tuple,
    error_if_malformed=True,
):
    """
    Check if a module's version satisfies requirements

    Usually, a module's version string will be like 'x.y.z', which would be represented
    as a tuple (x, y, z), but sometimes it could be an unexpected format. If the version
    string does not match the given tuple's format up to the length of the tuple (or the
    module has no ``__version__`` at all), either raise or emit a warning.

    Args:
        module: the module to check the version of
        req_version_tuple: tuple (usually of ints) representing the required version
        error_if_malformed: whether we should raise if module version string is malformed

    Returns:
        requirement_is_met: bool

    Raises:
        RuntimeError: if the version is malformed or missing and ``error_if_malformed`` is True.
    """
    # Read these defensively: the error path below formats them into the
    # message, and touching a missing attribute there would raise
    # AttributeError instead of taking the RuntimeError / warning path.
    module_name = getattr(module, "__name__", repr(module))
    version = getattr(module, "__version__", None)
    try:
        version_strs = version.split(".")
        # Cast module version fields to match the types of the required version
        module_version = tuple(
            type(req_field)(version_strs[idx])
            for idx, req_field in enumerate(req_version_tuple)
        )
        requirement_is_met = module_version >= req_version_tuple

    except Exception as e:
        message = (
            f"'{module_name}' module version string is malformed '{version}' and cannot be compared"
            f" with tuple {str(req_version_tuple)}"
        )
        if error_if_malformed:
            raise RuntimeError(message) from e
        else:
            warnings.warn(
                message + ", but continuing assuming that requirement is met",
                stacklevel=2,
            )
            requirement_is_met = True

    return requirement_is_met
def _cpu_tag(obj):
    """Return ``"cpu"`` if ``obj`` lives on the CPU, else ``None``."""
    return "cpu" if obj.device.type == "cpu" else None
def _mps_tag(obj):
    """Return ``"mps"`` if ``obj`` lives on an MPS device, else ``None``."""
    return "mps" if obj.device.type == "mps" else None
def _meta_tag(obj):
    """Return ``"meta"`` if ``obj`` lives on the meta device, else ``None``."""
    return "meta" if obj.device.type == "meta" else None
- def _backend_tag(backend_name, obj):
- if backend_name == "privateuse1":
- backend_name = torch._C._get_privateuse1_backend_name()
- if obj.device.type == backend_name:
- if obj.device.index is None:
- return backend_name
- else:
- return backend_name + ":" + str(obj.device.index)
def _cpu_deserialize(obj, location):
    """Return ``obj`` unchanged for the ``"cpu"`` location, else ``None``."""
    return obj if location == "cpu" else None
def _mps_deserialize(obj, location):
    """Move ``obj`` to MPS when ``location`` starts with ``"mps"``; otherwise return ``None``."""
    if not location.startswith("mps"):
        return None
    return obj.mps()
def _meta_deserialize(obj, location):
    """For the ``"meta"`` location, return a new meta ``UntypedStorage`` of the same byte size.

    Returns ``None`` for any other location.
    """
    if location != "meta":
        return None
    return torch.UntypedStorage(obj.nbytes(), device="meta")
def _is_meta_location(map_location):
    """
    Tell whether ``map_location`` definitely targets the meta device.

    Used to skip reading storage data from disk when loading to the meta
    device, since meta tensors hold no data.

    Args:
        map_location: The map_location argument passed to torch.load

    Returns:
        True only for the string ``"meta"`` or a ``torch.device`` of type
        ``"meta"``. ``None``, dicts and callables yield False, since their
        target cannot be known without evaluating them.
    """
    if isinstance(map_location, str):
        return map_location == "meta"
    if isinstance(map_location, torch.device):
        return map_location.type == "meta"
    return False
def _validate_device(location, backend_name):
    """
    Check that ``location`` names a usable device of the given backend and return it.

    For a privateuse1 backend you must first register a device module with
    ``torch._register_device_module``. When the device module provides
    ``_utils._get_device_index(location, True)``, that is used to parse the
    index; otherwise ``location`` is parsed with ``torch.device``. If the module
    provides ``is_available()`` and/or ``device_count()``, they are used to
    validate the device.

    Args:
        location: device string, e.g. ``"cuda:1"``
        backend_name: the backend name, or the (possibly renamed) name of privateuse1

    Returns:
        torch.device: the validated device. In the ``torch.device`` fallback
        path a missing index stays ``None`` on the returned device and is
        treated as 0 for the ``device_count()`` check.

    Raises:
        RuntimeError: if ``torch`` has no ``backend_name`` module, the backend
            reports it is unavailable, or the index is not below ``device_count()``.
    """
    if not hasattr(torch, backend_name):
        raise RuntimeError(
            f"The {backend_name.upper()} device module is not registered. "
            "If you are running on a CPU-only machine, "
            "please use torch.load with map_location=torch.device('cpu') "
            "to map your storages to the CPU."
        )
    device_module = getattr(torch, backend_name)
    if hasattr(device_module, "_utils") and hasattr(
        device_module._utils, "_get_device_index"
    ):
        device_index = device_module._utils._get_device_index(location, True)
        device = torch.device(backend_name, device_index)
    else:
        device = torch.device(location)
        # A missing index (None) is checked as index 0 below.
        device_index = device.index if device.index else 0
    if hasattr(device_module, "is_available") and not device_module.is_available():
        raise RuntimeError(
            f"Attempting to deserialize object on a {backend_name.upper()} "
            f"device but torch.{backend_name}.is_available() is False. "
            "If you are running on a CPU-only machine, "
            "please use torch.load with map_location=torch.device('cpu') "
            "to map your storages to the CPU."
        )
    if hasattr(device_module, "device_count"):
        device_count = device_module.device_count()
        if device_index >= device_count:
            raise RuntimeError(
                f"Attempting to deserialize object on {backend_name.upper()} device "
                f"{device_index} but torch.{backend_name}.device_count() is {device_count}. "
                "Please use torch.load with map_location to map your storages "
                "to an existing device."
            )
    return device
def validate_cuda_device(location):
    """Validate a CUDA ``location`` string and return the resulting device's index."""
    cuda_device = _validate_device(location, "cuda")
    return cuda_device.index
def validate_hpu_device(location):
    """Validate an HPU ``location`` string and return the resulting device's index."""
    hpu_device = _validate_device(location, "hpu")
    return hpu_device.index
def _deserialize(backend_name, obj, location):
    """Move ``obj`` to the device named by ``location`` if that location belongs to ``backend_name``.

    ``location`` matches when it equals the backend name, or starts with the
    backend name followed by ``:`` or digits (e.g. ``"cuda:0"``). The device is
    validated with :func:`_validate_device`. ``"privateuse1"`` is resolved to
    the currently registered privateuse1 backend name.

    Returns:
        ``obj.to(device=...)`` on a match, otherwise ``None`` so that the next
        registered deserializer can be tried.
    """
    if backend_name == "privateuse1":
        backend_name = torch._C._get_privateuse1_backend_name()
    # The backend name is interpolated into a regex; escape it so any
    # metacharacters in it are matched literally.
    if location == backend_name or re.match(
        f"{re.escape(backend_name)}(:|[0-9]+)", location
    ):
        device = _validate_device(location, backend_name)
        return obj.to(device=device)
    return None
# Built-in taggers/deserializers; lower priority values are consulted first.
register_package(priority=10, tagger=_cpu_tag, deserializer=_cpu_deserialize)
register_package(
    priority=20,
    tagger=functools.partial(_backend_tag, "cuda"),
    deserializer=functools.partial(_deserialize, "cuda"),
)
register_package(priority=21, tagger=_mps_tag, deserializer=_mps_deserialize)
register_package(priority=22, tagger=_meta_tag, deserializer=_meta_deserialize)
register_package(
    priority=23,
    tagger=functools.partial(_backend_tag, "privateuse1"),
    deserializer=functools.partial(_deserialize, "privateuse1"),
)
register_package(
    priority=24,
    tagger=functools.partial(_backend_tag, "hpu"),
    deserializer=functools.partial(_deserialize, "hpu"),
)
register_package(
    priority=25,
    tagger=functools.partial(_backend_tag, "xpu"),
    deserializer=functools.partial(_deserialize, "xpu"),
)
register_package(
    priority=26,
    tagger=functools.partial(_backend_tag, "mtia"),
    deserializer=functools.partial(_deserialize, "mtia"),
)
def location_tag(
    storage: Storage | torch.storage.TypedStorage | torch.UntypedStorage,
):
    """Return the location tag of ``storage`` from the first registered tagger that yields one.

    Taggers are tried in registry order and the first truthy result wins.

    Raises:
        RuntimeError: if no registered tagger produces a tag.
    """
    candidate_tags = (tag_fn(storage) for _, tag_fn, _ in _package_registry)
    tag = next((t for t in candidate_tags if t), None)
    if tag is None:
        raise RuntimeError(
            "don't know how to determine data location of " + torch.typename(storage)
        )
    return tag
def default_restore_location(storage, location):
    """
    Restore ``storage`` to ``location`` using the registered deserializers.

    Deserializers are tried in registry (priority) order; the first one that
    returns something other than ``None`` wins.

    Args:
        storage (STORAGE): the storage object to restore
        location (str): the location tag associated with the storage object

    Returns:
        storage: Optional[STORAGE]

    Raises:
        RuntimeError: if every registered deserializer returns ``None``.
    """
    for _priority, _tag_fn, restore_fn in _package_registry:
        restored = restore_fn(storage, location)
        if restored is not None:
            return restored
    raise RuntimeError(
        "don't know how to restore data location of "
        + torch.typename(storage)
        + " (tagged with "
        + location
        + ")"
    )
def normalize_storage_type(storage_type):
    """Return the attribute of the top-level ``torch`` module named like ``storage_type``."""
    type_name = storage_type.__name__
    return getattr(torch, type_name)
def storage_to_tensor_type(storage):
    """Return the tensor type paired with ``storage``'s type.

    The tensor type is looked up in the module that defines the storage's
    class, under the class name with ``"Storage"`` replaced by ``"Tensor"``
    (e.g. ``FloatStorage`` -> ``FloatTensor``).
    """
    storage_cls = type(storage)
    defining_module = _import_dotted_name(storage_cls.__module__)
    tensor_type_name = storage_cls.__name__.replace("Storage", "Tensor")
    return getattr(defining_module, tensor_type_name)
def _is_path(name_or_buffer: object) -> TypeIs[str | os.PathLike]:
    """Return True if ``name_or_buffer`` is a path (``str`` or ``os.PathLike``) rather than a buffer."""
    return isinstance(name_or_buffer, str) or isinstance(name_or_buffer, os.PathLike)
T = TypeVar("T")


class _opener(Generic[T]):
    """Minimal context manager that hands back the wrapped ``file_like`` on enter.

    The base ``__exit__`` does nothing; subclasses override it to close or
    flush the wrapped object.
    """

    def __init__(self, file_like: T) -> None:
        self.file_like: T = file_like

    def __enter__(self):
        return self.file_like

    def __exit__(self, *args):
        return None
class _open_file(_opener[IO[bytes]]):
    """Opener that opens the path ``name`` with ``mode`` and closes the file on exit."""

    def __init__(self, name: str | os.PathLike[str], mode: str) -> None:
        handle = open(name, mode)  # noqa: SIM115
        super().__init__(handle)

    def __exit__(self, *args):
        self.file_like.close()
class _open_buffer_reader(_opener[IO[bytes]]):
    """Opener wrapping a caller-supplied readable buffer; the buffer is not closed on exit."""

    def __init__(self, buffer: IO[bytes]) -> None:
        super().__init__(buffer)
        # NOTE(review): _check_seekable is defined elsewhere; presumably it
        # rejects buffers that do not support seeking.
        _check_seekable(buffer)
class _open_buffer_writer(_opener[IO[bytes]]):
    """Opener wrapping a caller-supplied writable buffer; it is flushed, not closed, on exit."""

    def __exit__(self, *args):
        buffer = self.file_like
        buffer.flush()
def _open_file_like(name_or_buffer: FileLike, mode: str) -> _opener[IO[bytes]]:
    """Return a context manager for reading or writing *name_or_buffer*.

    Paths are opened with ``open(name, mode)`` and closed on exit. Buffers are
    wrapped instead: a ``"w"`` mode gives a writer that flushes on exit, and an
    ``"r"`` mode gives a reader that is first checked for seekability.

    Raises:
        RuntimeError: if *name_or_buffer* is a buffer and *mode* has neither
            ``"w"`` nor ``"r"``.
    """
    if _is_path(name_or_buffer):
        return _open_file(name_or_buffer, mode)
    if "w" in mode:
        return _open_buffer_writer(name_or_buffer)
    if "r" in mode:
        return _open_buffer_reader(name_or_buffer)
    raise RuntimeError(f"Expected 'r' or 'w' in mode but got {mode}")
class _open_zipfile_reader(_opener[torch._C.PyTorchFileReader]):
    """Context manager yielding a ``torch._C.PyTorchFileReader`` over a path or buffer."""

    def __init__(self, name_or_buffer: str | IO[bytes]) -> None:
        reader = torch._C.PyTorchFileReader(name_or_buffer)
        super().__init__(reader)
class _open_zipfile_writer_file(_opener[torch._C.PyTorchFileWriter]):
    """Context manager that writes a zip-format checkpoint to a filesystem path.

    If *name* is pure ASCII, the path is handed straight to
    ``torch._C.PyTorchFileWriter``. Otherwise Python opens the file itself
    (``io.FileIO`` in ``"w"`` mode) and the writer is given that stream.
    On exit the archive is finalized with ``write_end_of_file()``, and any
    Python-owned stream is closed even if finalizing fails.
    """

    def __init__(self, name: str) -> None:
        self.file_stream = None
        self.name = name
        try:
            self.name.encode("ascii")
        except UnicodeEncodeError:
            # PyTorchFileWriter only supports ascii filename.
            # For filenames with non-ascii characters, we rely on Python
            # for writing out the file.
            # pyrefly: ignore [bad-assignment]
            self.file_stream = io.FileIO(self.name, mode="w")
            try:
                writer = torch._C.PyTorchFileWriter(  # pyrefly: ignore # no-matching-overload
                    self.file_stream, get_crc32_options(), _get_storage_alignment()
                )
            except BaseException:
                # Don't leak the OS file handle we just opened if the writer
                # cannot be constructed; __exit__ would never run in that case.
                self.file_stream.close()
                raise
            super().__init__(writer)
        else:
            super().__init__(
                torch._C.PyTorchFileWriter(
                    self.name, get_crc32_options(), _get_storage_alignment()
                )
            )

    def __exit__(self, *args) -> None:
        try:
            self.file_like.write_end_of_file()
        finally:
            # Always release the Python-owned stream, even if finalizing the
            # archive raised.
            if self.file_stream is not None:
                self.file_stream.close()
class _open_zipfile_writer_buffer(_opener[torch._C.PyTorchFileWriter]):
    """Context manager that writes a zip-format checkpoint into a caller's buffer.

    The buffer must expose a callable ``write``. On exit the archive is
    finalized with ``write_end_of_file()`` and the buffer is flushed (not closed).

    Raises:
        AttributeError: if the buffer has no ``write`` attribute.
        TypeError: if ``write`` exists but is not callable.
    """

    def __init__(self, buffer: IO[bytes]) -> None:
        write_fn = getattr(buffer, "write", None)
        if not callable(write_fn):
            msg = f"Buffer of {str(type(buffer)).strip('<>')} has no callable attribute 'write'"
            # A missing attribute maps to AttributeError; a non-callable one to TypeError.
            error_cls = TypeError if hasattr(buffer, "write") else AttributeError
            raise error_cls(msg)
        self.buffer = buffer
        writer = torch._C.PyTorchFileWriter(
            buffer, get_crc32_options(), _get_storage_alignment()
        )
        super().__init__(writer)

    def __exit__(self, *args) -> None:
        self.file_like.write_end_of_file()
        self.buffer.flush()
def _open_zipfile_writer(name_or_buffer: str | IO[bytes]) -> _opener:
    """Return the zip-writer context manager that matches *name_or_buffer*.

    Paths go to ``_open_zipfile_writer_file``; anything else is treated as a
    writable buffer and goes to ``_open_zipfile_writer_buffer``.
    """
    container: type[_opener] = (
        _open_zipfile_writer_file
        if _is_path(name_or_buffer)
        else _open_zipfile_writer_buffer
    )
    return container(name_or_buffer)  # type: ignore[arg-type]
def _is_compressed_file(f) -> bool:
    """Return True if *f*'s ``__module__`` names a compression module (only ``gzip``).

    Objects without a ``__module__`` attribute are reported as not compressed.
    """
    compress_modules = ("gzip",)
    return getattr(f, "__module__", None) in compress_modules
def _should_read_directly(f):
    """Return True if *f* should be read through its OS file descriptor.

    Two conditions must hold:
    - *f* is not a compressed stream, as decided by ``_is_compressed_file``.
    - ``f.fileno()`` returns a non-negative descriptor.

    Objects without ``fileno`` (AttributeError), or whose ``fileno`` raises
    ``io.UnsupportedOperation``, yield False.
    """
    if _is_compressed_file(f):
        return False
    try:
        has_descriptor = f.fileno() >= 0
    except (io.UnsupportedOperation, AttributeError):
        has_descriptor = False
    return has_descriptor
- def _check_seekable(f) -> bool:
- def raise_err_msg(patterns, e):
- for p in patterns:
- if p in str(e):
- msg = (
- str(e)
- + ". You can only torch.load from a file that is seekable."
- + " Please pre-load the data into a buffer like io.BytesIO and"
- + " try to load from it instead."
- )
- raise type(e)(msg)
- raise e
- try:
- f.seek(f.tell())
- return True
- except (io.UnsupportedOperation, AttributeError) as e:
- raise_err_msg(["seek", "tell"], e)
- return False
def _check_dill_version(pickle_module) -> None:
    """Reject ``dill`` versions older than 0.3.1 when dill is the pickle module.

    Any other module (or ``None``) is accepted without further checks.

    Args:
        pickle_module: module used for pickling metadata and objects

    Raises:
        ValueError: if *pickle_module* is ``dill`` with a version below 0.3.1.
    """
    if pickle_module is None or pickle_module.__name__ != "dill":
        return
    required_dill_version = (0, 3, 1)
    if check_module_version_greater_or_equal(
        pickle_module, required_dill_version, False
    ):
        return
    required_str = ".".join(str(num) for num in required_dill_version)
    raise ValueError(
        f"'torch' supports dill >= {required_str}, but you have dill {pickle_module.__version__}."
        " Please upgrade dill or switch to 'pickle'"
    )
def _check_save_filelike(f):
    """Raise AttributeError unless *f* is a path or has a ``write`` attribute."""
    if _is_path(f) or hasattr(f, "write"):
        return
    raise AttributeError(
        "expected 'f' to be string, path, or a file-like object with "
        "a 'write' attribute"
    )
def save(
    obj: object,
    f: FileLike,
    pickle_module: Any = pickle,
    pickle_protocol: int = DEFAULT_PROTOCOL,
    _use_new_zipfile_serialization: bool = True,
    _disable_byteorder_record: bool = False,
) -> None:
    # Reference: https://github.com/pytorch/pytorch/issues/54354
    # The first line of this docstring overrides the one Sphinx generates for the
    # documentation. We need it so that Sphinx doesn't leak `pickle`s path from
    # the build environment (e.g. `<module 'pickle' from '/leaked/path').

    """save(obj, f, pickle_module=pickle, pickle_protocol=2, _use_new_zipfile_serialization=True)

    Saves an object to a disk file.

    See also: :ref:`saving-loading-tensors`

    See :ref:`layout-control` for more advanced tools to manipulate a checkpoint.

    Args:
        obj: saved object
        f: a file-like object (has to implement write and flush) or a string or
           os.PathLike object containing a file name
        pickle_module: module used for pickling metadata and objects
        pickle_protocol: can be specified to override the default protocol

    .. note::
        A common PyTorch convention is to save tensors using .pt file extension.

    .. note::
        PyTorch preserves storage sharing across serialization. See
        :ref:`preserve-storage-sharing` for more details.

    .. note::
        The 1.6 release of PyTorch switched ``torch.save`` to use a new
        zipfile-based file format. ``torch.load`` still retains the ability to
        load files in the old format. If for any reason you want ``torch.save``
        to use the old format, pass the kwarg ``_use_new_zipfile_serialization=False``.

    Example:
        >>> # xdoctest: +SKIP("makes cwd dirty")
        >>> # Save to file
        >>> x = torch.tensor([0, 1, 2, 3, 4])
        >>> torch.save(x, "tensor.pt")
        >>> # Save to io.BytesIO buffer
        >>> buffer = io.BytesIO()
        >>> torch.save(x, buffer)
    """
    torch._C._log_api_usage_once("torch.save")
    _check_dill_version(pickle_module)
    _check_save_filelike(f)

    # Normalize path-like objects to plain strings before opening.
    if isinstance(f, (str, os.PathLike)):
        f = os.fspath(f)

    if not _use_new_zipfile_serialization:
        # Legacy (pre-1.6) format: it cannot record storage metadata alone,
        # so skip_data mode is rejected.
        global _serialization_tls
        if _serialization_tls.skip_data:
            raise RuntimeError(
                "Cannot use skip_data=True with _use_new_zipfile_serialization=False"
            )
        with _open_file_like(f, "wb") as opened_file:
            _legacy_save(obj, opened_file, pickle_module, pickle_protocol)
        return

    with _open_zipfile_writer(f) as opened_zipfile:
        _save(
            obj,
            opened_zipfile,
            pickle_module,
            pickle_protocol,
            _disable_byteorder_record,
        )
def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
    """Write *obj* to the open stream *f* in the legacy (pre-zipfile) format.

    The stream is written in this order:
    1. The pickled ``MAGIC_NUMBER``, ``PROTOCOL_VERSION`` and a ``sys_info`` dict.
    2. The pickled object graph, in which ``nn.Module`` subclasses and
       storages are replaced by persistent-id tuples.
    3. The sorted list of storage keys.
    4. Each storage's raw bytes, via ``storage._write_file``, in that sorted
       key order.

    The order of these writes is the on-disk format, so it must not change.

    Args:
        obj: object to serialize.
        f: writable, flushable stream.
        pickle_module: module whose ``dump``/``Pickler`` are used.
        pickle_protocol: pickle protocol passed to every dump.

    Raises:
        TypeError: if a storage is neither ``TypedStorage`` nor ``UntypedStorage``.
        RuntimeError: if two storages with the same non-null data pointer are
            saved with different dtypes.
    """
    import torch.nn as nn

    # nn.Module subclasses already emitted as a ("module", ...) persistent id.
    serialized_container_types = {}
    # storage_key -> (untyped storage, dtype whose element size is used on write)
    serialized_storages: dict[str, tuple[torch.UntypedStorage, torch.dtype]] = {}

    # Since loading storages that view the same data with different dtypes is
    # not supported, we need to keep track of the dtype associated with each
    # storage data_ptr and throw an error if the dtype is ever different.
    # TODO: This feature could be added in the future
    storage_dtypes: dict[int, torch.dtype] = {}

    def persistent_id(obj: Any) -> tuple | None:
        # FIXME: the docs say that persistent_id should only return a string
        # but torch store returns tuples. This works only in the binary protocol
        # see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        # Returning None lets the pickler serialize *obj* normally.
        if isinstance(obj, type) and issubclass(obj, nn.Module):
            # Only the first occurrence of a container class records its source.
            if obj in serialized_container_types:
                return None
            serialized_container_types[obj] = True
            source_file = source = None
            try:
                source_lines, _, source_file = get_source_lines_and_file(obj)
                source = "".join(source_lines)
            except (
                Exception
            ):  # saving the source is optional, so we can ignore any errors
                warnings.warn(
                    "Couldn't retrieve source code for container of "
                    "type " + obj.__name__ + ". It won't be checked "
                    "for correctness upon loading.",
                    stacklevel=2,
                )
            return ("module", obj, source_file, source)

        if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
            storage: torch.UntypedStorage

            if isinstance(obj, torch.storage.TypedStorage):
                # TODO: Once we decide to break serialization FC, this case
                # can be deleted
                storage = obj._untyped_storage
                storage_dtype = obj.dtype
                storage_type_str = obj._pickle_storage_type()
                storage_type = getattr(torch, storage_type_str)
                dtype = obj.dtype
                storage_numel = obj._size()

            elif isinstance(obj, torch.UntypedStorage):
                # Untyped storages are recorded as bytes (uint8, numel == nbytes).
                storage = obj
                storage_dtype = torch.uint8
                storage_type = normalize_storage_type(type(obj))
                dtype = torch.uint8
                storage_numel = storage.nbytes()
            else:
                raise TypeError(f"type not recognized: {type(obj)}")

            # If storage is allocated, ensure that any other saved storages
            # pointing to the same data all have the same dtype. If storage is
            # not allocated, don't perform this check
            if storage.data_ptr() != 0:
                if storage.data_ptr() in storage_dtypes:
                    if storage_dtype != storage_dtypes[storage.data_ptr()]:
                        raise RuntimeError(
                            "Cannot save multiple tensors or storages that "
                            "view the same data as different types"
                        )
                else:
                    storage_dtypes[storage.data_ptr()] = storage_dtype

            view_metadata: tuple[str, int, int] | None

            # Offset is always 0, but we keep it for backwards compatibility
            # with the old serialization format (which supported storage views)
            offset = 0
            # Key is the string form of the storage's _cdata handle; the final
            # key list is sorted as strings.
            storage_key = str(storage._cdata)
            location = location_tag(storage)

            # TODO: There's an issue here with FC. It might be impossible to
            # solve, but it's worth noting. Imagine we save a list `[storage,
            # tensor]`, where `tensor.storage()` is the same as `storage`, and
            # `tensor.element_size() > 1`. Let's say that `tensor.dtype ==
            # torch.float`. The storage will be serialized with element size
            # of 1, since we're choosing to serialize the first occurrence of
            # a duplicate storage. Since this legacy serialization format saves
            # the numel of the storage, rather than nbytes directly, we'll be
            # effectively saving nbytes in this case. We'll be able to load it
            # and the tensor back up with no problems in _this_ and future
            # versions of pytorch, but in older versions, here's the problem:
            # the storage will be loaded up as a UntypedStorage, and then the
            # FloatTensor will loaded and the UntypedStorage will be assigned to
            # it. Since the storage dtype does not match the tensor dtype, this
            # will cause an error. If we reverse the list, like `[tensor,
            # storage]`, then we will save the `tensor.storage()` as a faked
            # `FloatStorage`, and the saved size will be the correct
            # dtype-specific numel count that old versions expect. `tensor`
            # will be able to load up properly in old versions, pointing to
            # a FloatStorage. However, `storage` is still being translated to
            # a UntypedStorage, and it will try to resolve to the same
            # FloatStorage that `tensor` contains. This will also cause an
            # error. It doesn't seem like there's any way around this.
            # Probably, we just cannot maintain FC for the legacy format if the
            # saved list contains both a tensor and a storage that point to the
            # same data. We should still be able to maintain FC for lists of
            # just tensors, as long as all views share the same dtype as the
            # tensor they are viewing.

            if storage_key not in serialized_storages:
                serialized_storages[storage_key] = (storage, dtype)
            # NOTE(review): this compares _cdata with itself, so is_view is
            # always False and view_metadata is always None. Presumably kept
            # only to preserve the 6-tuple layout of the legacy format.
            is_view = storage._cdata != storage._cdata
            if is_view:
                view_metadata = (str(storage._cdata), offset, storage.nbytes())
            else:
                view_metadata = None

            res = (
                "storage",
                storage_type,
                storage_key,
                location,
                storage_numel,
                view_metadata,
            )
            return res
        return None

    # Header describing the writer's platform, pickled right after the version.
    sys_info = {
        "protocol_version": PROTOCOL_VERSION,
        "little_endian": sys.byteorder == "little",
        "type_sizes": {
            "short": SHORT_SIZE,
            "int": INT_SIZE,
            "long": LONG_SIZE,
        },
    }

    pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol)
    pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol)
    pickle_module.dump(sys_info, f, protocol=pickle_protocol)

    class PyTorchLegacyPickler(pickle_module.Pickler):
        def persistent_id(self, obj):
            return persistent_id(obj)  # noqa: F821

    pickler = PyTorchLegacyPickler(f, protocol=pickle_protocol)
    pickler.dump(obj)

    # The class def keeps the persistent_id closure alive, leaking memory.
    del persistent_id

    serialized_storage_keys = sorted(serialized_storages.keys())
    pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol)
    # Flush pickled metadata before the storages write their raw bytes to f.
    f.flush()
    for key in serialized_storage_keys:
        storage, dtype = serialized_storages[key]
        storage._write_file(
            f, _should_read_directly(f), True, torch._utils._element_size(dtype)
        )
def _save(
    obj,
    zip_file,
    pickle_module,
    pickle_protocol,
    _disable_byteorder_record,
):
    """Serialize *obj* into *zip_file* using the zipfile-based checkpoint format.

    Records are written in this order:
    1. ``data.pkl``: the pickled object graph, with every storage replaced by
       a persistent id.
    2. ``.format_version``.
    3. ``.storage_alignment``.
    4. ``byteorder``, unless *_disable_byteorder_record* is set.
    5. One ``data/<key>`` record per distinct storage.

    When ``_serialization_tls.skip_data`` is set, only each storage record's
    name and size are written (``write_record_metadata``), not the bytes.

    Args:
        obj: object to serialize.
        zip_file: writer exposing ``write_record`` / ``write_record_metadata``.
            When called from :func:`save`, this is a
            ``torch._C.PyTorchFileWriter``.
        pickle_module: module whose ``Pickler`` serializes the object graph.
        pickle_protocol: pickle protocol passed to the pickler.
        _disable_byteorder_record: skip writing the ``byteorder`` record.

    Raises:
        RuntimeError: if two storages with the same non-null data pointer are
            saved with different dtypes.
        ValueError: if ``sys.byteorder`` is neither ``"little"`` nor ``"big"``.
    """
    # storage_key -> untyped storage whose bytes go into data/<storage_key>
    serialized_storages: dict[str, torch.storage.UntypedStorage] = {}
    # storage._cdata -> storage_key, so a storage shared by several objects is
    # written only once.
    id_map: dict[int, str] = {}

    # Since loading storages that view the same data with different dtypes is
    # not supported, we need to keep track of the dtype associated with each
    # storage data_ptr and throw an error if the dtype is ever different.
    # TODO: This feature could be added in the future
    storage_dtypes: dict[int, torch.dtype] = {}

    def persistent_id(obj):
        # FIXME: the docs say that persistent_id should only return a string
        # but torch store returns tuples. This works only in the binary protocol
        # see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        # Storages become ("storage", type, key, location, numel); returning
        # None lets the pickler serialize every other object normally.
        if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
            if isinstance(obj, torch.storage.TypedStorage):
                # TODO: Once we decide to break serialization FC, this case
                # can be deleted
                storage = obj._untyped_storage
                storage_dtype = obj.dtype
                storage_type_str = obj._pickle_storage_type()
                storage_type = getattr(torch, storage_type_str)
                storage_numel = obj._size()

            else:
                # Untyped storages are recorded as bytes (uint8, numel == nbytes).
                storage = obj
                storage_dtype = torch.uint8
                storage_type = normalize_storage_type(type(obj))
                storage_numel = storage.nbytes()

            # If storage is allocated, ensure that any other saved storages
            # pointing to the same data all have the same dtype. If storage is
            # not allocated, don't perform this check
            if str(storage.device) != "meta" and storage.data_ptr() != 0:
                if storage.data_ptr() in storage_dtypes:
                    if storage_dtype != storage_dtypes[storage.data_ptr()]:
                        raise RuntimeError(
                            "Cannot save multiple tensors or storages that "
                            "view the same data as different types"
                        )
                else:
                    storage_dtypes[storage.data_ptr()] = storage_dtype

            # Keys are "0", "1", ... in first-seen order of storage._cdata.
            storage_key = id_map.setdefault(storage._cdata, str(len(id_map)))
            # A non-None _fake_device overrides the storage's own location tag.
            # NOTE(review): presumably set for fake/meta-tensor saving flows;
            # confirm.
            if hasattr(obj, "_fake_device") and obj._fake_device is not None:
                location = str(obj._fake_device)
            else:
                location = location_tag(storage)
            serialized_storages[storage_key] = storage

            return ("storage", storage_type, storage_key, location, storage_numel)

        return None

    # Write the pickle data for `obj`
    data_buf = io.BytesIO()

    class PyTorchPickler(pickle_module.Pickler):  # type: ignore[name-defined]
        def persistent_id(self, obj):
            return persistent_id(obj)  # noqa: F821

    pickler = PyTorchPickler(data_buf, protocol=pickle_protocol)
    pickler.dump(obj)

    # The class def keeps the persistent_id closure alive, leaking memory.
    del persistent_id

    data_value = data_buf.getvalue()
    zip_file.write_record("data.pkl", data_value, len(data_value))
    # .format_version is used to track
    #     1. version 1 represents the order of storages being changed from
    #        lexicographical based on keys to numerically ordered based on keys
    #     2. version 2 represents including storage_alignment as a record
    #        within the zipfile
    # NOTE(review): the record written is "1" even though .storage_alignment
    # is also written below; confirm whether the version should be bumped.
    zip_file.write_record(".format_version", "1", len("1"))
    storage_alignment = str(_get_storage_alignment())
    zip_file.write_record(
        ".storage_alignment", storage_alignment, len(storage_alignment)
    )

    # Write byte order marker
    if not _disable_byteorder_record:
        if sys.byteorder not in ["little", "big"]:
            raise ValueError("Unknown endianness type: " + sys.byteorder)

        zip_file.write_record("byteorder", sys.byteorder, len(sys.byteorder))

    # Write each tensor to a file named tensor/the_tensor_key in the zip archive
    for key in serialized_storages:
        name = f"data/{key}"
        storage = serialized_storages[key]
        num_bytes = storage.nbytes()
        global _serialization_tls
        if _serialization_tls.skip_data:
            # Record only the entry's name and size; no tensor bytes are written.
            zip_file.write_record_metadata(name, num_bytes)
        else:
            # given that we copy things around anyway, we might use storage.cpu()
            # this means to that to get tensors serialized, you need to implement
            # .cpu() on the underlying Storage
            if storage.device.type != "cpu":
                from torch.utils.serialization import config

                # Stage through a pinned host buffer when enabled and the
                # storage lives on the current accelerator.
                if (
                    config.save.use_pinned_memory_for_d2h
                    and (
                        acc := torch.accelerator.current_accelerator(
                            check_available=True
                        )
                    )
                    is not None
                    and acc.type == storage.device.type
                ):
                    new_storage = torch.empty(
                        num_bytes, dtype=torch.uint8, device="cpu", pin_memory=True
                    ).untyped_storage()
                    new_storage.copy_(storage)
                    # Wait on the device's current stream before reading the
                    # host copy; presumably copy_ into pinned memory may be
                    # asynchronous.
                    torch.accelerator.current_stream(storage.device.index).synchronize()
                    storage = new_storage
                else:
                    storage = storage.cpu()
            # Now that it is on the CPU we can directly copy it into the zip file
            zip_file.write_record(name, storage, num_bytes)
def load(
    f: FileLike,
    map_location: MAP_LOCATION = None,
    pickle_module: Any = None,
    *,
    weights_only: bool | None = None,
    mmap: bool | None = None,
    **pickle_load_args: Any,
) -> Any:
    # Reference: https://github.com/pytorch/pytorch/issues/54354
    # The first line of this docstring overrides the one Sphinx generates for the
    # documentation. We need it so that Sphinx doesn't leak `pickle`s path from
    # the build environment (e.g. `<module 'pickle' from '/leaked/path').

    """load(f, map_location=None, pickle_module=pickle, *, weights_only=True, mmap=None, **pickle_load_args)

    Loads an object saved with :func:`torch.save` from a file.

    .. warning::
        :func:`torch.load()` uses an unpickler under the hood. **Never load data from an untrusted source.**
        See :ref:`weights-only-security` for more details.

    :func:`torch.load` uses Python's unpickling facilities but treats storages,
    which underlie tensors, specially. They are first deserialized on the
    CPU and are then moved to the device they were saved from. If this fails
    (e.g. because the run time system doesn't have certain devices), an exception
    is raised. However, storages can be dynamically remapped to an alternative
    set of devices using the :attr:`map_location` argument.

    If :attr:`map_location` is a callable, it will be called once for each serialized
    storage with two arguments: storage and location. The storage argument
    will be the initial deserialization of the storage, residing on the CPU.
    Each serialized storage has a location tag associated with it which
    identifies the device it was saved from, and this tag is the second
    argument passed to :attr:`map_location`. The builtin location tags are ``'cpu'``
    for CPU tensors and ``'cuda:device_id'`` (e.g. ``'cuda:2'``) for CUDA tensors.
    :attr:`map_location` should return either ``None`` or a storage. If
    :attr:`map_location` returns a storage, it will be used as the final deserialized
    object, already moved to the right device. Otherwise, :func:`torch.load` will
    fall back to the default behavior, as if :attr:`map_location` wasn't specified.

    If :attr:`map_location` is a :class:`torch.device` object or a string containing
    a device tag, it indicates the location where all tensors should be loaded.

    Otherwise, if :attr:`map_location` is a dict, it will be used to remap location tags
    appearing in the file (keys), to ones that specify where to put the
    storages (values).

    User extensions can register their own location tags and tagging and
    deserialization methods using :func:`torch.serialization.register_package`.

    See :ref:`layout-control` for more advanced tools to manipulate a checkpoint.

    Args:
        f: a file-like object (has to implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`),
            or a string or os.PathLike object containing a file name
        map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage
            locations
        pickle_module: module used for unpickling metadata and objects (has to
            match the :attr:`pickle_module` used to serialize file)
        weights_only: Indicates whether unpickler should be restricted to
            loading only tensors, primitive types, dictionaries
            and any types added via :func:`torch.serialization.add_safe_globals`.
            See :ref:`weights-only` for more details.
        mmap: Indicates whether the file should be mapped rather than loading all the storages into memory.
            Typically, tensor storages in the file will first be moved from disk to CPU memory, after which they
            are moved to the location that they were tagged with when saving, or specified by ``map_location``. This
            second step is a no-op if the final location is CPU. When the ``mmap`` flag is set, instead of copying the
            tensor storages from disk to CPU memory in the first step, ``f`` is mapped, which means tensor storages
            will be lazily loaded when their data is accessed.
        pickle_load_args: (Python 3 only) optional keyword arguments passed over to
            :func:`pickle_module.load` and :func:`pickle_module.Unpickler`,
            only works if :attr:`weights_only=False`, e.g., :attr:`errors=...`.

    .. note::
        When you call :func:`torch.load()` on a file which contains GPU tensors, those tensors
        will be loaded to GPU by default. You can call ``torch.load(.., map_location='cpu')``
        and then :meth:`load_state_dict` to avoid GPU RAM surge when loading a model checkpoint.

    .. note::
        By default, we decode byte strings as ``utf-8``. This is to avoid a common error
        case ``UnicodeDecodeError: 'ascii' codec can't decode byte 0x...``
        when loading files saved by Python 2 in Python 3. If this default
        is incorrect, you may use an extra :attr:`encoding` keyword argument to specify how
        these objects should be loaded, e.g., :attr:`encoding='latin1'` decodes them
        to strings using ``latin1`` encoding, and :attr:`encoding='bytes'` keeps them
        as byte arrays which can be decoded later with ``byte_array.decode(...)``.

    Example:
        >>> # xdoctest: +SKIP("undefined filepaths")
        >>> torch.load("tensors.pt", weights_only=True)
        # Load all tensors onto the CPU
        >>> torch.load(
        ...     "tensors.pt",
        ...     map_location=torch.device("cpu"),
        ...     weights_only=True,
        ... )
        # Load all tensors onto the CPU, using a function
        >>> torch.load(
        ...     "tensors.pt",
        ...     map_location=lambda storage, loc: storage,
        ...     weights_only=True,
        ... )
        # Load all tensors onto GPU 1
        >>> torch.load(
        ...     "tensors.pt",
        ...     map_location=lambda storage, loc: storage.cuda(1),  # type: ignore[attr-defined]
        ...     weights_only=True,
        ... )  # type: ignore[attr-defined]
        # Map tensors from GPU 1 to GPU 0
        >>> torch.load(
        ...     "tensors.pt",
        ...     map_location={"cuda:1": "cuda:0"},
        ...     weights_only=True,
        ... )
        # Load tensor from io.BytesIO object
        # Loading from a buffer setting weights_only=False, warning this can be unsafe
        >>> with open("tensor.pt", "rb") as f:
        ...     buffer = io.BytesIO(f.read())
        >>> torch.load(buffer, weights_only=False)
        # Load a module with 'ascii' encoding for unpickling
        # Loading from a module setting weights_only=False, warning this can be unsafe
        >>> torch.load("module.pt", encoding="ascii", weights_only=False)
    """
    torch._C._log_api_usage_once("torch.load")
    DOCS_MESSAGE = (
        "\n\nCheck the documentation of torch.load to learn more about types accepted by default with "
        "weights_only https://pytorch.org/docs/stable/generated/torch.load.html."
    )

    def _get_wo_message(message: str) -> str:
        # Build a user-facing explanation for a weights_only unpickling failure,
        # specialised by which kind of rejection the unpickler reported.
        unsafe_global_pattern = r"GLOBAL (\S+) was not an allowed global by default."
        has_unsafe_global = re.search(unsafe_global_pattern, message) is not None
        blocklist_pattern = r"whose module (\S+) is blocked"
        has_blocklist = re.search(blocklist_pattern, message) is not None
        import_pattern = r"(\S+) must be (\S+) to load"
        has_import = re.search(import_pattern, message) is not None
        if has_unsafe_global:
            updated_message = (
                "Weights only load failed. This file can still be loaded, to do so you have two options, "
                "\033[1mdo those steps only if you trust the source of the checkpoint\033[0m. "
                f"\n\t(1) {UNSAFE_MESSAGE}\n\t(2) Alternatively, to load with `weights_only=True` please check "
                "the recommended steps in the following error message.\n\tWeightsUnpickler error: "
                + message
            )
        else:
            if has_import:
                return f"Weights only load failed. {message}\n {UNSAFE_MESSAGE}\n"
            else:
                updated_message = f"Weights only load failed. {UNSAFE_MESSAGE}\n"
                if not has_blocklist:
                    updated_message += (
                        "Please file an issue with the following so that we can make "
                        "`weights_only=True` compatible with your use case: WeightsUnpickler error: "
                    )
            updated_message += "\n\n" + message
        return updated_message + DOCS_MESSAGE

    weights_only_not_set = weights_only is None

    if weights_only_not_set:
        weights_only = _default_to_weights_only(pickle_module)

    true_values = ["1", "y", "yes", "true"]
    # Add ability to force safe only or non-safe weight loads via environment variables
    force_weights_only_load = (
        os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0") in true_values
    )
    force_no_weights_only_load = (
        os.getenv("TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD", "0") in true_values
    )

    if force_weights_only_load and force_no_weights_only_load:
        raise RuntimeError(
            "Only one of `TORCH_FORCE_WEIGHTS_ONLY_LOAD` or `TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD` "
            "should be set, but both were set."
        )
    elif force_weights_only_load:
        weights_only = True
    elif force_no_weights_only_load:
        # TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD can only override if callsite did not explicitly set weights_only
        if weights_only_not_set:
            warnings.warn(
                "Environment variable TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD detected, since the "
                "`weights_only` argument was not explicitly passed to `torch.load`, forcing weights_only=False.",
                UserWarning,
                stacklevel=2,
            )
            weights_only = False

    if weights_only:
        # The restricted unpickler is used in place of any pickle module.
        if pickle_module is not None:
            raise RuntimeError(
                "Can not safely load weights when explicit pickle_module is specified"
            )
    else:
        if pickle_module is None:
            pickle_module = pickle

    if pickle_load_args and weights_only:
        warnings.warn(
            "pickle_load_args only works if `weights_only=False`.", stacklevel=2
        )

    # make flipping default BC-compatible
    if mmap is None:
        from torch.utils.serialization import config

        mmap = config.load.mmap

    _check_dill_version(pickle_module)

    if "encoding" not in pickle_load_args:
        pickle_load_args["encoding"] = "utf-8"

    with _open_file_like(f, "rb") as opened_file:
        if _is_zipfile(opened_file):
            # The zipfile reader is going to advance the current file position.
            # If we want to actually tail call to torch.jit.load, we need to
            # reset back to the original position.
            orig_position = opened_file.tell()
            overall_storage = None
            with _open_zipfile_reader(opened_file) as opened_zipfile:
                if _is_torchscript_zip(opened_zipfile):
                    warnings.warn(
                        "'torch.load' received a zip file that looks like a TorchScript archive"
                        " dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to"
                        " silence this warning)",
                        UserWarning,
                        stacklevel=2,
                    )
                    if weights_only:
                        raise RuntimeError(
                            "Cannot use ``weights_only=True`` with TorchScript archives passed to "
                            "``torch.load``. " + UNSAFE_MESSAGE
                        )
                    opened_file.seek(orig_position)
                    return torch.jit.load(opened_file, map_location=map_location)
                if mmap:
                    # mmap needs a real file on disk; buffers are rejected.
                    if not _is_path(f):
                        raise ValueError(
                            "f must be a file path in order to use the mmap argument"
                        )
                    size = os.path.getsize(f)
                    if not IS_WINDOWS:
                        shared = get_default_mmap_options() == MAP_SHARED
                    else:
                        shared = False
                    overall_storage = torch.UntypedStorage.from_file(
                        os.fspath(f),
                        shared,
                        size,
                    )
                if weights_only:
                    try:
                        return _load(
                            opened_zipfile,
                            map_location,
                            _weights_only_unpickler,
                            overall_storage=overall_storage,
                            **pickle_load_args,
                        )
                    except pickle.UnpicklingError as e:
                        raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
                return _load(
                    opened_zipfile,
                    map_location,
                    pickle_module,
                    overall_storage=overall_storage,
                    **pickle_load_args,
                )
        # Not a zip archive: only the legacy loader applies, which cannot mmap.
        if mmap:
            f_name = "" if not isinstance(f, str) else f"{f}, "
            raise RuntimeError(
                "mmap can only be used with files saved with "
                f"`torch.save({f_name}_use_new_zipfile_serialization=True), "
                "please torch.save your checkpoint with this option in order to use mmap."
            )
        if weights_only:
            try:
                return _legacy_load(
                    opened_file,
                    map_location,
                    _weights_only_unpickler,
                    **pickle_load_args,
                )
            except pickle.UnpicklingError as e:
                raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
        return _legacy_load(
            opened_file, map_location, pickle_module, **pickle_load_args
        )
# Register pickling support for layout instances such as
# torch.sparse_coo, etc
def _get_layout(name):
    """Return the ``torch.layout`` instance whose ``str()`` equals *name*.

    Used as the reconstructor when unpickling layouts. The ``str -> layout``
    table lives in the ``_get_layout.cache`` function attribute. It is filled
    on first use by scanning ``torch.__dict__`` for ``torch.layout`` values.

    Raises:
        KeyError: if *name* is not the string form of any ``torch`` layout.
    """
    table = _get_layout.cache  # type: ignore[attr-defined]
    if not table:
        table.update(
            (str(value), value)
            for value in torch.__dict__.values()
            if isinstance(value, torch.layout)
        )
    return table[name]


# Function attributes cannot be type-annotated cleanly yet: https://github.com/python/mypy/issues/2087
_get_layout.cache = {}  # type: ignore[attr-defined]
# Pickle a layout as a call to _get_layout(str(layout)).
copyreg.pickle(torch.layout, lambda obj: (_get_layout, (str(obj),)))
- def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
- deserialized_objects: dict[int, Any] = {}
- restore_location = _get_restore_location(map_location)
- class UnpicklerWrapper(pickle_module.Unpickler): # type: ignore[name-defined]
- def find_class(self, mod_name, name):
- if type(name) is str and "Storage" in name:
- try:
- return StorageType(name)
- except KeyError:
- pass
- return super().find_class(mod_name, name)
- def _check_container_source(container_type, source_file, original_source):
- try:
- current_source = "".join(get_source_lines_and_file(container_type)[0])
- except Exception: # saving the source is optional, so we can ignore any errors
- warnings.warn(
- "Couldn't retrieve source code for container of "
- "type " + container_type.__name__ + ". It won't be checked "
- "for correctness upon loading.",
- stacklevel=2,
- )
- return
- if original_source != current_source:
- if container_type.dump_patches:
- file_name = container_type.__name__ + ".patch"
- diff = difflib.unified_diff(
- current_source.split("\n"),
- original_source.split("\n"),
- source_file,
- source_file,
- lineterm="",
- )
- lines = "\n".join(diff)
- try:
- with open(file_name, "a+") as f:
- file_size = f.seek(0, 2)
- f.seek(0)
- if file_size == 0:
- f.write(lines)
- elif file_size != len(lines) or f.read() != lines:
- raise OSError
- msg = (
- "Saved a reverse patch to " + file_name + ". "
- "Run `patch -p0 < " + file_name + "` to revert your "
- "changes."
- )
- except OSError:
- msg = (
- "Tried to save a patch, but couldn't create a "
- "writable file " + file_name + ". Make sure it "
- "doesn't exist and your working directory is "
- "writable."
- )
- else:
- msg = (
- "you can retrieve the original source code by "
- "accessing the object's source attribute or set "
- "`torch.nn.Module.dump_patches = True` and use the "
- "patch tool to revert the changes."
- )
- msg = f"source code of class '{torch.typename(container_type)}' has changed. {msg}"
- warnings.warn(msg, SourceChangeWarning, stacklevel=2)
- def legacy_load(f):
- deserialized_objects: dict[int, Any] = {}
- def persistent_load(saved_id):
- if isinstance(saved_id, tuple):
- # Ignore containers that don't have any sources saved
- if all(saved_id[1:]):
- _check_container_source(*saved_id)
- return saved_id[0]
- return deserialized_objects[int(saved_id)]
- with (
- closing(
- tarfile.open(fileobj=f, mode="r:", format=tarfile.PAX_FORMAT)
- ) as tar,
- mkdtemp() as tmpdir,
- ):
- if pickle_module is _weights_only_unpickler:
- raise RuntimeError(
- "Cannot use ``weights_only=True`` with files saved in the "
- "legacy .tar format. " + UNSAFE_MESSAGE
- )
- tar.extract("storages", path=tmpdir)
- with open(os.path.join(tmpdir, "storages"), "rb", 0) as f:
- num_storages = pickle_module.load(f, **pickle_load_args)
- for _ in range(num_storages):
- args = pickle_module.load(f, **pickle_load_args)
- key, location, storage_type = args
- dtype = storage_type._dtype
- obj = cast(Storage, torch.UntypedStorage)._new_with_file(
- f, torch._utils._element_size(dtype)
- )
- obj = restore_location(obj, location)
- # TODO: Once we decide to break serialization FC, we can
- # stop wrapping with TypedStorage
- deserialized_objects[key] = torch.storage.TypedStorage(
- wrap_storage=obj, dtype=dtype, _internal=True
- )
- storage_views = pickle_module.load(f, **pickle_load_args)
- for target_cdata, root_cdata, offset, numel in storage_views:
- root = deserialized_objects[root_cdata]
- element_size = torch._utils._element_size(root.dtype)
- offset_bytes = offset * element_size
- # TODO: Once we decide to break serialization FC, we can
- # stop wrapping with TypedStorage
- deserialized_objects[target_cdata] = torch.storage.TypedStorage(
- wrap_storage=root._untyped_storage[
- offset_bytes : offset_bytes + numel * element_size
- ],
- dtype=root.dtype,
- _internal=True,
- )
- tar.extract("tensors", path=tmpdir)
- with open(os.path.join(tmpdir, "tensors"), "rb", 0) as f:
- num_tensors = pickle_module.load(f, **pickle_load_args)
- for _ in range(num_tensors):
- args = pickle_module.load(f, **pickle_load_args)
- key, storage_id, _original_tensor_type = args
- storage = deserialized_objects[storage_id]
- (ndim,) = struct.unpack("<i", f.read(4))
- # skip next 4 bytes; legacy encoding treated ndim as 8 bytes
- f.read(4)
- numel = struct.unpack(f"<{ndim}q", f.read(8 * ndim))
- stride = struct.unpack(f"<{ndim}q", f.read(8 * ndim))
- (storage_offset,) = struct.unpack("<q", f.read(8))
- tensor = torch.empty((0,), dtype=storage.dtype).set_(
- storage._untyped_storage, storage_offset, numel, stride
- )
- deserialized_objects[key] = tensor
- pickle_file = tar.extractfile("pickle")
- unpickler = UnpicklerWrapper(pickle_file, **pickle_load_args)
- unpickler.persistent_load = persistent_load
- result = unpickler.load()
- return result
- deserialized_objects = {}
- def persistent_load(saved_id):
- if not isinstance(saved_id, tuple):
- raise AssertionError(
- f"saved_id must be a tuple, got {type(saved_id).__name__}"
- )
- typename = _maybe_decode_ascii(saved_id[0])
- data = saved_id[1:]
- if typename == "module":
- # Ignore containers that don't have any sources saved
- if all(data[1:]):
- _check_container_source(*data)
- return data[0]
- elif typename == "storage":
- storage_type, root_key, location, numel, view_metadata = data
- location = _maybe_decode_ascii(location)
- dtype = storage_type.dtype
- nbytes = numel * torch._utils._element_size(dtype)
- if root_key not in deserialized_objects:
- if torch._guards.active_fake_mode() is not None:
- obj = cast(Storage, torch.UntypedStorage(nbytes, device="meta"))
- elif _serialization_tls.skip_data:
- obj = cast(Storage, torch.UntypedStorage(nbytes))
- obj = restore_location(obj, location)
- else:
- obj = cast(Storage, torch.UntypedStorage(nbytes))
- obj._torch_load_uninitialized = True
- obj = restore_location(obj, location)
- # TODO: Once we decide to break serialization FC, we can
- # stop wrapping with TypedStorage
- typed_storage = torch.storage.TypedStorage(
- wrap_storage=obj, dtype=dtype, _internal=True
- )
- deserialized_objects[root_key] = typed_storage
- else:
- typed_storage = deserialized_objects[root_key]
- if typed_storage._data_ptr() == 0:
- typed_storage = torch.storage.TypedStorage(
- device=typed_storage._untyped_storage.device,
- dtype=dtype,
- _internal=True,
- )
- if view_metadata is not None:
- view_key, offset, view_size = view_metadata
- offset_bytes = offset * torch._utils._element_size(dtype)
- view_size_bytes = view_size * torch._utils._element_size(dtype)
- if view_key not in deserialized_objects:
- # TODO: Once we decide to break serialization FC, we can
- # stop wrapping with TypedStorage
- deserialized_objects[view_key] = torch.storage.TypedStorage(
- wrap_storage=typed_storage._untyped_storage[
- offset_bytes : offset_bytes + view_size_bytes
- ],
- dtype=dtype,
- _internal=True,
- )
- res = deserialized_objects[view_key]
- else:
- res = typed_storage
- return res
- else:
- raise RuntimeError(f"Unknown saved id type: {saved_id[0]}")
- _check_seekable(f)
- f_should_read_directly = _should_read_directly(f)
- if f_should_read_directly and f.tell() == 0:
- # legacy_load requires that f has fileno()
- # only if offset is zero we can attempt the legacy tar file loader
- try:
- return legacy_load(f)
- except tarfile.TarError:
- if _is_zipfile(f):
- # .zip is used for torch.jit.save and will throw an un-pickling error here
- raise RuntimeError(
- f"{f.name} is a zip archive (did you mean to use torch.jit.load()?)"
- ) from None
- # if not a tarfile, reset file offset and proceed
- f.seek(0)
- magic_number = pickle_module.load(f, **pickle_load_args)
- if magic_number != MAGIC_NUMBER:
- raise RuntimeError("Invalid magic number; corrupt file?")
- protocol_version = pickle_module.load(f, **pickle_load_args)
- if protocol_version != PROTOCOL_VERSION:
- raise RuntimeError(f"Invalid protocol version: {protocol_version}")
- _sys_info = pickle_module.load(f, **pickle_load_args)
- unpickler = UnpicklerWrapper(f, **pickle_load_args)
- unpickler.persistent_load = persistent_load
- result = unpickler.load()
- deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)
- if torch._guards.active_fake_mode() is None and not _serialization_tls.skip_data:
- offset = f.tell() if f_should_read_directly else None
- for key in deserialized_storage_keys:
- if key not in deserialized_objects:
- raise AssertionError(
- f"storage key {key!r} not found in deserialized_objects"
- )
- typed_storage = deserialized_objects[key]
- typed_storage._untyped_storage._set_from_file(
- f,
- offset,
- f_should_read_directly,
- torch._utils._element_size(typed_storage.dtype),
- )
- if offset is not None:
- offset = f.tell()
- torch._utils._validate_loaded_sparse_tensors()
- return result
- def _maybe_decode_ascii(bytes_str: bytes | str) -> str:
- # When using encoding='bytes' in Py3, some **internal** keys stored as
- # strings in Py2 are loaded as bytes. This function decodes them with
- # ascii encoding, one that Py3 uses by default.
- #
- # NOTE: This should only be used on internal keys (e.g., `typename` and
- # `location` in `persistent_load` below!
- if isinstance(bytes_str, bytes):
- return bytes_str.decode("ascii")
- return bytes_str
- def _get_restore_location(map_location):
- if map_location is None:
- restore_location = default_restore_location
- elif isinstance(map_location, dict):
- def restore_location(storage, location):
- location = map_location.get(location, location)
- return default_restore_location(storage, location)
- elif isinstance(map_location, (str, bytes)):
- def restore_location(storage, location):
- return default_restore_location(storage, map_location)
- elif isinstance(map_location, torch.device):
- def restore_location(storage, location):
- return default_restore_location(storage, str(map_location))
- else:
- def restore_location(storage, location):
- result = map_location(storage, location)
- if result is None:
- result = default_restore_location(storage, location)
- return result
- return restore_location
class StorageType:
    """Placeholder for a pickled class name that contains ``"Storage"``.

    ``UnpicklerWrapper.find_class`` returns an instance of this in place of
    the named class. The instance carries only the dtype that
    ``_get_dtype_from_pickle_storage_type`` derives from that name.
    """

    def __init__(self, name):
        # May raise KeyError for names the helper does not recognise;
        # find_class catches that and falls back to normal lookup.
        resolved_dtype = _get_dtype_from_pickle_storage_type(name)
        self._dtype = resolved_dtype

    def __str__(self):
        return "StorageType(dtype={})".format(self.dtype)

    @property
    def dtype(self):
        """The dtype resolved from the pickled storage class name."""
        return self._dtype
def _load(
    zip_file,
    map_location,
    pickle_module,
    pickle_file="data.pkl",
    overall_storage=None,
    **pickle_load_args,
):
    """Unpickle an object from a zip-format checkpoint, materializing its storages.

    Args:
        zip_file: Archive reader exposing ``has_record``, ``get_record``,
            ``get_record_offset``, ``get_record_header_offset``,
            ``get_record_offset_no_read``, ``get_storage_from_record``,
            ``serialization_id`` (all called below).
        map_location: Forwarded to ``_get_restore_location`` to decide where
            each storage is restored. It is also published as
            ``_serialization_tls.map_location`` while unpickling runs.
        pickle_module: Module whose ``Unpickler`` is subclassed to read the data.
        pickle_file: Name of the record holding the pickled object graph.
        overall_storage: Optional storage that is sliced at each record's
            offset instead of reading the record from ``zip_file``
            (presumably a storage spanning the whole checkpoint file, e.g.
            for mmap loading; confirm against callers).
        **pickle_load_args: Forwarded to the unpickler constructor.

    Returns:
        The unpickled object.

    Raises:
        ValueError: if the ``byteorder`` record is neither ``little`` nor
            ``big``, or the default load endianness setting is unrecognized.
        AssertionError: if a persistent id is not a tuple or is not a
            ``"storage"`` id.
        RuntimeError: when ``TORCH_SERIALIZATION_DEBUG=1`` and a computed
            storage offset disagrees with the archive's record offset.
    """
    restore_location = _get_restore_location(map_location)

    # key -> TypedStorage, so storages shared by several tensors load once.
    loaded_storages = {}

    is_meta_map_location = _is_meta_location(map_location)

    # Offsets can only be predicted for archives with format version >= 1.
    can_calculate_storage_offsets = False
    if zip_file.has_record(".format_version"):
        version = zip_file.get_record(".format_version")
        can_calculate_storage_offsets = version >= b"1"

    # Check whether byteswapping is needed. Without a byteorder record, fall
    # back to the configured default; None means "little".
    byteordername = "byteorder"
    has_byteorder_record = zip_file.has_record(byteordername)
    default_endianness = get_default_load_endianness()
    byteorderdata = None
    if has_byteorder_record:
        byteorderdata = zip_file.get_record(byteordername)
        if byteorderdata not in [b"little", b"big"]:
            raise ValueError("Unknown endianness type: " + byteorderdata.decode())
    elif default_endianness is None or default_endianness == LoadEndianness.LITTLE:
        byteorderdata = b"little"
    elif default_endianness == LoadEndianness.BIG:
        byteorderdata = b"big"
    elif default_endianness == LoadEndianness.NATIVE:
        pass  # leave byteorderdata as None: never byteswap
    else:
        raise ValueError("Invalid load endianness type")

    storage_alignment = 64
    if zip_file.has_record(".storage_alignment"):
        storage_alignment = int(zip_file.get_record(".storage_alignment"))

    if (
        not has_byteorder_record
        and default_endianness is None
        and sys.byteorder == "big"
    ):
        # Default behaviour was changed
        # See https://github.com/pytorch/pytorch/issues/101688
        warnings.warn(
            "The default load endianness for checkpoints without a byteorder mark "
            "on big endian machines was changed from 'native' to 'little' endian, "
            "to avoid this behavior please use "
            "torch.serialization.set_default_load_endianness to set "
            "the desired default load endianness",
            UserWarning,
            stacklevel=2,
        )

    from torch.utils.serialization import config

    calculate_storage_offsets = config.load.calculate_storage_offsets
    run_debug_asserts = os.environ.get("TORCH_SERIALIZATION_DEBUG", "0") == "1"
    current_offset = None
    # constants from miniz.h/miniz.c
    data_descriptor_size64 = 24
    data_descriptor_size32 = 16
    mz_uint32_max = 0xFFFFFFFF
    offsets: dict[str, int] = dict()

    def _get_offset(key, name, numel):
        """
        Return the offset of the storage associated with key with record name `name` and size numel.
        It is expected that the zipfile header of this storage starts at current_offset.

        WARNING: This function relies on the behavior of the zipwriter in miniz.c. In particular,
        the behavior of `mz_zip_writer_add_mem_ex_v2`. The behavior of this function must be kept
        in sync with that of miniz!

        After reading a storage of size numel that starts at storage_offset
        if it is the first time that storage was read, update nonlocal variable
        current_offset to the start of the next zipfile header by incrementing
        it by numel and the data descriptor size.
        """
        nonlocal current_offset
        if name in offsets:
            return offsets[name]

        if current_offset is None:
            # The very first storage read must be data/0; its offset is read
            # from the archive and anchors all later predictions.
            if key != "0":
                raise AssertionError(f"expected key '0', got {key!r}")
            current_offset = zip_file.get_record_offset(name)
            local_header_offset = zip_file.get_record_header_offset(name)
            storage_offset = current_offset
        else:
            storage_offset = zip_file.get_record_offset_no_read(
                current_offset, name, numel, storage_alignment
            )
            local_header_offset = current_offset

        # This is only actually needed for storages that have typed_storage._data_ptr() == 0
        # after being read. Otherwise persistent_load would never "re-call" load_tensor
        # for a given key.
        offsets[name] = storage_offset

        # Increment current_offset to offset where next zipfile header starts
        current_offset = storage_offset + numel
        # add size of data descriptor after payload
        if numel > 0:
            if local_header_offset >= mz_uint32_max or numel >= mz_uint32_max:
                current_offset += data_descriptor_size64
            else:
                current_offset += data_descriptor_size32

        return storage_offset

    def _debug_check_offset(name, storage_offset):
        """Raise RuntimeError if a predicted offset differs from the archive's."""
        expected = zip_file.get_record_offset(name)
        if storage_offset != expected:
            raise RuntimeError(
                "This is a debug assert that was run as the `TORCH_SERIALIZATION_DEBUG` environment "
                f"variable was set: Incorrect offset for {name}, got {storage_offset} expected "
                f"{expected}"
            )

    def load_tensor(dtype, nbytes, key, location):
        """Create and return the TypedStorage for record ``data/{key}``."""
        name = f"data/{key}"
        if torch._guards.detect_fake_mode(None) is not None or is_meta_map_location:
            # No data is read; only record where it lives in the checkpoint.
            storage = torch.UntypedStorage(nbytes, device="meta")
            if can_calculate_storage_offsets:
                storage._checkpoint_offset = _get_offset(key, name, nbytes)
            else:
                storage._checkpoint_offset = zip_file.get_record_offset(name)
        elif _serialization_tls.skip_data:
            storage = torch.UntypedStorage(nbytes)
        elif overall_storage is not None:
            if can_calculate_storage_offsets and calculate_storage_offsets:
                storage_offset = _get_offset(key, name, nbytes)
                if run_debug_asserts:
                    _debug_check_offset(name, storage_offset)
            else:
                storage_offset = zip_file.get_record_offset(name)
            storage = overall_storage[storage_offset : storage_offset + nbytes]
        else:
            if can_calculate_storage_offsets and run_debug_asserts:
                # This is debug code that we use to test the validity of
                # torch.utils.serialization.config.load.calculate_storage_offsets throughout CI
                storage_offset = _get_offset(key, name, nbytes)
                _debug_check_offset(name, storage_offset)
            storage = (
                zip_file.get_storage_from_record(name, nbytes, torch.UntypedStorage)
                ._typed_storage()
                ._untyped_storage
            )
        # swap here if byteswapping is needed
        if byteorderdata is not None:
            if byteorderdata.decode() != sys.byteorder:
                storage.byteswap(dtype)

        # TODO: Once we decide to break serialization FC, we can
        # stop wrapping with TypedStorage

        if is_meta_map_location:
            # Skip restore_location for meta map_location. Since we already created
            # a meta storage above, calling restore_location would just redundantly
            # call _meta_deserialize which creates another meta storage with the same
            # size.
            wrap_storage = storage
        elif torch._guards.detect_fake_mode(None) is None:
            wrap_storage = restore_location(storage, location)
        else:
            storage._fake_device = location
            wrap_storage = storage

        typed_storage = torch.storage.TypedStorage(
            wrap_storage=wrap_storage,
            dtype=dtype,
            _internal=True,
        )

        # Storages with a null data pointer are not cached.
        if typed_storage._data_ptr() != 0:
            loaded_storages[key] = typed_storage

        return typed_storage

    def persistent_load(saved_id):
        """Resolve a ``("storage", storage_type, key, location, numel)`` id."""
        if not isinstance(saved_id, tuple):
            raise AssertionError(
                f"saved_id must be a tuple, got {type(saved_id).__name__}"
            )
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        if typename != "storage":
            raise AssertionError(
                f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
            )
        storage_type, key, location, numel = data
        if storage_type is torch.UntypedStorage:
            dtype = torch.uint8
        else:
            dtype = storage_type.dtype

        if key in loaded_storages:
            typed_storage = loaded_storages[key]
        else:
            nbytes = numel * torch._utils._element_size(dtype)
            typed_storage = load_tensor(
                dtype, nbytes, key, _maybe_decode_ascii(location)
            )

        return typed_storage

    load_module_mapping: dict[str, str] = {
        # See https://github.com/pytorch/pytorch/pull/51633
        "torch.tensor": "torch._tensor"
    }

    # Need to subclass Unpickler instead of directly monkey-patching the find_class method
    # because it's marked readonly in pickle.
    # The type: ignore is because mypy can't statically determine the type of this class.
    class UnpicklerWrapper(pickle_module.Unpickler):  # type: ignore[name-defined]
        # from https://stackoverflow.com/questions/13398462/unpickling-python-objects-with-a-changed-module-path/13405732
        # Lets us override the imports that pickle uses when unpickling an object.
        # This is useful for maintaining BC if we change a module path that tensor instantiation relies on.
        def find_class(self, mod_name, name):
            if type(name) is str and "Storage" in name:
                try:
                    return StorageType(name)
                except KeyError:
                    pass
            mod_name = load_module_mapping.get(mod_name, mod_name)
            return super().find_class(mod_name, name)

    # Load the data (which may in turn use `persistent_load` to load tensors)
    data_file = io.BytesIO(zip_file.get_record(pickle_file))

    unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
    unpickler.persistent_load = persistent_load
    # Needed for tensors where storage device and rebuild tensor device are
    # not connected (wrapper subclasses and tensors rebuilt using numpy).
    # The previous value is restored in `finally`. Otherwise a failed load
    # leaves a stale map_location behind for later readers, and a nested load
    # would reset the outer load's value to None.
    previous_map_location = getattr(_serialization_tls, "map_location", None)
    _serialization_tls.map_location = map_location
    try:
        result = unpickler.load()
    finally:
        _serialization_tls.map_location = previous_map_location

    torch._utils._validate_loaded_sparse_tensors()
    torch._C._log_api_usage_metadata(
        "torch.load.metadata", {"serialization_id": zip_file.serialization_id()}
    )
    return result
- def _is_torchscript_zip(zip_file):
- return "constants.pkl" in zip_file.get_all_records()
|