importer.py 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. import importlib
  2. import logging
  3. import sys
  4. from abc import ABC, abstractmethod
  5. from pickle import (
  6. _getattribute, # pyrefly: ignore [missing-module-attribute]
  7. _Pickler,
  8. whichmodule as _pickle_whichmodule, # pyrefly: ignore [missing-module-attribute]
  9. )
  10. from types import ModuleType
  11. from typing import Any
  12. from ._mangling import demangle, get_mangle_prefix, is_mangled
  13. __all__ = ["ObjNotFoundError", "ObjMismatchError", "Importer", "OrderedImporter"]
  14. log = logging.getLogger(__name__)
  15. class ObjNotFoundError(Exception):
  16. """Raised when an importer cannot find an object by searching for its name."""
  17. class ObjMismatchError(Exception):
  18. """Raised when an importer found a different object with the same name as the user-provided one."""
  19. class Importer(ABC):
  20. """Represents an environment to import modules from.
  21. By default, you can figure out what module an object belongs by checking
  22. __module__ and importing the result using __import__ or importlib.import_module.
  23. torch.package introduces module importers other than the default one.
  24. Each PackageImporter introduces a new namespace. Potentially a single
  25. name (e.g. 'foo.bar') is present in multiple namespaces.
  26. It supports two main operations:
  27. import_module: module_name -> module object
  28. get_name: object -> (parent module name, name of obj within module)
  29. The guarantee is that following round-trip will succeed or throw an ObjNotFoundError/ObjMisMatchError.
  30. module_name, obj_name = env.get_name(obj)
  31. module = env.import_module(module_name)
  32. obj2 = getattr(module, obj_name)
  33. assert obj1 is obj2
  34. """
  35. modules: dict[str, ModuleType]
  36. @abstractmethod
  37. def import_module(self, module_name: str) -> ModuleType:
  38. """Import `module_name` from this environment.
  39. The contract is the same as for importlib.import_module.
  40. """
  41. def get_name(self, obj: Any, name: str | None = None) -> tuple[str, str]:
  42. """Given an object, return a name that can be used to retrieve the
  43. object from this environment.
  44. Args:
  45. obj: An object to get the module-environment-relative name for.
  46. name: If set, use this name instead of looking up __name__ or __qualname__ on `obj`.
  47. This is only here to match how Pickler handles __reduce__ functions that return a string,
  48. don't use otherwise.
  49. Returns:
  50. A tuple (parent_module_name, attr_name) that can be used to retrieve `obj` from this environment.
  51. Use it like:
  52. mod = importer.import_module(parent_module_name)
  53. obj = getattr(mod, attr_name)
  54. Raises:
  55. ObjNotFoundError: we couldn't retrieve `obj by name.
  56. ObjMisMatchError: we found a different object with the same name as `obj`.
  57. """
  58. if name is None and obj and _Pickler.dispatch.get(type(obj)) is None:
  59. # Honor the string return variant of __reduce__, which will give us
  60. # a global name to search for in this environment.
  61. # TODO: I guess we should do copyreg too?
  62. reduce = getattr(obj, "__reduce__", None)
  63. if reduce is not None:
  64. try:
  65. rv = reduce()
  66. if isinstance(rv, str):
  67. name = rv
  68. except Exception:
  69. pass
  70. if name is None:
  71. name = getattr(obj, "__qualname__", None)
  72. if name is None:
  73. name = obj.__name__
  74. orig_module_name = self.whichmodule(obj, name)
  75. # Demangle the module name before importing. If this obj came out of a
  76. # PackageImporter, `__module__` will be mangled. See mangling.md for
  77. # details.
  78. module_name = demangle(orig_module_name)
  79. # Check that this name will indeed return the correct object
  80. try:
  81. module = self.import_module(module_name)
  82. if sys.version_info >= (3, 14):
  83. # pickle._getatribute signature changes in 3.14
  84. # to take iterable and return just one object
  85. obj2 = _getattribute(module, name.split("."))
  86. else:
  87. obj2, _ = _getattribute(module, name)
  88. except (ImportError, KeyError, AttributeError):
  89. raise ObjNotFoundError(
  90. f"{obj} was not found as {module_name}.{name}"
  91. ) from None
  92. if obj is obj2:
  93. return module_name, name
  94. def get_obj_info(obj):
  95. if name is None:
  96. raise AssertionError("name must not be None")
  97. module_name = self.whichmodule(obj, name)
  98. is_mangled_ = is_mangled(module_name)
  99. location = (
  100. get_mangle_prefix(module_name)
  101. if is_mangled_
  102. else "the current Python environment"
  103. )
  104. importer_name = (
  105. f"the importer for {get_mangle_prefix(module_name)}"
  106. if is_mangled_
  107. else "'sys_importer'"
  108. )
  109. return module_name, location, importer_name
  110. obj_module_name, obj_location, obj_importer_name = get_obj_info(obj)
  111. obj2_module_name, obj2_location, obj2_importer_name = get_obj_info(obj2)
  112. msg = (
  113. f"\n\nThe object provided is from '{obj_module_name}', "
  114. f"which is coming from {obj_location}."
  115. f"\nHowever, when we import '{obj2_module_name}', it's coming from {obj2_location}."
  116. "\nTo fix this, make sure this 'PackageExporter's importer lists "
  117. f"{obj_importer_name} before {obj2_importer_name}."
  118. )
  119. raise ObjMismatchError(msg)
  120. def whichmodule(self, obj: Any, name: str) -> str:
  121. """Find the module name an object belongs to.
  122. This should be considered internal for end-users, but developers of
  123. an importer can override it to customize the behavior.
  124. Taken from pickle.py, but modified to exclude the search into sys.modules
  125. """
  126. module_name = getattr(obj, "__module__", None)
  127. if module_name is not None:
  128. return module_name
  129. # Protect the iteration by using a list copy of self.modules against dynamic
  130. # modules that trigger imports of other modules upon calls to getattr.
  131. for module_name, module in self.modules.copy().items():
  132. if (
  133. module_name == "__main__"
  134. or module_name == "__mp_main__" # bpo-42406
  135. or module is None
  136. ):
  137. continue
  138. try:
  139. if _getattribute(module, name)[0] is obj:
  140. return module_name
  141. except AttributeError:
  142. pass
  143. return "__main__"
  144. class _SysImporter(Importer):
  145. """An importer that implements the default behavior of Python."""
  146. def import_module(self, module_name: str):
  147. return importlib.import_module(module_name)
  148. def whichmodule(self, obj: Any, name: str) -> str:
  149. # In Python 3.14+, pickle.whichmodule tries to import the module,
  150. # which fails for mangled package names like '<torch_package_0>'.
  151. # Check __module__ first before calling pickle.whichmodule.
  152. module_name = getattr(obj, "__module__", None)
  153. if module_name is not None:
  154. return module_name
  155. return _pickle_whichmodule(obj, name)
  156. sys_importer = _SysImporter()
  157. class OrderedImporter(Importer):
  158. """A compound importer that takes a list of importers and tries them one at a time.
  159. The first importer in the list that returns a result "wins".
  160. """
  161. def __init__(self, *args):
  162. self._importers: list[Importer] = list(args)
  163. def _is_torchpackage_dummy(self, module):
  164. """Returns true iff this module is an empty PackageNode in a torch.package.
  165. If you intern `a.b` but never use `a` in your code, then `a` will be an
  166. empty module with no source. This can break cases where we are trying to
  167. re-package an object after adding a real dependency on `a`, since
  168. OrderedImportere will resolve `a` to the dummy package and stop there.
  169. See: https://github.com/pytorch/pytorch/pull/71520#issuecomment-1029603769
  170. """
  171. if not getattr(module, "__torch_package__", False):
  172. return False
  173. if not hasattr(module, "__path__"):
  174. return False
  175. if not hasattr(module, "__file__"):
  176. return True
  177. return module.__file__ is None
  178. def get_name(self, obj: Any, name: str | None = None) -> tuple[str, str]:
  179. for importer in self._importers:
  180. try:
  181. return importer.get_name(obj, name)
  182. except (ObjNotFoundError, ObjMismatchError) as e:
  183. warning_message = (
  184. f"Tried to call get_name with obj {obj}, "
  185. f"and name {name} on {importer} and got {e}"
  186. )
  187. log.warning(warning_message)
  188. raise ObjNotFoundError(
  189. f"Could not find obj {obj} and name {name} in any of the importers {self._importers}"
  190. )
  191. def import_module(self, module_name: str) -> ModuleType:
  192. last_err = None
  193. for importer in self._importers:
  194. if not isinstance(importer, Importer):
  195. raise TypeError(
  196. f"{importer} is not a Importer. "
  197. "All importers in OrderedImporter must inherit from Importer."
  198. )
  199. try:
  200. module = importer.import_module(module_name)
  201. if self._is_torchpackage_dummy(module):
  202. continue
  203. return module
  204. except ModuleNotFoundError as err:
  205. last_err = err
  206. if last_err is not None:
  207. raise last_err
  208. else:
  209. raise ModuleNotFoundError(module_name)
  210. def whichmodule(self, obj: Any, name: str) -> str:
  211. for importer in self._importers:
  212. module_name = importer.whichmodule(obj, name)
  213. if module_name != "__main__":
  214. return module_name
  215. return "__main__"