# _weights_only_unpickler.py
# mypy: allow-untyped-defs
# Unpickler restricted to loading only state dicts.
# Restrict constructing types to a list defined in _get_allowed_globals().
# Restrict the BUILD operation to `Tensor`, `Parameter`, `OrderedDict` and user-allowlisted types only.
# Restrict APPEND/APPENDS to `list`.
# In the `GLOBAL` operation, do not look up classes by name; instead rely on the dictionary
# returned by `_get_allowed_globals()`, which contains:
# - torch types (Storage, dtypes, Tensor, `torch.Size`),
# - `torch._utils._rebuild` functions,
# - `torch.nn.Parameter`,
# - `collections.Counter`,
# - `collections.OrderedDict`.
# Additionally, users can allowlist classes they have deemed safe. The allowlist is managed with
# `_add_safe_globals()` (`torch.serialization.add_safe_globals`),
# `_clear_safe_globals()` (`torch.serialization.clear_safe_globals`) and
# `_get_safe_globals()` (`torch.serialization.get_safe_globals`).
# Based on https://github.com/python/cpython/blob/main/Lib/pickle.py
# Expected to be useful for loading PyTorch model weights.
# For example:
# data = urllib.request.urlopen('https://download.pytorch.org/models/resnet50-0676ba61.pth').read()
# buf = io.BytesIO(data)
# weights = torch.load(buf, weights_only=True)
  23. import functools as _functools
  24. import warnings
  25. from _codecs import encode
  26. from collections import Counter, OrderedDict
  27. from collections.abc import Callable
  28. from pickle import (
  29. APPEND,
  30. APPENDS,
  31. BINFLOAT,
  32. BINGET,
  33. BININT,
  34. BININT1,
  35. BININT2,
  36. BINPERSID,
  37. BINPUT,
  38. BINUNICODE,
  39. BUILD,
  40. bytes_types,
  41. decode_long,
  42. EMPTY_DICT,
  43. EMPTY_LIST,
  44. EMPTY_SET,
  45. EMPTY_TUPLE,
  46. GLOBAL,
  47. LONG1,
  48. LONG_BINGET,
  49. LONG_BINPUT,
  50. MARK,
  51. NEWFALSE,
  52. NEWOBJ,
  53. NEWTRUE,
  54. NONE,
  55. PROTO,
  56. REDUCE,
  57. SETITEM,
  58. SETITEMS,
  59. SHORT_BINSTRING,
  60. STOP,
  61. TUPLE,
  62. TUPLE1,
  63. TUPLE2,
  64. TUPLE3,
  65. UnpicklingError,
  66. )
  67. from struct import unpack
  68. from sys import maxsize
  69. from typing import Any
  70. import torch
  71. from torch._utils import _sparse_tensors_to_validate, IMPORT_MAPPING, NAME_MAPPING
  72. # modules in this list are never allowed, even if the user attempts to allowlist
  73. # functions/classes from them
  74. _blocklisted_modules = [
  75. "sys",
  76. "os",
  77. "posix",
  78. "nt",
  79. ]
  80. _marked_safe_globals_set: set[Callable | tuple[Callable, str]] = set()
  81. def _add_safe_globals(safe_globals: list[Callable | tuple[Callable, str]]):
  82. global _marked_safe_globals_set
  83. _marked_safe_globals_set = _marked_safe_globals_set.union(set(safe_globals))
  84. def _get_safe_globals() -> list[Callable | tuple[Callable, str]]:
  85. global _marked_safe_globals_set
  86. return list(_marked_safe_globals_set)
  87. def _clear_safe_globals():
  88. global _marked_safe_globals_set
  89. _marked_safe_globals_set = set()
  90. def _remove_safe_globals(
  91. globals_to_remove: list[Callable | tuple[Callable, str]],
  92. ):
  93. global _marked_safe_globals_set
  94. _marked_safe_globals_set = _marked_safe_globals_set - set(globals_to_remove)
  95. class _safe_globals:
  96. def __init__(self, safe_globals: list[Callable | tuple[Callable, str]]):
  97. self.safe_globals = safe_globals
  98. def __enter__(self):
  99. _add_safe_globals(self.safe_globals)
  100. def __exit__(self, type, value, tb):
  101. _remove_safe_globals(self.safe_globals)
  # Kept separate from _get_allowed_globals because that function is wrapped in
# lru_cache. For a script such as
#     torch.load(file_a)
#     torch.serialization._add_safe_globals([torch.foo])
#     torch.load(file_b)
# the entry added between the two loads must be visible to the second one,
# which a cached result would hide.
def _get_user_allowed_globals():
    """Build a ``{GLOBAL path: object}`` mapping from the user allowlist.

    Plain entries are keyed by ``f"{obj.__module__}.{obj.__qualname__}"``;
    ``(obj, path)`` tuples are keyed by their explicit ``path`` string.

    Raises:
        ValueError: if a tuple entry does not have exactly two items.
        TypeError: if the second item of a tuple entry is not a ``str``.
    """
    mapping: dict[str, Any] = {}
    for entry in _marked_safe_globals_set:
        if not isinstance(entry, tuple):
            mapping[f"{entry.__module__}.{entry.__qualname__}"] = entry
            continue
        if len(entry) != 2:
            raise ValueError(
                f"Expected tuple of length 2 (global, str of callable full path), but got tuple of length: {len(entry)}"
            )
        obj, full_path = entry
        if type(full_path) is not str:
            raise TypeError(
                f"Expected second item in tuple to be str of callable full path, but got: {type(full_path)}"
            )
        mapping[full_path] = obj
    return mapping
  127. def _tensor_rebuild_functions():
  128. return {
  129. torch._utils._rebuild_parameter,
  130. torch._utils._rebuild_parameter_with_state,
  131. torch._utils._rebuild_qtensor,
  132. torch._utils._rebuild_tensor,
  133. torch._utils._rebuild_tensor_v2,
  134. torch._utils._rebuild_tensor_v3,
  135. torch._utils._rebuild_sparse_tensor,
  136. torch._utils._rebuild_meta_tensor_no_storage,
  137. torch._utils._rebuild_nested_tensor,
  138. torch._utils._rebuild_wrapper_subclass,
  139. # Allowlisting this, but not allowlisting the numpy functions by default
  140. # Reasoning is that we don't have control over the numpy functions, but
  141. # this utility is provided by pytorch
  142. torch._utils._rebuild_device_tensor_from_numpy,
  143. # In 2.6, we should no longer have a dependency on numpy and the above
  144. # _rebuild_device_tensor_from_numpy function.
  145. torch._utils._rebuild_device_tensor_from_cpu_tensor,
  146. }
  # Unpickling machinery
@_functools.lru_cache(maxsize=1)
def _get_allowed_globals():
    """Return the built-in ``{GLOBAL path: object}`` allowlist.

    The allowlist covers collections containers, core torch types, dtypes,
    tensor and storage classes, quantization schemes and the tensor rebuild
    functions. The result is cached with ``lru_cache(maxsize=1)``, so
    user-added entries are looked up separately via
    ``_get_user_allowed_globals``.
    """
    allowed: dict[str, Any] = {
        "collections.OrderedDict": OrderedDict,
        "collections.Counter": Counter,
        "torch.nn.parameter.Parameter": torch.nn.Parameter,
        "torch.serialization._get_layout": torch.serialization._get_layout,
        "torch.Size": torch.Size,
        "torch.Tensor": torch.Tensor,
        "torch.device": torch.device,
        "_codecs.encode": encode,  # for bytes
        "builtins.bytearray": bytearray,  # for bytearray
        "builtins.set": set,  # for set
        "builtins.complex": complex,  # for complex
    }

    def _qualified(cls) -> str:
        # GLOBAL path under which a class is pickled.
        return f"{cls.__module__}.{cls.__name__}"

    # dtypes, keyed by str(dtype)
    dtypes = [
        *torch.storage._dtype_to_storage_type_map(),
        *torch.storage._new_dtypes(),
        *(getattr(torch, f"uint{bits}") for bits in range(1, 8)),
        *(getattr(torch, f"int{bits}") for bits in range(1, 8)),
    ]
    for dtype in dtypes:
        allowed[str(dtype)] = dtype

    # Tensor classes
    allowed.update({_qualified(cls): cls for cls in torch._tensor_classes})

    # Storage classes: TypedStorage/UntypedStorage map to themselves, every
    # other (legacy) storage type is wrapped in a StorageType placeholder.
    modern_storages = (torch.storage.TypedStorage, torch.storage.UntypedStorage)
    for storage_cls in torch._storage_classes:
        if storage_cls in modern_storages:
            allowed[_qualified(storage_cls)] = storage_cls
        else:
            allowed[_qualified(storage_cls)] = torch.serialization.StorageType(
                storage_cls.__name__
            )

    # Quantization specific
    quant_schemes = (
        torch.per_tensor_affine,
        torch.per_tensor_symmetric,
        torch.per_channel_affine,
        torch.per_channel_symmetric,
        torch.per_channel_affine_float_qparams,
    )
    for qscheme in quant_schemes:
        allowed[str(qscheme)] = qscheme

    # Rebuild functions
    for rebuild_fn in _tensor_rebuild_functions():
        allowed[f"torch._utils.{rebuild_fn.__name__}"] = rebuild_fn
    # Handles Tensor subclasses and Tensors with attributes; for regular
    # Tensor types it calls into the rebuild functions above.
    allowed["torch._tensor._rebuild_from_type_v2"] = torch._tensor._rebuild_from_type_v2
    return allowed
  200. def _read_global_instruction(readline: Callable) -> tuple[str, str]:
  201. module = readline()[:-1].decode("utf-8")
  202. name = readline()[:-1].decode("utf-8")
  203. # Patch since torch.save default protocol is 2
  204. # users will be running this code in python > 3
  205. if (module, name) in NAME_MAPPING:
  206. module, name = NAME_MAPPING[(module, name)]
  207. elif module in IMPORT_MAPPING:
  208. module = IMPORT_MAPPING[module]
  209. return module, name
  # Static scan of a pickle stream for the GLOBALs it references.
def get_globals_in_pkl(file) -> set[str]:
    """Return the ``"module.name"`` string of every GLOBAL opcode in ``file``.

    The stream is walked opcode by opcode up to STOP without constructing any
    object. Only the opcodes that the weights-only ``Unpickler`` implements
    are understood.

    Raises:
        EOFError: if the stream ends before a STOP opcode.
        UnpicklingError: on an unsupported opcode or an oversized BINUNICODE.
    """
    found: set[str] = set()
    read = file.read
    readline = file.readline

    # Opcodes with a fixed-size operand, mapped to that operand's size in bytes.
    fixed_operand_size: dict[int, int] = dict.fromkeys(
        (
            op[0]
            for op in (
                NEWOBJ,
                REDUCE,
                BUILD,
                APPEND,
                APPENDS,
                SETITEM,
                SETITEMS,
                MARK,
                TUPLE,
                TUPLE1,
                TUPLE2,
                TUPLE3,
                NONE,
                NEWFALSE,
                NEWTRUE,
                EMPTY_TUPLE,
                EMPTY_LIST,
                EMPTY_DICT,
                EMPTY_SET,
                BINPERSID,
            )
        ),
        0,
    )
    fixed_operand_size.update(
        {
            BININT[0]: 4,
            BININT1[0]: 1,
            BININT2[0]: 2,
            BINFLOAT[0]: 8,
            BINGET[0]: 1,
            LONG_BINGET[0]: 4,
            BINPUT[0]: 1,
            LONG_BINPUT[0]: 4,
        }
    )
    # Opcodes whose payload length is given by one leading byte.
    one_byte_length_ops = {SHORT_BINSTRING[0], LONG1[0]}

    while True:
        opcode_byte = read(1)
        if not opcode_byte:
            raise EOFError
        if not isinstance(opcode_byte, bytes_types):
            raise AssertionError(f"Expected bytes, got {type(opcode_byte).__name__}")
        op = opcode_byte[0]

        if op == GLOBAL[0]:
            module, name = _read_global_instruction(readline)
            found.add(f"{module}.{name}")
            continue
        if op in fixed_operand_size:
            size = fixed_operand_size[op]
            if size:
                read(size)
            continue
        if op == BINUNICODE[0]:
            # Payload length is a 4-byte little-endian unsigned int.
            length = unpack("<I", read(4))[0]
            if length > maxsize:
                raise UnpicklingError("String is too long")
            read(length)
            continue
        if op in one_byte_length_ops:
            read(read(1)[0])
            continue
        if op == PROTO[0]:
            # The protocol version byte is consumed but otherwise ignored.
            read(1)[0]
            continue
        if op == STOP[0]:
            return found
        raise UnpicklingError(f"Unsupported operand {op}")
  # Restricted (weights-only) unpickler.
class Unpickler:
    """Pickle VM restricted to the opcodes and globals needed for state dicts.

    Only a subset of pickle opcodes is implemented; any other opcode raises
    ``UnpicklingError``. GLOBAL may only resolve to entries of
    ``_get_allowed_globals()`` or the user allowlist. NEWOBJ, REDUCE, BUILD,
    APPEND(S) and SETITEM(S) only act on allowlisted callables or on the
    expected container types. ``persistent_load`` rejects every id unless it
    is overridden.
    """

    def __init__(self, file, *, encoding: str = "bytes"):
        # ``encoding`` only affects SHORT_BINSTRING payloads: "bytes" keeps
        # them as bytes, any other value decodes them with that codec.
        self.encoding = encoding
        self.readline = file.readline
        self.read = file.read
        self.memo: dict[int, Any] = {}
        self.proto: int = -1

    def load(self):
        """Read a pickled object representation from the open file.

        Return the reconstituted object hierarchy specified in the file.

        Raises:
            EOFError: if the stream ends before a STOP opcode.
            UnpicklingError: on an unsupported opcode, a disallowed global or
                construction target, or a BINGET/LONG_BINGET index that was
                never stored with BINPUT/LONG_BINPUT.
        """
        self.metastack = []
        self.stack: list[Any] = []
        self.append = self.stack.append
        read = self.read
        while True:
            key = read(1)
            if not key:
                raise EOFError
            if not isinstance(key, bytes_types):
                raise AssertionError(f"Expected bytes, got {type(key).__name__}")
            # Risky operators
            if key[0] == GLOBAL[0]:
                module, name = _read_global_instruction(self.readline)
                self.append(self._resolve_global(module, name))
            elif key[0] == NEWOBJ[0]:
                args = self.stack.pop()
                cls = self.stack.pop()
                if cls is torch.nn.Parameter:
                    self.append(torch.nn.Parameter(*args))
                elif (
                    cls in _get_user_allowed_globals().values()
                    or cls in _get_allowed_globals().values()
                ):
                    result = cls.__new__(cls, *args)
                    if cls in torch._tensor_classes and "sparse" in cls.__module__:
                        _sparse_tensors_to_validate.append(result)
                    self.append(result)
                else:
                    raise UnpicklingError(
                        "Can only create new object for nn.Parameter or classes allowlisted "
                        f"via `add_safe_globals` but got {cls}"
                    )
            elif key[0] == REDUCE[0]:
                args = self.stack.pop()
                func = self.stack[-1]
                if (
                    func not in _get_allowed_globals().values()
                    and func not in _get_user_allowed_globals().values()
                ):
                    error_msg = (
                        f"Trying to call reduce for unrecognized function {func}"
                    )
                    if hasattr(func, "__self__"):
                        error_msg += f" which belongs to {func.__self__}"
                    raise UnpicklingError(error_msg)
                result = func(*args)
                if func in torch._tensor_classes and "sparse" in func.__module__:
                    _sparse_tensors_to_validate.append(result)
                self.stack[-1] = result
            elif key[0] == BUILD[0]:
                state = self.stack.pop()
                inst = self.stack[-1]
                if type(inst) is torch.Tensor:
                    # Legacy unpickling
                    inst.set_(*state)
                elif type(inst) is torch.nn.Parameter:
                    inst.__setstate__(state)
                elif type(inst) is OrderedDict:
                    inst.__dict__.update(state)
                elif (
                    type(inst) in _get_user_allowed_globals().values()
                    or type(inst) in _get_allowed_globals().values()
                ):
                    if hasattr(inst, "__setstate__"):
                        inst.__setstate__(state)
                    else:
                        # mimics load_build in pickle
                        # https://github.com/python/cpython/blob/f0c6fccd08904787a39269367f09f263d496114c/Lib/pickle.py#L1854-L1867
                        slotstate = None
                        if isinstance(state, tuple) and len(state) == 2:
                            state, slotstate = state
                        if state:
                            inst.__dict__.update(state)
                        if slotstate:
                            for k, v in slotstate.items():
                                setattr(inst, k, v)
                else:
                    raise UnpicklingError(
                        "Can only build Tensor, Parameter, OrderedDict or types allowlisted "
                        f"via `add_safe_globals`, but got {type(inst)}"
                    )
            # Stack manipulation
            elif key[0] == APPEND[0]:
                item = self.stack.pop()
                list_obj = self.stack[-1]
                if type(list_obj) is not list:
                    raise UnpicklingError(
                        f"Can only append to lists, but got {type(list_obj)}"
                    )
                list_obj.append(item)
            elif key[0] == APPENDS[0]:
                items = self.pop_mark()
                list_obj = self.stack[-1]
                if type(list_obj) is not list:
                    raise UnpicklingError(
                        f"Can only extend lists, but got {type(list_obj)}"
                    )
                list_obj.extend(items)
            elif key[0] == SETITEM[0]:
                (v, k) = (self.stack.pop(), self.stack.pop())
                self._check_set_item_target("SETITEM")
                self.stack[-1][k] = v
            elif key[0] == SETITEMS[0]:
                items = self.pop_mark()
                self._check_set_item_target("SETITEMS")
                for i in range(0, len(items), 2):
                    self.stack[-1][items[i]] = items[i + 1]
            elif key[0] == MARK[0]:
                self.metastack.append(self.stack)
                self.stack = []
                self.append = self.stack.append
            elif key[0] == TUPLE[0]:
                items = self.pop_mark()
                self.append(tuple(items))
            elif key[0] == TUPLE1[0]:
                self.stack[-1] = (self.stack[-1],)
            elif key[0] == TUPLE2[0]:
                self.stack[-2:] = [(self.stack[-2], self.stack[-1])]
            elif key[0] == TUPLE3[0]:
                self.stack[-3:] = [(self.stack[-3], self.stack[-2], self.stack[-1])]
            # Basic types construction
            elif key[0] == NONE[0]:
                self.append(None)
            elif key[0] == NEWFALSE[0]:
                self.append(False)
            elif key[0] == NEWTRUE[0]:
                self.append(True)
            elif key[0] == EMPTY_TUPLE[0]:
                self.append(())
            elif key[0] == EMPTY_LIST[0]:
                self.append([])
            elif key[0] == EMPTY_DICT[0]:
                self.append({})
            elif key[0] == EMPTY_SET[0]:
                self.append(set())
            elif key[0] == BININT[0]:
                self.append(unpack("<i", read(4))[0])
            elif key[0] == BININT1[0]:
                self.append(self.read(1)[0])
            elif key[0] == BININT2[0]:
                self.append(unpack("<H", read(2))[0])
            elif key[0] == BINFLOAT[0]:
                self.append(unpack(">d", self.read(8))[0])
            elif key[0] == BINUNICODE[0]:
                strlen = unpack("<I", read(4))[0]
                if strlen > maxsize:
                    raise UnpicklingError("String is too long")
                strval = str(read(strlen), "utf-8", "surrogatepass")
                self.append(strval)
            elif key[0] == SHORT_BINSTRING[0]:
                strlen = read(1)[0]
                strdata = read(strlen)
                if self.encoding != "bytes":
                    strdata = strdata.decode(self.encoding, "strict")
                self.append(strdata)
            elif key[0] == BINPERSID[0]:
                pid = self.stack.pop()
                # Only allow persistent load of storage
                if type(pid) is not tuple and type(pid) is not int:
                    raise UnpicklingError(
                        f"persistent_load id must be tuple or int, but got {type(pid)}"
                    )
                if (
                    type(pid) is tuple
                    and len(pid) > 0
                    and torch.serialization._maybe_decode_ascii(pid[0]) != "storage"
                ):
                    # Report the offending tag itself; its type alone does
                    # not identify what was requested.
                    raise UnpicklingError(
                        f"Only persistent_load of storage is allowed, but got {pid[0]}"
                    )
                self.append(self.persistent_load(pid))
            elif key[0] in [BINGET[0], LONG_BINGET[0]]:
                idx = (read(1) if key[0] == BINGET[0] else unpack("<I", read(4)))[0]
                self.append(self._memo_get(idx))
            elif key[0] in [BINPUT[0], LONG_BINPUT[0]]:
                i = (read(1) if key[0] == BINPUT[0] else unpack("<I", read(4)))[0]
                if i < 0:
                    raise ValueError("negative argument")
                self.memo[i] = self.stack[-1]
            elif key[0] == LONG1[0]:
                n = read(1)[0]
                data = read(n)
                self.append(decode_long(data))
            # First and last deserializer ops
            elif key[0] == PROTO[0]:
                self.proto = read(1)[0]
                if self.proto != 2:
                    warnings.warn(
                        f"Detected pickle protocol {self.proto} in the checkpoint, which was "
                        "not the default pickle protocol used by `torch.load` (2). The weights_only "
                        "Unpickler might not support all instructions implemented by this protocol, "
                        "please file an issue for adding support if you encounter this.",
                        stacklevel=2,
                    )
            elif key[0] == STOP[0]:
                rc = self.stack.pop()
                return rc
            else:
                raise UnpicklingError(f"Unsupported operand {key[0]}")

    def _resolve_global(self, module: str, name: str) -> Any:
        """Return the allowlisted object for GLOBAL ``module.name``.

        Raises UnpicklingError if the module is blocklisted or the global is
        allowlisted neither by default nor by the user.
        """
        full_path = f"{module}.{name}"
        if module in _blocklisted_modules:
            raise UnpicklingError(
                f"Trying to load unsupported GLOBAL {full_path} whose module {module} is blocked."
            )
        allowed = _get_allowed_globals()
        if full_path in allowed:
            return allowed[full_path]
        # Rebuilt here (not cached) so allowlist changes between loads apply;
        # built once per lookup instead of twice.
        user_allowed = _get_user_allowed_globals()
        if full_path in user_allowed:
            return user_allowed[full_path]
        if full_path in (
            "torch.nested._internal.nested_tensor.NestedTensor",
            "torch.nested._internal.nested_tensor._rebuild_njt",
            "torch._dynamo.decorators._DimRange",
        ):
            raise UnpicklingError(
                "``torch.nested`` and ``torch._dynamo`` must be imported to load nested jagged tensors (NJTs)"
            )
        if full_path in (
            "torch.distributed.device_mesh.DeviceMesh",
            "torch.distributed.tensor._dtensor_spec.DTensorSpec",
            "torch.distributed.tensor._dtensor_spec.TensorMeta",
            "torch.distributed.tensor.DTensor",
            "torch.distributed.tensor.placement_types.Partial",
            "torch.distributed.tensor.placement_types.Replicate",
            "torch.distributed.tensor.placement_types.Shard",
        ):
            raise UnpicklingError(
                "``torch.distributed.tensor`` must be imported to load DTensors"
            )
        # For builtins, drop the "builtins." prefix so the suggested
        # add_safe_globals call names the builtin directly.
        builtins_name = "builtins"
        if full_path.startswith(builtins_name):
            rest = full_path[len(builtins_name) :]
            full_path = rest[1:] if rest.startswith(".") else builtins_name + rest
        raise UnpicklingError(
            f"Unsupported global: GLOBAL {full_path} was not an allowed global by default. "
            f"Please use `torch.serialization.add_safe_globals([{full_path}])` or the "
            f"`torch.serialization.safe_globals([{full_path}])` context manager to allowlist this global "
            "if you trust this class/function."
        )

    def _memo_get(self, idx: int) -> Any:
        """Return memo entry ``idx``; raise UnpicklingError (as CPython does) if absent."""
        try:
            return self.memo[idx]
        except KeyError:
            raise UnpicklingError(f"Memo value not found at index {idx}") from None

    # Return a list of items pushed in the stack after last MARK instruction.
    def pop_mark(self):
        items = self.stack
        self.stack = self.metastack.pop()
        self.append = self.stack.append
        return items

    def _check_set_item_target(self, opcode: str):
        # SETITEM(S) may only target plain dicts and the allowlisted dict subclasses.
        if type(self.stack[-1]) not in [dict, OrderedDict, Counter]:
            raise UnpicklingError(
                f"Can only {opcode} for dict, collections.OrderedDict, "
                f"collections.Counter, but got {type(self.stack[-1])}"
            )

    def persistent_load(self, pid):
        # Rejects every persistent id unless this method is overridden.
        raise UnpicklingError("unsupported persistent id encountered")
  # Module-level convenience wrapper around the weights-only Unpickler.
def load(file, *, encoding: str = "ASCII"):
    """Unpickle and return one object from ``file`` with the weights-only ``Unpickler``.

    ``encoding`` is forwarded to ``Unpickler``, which uses it only to decode
    SHORT_BINSTRING payloads.
    """
    unpickler = Unpickler(file, encoding=encoding)
    return unpickler.load()