loader.py 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
# Used to load and initialize polyfill handlers when importing torch._dynamo.
# When adding a new polyfill module, add its name to POLYFILLED_MODULE_NAMES
# below; every listed name is imported here via importlib.
  3. import importlib
  4. from typing import TYPE_CHECKING
  5. import torch.utils._pytree as python_pytree
  6. from .. import polyfills, trace_rules
  7. if TYPE_CHECKING:
  8. from types import ModuleType
  9. # See also the TYPE_CHECKING block in torch/_dynamo/polyfills/__init__.py
  10. POLYFILLED_MODULE_NAMES: tuple[str, ...] = (
  11. "_collections",
  12. "builtins",
  13. "functools",
  14. "itertools",
  15. "operator",
  16. "os",
  17. "struct",
  18. "sys",
  19. "fx",
  20. "tensor",
  21. "torch_c_nn",
  22. "traceback",
  23. )
  24. if python_pytree._cxx_pytree_dynamo_traceable:
  25. POLYFILLED_MODULE_NAMES += ("pytree",)
  26. POLYFILLED_MODULES: tuple["ModuleType", ...] = tuple(
  27. importlib.import_module(f".{submodule}", package=polyfills.__name__)
  28. for submodule in POLYFILLED_MODULE_NAMES
  29. )
# Unregister the original function behind every polyfill handler from
# trace_rules._builtin_function_ids, so that these functions are dispatched with
# the appropriate VariableTracker type. Otherwise, any function still present in
# _builtin_function_ids would be dispatched with BuiltinVariable.
#
# Every name in a polyfill module's __all__ is looked up and is expected to
# carry a __torch_dynamo_original__ attribute; if one does not, the attribute
# access raises AttributeError while this module is being imported.
#
# NOTE: this loop runs at module scope, so polyfill_module, polyfill_name,
# polyfill_handler and original_fn stay bound as module globals afterwards.
for polyfill_module in POLYFILLED_MODULES:
    for polyfill_name in polyfill_module.__all__:
        polyfill_handler = getattr(polyfill_module, polyfill_name)
        # Presumably the original function this handler stands in for.
        original_fn = polyfill_handler.__torch_dynamo_original__
        # Entries are removed by id(original_fn).
        # NOTE(review): whether remove() tolerates an id that was never
        # registered depends on trace_rules; confirm there.
        trace_rules._builtin_function_ids.remove(id(original_fn))