| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590 |
- # mypy: allow-untyped-defs
- # Unpickler restricted to loading only state dicts
- # Restrict constructing types to a list defined in _get_allowed_globals()
# Restrict BUILD operation to `Tensor`, `Parameter`, `OrderedDict` and user-allowlisted types only
# Restrict APPEND/APPENDS to `list`
# In the `GLOBAL` operation, do not look classes up by name; instead rely on the dictionary
# returned by `_get_allowed_globals()`, which contains:
- # - torch types (Storage, dtypes, Tensor, `torch.Size`),
- # - `torch._utils._rebuild` functions.
- # - `torch.nn.Parameter`
- # - `collections.Counter`
- # - `collections.OrderedDict`
- # Additionally, users can use an allowlist for adding classes they have deemed as safe using
- # `_add_safe_globals()` (`torch.serialization.add_safe_globals`)
- # `_clear_safe_globals()` (`torch.serialization.clear_safe_globals`)
- # `_get_safe_globals()` (`torch.serialization.get_safe_globals`)
# Based on https://github.com/python/cpython/blob/main/Lib/pickle.py
- # Expected to be useful for loading PyTorch model weights
- # For example:
- # data = urllib.request.urlopen('https://download.pytorch.org/models/resnet50-0676ba61.pth').read()
- # buf = io.BytesIO(data)
- # weights = torch.load(buf, weights_only = True)
- import functools as _functools
- import warnings
- from _codecs import encode
- from collections import Counter, OrderedDict
- from collections.abc import Callable
- from pickle import (
- APPEND,
- APPENDS,
- BINFLOAT,
- BINGET,
- BININT,
- BININT1,
- BININT2,
- BINPERSID,
- BINPUT,
- BINUNICODE,
- BUILD,
- bytes_types,
- decode_long,
- EMPTY_DICT,
- EMPTY_LIST,
- EMPTY_SET,
- EMPTY_TUPLE,
- GLOBAL,
- LONG1,
- LONG_BINGET,
- LONG_BINPUT,
- MARK,
- NEWFALSE,
- NEWOBJ,
- NEWTRUE,
- NONE,
- PROTO,
- REDUCE,
- SETITEM,
- SETITEMS,
- SHORT_BINSTRING,
- STOP,
- TUPLE,
- TUPLE1,
- TUPLE2,
- TUPLE3,
- UnpicklingError,
- )
- from struct import unpack
- from sys import maxsize
- from typing import Any
- import torch
- from torch._utils import _sparse_tensors_to_validate, IMPORT_MAPPING, NAME_MAPPING
# Modules that can never be allowlisted: GLOBALs resolving into any of these
# modules are rejected even if the user explicitly registered them as safe.
_blocklisted_modules = ["sys", "os", "posix", "nt"]
- _marked_safe_globals_set: set[Callable | tuple[Callable, str]] = set()
- def _add_safe_globals(safe_globals: list[Callable | tuple[Callable, str]]):
- global _marked_safe_globals_set
- _marked_safe_globals_set = _marked_safe_globals_set.union(set(safe_globals))
- def _get_safe_globals() -> list[Callable | tuple[Callable, str]]:
- global _marked_safe_globals_set
- return list(_marked_safe_globals_set)
- def _clear_safe_globals():
- global _marked_safe_globals_set
- _marked_safe_globals_set = set()
- def _remove_safe_globals(
- globals_to_remove: list[Callable | tuple[Callable, str]],
- ):
- global _marked_safe_globals_set
- _marked_safe_globals_set = _marked_safe_globals_set - set(globals_to_remove)
- class _safe_globals:
- def __init__(self, safe_globals: list[Callable | tuple[Callable, str]]):
- self.safe_globals = safe_globals
- def __enter__(self):
- _add_safe_globals(self.safe_globals)
- def __exit__(self, type, value, tb):
- _remove_safe_globals(self.safe_globals)
- # Separate from _get_allowed_globals because of the lru_cache on _get_allowed_globals
- # For example if user had a script like
- # torch.load(file_a)
- # torch.serialization._add_safe_globals([torch.foo])
- # torch.load(file_b)
- # the dynamic additions to safe_globals would not be picked up by
- # _get_allowed_globals due to the lru_cache
- def _get_user_allowed_globals():
- rc: dict[str, Any] = {}
- for f in _marked_safe_globals_set:
- if isinstance(f, tuple):
- if len(f) != 2:
- raise ValueError(
- f"Expected tuple of length 2 (global, str of callable full path), but got tuple of length: {len(f)}"
- )
- if type(f[1]) is not str:
- raise TypeError(
- f"Expected second item in tuple to be str of callable full path, but got: {type(f[1])}"
- )
- f, name = f
- rc[name] = f
- else:
- module, name = f.__module__, f.__qualname__
- rc[f"{module}.{name}"] = f
- return rc
- def _tensor_rebuild_functions():
- return {
- torch._utils._rebuild_parameter,
- torch._utils._rebuild_parameter_with_state,
- torch._utils._rebuild_qtensor,
- torch._utils._rebuild_tensor,
- torch._utils._rebuild_tensor_v2,
- torch._utils._rebuild_tensor_v3,
- torch._utils._rebuild_sparse_tensor,
- torch._utils._rebuild_meta_tensor_no_storage,
- torch._utils._rebuild_nested_tensor,
- torch._utils._rebuild_wrapper_subclass,
- # Allowlisting this, but not allowlisting the numpy functions by default
- # Reasoning is that we don't have control over the numpy functions, but
- # this utility is provided by pytorch
- torch._utils._rebuild_device_tensor_from_numpy,
- # In 2.6, we should no longer have a dependency on numpy and the above
- # _rebuild_device_tensor_from_numpy function.
- torch._utils._rebuild_device_tensor_from_cpu_tensor,
- }
- # Unpickling machinery
- @_functools.lru_cache(maxsize=1)
- def _get_allowed_globals():
- rc: dict[str, Any] = {
- "collections.OrderedDict": OrderedDict,
- "collections.Counter": Counter,
- "torch.nn.parameter.Parameter": torch.nn.Parameter,
- "torch.serialization._get_layout": torch.serialization._get_layout,
- "torch.Size": torch.Size,
- "torch.Tensor": torch.Tensor,
- "torch.device": torch.device,
- "_codecs.encode": encode, # for bytes
- "builtins.bytearray": bytearray, # for bytearray
- "builtins.set": set, # for set
- "builtins.complex": complex, # for complex
- }
- # dtype
- for t in torch.storage._dtype_to_storage_type_map():
- rc[str(t)] = t
- for t in torch.storage._new_dtypes():
- rc[str(t)] = t
- for t in [getattr(torch, f"uint{x}") for x in range(1, 8)]:
- rc[str(t)] = t
- for t in [getattr(torch, f"int{x}") for x in range(1, 8)]:
- rc[str(t)] = t
- # Tensor classes
- for tt in torch._tensor_classes:
- rc[f"{tt.__module__}.{tt.__name__}"] = tt
- # Storage classes
- for ts in torch._storage_classes:
- if ts not in (torch.storage.TypedStorage, torch.storage.UntypedStorage):
- # Wrap legacy storage types in a dummy class
- rc[f"{ts.__module__}.{ts.__name__}"] = torch.serialization.StorageType(
- ts.__name__
- )
- else:
- rc[f"{ts.__module__}.{ts.__name__}"] = ts
- # Quantization specific
- for qt in [
- torch.per_tensor_affine,
- torch.per_tensor_symmetric,
- torch.per_channel_affine,
- torch.per_channel_symmetric,
- torch.per_channel_affine_float_qparams,
- ]:
- rc[str(qt)] = qt
- # Rebuild functions
- for f in _tensor_rebuild_functions():
- rc[f"torch._utils.{f.__name__}"] = f
- # Handles Tensor Subclasses, Tensor's with attributes.
- # NOTE: It calls into above rebuild functions for regular Tensor types.
- rc["torch._tensor._rebuild_from_type_v2"] = torch._tensor._rebuild_from_type_v2
- return rc
- def _read_global_instruction(readline: Callable) -> tuple[str, str]:
- module = readline()[:-1].decode("utf-8")
- name = readline()[:-1].decode("utf-8")
- # Patch since torch.save default protocol is 2
- # users will be running this code in python > 3
- if (module, name) in NAME_MAPPING:
- module, name = NAME_MAPPING[(module, name)]
- elif module in IMPORT_MAPPING:
- module = IMPORT_MAPPING[module]
- return module, name
- def get_globals_in_pkl(file) -> set[str]:
- globals_in_checkpoint = set()
- read = file.read
- readline = file.readline
- op_to_bytes_to_read = {
- NEWOBJ[0]: 0,
- REDUCE[0]: 0,
- BUILD[0]: 0,
- APPEND[0]: 0,
- APPENDS[0]: 0,
- SETITEM[0]: 0,
- SETITEMS[0]: 0,
- MARK[0]: 0,
- TUPLE[0]: 0,
- TUPLE1[0]: 0,
- TUPLE2[0]: 0,
- TUPLE3[0]: 0,
- NONE[0]: 0,
- NEWFALSE[0]: 0,
- NEWTRUE[0]: 0,
- EMPTY_TUPLE[0]: 0,
- EMPTY_LIST[0]: 0,
- EMPTY_DICT[0]: 0,
- EMPTY_SET[0]: 0,
- BINPERSID[0]: 0,
- BININT[0]: 4,
- BININT1[0]: 1,
- BININT2[0]: 2,
- BINFLOAT[0]: 8,
- BINGET[0]: 1,
- LONG_BINGET[0]: 4,
- BINPUT[0]: 1,
- LONG_BINPUT[0]: 4,
- }
- while True:
- key = read(1)
- if not key:
- raise EOFError
- if not isinstance(key, bytes_types):
- raise AssertionError(f"Expected bytes, got {type(key).__name__}")
- if key[0] == GLOBAL[0]:
- module, name = _read_global_instruction(readline)
- globals_in_checkpoint.add(f"{module}.{name}")
- elif key[0] in op_to_bytes_to_read:
- bytes_to_read = op_to_bytes_to_read[key[0]]
- if bytes_to_read:
- read(bytes_to_read)
- # ops where bytes to read depends on the data
- elif key[0] == BINUNICODE[0]:
- strlen = unpack("<I", read(4))[0]
- if strlen > maxsize:
- raise UnpicklingError("String is too long")
- read(strlen)
- elif key[0] in {SHORT_BINSTRING[0], LONG1[0]}:
- strlen = read(1)[0]
- read(strlen)
- # first and last op
- elif key[0] == PROTO[0]:
- read(1)[0]
- elif key[0] == STOP[0]:
- return globals_in_checkpoint
- else:
- raise UnpicklingError(f"Unsupported operand {key[0]}")
- class Unpickler:
- def __init__(self, file, *, encoding: str = "bytes"):
- self.encoding = encoding
- self.readline = file.readline
- self.read = file.read
- self.memo: dict[int, Any] = {}
- self.proto: int = -1
- def load(self):
- """Read a pickled object representation from the open file.
- Return the reconstituted object hierarchy specified in the file.
- """
- self.metastack = []
- self.stack: list[Any] = []
- self.append = self.stack.append
- read = self.read
- while True:
- key = read(1)
- if not key:
- raise EOFError
- if not isinstance(key, bytes_types):
- raise AssertionError(f"Expected bytes, got {type(key).__name__}")
- # Risky operators
- if key[0] == GLOBAL[0]:
- module, name = _read_global_instruction(self.readline)
- full_path = f"{module}.{name}"
- if module in _blocklisted_modules:
- raise UnpicklingError(
- f"Trying to load unsupported GLOBAL {full_path} whose module {module} is blocked."
- )
- if full_path in _get_allowed_globals():
- self.append(_get_allowed_globals()[full_path])
- elif full_path in _get_user_allowed_globals():
- self.append(_get_user_allowed_globals()[full_path])
- elif full_path in (
- [
- "torch.nested._internal.nested_tensor.NestedTensor",
- "torch.nested._internal.nested_tensor._rebuild_njt",
- "torch._dynamo.decorators._DimRange",
- ]
- ):
- raise UnpicklingError(
- "``torch.nested`` and ``torch._dynamo`` must be imported to load nested jagged tensors (NJTs)"
- )
- elif full_path in (
- [
- "torch.distributed.device_mesh.DeviceMesh",
- "torch.distributed.tensor._dtensor_spec.DTensorSpec",
- "torch.distributed.tensor._dtensor_spec.TensorMeta",
- "torch.distributed.tensor.DTensor",
- "torch.distributed.tensor.placement_types.Partial",
- "torch.distributed.tensor.placement_types.Replicate",
- "torch.distributed.tensor.placement_types.Shard",
- ]
- ):
- raise UnpicklingError(
- "``torch.distributed.tensor`` must be imported to load DTensors"
- )
- else:
- builtins_name = "builtins"
- if (
- builtins_name in full_path
- and builtins_name == full_path[: len(builtins_name)]
- ):
- full_path = full_path[len(builtins_name) :]
- full_path = (
- full_path[1:]
- if len(full_path) > 0 and full_path[0] == "."
- else builtins_name + full_path
- )
- raise UnpicklingError(
- f"Unsupported global: GLOBAL {full_path} was not an allowed global by default. "
- f"Please use `torch.serialization.add_safe_globals([{full_path}])` or the "
- f"`torch.serialization.safe_globals([{full_path}])` context manager to allowlist this global "
- "if you trust this class/function."
- )
- elif key[0] == NEWOBJ[0]:
- args = self.stack.pop()
- cls = self.stack.pop()
- if cls is torch.nn.Parameter:
- self.append(torch.nn.Parameter(*args))
- elif (
- cls in _get_user_allowed_globals().values()
- or cls in _get_allowed_globals().values()
- ):
- result = cls.__new__(cls, *args)
- if cls in torch._tensor_classes and "sparse" in cls.__module__:
- _sparse_tensors_to_validate.append(result)
- self.append(result)
- else:
- raise UnpicklingError(
- "Can only create new object for nn.Parameter or classes allowlisted "
- f"via `add_safe_globals` but got {cls}"
- )
- elif key[0] == REDUCE[0]:
- args = self.stack.pop()
- func = self.stack[-1]
- if (
- func not in _get_allowed_globals().values()
- and func not in _get_user_allowed_globals().values()
- ):
- error_msg = (
- f"Trying to call reduce for unrecognized function {func}"
- )
- if hasattr(func, "__self__"):
- error_msg += f" which belongs to {func.__self__}"
- raise UnpicklingError(error_msg)
- result = func(*args)
- if func in torch._tensor_classes and "sparse" in func.__module__:
- _sparse_tensors_to_validate.append(result)
- self.stack[-1] = result
- elif key[0] == BUILD[0]:
- state = self.stack.pop()
- inst = self.stack[-1]
- if type(inst) is torch.Tensor:
- # Legacy unpickling
- inst.set_(*state)
- elif type(inst) is torch.nn.Parameter:
- inst.__setstate__(state)
- elif type(inst) is OrderedDict:
- inst.__dict__.update(state)
- elif (
- type(inst) in _get_user_allowed_globals().values()
- or type(inst) in _get_allowed_globals().values()
- ):
- if hasattr(inst, "__setstate__"):
- inst.__setstate__(state)
- else:
- # mimics load_build in pickle
- # https://github.com/python/cpython/blob/f0c6fccd08904787a39269367f09f263d496114c/Lib/pickle.py#L1854-L1867
- slotstate = None
- if isinstance(state, tuple) and len(state) == 2:
- state, slotstate = state
- if state:
- inst.__dict__.update(state)
- if slotstate:
- for k, v in slotstate.items():
- setattr(inst, k, v)
- else:
- raise UnpicklingError(
- "Can only build Tensor, Parameter, OrderedDict or types allowlisted "
- f"via `add_safe_globals`, but got {type(inst)}"
- )
- # Stack manipulation
- elif key[0] == APPEND[0]:
- item = self.stack.pop()
- list_obj = self.stack[-1]
- if type(list_obj) is not list:
- raise UnpicklingError(
- f"Can only append to lists, but got {type(list_obj)}"
- )
- list_obj.append(item)
- elif key[0] == APPENDS[0]:
- items = self.pop_mark()
- list_obj = self.stack[-1]
- if type(list_obj) is not list:
- raise UnpicklingError(
- f"Can only extend lists, but got {type(list_obj)}"
- )
- list_obj.extend(items)
- elif key[0] == SETITEM[0]:
- (v, k) = (self.stack.pop(), self.stack.pop())
- self._check_set_item_target("SETITEM")
- self.stack[-1][k] = v
- elif key[0] == SETITEMS[0]:
- items = self.pop_mark()
- self._check_set_item_target("SETITEMS")
- for i in range(0, len(items), 2):
- self.stack[-1][items[i]] = items[i + 1]
- elif key[0] == MARK[0]:
- self.metastack.append(self.stack)
- self.stack = []
- self.append = self.stack.append
- elif key[0] == TUPLE[0]:
- items = self.pop_mark()
- self.append(tuple(items))
- elif key[0] == TUPLE1[0]:
- self.stack[-1] = (self.stack[-1],)
- elif key[0] == TUPLE2[0]:
- self.stack[-2:] = [(self.stack[-2], self.stack[-1])]
- elif key[0] == TUPLE3[0]:
- self.stack[-3:] = [(self.stack[-3], self.stack[-2], self.stack[-1])]
- # Basic types construction
- elif key[0] == NONE[0]:
- self.append(None)
- elif key[0] == NEWFALSE[0]:
- self.append(False)
- elif key[0] == NEWTRUE[0]:
- self.append(True)
- elif key[0] == EMPTY_TUPLE[0]:
- self.append(())
- elif key[0] == EMPTY_LIST[0]:
- self.append([])
- elif key[0] == EMPTY_DICT[0]:
- self.append({})
- elif key[0] == EMPTY_SET[0]:
- self.append(set())
- elif key[0] == BININT[0]:
- self.append(unpack("<i", read(4))[0])
- elif key[0] == BININT1[0]:
- self.append(self.read(1)[0])
- elif key[0] == BININT2[0]:
- self.append(unpack("<H", read(2))[0])
- elif key[0] == BINFLOAT[0]:
- self.append(unpack(">d", self.read(8))[0])
- elif key[0] == BINUNICODE[0]:
- strlen = unpack("<I", read(4))[0]
- if strlen > maxsize:
- raise UnpicklingError("String is too long")
- strval = str(read(strlen), "utf-8", "surrogatepass")
- self.append(strval)
- elif key[0] == SHORT_BINSTRING[0]:
- strlen = read(1)[0]
- strdata = read(strlen)
- if self.encoding != "bytes":
- strdata = strdata.decode(self.encoding, "strict")
- self.append(strdata)
- elif key[0] == BINPERSID[0]:
- pid = self.stack.pop()
- # Only allow persistent load of storage
- if type(pid) is not tuple and type(pid) is not int:
- raise UnpicklingError(
- f"persistent_load id must be tuple or int, but got {type(pid)}"
- )
- if (
- type(pid) is tuple
- and len(pid) > 0
- and torch.serialization._maybe_decode_ascii(pid[0]) != "storage"
- ):
- raise UnpicklingError(
- f"Only persistent_load of storage is allowed, but got {type(pid[0])}"
- )
- self.append(self.persistent_load(pid))
- elif key[0] in [BINGET[0], LONG_BINGET[0]]:
- idx = (read(1) if key[0] == BINGET[0] else unpack("<I", read(4)))[0]
- self.append(self.memo[idx])
- elif key[0] in [BINPUT[0], LONG_BINPUT[0]]:
- i = (read(1) if key[0] == BINPUT[0] else unpack("<I", read(4)))[0]
- if i < 0:
- raise ValueError("negative argument")
- self.memo[i] = self.stack[-1]
- elif key[0] == LONG1[0]:
- n = read(1)[0]
- data = read(n)
- self.append(decode_long(data))
- # First and last deserializer ops
- elif key[0] == PROTO[0]:
- self.proto = read(1)[0]
- if self.proto != 2:
- warnings.warn(
- f"Detected pickle protocol {self.proto} in the checkpoint, which was "
- "not the default pickle protocol used by `torch.load` (2). The weights_only "
- "Unpickler might not support all instructions implemented by this protocol, "
- "please file an issue for adding support if you encounter this.",
- stacklevel=2,
- )
- elif key[0] == STOP[0]:
- rc = self.stack.pop()
- return rc
- else:
- raise UnpicklingError(f"Unsupported operand {key[0]}")
- # Return a list of items pushed in the stack after last MARK instruction.
- def pop_mark(self):
- items = self.stack
- self.stack = self.metastack.pop()
- self.append = self.stack.append
- return items
- def _check_set_item_target(self, opcode: str):
- if type(self.stack[-1]) not in [dict, OrderedDict, Counter]:
- raise UnpicklingError(
- f"Can only {opcode} for dict, collections.OrderedDict, "
- f"collections.Counter, but got {type(self.stack[-1])}"
- )
- def persistent_load(self, pid):
- raise UnpicklingError("unsupported persistent id encountered")
- def load(file, *, encoding: str = "ASCII"):
- return Unpickler(file, encoding=encoding).load()
|