rewriter.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. # mypy: allow-untyped-decorators
  2. # mypy: allow-untyped-defs
  3. import ast
  4. import copy
  5. import functools
  6. import inspect
  7. import textwrap
  8. from collections.abc import Callable
  9. from types import FunctionType
  10. from typing import Any, cast, Optional, Union
  11. import torch
  12. from torch._sources import normalize_source_lines
  13. from torch.fx._symbolic_trace import Tracer
  14. from torch.fx.graph import Graph
  15. class AST_Rewriter(ast.NodeTransformer):
  16. """
  17. Take a FunctionType object representing a `forward` method, then
  18. perform an AST rewrite to swap out nodes that are not symbolically
  19. traceable with a callsite to the FX alternative.
  20. To support swapping out an AST node, define a new `visit` method on
  21. that node. For more details, see:
  22. https://docs.python.org/3/library/ast.html#ast.NodeTransformer
  23. """
  24. # This function checks for new keys added in the globals dict. TorchDynamo
  25. # can insert new keys in the global dict and upset the check. Therefore, put
  26. # a disable here. This function is an optimization pass and not really
  27. # suitable for dynamo tracing anyways.
  28. @torch._dynamo.disable
  29. def rewrite(self, fn: FunctionType):
  30. # Normalize the source lines
  31. sourcelines, _ = inspect.getsourcelines(fn)
  32. sourcelines = normalize_source_lines(sourcelines)
  33. source = "".join(sourcelines)
  34. normalized_str = textwrap.dedent(source)
  35. # Rewrite the original AST
  36. source_ast = ast.parse(normalized_str)
  37. dest_ast = ast.fix_missing_locations(self.visit(source_ast))
  38. # Pull out the compiled function from the newly-created Module
  39. code = compile(dest_ast, "", "exec")
  40. globals_dict = copy.copy(fn.__globals__)
  41. keys_before = set(globals_dict.keys())
  42. exec(code, globals_dict)
  43. new_keys = list(set(globals_dict.keys()) - keys_before)
  44. if len(new_keys) != 1:
  45. raise AssertionError(f"Expected 1 new key, got {len(new_keys)}")
  46. fn_compiled = globals_dict[new_keys[0]]
  47. # return the compiled function with the original globals
  48. def change_func_globals(f, globals):
  49. """Based on https://stackoverflow.com/a/13503277/2988730 (@unutbu)"""
  50. # __globals__ is a private member of the function class
  51. # so we have to copy the function, f, all of its member, except f.__globals__
  52. g = FunctionType(
  53. f.__code__,
  54. globals,
  55. name=f.__name__,
  56. argdefs=f.__defaults__,
  57. closure=f.__closure__,
  58. )
  59. g = functools.update_wrapper(g, f)
  60. g.__kwdefaults__ = copy.copy(f.__kwdefaults__) # type:ignore[attr-defined]
  61. return g
  62. # Return the correct FunctionType object
  63. return change_func_globals(fn_compiled, globals=fn.__globals__)
  64. def visit_Assert(self, node):
  65. """
  66. Swap out the Assert node (Python's `assert`) with a callsite to the
  67. symbolically-traceable torch._assert function
  68. """
  69. # Create the Call node
  70. n = ast.parse("torch._assert()", mode="eval")
  71. if not isinstance(n, ast.Expression):
  72. raise AssertionError(f"Expected ast.Expression, got {type(n)}")
  73. call_node = n.body
  74. if not isinstance(call_node, ast.Call):
  75. raise AssertionError(f"Expected ast.Call, got {type(call_node)}")
  76. msg = node.msg if node.msg else ast.Constant(value="", kind=None)
  77. call_node.args = [node.test, msg]
  78. # Ensure that the new node conforms to the Python AST grammar
  79. expr_wrapper = ast.Expr(value=call_node)
  80. # Return the new Call node to signify that we want to use it as
  81. # a replacement for the original _assert node
  82. return ast.copy_location(expr_wrapper, node)
  83. def visit_AnnAssign(self, node):
  84. """
  85. Swap out Python's AnnAssign with an Assign node where the annotation function is called.
  86. Example:
  87. Original:
  88. y: Tensor_Type(1,2,3, Dyn) = f2(x)
  89. Output:
  90. y = annotate(f2(x),Tensor_Type((1,2,3,Dyn)))
  91. """
  92. return ast.Assign(
  93. targets=[node.target],
  94. value=ast.Call(
  95. func=ast.Name(id="annotate", ctx=ast.Load()),
  96. args=[node.value, node.annotation],
  97. keywords=[],
  98. ),
  99. )
  100. class RewritingTracer(Tracer):
  101. def trace(
  102. self,
  103. root: Union[torch.nn.Module, Callable],
  104. concrete_args: Optional[dict[str, Any]] = None,
  105. ) -> Graph:
  106. return super().trace(_rewrite(root), concrete_args)
  107. def _rewrite(fn: Union[torch.nn.Module, Callable]) -> Union[torch.nn.Module, Callable]:
  108. if isinstance(fn, torch.nn.Module):
  109. # Rewrite this module's `forward` as well as the `forward`s of
  110. # all of this module's recursive descendents. Return the new,
  111. # rewritten module hierarchy.
  112. def rewrite_module(m: torch.nn.Module):
  113. class RewrittenModule(torch.nn.Module):
  114. def __init__(self, orig):
  115. super().__init__()
  116. for k, v in orig.__dict__.items():
  117. if isinstance(v, torch.nn.Module):
  118. self.__dict__[k] = copy.copy(rewrite_module(v))
  119. else:
  120. self.__dict__[k] = copy.copy(v)
  121. RewrittenModule.forward = AST_Rewriter().rewrite(
  122. cast(FunctionType, m.forward)
  123. )
  124. return RewrittenModule(m)
  125. return rewrite_module(fn)
  126. else:
  127. # Rewrite this single free function
  128. return AST_Rewriter().rewrite(cast(FunctionType, fn))