package_importer.py 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805
  1. # mypy: allow-untyped-defs
  2. import builtins
  3. import importlib
  4. import importlib.machinery
  5. import inspect
  6. import io
  7. import linecache
  8. import os
  9. import sys
  10. import types
  11. from collections.abc import Callable, Iterable
  12. from contextlib import contextmanager
  13. from typing import Any, cast, TYPE_CHECKING
  14. from weakref import WeakValueDictionary
  15. import torch
  16. from torch.serialization import _get_restore_location, _maybe_decode_ascii
  17. from torch.types import FileLike
  18. from ._directory_reader import DirectoryReader
  19. from ._importlib import (
  20. _calc___package__,
  21. _normalize_line_endings,
  22. _normalize_path,
  23. _resolve_name,
  24. _sanity_check,
  25. )
  26. from ._mangling import demangle, PackageMangler
  27. from ._package_unpickler import PackageUnpickler
  28. from .file_structure_representation import _create_directory_from_file_list, Directory
  29. from .importer import Importer
  30. if TYPE_CHECKING:
  31. from .glob_group import GlobPattern
# Public API of this module: only the importer class is exported.
__all__ = ["PackageImporter"]
  33. # This is a list of imports that are implicitly allowed even if they haven't
  34. # been marked as extern. This is to work around the fact that Torch implicitly
  35. # depends on numpy and package can't track it.
  36. # https://github.com/pytorch/multipy/issues/46 # codespell:ignore multipy
  37. IMPLICIT_IMPORT_ALLOWLIST: Iterable[str] = [
  38. "numpy",
  39. "numpy.core",
  40. "numpy.core._multiarray_umath",
  41. # FX GraphModule might depend on builtins module and users usually
  42. # don't extern builtins. Here we import it here by default.
  43. "builtins",
  44. ]
  45. # Compatibility name mapping to facilitate upgrade of external modules.
  46. # The primary motivation is to enable Numpy upgrade that many modules
  47. # depend on. The latest release of Numpy removed `numpy.str` and
  48. # `numpy.bool` breaking unpickling for many modules.
  49. EXTERN_IMPORT_COMPAT_NAME_MAPPING: dict[str, dict[str, Any]] = {
  50. "numpy": {
  51. "str": str,
  52. "bool": bool,
  53. },
  54. }
  55. class PackageImporter(Importer):
  56. """Importers allow you to load code written to packages by :class:`PackageExporter`.
  57. Code is loaded in a hermetic way, using files from the package
  58. rather than the normal python import system. This allows
  59. for the packaging of PyTorch model code and data so that it can be run
  60. on a server or used in the future for transfer learning.
  61. The importer for packages ensures that code in the module can only be loaded from
  62. within the package, except for modules explicitly listed as external during export.
  63. The file ``extern_modules`` in the zip archive lists all the modules that a package externally depends on.
  64. This prevents "implicit" dependencies where the package runs locally because it is importing
  65. a locally-installed package, but then fails when the package is copied to another machine.
  66. """
  67. """The dictionary of already loaded modules from this package, equivalent to ``sys.modules`` but
  68. local to this importer.
  69. """
  70. modules: dict[str, types.ModuleType]
  71. def __init__(
  72. self,
  73. file_or_buffer: FileLike | torch._C.PyTorchFileReader,
  74. module_allowed: Callable[[str], bool] = lambda module_name: True,
  75. ):
  76. """Open ``file_or_buffer`` for importing. This checks that the imported package only requires modules
  77. allowed by ``module_allowed``
  78. Args:
  79. file_or_buffer: a file-like object (has to implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`),
  80. a string, or an ``os.PathLike`` object containing a filename.
  81. module_allowed (Callable[[str], bool], optional): A method to determine if a externally provided module
  82. should be allowed. Can be used to ensure packages loaded do not depend on modules that the server
  83. does not support. Defaults to allowing anything.
  84. Raises:
  85. ImportError: If the package will use a disallowed module.
  86. """
  87. torch._C._log_api_usage_once("torch.package.PackageImporter")
  88. self.zip_reader: Any
  89. if isinstance(file_or_buffer, torch._C.PyTorchFileReader):
  90. self.filename = "<pytorch_file_reader>"
  91. self.zip_reader = file_or_buffer
  92. elif isinstance(file_or_buffer, (os.PathLike, str)):
  93. self.filename = os.fspath(file_or_buffer)
  94. if not os.path.isdir(self.filename):
  95. self.zip_reader = torch._C.PyTorchFileReader(self.filename)
  96. else:
  97. self.zip_reader = DirectoryReader(self.filename)
  98. else:
  99. self.filename = "<binary>"
  100. self.zip_reader = torch._C.PyTorchFileReader(file_or_buffer)
  101. torch._C._log_api_usage_metadata(
  102. "torch.package.PackageImporter.metadata",
  103. {
  104. "serialization_id": self.zip_reader.serialization_id(),
  105. "file_name": self.filename,
  106. },
  107. )
  108. self.root = _PackageNode(None)
  109. self.modules = {}
  110. self.extern_modules = self._read_extern()
  111. for extern_module in self.extern_modules:
  112. if not module_allowed(extern_module):
  113. raise ImportError(
  114. f"package '{file_or_buffer}' needs the external module '{extern_module}' "
  115. f"but that module has been disallowed"
  116. )
  117. self._add_extern(extern_module)
  118. for fname in self.zip_reader.get_all_records():
  119. self._add_file(fname)
  120. self.patched_builtins = builtins.__dict__.copy()
  121. self.patched_builtins["__import__"] = self.__import__
  122. # Allow packaged modules to reference their PackageImporter
  123. self.modules["torch_package_importer"] = self # type: ignore[assignment]
  124. self._mangler = PackageMangler()
  125. # used for reduce deserializaiton
  126. self.storage_context: Any = None
  127. self.last_map_location = None
  128. # used for torch.serialization._load
  129. self.Unpickler = lambda *args, **kwargs: PackageUnpickler(self, *args, **kwargs)
  130. def import_module(self, name: str, package=None):
  131. """Load a module from the package if it hasn't already been loaded, and then return
  132. the module. Modules are loaded locally
  133. to the importer and will appear in ``self.modules`` rather than ``sys.modules``.
  134. Args:
  135. name (str): Fully qualified name of the module to load.
  136. package ([type], optional): Unused, but present to match the signature of importlib.import_module. Defaults to ``None``.
  137. Returns:
  138. types.ModuleType: The (possibly already) loaded module.
  139. """
  140. # We should always be able to support importing modules from this package.
  141. # This is to support something like:
  142. # obj = importer.load_pickle(...)
  143. # importer.import_module(obj.__module__) <- this string will be mangled
  144. #
  145. # Note that _mangler.demangle will not demangle any module names
  146. # produced by a different PackageImporter instance.
  147. name = self._mangler.demangle(name)
  148. return self._gcd_import(name)
  149. def load_binary(self, package: str, resource: str) -> bytes:
  150. """Load raw bytes.
  151. Args:
  152. package (str): The name of module package (e.g. ``"my_package.my_subpackage"``).
  153. resource (str): The unique name for the resource.
  154. Returns:
  155. bytes: The loaded data.
  156. """
  157. path = self._zipfile_path(package, resource)
  158. return self.zip_reader.get_record(path)
  159. def load_text(
  160. self,
  161. package: str,
  162. resource: str,
  163. encoding: str = "utf-8",
  164. errors: str = "strict",
  165. ) -> str:
  166. """Load a string.
  167. Args:
  168. package (str): The name of module package (e.g. ``"my_package.my_subpackage"``).
  169. resource (str): The unique name for the resource.
  170. encoding (str, optional): Passed to ``decode``. Defaults to ``'utf-8'``.
  171. errors (str, optional): Passed to ``decode``. Defaults to ``'strict'``.
  172. Returns:
  173. str: The loaded text.
  174. """
  175. data = self.load_binary(package, resource)
  176. return data.decode(encoding, errors)
  177. def load_pickle(self, package: str, resource: str, map_location=None) -> Any:
  178. """Unpickles the resource from the package, loading any modules that are needed to construct the objects
  179. using :meth:`import_module`.
  180. Args:
  181. package (str): The name of module package (e.g. ``"my_package.my_subpackage"``).
  182. resource (str): The unique name for the resource.
  183. map_location: Passed to `torch.load` to determine how tensors are mapped to devices. Defaults to ``None``.
  184. Returns:
  185. Any: The unpickled object.
  186. """
  187. pickle_file = self._zipfile_path(package, resource)
  188. restore_location = _get_restore_location(map_location)
  189. loaded_storages = {}
  190. loaded_reduces = {}
  191. storage_context = torch._C.DeserializationStorageContext()
  192. def load_tensor(dtype, size, key, location, restore_location):
  193. name = f"{key}.storage"
  194. if storage_context.has_storage(name):
  195. storage = storage_context.get_storage(name, dtype)._typed_storage()
  196. else:
  197. tensor = self.zip_reader.get_storage_from_record(
  198. ".data/" + name, size, dtype
  199. )
  200. if isinstance(self.zip_reader, torch._C.PyTorchFileReader):
  201. storage_context.add_storage(name, tensor)
  202. storage = tensor._typed_storage()
  203. loaded_storages[key] = restore_location(storage, location)
  204. def persistent_load(saved_id):
  205. if not isinstance(saved_id, tuple):
  206. raise AssertionError(
  207. f"saved_id must be a tuple, got {type(saved_id).__name__}"
  208. )
  209. typename = _maybe_decode_ascii(saved_id[0])
  210. data = saved_id[1:]
  211. if typename == "storage":
  212. storage_type, key, location, size = data
  213. if storage_type is torch.UntypedStorage:
  214. dtype = torch.uint8
  215. else:
  216. dtype = storage_type.dtype
  217. if key not in loaded_storages:
  218. load_tensor(
  219. dtype,
  220. size,
  221. key,
  222. _maybe_decode_ascii(location),
  223. restore_location,
  224. )
  225. storage = loaded_storages[key]
  226. # TODO: Once we decide to break serialization FC, we can
  227. # stop wrapping with TypedStorage
  228. return torch.storage.TypedStorage(
  229. wrap_storage=storage._untyped_storage, dtype=dtype, _internal=True
  230. )
  231. elif typename == "reduce_package":
  232. # to fix BC breaking change, objects on this load path
  233. # will be loaded multiple times erroneously
  234. if len(data) == 2:
  235. func, args = data
  236. return func(self, *args)
  237. reduce_id, func, args = data
  238. if reduce_id not in loaded_reduces:
  239. loaded_reduces[reduce_id] = func(self, *args)
  240. return loaded_reduces[reduce_id]
  241. else:
  242. f"Unknown typename for persistent_load, expected 'storage' or 'reduce_package' but got '{typename}'"
  243. # Load the data (which may in turn use `persistent_load` to load tensors)
  244. data_file = io.BytesIO(self.zip_reader.get_record(pickle_file))
  245. unpickler = self.Unpickler(data_file)
  246. unpickler.persistent_load = persistent_load # type: ignore[assignment]
  247. @contextmanager
  248. def set_deserialization_context():
  249. # to let reduce_package access deserializaiton context
  250. self.storage_context = storage_context
  251. self.last_map_location = map_location
  252. try:
  253. yield
  254. finally:
  255. self.storage_context = None
  256. self.last_map_location = None
  257. with set_deserialization_context():
  258. result = unpickler.load()
  259. # TODO from zdevito:
  260. # This stateful weird function will need to be removed in our efforts
  261. # to unify the format. It has a race condition if multiple python
  262. # threads try to read independent files
  263. torch._utils._validate_loaded_sparse_tensors()
  264. return result
  265. def id(self):
  266. """
  267. Returns internal identifier that torch.package uses to distinguish :class:`PackageImporter` instances.
  268. Looks like::
  269. <torch_package_0>
  270. """
  271. return self._mangler.parent_name()
  272. def file_structure(
  273. self, *, include: "GlobPattern" = "**", exclude: "GlobPattern" = ()
  274. ) -> Directory:
  275. """Returns a file structure representation of package's zipfile.
  276. Args:
  277. include (Union[List[str], str]): An optional string e.g. ``"my_package.my_subpackage"``, or optional list of strings
  278. for the names of the files to be included in the zipfile representation. This can also be
  279. a glob-style pattern, as described in :meth:`PackageExporter.mock`
  280. exclude (Union[List[str], str]): An optional pattern that excludes files whose name match the pattern.
  281. Returns:
  282. :class:`Directory`
  283. """
  284. return _create_directory_from_file_list(
  285. self.filename, self.zip_reader.get_all_records(), include, exclude
  286. )
  287. def python_version(self):
  288. """Returns the version of python that was used to create this package.
  289. Note: this function is experimental and not Forward Compatible. The plan is to move this into a lock
  290. file later on.
  291. Returns:
  292. :class:`Optional[str]` a python version e.g. 3.8.9 or None if no version was stored with this package
  293. """
  294. python_version_path = ".data/python_version"
  295. return (
  296. self.zip_reader.get_record(python_version_path).decode("utf-8").strip()
  297. if self.zip_reader.has_record(python_version_path)
  298. else None
  299. )
  300. def _read_extern(self):
  301. return (
  302. self.zip_reader.get_record(".data/extern_modules")
  303. .decode("utf-8")
  304. .splitlines(keepends=False)
  305. )
  306. def _make_module(
  307. self, name: str, filename: str | None, is_package: bool, parent: str
  308. ):
  309. mangled_filename = self._mangler.mangle(filename) if filename else None
  310. spec = importlib.machinery.ModuleSpec(
  311. name,
  312. self, # type: ignore[arg-type]
  313. origin="<package_importer>",
  314. is_package=is_package,
  315. )
  316. module = importlib.util.module_from_spec(spec)
  317. self.modules[name] = module
  318. module.__name__ = self._mangler.mangle(name)
  319. ns = module.__dict__
  320. ns["__spec__"] = spec
  321. ns["__loader__"] = self
  322. ns["__file__"] = mangled_filename
  323. ns["__cached__"] = None
  324. ns["__builtins__"] = self.patched_builtins
  325. ns["__torch_package__"] = True
  326. # Add this module to our private global registry. It should be unique due to mangling.
  327. if module.__name__ in _package_imported_modules:
  328. raise AssertionError(
  329. f"module {module.__name__} already exists in _package_imported_modules"
  330. )
  331. _package_imported_modules[module.__name__] = module
  332. # preemptively install on the parent to prevent IMPORT_FROM from trying to
  333. # access sys.modules
  334. self._install_on_parent(parent, name, module)
  335. if filename is not None:
  336. if mangled_filename is None:
  337. raise AssertionError(
  338. "mangled_filename must not be None when filename is set"
  339. )
  340. # preemptively install the source in `linecache` so that stack traces,
  341. # `inspect`, etc. work.
  342. if filename in linecache.cache: # type: ignore[attr-defined]
  343. raise AssertionError(f"filename {filename} already in linecache.cache")
  344. linecache.lazycache(mangled_filename, ns)
  345. code = self._compile_source(filename, mangled_filename)
  346. exec(code, ns)
  347. return module
  348. def _load_module(self, name: str, parent: str):
  349. cur: _PathNode = self.root
  350. for atom in name.split("."):
  351. if not isinstance(cur, _PackageNode) or atom not in cur.children:
  352. if name in IMPLICIT_IMPORT_ALLOWLIST:
  353. module = self.modules[name] = importlib.import_module(name)
  354. return module
  355. raise ModuleNotFoundError(
  356. f'No module named "{name}" in self-contained archive "{self.filename}"'
  357. f" and the module is also not in the list of allowed external modules: {self.extern_modules}",
  358. name=name,
  359. )
  360. cur = cur.children[atom]
  361. if isinstance(cur, _ExternNode):
  362. module = self.modules[name] = importlib.import_module(name)
  363. if compat_mapping := EXTERN_IMPORT_COMPAT_NAME_MAPPING.get(name):
  364. for old_name, new_name in compat_mapping.items():
  365. module.__dict__.setdefault(old_name, new_name)
  366. return module
  367. return self._make_module(
  368. name,
  369. cur.source_file, # type: ignore[attr-defined]
  370. isinstance(cur, _PackageNode),
  371. parent,
  372. )
  373. def _compile_source(self, fullpath: str, mangled_filename: str):
  374. source = self.zip_reader.get_record(fullpath)
  375. source = _normalize_line_endings(source)
  376. return compile(source, mangled_filename, "exec", dont_inherit=True)
  377. # note: named `get_source` so that linecache can find the source
  378. # when this is the __loader__ of a module.
  379. def get_source(self, module_name) -> str:
  380. # linecache calls `get_source` with the `module.__name__` as the argument, so we must demangle it here.
  381. module = self.import_module(demangle(module_name))
  382. return self.zip_reader.get_record(demangle(module.__file__)).decode("utf-8")
  383. # note: named `get_resource_reader` so that importlib.resources can find it.
  384. # This is otherwise considered an internal method.
  385. def get_resource_reader(self, fullname):
  386. try:
  387. package = self._get_package(fullname)
  388. except ImportError:
  389. return None
  390. if package.__loader__ is not self:
  391. return None
  392. return _PackageResourceReader(self, fullname)
  393. def _install_on_parent(self, parent: str, name: str, module: types.ModuleType):
  394. if not parent:
  395. return
  396. # Set the module as an attribute on its parent.
  397. parent_module = self.modules[parent]
  398. if parent_module.__loader__ is self:
  399. setattr(parent_module, name.rpartition(".")[2], module)
  400. # note: copied from cpython's import code, with call to create module replaced with _make_module
  401. def _do_find_and_load(self, name):
  402. parent = name.rpartition(".")[0]
  403. module_name_no_parent = name.rpartition(".")[-1]
  404. if parent:
  405. if parent not in self.modules:
  406. self._gcd_import(parent)
  407. # Crazy side-effects!
  408. if name in self.modules:
  409. return self.modules[name]
  410. parent_module = self.modules[parent]
  411. try:
  412. parent_module.__path__ # type: ignore[attr-defined]
  413. except AttributeError:
  414. # when we attempt to import a package only containing pybinded files,
  415. # the parent directory isn't always a package as defined by python,
  416. # so we search if the package is actually there or not before calling the error.
  417. if isinstance(
  418. parent_module.__loader__,
  419. importlib.machinery.ExtensionFileLoader,
  420. ):
  421. if name not in self.extern_modules:
  422. msg = (
  423. _ERR_MSG
  424. + "; {!r} is a c extension module which was not externed. C extension modules \
  425. need to be externed by the PackageExporter in order to be used as we do not support interning them.}."
  426. ).format(name, name)
  427. raise ModuleNotFoundError(msg, name=name) from None
  428. if not isinstance(
  429. parent_module.__dict__.get(module_name_no_parent),
  430. types.ModuleType,
  431. ):
  432. msg = (
  433. _ERR_MSG
  434. + "; {!r} is a c extension package which does not contain {!r}."
  435. ).format(name, parent, name)
  436. raise ModuleNotFoundError(msg, name=name) from None
  437. else:
  438. msg = (_ERR_MSG + "; {!r} is not a package").format(name, parent)
  439. raise ModuleNotFoundError(msg, name=name) from None
  440. module = self._load_module(name, parent)
  441. self._install_on_parent(parent, name, module)
  442. return module
  443. # note: copied from cpython's import code
  444. def _find_and_load(self, name):
  445. module = self.modules.get(name, _NEEDS_LOADING)
  446. if module is _NEEDS_LOADING:
  447. return self._do_find_and_load(name)
  448. if module is None:
  449. message = f"import of {name} halted; None in sys.modules"
  450. raise ModuleNotFoundError(message, name=name)
  451. # To handle https://github.com/pytorch/pytorch/issues/57490, where std's
  452. # creation of fake submodules via the hacking of sys.modules is not import
  453. # friendly
  454. if name == "os":
  455. self.modules["os.path"] = cast(Any, module).path
  456. elif name == "typing":
  457. if sys.version_info < (3, 13):
  458. self.modules["typing.io"] = cast(Any, module).io
  459. self.modules["typing.re"] = cast(Any, module).re
  460. return module
  461. def _gcd_import(self, name, package=None, level=0):
  462. """Import and return the module based on its name, the package the call is
  463. being made from, and the level adjustment.
  464. This function represents the greatest common denominator of functionality
  465. between import_module and __import__. This includes setting __package__ if
  466. the loader did not.
  467. """
  468. _sanity_check(name, package, level)
  469. if level > 0:
  470. name = _resolve_name(name, package, level)
  471. return self._find_and_load(name)
  472. # note: copied from cpython's import code
  473. def _handle_fromlist(self, module, fromlist, *, recursive=False):
  474. """Figure out what __import__ should return.
  475. The import_ parameter is a callable which takes the name of module to
  476. import. It is required to decouple the function from assuming importlib's
  477. import implementation is desired.
  478. """
  479. module_name = demangle(module.__name__)
  480. # The hell that is fromlist ...
  481. # If a package was imported, try to import stuff from fromlist.
  482. if hasattr(module, "__path__"):
  483. for x in fromlist:
  484. if not isinstance(x, str):
  485. if recursive:
  486. where = module_name + ".__all__"
  487. else:
  488. where = "``from list''"
  489. raise TypeError(
  490. f"Item in {where} must be str, not {type(x).__name__}"
  491. )
  492. elif x == "*":
  493. if not recursive and hasattr(module, "__all__"):
  494. self._handle_fromlist(module, module.__all__, recursive=True)
  495. elif not hasattr(module, x):
  496. from_name = f"{module_name}.{x}"
  497. try:
  498. self._gcd_import(from_name)
  499. except ModuleNotFoundError as exc:
  500. # Backwards-compatibility dictates we ignore failed
  501. # imports triggered by fromlist for modules that don't
  502. # exist.
  503. if (
  504. exc.name == from_name
  505. and self.modules.get(from_name, _NEEDS_LOADING) is not None
  506. ):
  507. continue
  508. raise
  509. return module
  510. def __import__(self, name, globals=None, locals=None, fromlist=(), level=0):
  511. if level == 0:
  512. module = self._gcd_import(name)
  513. else:
  514. globals_ = globals if globals is not None else {}
  515. package = _calc___package__(globals_)
  516. module = self._gcd_import(name, package, level)
  517. if not fromlist:
  518. # Return up to the first dot in 'name'. This is complicated by the fact
  519. # that 'name' may be relative.
  520. if level == 0:
  521. return self._gcd_import(name.partition(".")[0])
  522. elif not name:
  523. return module
  524. else:
  525. # Figure out where to slice the module's name up to the first dot
  526. # in 'name'.
  527. cut_off = len(name) - len(name.partition(".")[0])
  528. # Slice end needs to be positive to alleviate need to special-case
  529. # when ``'.' not in name``.
  530. module_name = demangle(module.__name__)
  531. return self.modules[module_name[: len(module_name) - cut_off]]
  532. else:
  533. return self._handle_fromlist(module, fromlist)
  534. def _get_package(self, package):
  535. """Take a package name or module object and return the module.
  536. If a name, the module is imported. If the passed or imported module
  537. object is not a package, raise an exception.
  538. """
  539. if hasattr(package, "__spec__"):
  540. if package.__spec__.submodule_search_locations is None:
  541. raise TypeError(f"{package.__spec__.name!r} is not a package")
  542. else:
  543. return package
  544. else:
  545. module = self.import_module(package)
  546. if module.__spec__.submodule_search_locations is None:
  547. raise TypeError(f"{package!r} is not a package")
  548. else:
  549. return module
  550. def _zipfile_path(self, package, resource=None):
  551. package = self._get_package(package)
  552. if package.__loader__ is not self:
  553. raise AssertionError(
  554. f"package.__loader__ must be self, got {package.__loader__}"
  555. )
  556. name = demangle(package.__name__)
  557. if resource is not None:
  558. resource = _normalize_path(resource)
  559. return f"{name.replace('.', '/')}/{resource}"
  560. else:
  561. return f"{name.replace('.', '/')}"
  562. def _get_or_create_package(self, atoms: list[str]) -> "_PackageNode | _ExternNode":
  563. cur = self.root
  564. for i, atom in enumerate(atoms):
  565. node = cur.children.get(atom, None)
  566. if node is None:
  567. node = cur.children[atom] = _PackageNode(None)
  568. if isinstance(node, _ExternNode):
  569. return node
  570. if isinstance(node, _ModuleNode):
  571. name = ".".join(atoms[:i])
  572. raise ImportError(
  573. f"inconsistent module structure. module {name} is not a package, but has submodules"
  574. )
  575. if not isinstance(node, _PackageNode):
  576. raise AssertionError(
  577. f"expected _PackageNode, got {type(node).__name__}"
  578. )
  579. cur = node
  580. return cur
  581. def _add_file(self, filename: str):
  582. """Assembles a Python module out of the given file. Will ignore files in the .data directory.
  583. Args:
  584. filename (str): the name of the file inside of the package archive to be added
  585. """
  586. *prefix, last = filename.split("/")
  587. if len(prefix) > 1 and prefix[0] == ".data":
  588. return
  589. package = self._get_or_create_package(prefix)
  590. if isinstance(package, _ExternNode):
  591. raise ImportError(
  592. f"inconsistent module structure. package contains a module file {filename}"
  593. f" that is a subpackage of a module marked external."
  594. )
  595. if last == "__init__.py":
  596. package.source_file = filename
  597. elif last.endswith(".py"):
  598. package_name = last[: -len(".py")]
  599. package.children[package_name] = _ModuleNode(filename)
  600. def _add_extern(self, extern_name: str):
  601. *prefix, last = extern_name.split(".")
  602. package = self._get_or_create_package(prefix)
  603. if isinstance(package, _ExternNode):
  604. return # the shorter extern covers this extern case
  605. package.children[last] = _ExternNode()
  606. _NEEDS_LOADING = object()
  607. _ERR_MSG_PREFIX = "No module named "
  608. _ERR_MSG = _ERR_MSG_PREFIX + "{!r}"
  609. class _PathNode:
  610. __slots__ = []
  611. class _PackageNode(_PathNode):
  612. def __init__(self, source_file: str | None):
  613. self.source_file = source_file
  614. self.children: dict[str, _PathNode] = {}
  615. class _ModuleNode(_PathNode):
  616. __slots__ = ["source_file"]
  617. def __init__(self, source_file: str):
  618. self.source_file = source_file
  619. class _ExternNode(_PathNode):
  620. pass
  621. # A private global registry of all modules that have been package-imported.
  622. _package_imported_modules: WeakValueDictionary = WeakValueDictionary()
  623. # `inspect` by default only looks in `sys.modules` to find source files for classes.
  624. # Patch it to check our private registry of package-imported modules as well.
  625. _orig_getfile = inspect.getfile
  626. def _patched_getfile(object):
  627. if inspect.isclass(object):
  628. if object.__module__ in _package_imported_modules:
  629. return _package_imported_modules[object.__module__].__file__
  630. return _orig_getfile(object)
  631. inspect.getfile = _patched_getfile
  632. class _PackageResourceReader:
  633. """Private class used to support PackageImporter.get_resource_reader().
  634. Confirms to the importlib.abc.ResourceReader interface. Allowed to access
  635. the innards of PackageImporter.
  636. """
  637. def __init__(self, importer, fullname):
  638. self.importer = importer
  639. self.fullname = fullname
  640. def open_resource(self, resource):
  641. from io import BytesIO
  642. return BytesIO(self.importer.load_binary(self.fullname, resource))
  643. def resource_path(self, resource):
  644. # The contract for resource_path is that it either returns a concrete
  645. # file system path or raises FileNotFoundError.
  646. if isinstance(
  647. self.importer.zip_reader, DirectoryReader
  648. ) and self.importer.zip_reader.has_record(
  649. os.path.join(self.fullname, resource)
  650. ):
  651. return os.path.join(
  652. self.importer.zip_reader.directory, self.fullname, resource
  653. )
  654. raise FileNotFoundError
  655. def is_resource(self, name):
  656. path = self.importer._zipfile_path(self.fullname, name)
  657. return self.importer.zip_reader.has_record(path)
  658. def contents(self):
  659. from pathlib import Path
  660. filename = self.fullname.replace(".", "/")
  661. fullname_path = Path(self.importer._zipfile_path(self.fullname))
  662. files = self.importer.zip_reader.get_all_records()
  663. subdirs_seen = set()
  664. for filename in files:
  665. try:
  666. relative = Path(filename).relative_to(fullname_path)
  667. except ValueError:
  668. continue
  669. # If the path of the file (which is relative to the top of the zip
  670. # namespace), relative to the package given when the resource
  671. # reader was created, has a parent, then it's a name in a
  672. # subdirectory and thus we skip it.
  673. parent_name = relative.parent.name
  674. if len(parent_name) == 0:
  675. yield relative.name
  676. elif parent_name not in subdirs_seen:
  677. subdirs_seen.add(parent_name)
  678. yield parent_name