tools.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. # mypy: allow-untyped-defs
  2. import logging
  3. import warnings
  4. from collections.abc import Iterable
  5. from typing import Any, Optional
  6. import torch
  7. import torch.export
  8. import torch.export._trace
  9. from torch._utils_internal import log_export_usage
  10. log = logging.getLogger(__name__)
  11. __all__ = ["report_exportability"]
  12. def _generate_inputs_for_submodules(
  13. model: torch.nn.Module,
  14. target_submodules: Iterable[str],
  15. args: tuple[Any, ...],
  16. kwargs: Optional[dict[str, Any]] = None,
  17. ) -> dict[str, tuple[Any, Any]]:
  18. """
  19. Generate inputs for targeting submdoules in the given model. Note that if two submodules refer to the same obj, this
  20. function doesn't work.
  21. Args:
  22. model: root model.
  23. inputs: inputs to the root model.
  24. target_submodules: submodules that we want to generate inputs for.
  25. Returns:
  26. A dict that maps from submodule name to its inputs.
  27. """
  28. kwargs = kwargs or {}
  29. handles = []
  30. results = {}
  31. submodule_to_names = {mod: name for name, mod in model.named_modules()}
  32. def pre_forward(module, module_args, module_kwargs):
  33. results[submodule_to_names[module]] = (module_args, module_kwargs)
  34. try:
  35. for name, mod in model.named_modules():
  36. if name in target_submodules:
  37. handles.append(
  38. mod.register_forward_pre_hook(pre_forward, with_kwargs=True)
  39. )
  40. model(*args, **kwargs)
  41. except Exception as e:
  42. warnings.warn(
  43. f"Failed to generate submodule inputs because of the following error:\n{e}",
  44. stacklevel=2,
  45. )
  46. finally:
  47. for h in handles:
  48. h.remove()
  49. return results
  50. def report_exportability(
  51. mod: torch.nn.Module,
  52. args: tuple[Any, ...],
  53. kwargs: Optional[dict[str, Any]] = None,
  54. *,
  55. strict: bool = True,
  56. pre_dispatch: bool = False,
  57. ) -> dict[str, Optional[Exception]]:
  58. """
  59. Report exportability issues for a module in one-shot.
  60. Args:
  61. mod: root module.
  62. args: args to the root module.
  63. kwargs: kwargs to the root module.
  64. Returns:
  65. A dict that maps from submodule name to the exception that was raised when trying to export it.
  66. `None` means the module is exportable without issue.
  67. Sample output:
  68. {
  69. '': UnsupportedOperatorException(func=<OpOverload(op='testlib.op_missing_meta', overload='default')>),
  70. 'submod_1': UnsupportedOperatorException(func=<OpOverload(op='testlib.op_missing_meta', overload='default')>),
  71. 'submod_2': None
  72. }
  73. """
  74. log_export_usage(event="export.report_exportability")
  75. kwargs = kwargs or {}
  76. all_submod_names = [name for name, _ in mod.named_modules() if name != ""]
  77. submod_inputs = _generate_inputs_for_submodules(mod, all_submod_names, args, kwargs)
  78. tried_module_types = set()
  79. report: dict[str, Optional[Exception]] = {}
  80. def try_export(module, module_name, args, kwargs):
  81. nonlocal submod_inputs, report, strict, pre_dispatch, tried_module_types
  82. if type(module) in tried_module_types:
  83. return
  84. tried_module_types.add(type(module))
  85. if args is not None or kwargs is not None:
  86. try:
  87. torch.export._trace._export(
  88. module,
  89. args,
  90. kwargs,
  91. strict=strict,
  92. pre_dispatch=pre_dispatch,
  93. )
  94. report[module_name] = None
  95. log.info("Successfully exported `%s`", module_name)
  96. return
  97. except Exception as e:
  98. short_msg = repr(e).split("\n")[0]
  99. log.warning(
  100. "Failed exporting `%s` with exception: %s", module_name, short_msg
  101. )
  102. report[module_name] = e
  103. for name, submod in module.named_children():
  104. sub_module_name = name if module_name == "" else f"{module_name}.{name}"
  105. submod_args, submod_kwargs = submod_inputs.get(
  106. sub_module_name, (None, None)
  107. )
  108. try_export(submod, sub_module_name, submod_args, submod_kwargs)
  109. return
  110. try_export(mod, "", args, kwargs)
  111. unique_issues = set()
  112. for exception in report.values():
  113. if exception is not None:
  114. key = repr(exception).split("\\n")[0]
  115. unique_issues.add(key)
  116. log.warning("Found %d export issues:", len(unique_issues))
  117. for issue in unique_issues:
  118. log.warning(issue)
  119. return report