_package_pickler.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. import sys
  2. from pickle import (
  3. _compat_pickle, # pyrefly: ignore [missing-module-attribute]
  4. _extension_registry, # pyrefly: ignore [missing-module-attribute]
  5. _getattribute, # pyrefly: ignore [missing-module-attribute]
  6. _Pickler,
  7. EXT1,
  8. EXT2,
  9. EXT4,
  10. GLOBAL,
  11. PicklingError,
  12. STACK_GLOBAL,
  13. )
  14. from struct import pack
  15. from types import FunctionType
  16. from .importer import Importer, ObjMismatchError, ObjNotFoundError, sys_importer
  17. class _PyTorchLegacyPickler(_Pickler):
  18. def __init__(self, *args, **kwargs):
  19. super().__init__(*args, **kwargs)
  20. self._persistent_id = None
  21. def persistent_id(self, obj):
  22. if self._persistent_id is None:
  23. return super().persistent_id(obj)
  24. return self._persistent_id(obj)
  25. class PackagePickler(_PyTorchLegacyPickler):
  26. """Package-aware pickler.
  27. This behaves the same as a normal pickler, except it uses an `Importer`
  28. to find objects and modules to save.
  29. """
  30. def __init__(self, importer: Importer, *args, **kwargs):
  31. self.importer = importer
  32. super().__init__(*args, **kwargs)
  33. # Make sure the dispatch table copied from _Pickler is up-to-date.
  34. # Previous issues have been encountered where a library (e.g. dill)
  35. # mutate _Pickler.dispatch, PackagePickler makes a copy when this lib
  36. # is imported, then the offending library removes its dispatch entries,
  37. # leaving PackagePickler with a stale dispatch table that may cause
  38. # unwanted behavior.
  39. self.dispatch = _Pickler.dispatch.copy() # type: ignore[misc]
  40. self.dispatch[FunctionType] = PackagePickler.save_global # type: ignore[assignment]
  41. def save_global(self, obj, name=None):
  42. # ruff: noqa: F841
  43. # unfortunately the pickler code is factored in a way that
  44. # forces us to copy/paste this function. The only change is marked
  45. # CHANGED below.
  46. write = self.write # type: ignore[attr-defined]
  47. memo = self.memo # type: ignore[attr-defined]
  48. # CHANGED: import module from module environment instead of __import__
  49. try:
  50. module_name, name = self.importer.get_name(obj, name)
  51. except (ObjNotFoundError, ObjMismatchError) as err:
  52. raise PicklingError(f"Can't pickle {obj}: {str(err)}") from err
  53. module = self.importer.import_module(module_name)
  54. if sys.version_info >= (3, 14):
  55. # pickle._getattribute signature changes in 3.14
  56. # to take iterable and return just the object (not tuple)
  57. # We need to get the parent object that contains the attribute
  58. name_parts = name.split(".")
  59. if "<locals>" in name_parts:
  60. raise PicklingError(f"Can't pickle local object {obj!r}")
  61. if len(name_parts) == 1:
  62. parent = module
  63. else:
  64. parent = _getattribute(module, name_parts[:-1])
  65. else:
  66. _, parent = _getattribute(module, name)
  67. # END CHANGED
  68. if self.proto >= 2: # type: ignore[attr-defined]
  69. code = _extension_registry.get((module_name, name))
  70. if code:
  71. if code <= 0:
  72. raise AssertionError(
  73. f"expected positive extension code, got {code}"
  74. )
  75. if code <= 0xFF:
  76. write(EXT1 + pack("<B", code))
  77. elif code <= 0xFFFF:
  78. write(EXT2 + pack("<H", code))
  79. else:
  80. write(EXT4 + pack("<i", code))
  81. return
  82. lastname = name.rpartition(".")[2]
  83. if parent is module:
  84. name = lastname
  85. # Non-ASCII identifiers are supported only with protocols >= 3.
  86. if self.proto >= 4: # type: ignore[attr-defined]
  87. self.save(module_name) # type: ignore[attr-defined]
  88. self.save(name) # type: ignore[attr-defined]
  89. write(STACK_GLOBAL)
  90. elif parent is not module:
  91. self.save_reduce(getattr, (parent, lastname)) # type: ignore[attr-defined]
  92. elif self.proto >= 3: # type: ignore[attr-defined]
  93. write(
  94. GLOBAL
  95. + bytes(module_name, "utf-8")
  96. + b"\n"
  97. + bytes(name, "utf-8")
  98. + b"\n"
  99. )
  100. else:
  101. if self.fix_imports: # type: ignore[attr-defined]
  102. r_name_mapping = _compat_pickle.REVERSE_NAME_MAPPING
  103. r_import_mapping = _compat_pickle.REVERSE_IMPORT_MAPPING
  104. if (module_name, name) in r_name_mapping:
  105. module_name, name = r_name_mapping[(module_name, name)]
  106. elif module_name in r_import_mapping:
  107. module_name = r_import_mapping[module_name]
  108. try:
  109. write(
  110. GLOBAL
  111. + bytes(module_name, "ascii")
  112. + b"\n"
  113. + bytes(name, "ascii")
  114. + b"\n"
  115. )
  116. except UnicodeEncodeError as exc:
  117. raise PicklingError(
  118. f"can't pickle global identifier '{module}.{name}' using "
  119. f"pickle protocol {self.proto:d}" # type: ignore[attr-defined]
  120. ) from exc
  121. self.memoize(obj) # type: ignore[attr-defined]
  122. def create_pickler(data_buf, importer, protocol=4):
  123. if importer is sys_importer:
  124. # if we are using the normal import library system, then
  125. # we can use the C implementation of pickle which is faster
  126. return _PyTorchLegacyPickler(data_buf, protocol=protocol)
  127. else:
  128. return PackagePickler(importer, data_buf, protocol=protocol)