_mangling.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. # mypy: allow-untyped-defs
  2. """Import mangling.
  3. See mangling.md for details.
  4. """
  5. import re
  6. _mangle_index = 0
  7. class PackageMangler:
  8. """
  9. Used on import, to ensure that all modules imported have a shared mangle parent.
  10. """
  11. def __init__(self) -> None:
  12. global _mangle_index
  13. self._mangle_index = _mangle_index
  14. # Increment the global index
  15. _mangle_index += 1
  16. # Angle brackets are used so that there is almost no chance of
  17. # confusing this module for a real module. Plus, it is Python's
  18. # preferred way of denoting special modules.
  19. self._mangle_parent = f"<torch_package_{self._mangle_index}>"
  20. def mangle(self, name) -> str:
  21. if len(name) == 0:
  22. raise AssertionError("name must not be empty")
  23. return self._mangle_parent + "." + name
  24. def demangle(self, mangled: str) -> str:
  25. """
  26. Note: This only demangles names that were mangled by this specific
  27. PackageMangler. It will pass through names created by a different
  28. PackageMangler instance.
  29. """
  30. if mangled.startswith(self._mangle_parent + "."):
  31. return mangled.partition(".")[2]
  32. # wasn't a mangled name
  33. return mangled
  34. def parent_name(self):
  35. return self._mangle_parent
  36. def is_mangled(name: str) -> bool:
  37. return bool(re.match(r"<torch_package_\d+>", name))
  38. def demangle(name: str) -> str:
  39. """
  40. Note: Unlike PackageMangler.demangle, this version works on any
  41. mangled name, irrespective of which PackageMangler created it.
  42. """
  43. if is_mangled(name):
  44. _first, sep, last = name.partition(".")
  45. # If there is only a base mangle prefix, e.g. '<torch_package_0>',
  46. # then return an empty string.
  47. return last if len(sep) != 0 else ""
  48. return name
  49. def get_mangle_prefix(name: str) -> str:
  50. return name.partition(".")[0] if is_mangled(name) else name