__init__.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. r'''
  2. FX is a toolkit that developers can use to transform ``nn.Module``
  3. instances. FX consists of three main components: a **symbolic tracer**,
  4. an **intermediate representation**, and **Python code generation**. A
  5. demonstration of these components in action:
  6. ::
  7. import torch
  8. # Simple module for demonstration
  9. class MyModule(torch.nn.Module):
  10. def __init__(self) -> None:
  11. super().__init__()
  12. self.param = torch.nn.Parameter(torch.rand(3, 4))
  13. self.linear = torch.nn.Linear(4, 5)
  14. def forward(self, x):
  15. return self.linear(x + self.param).clamp(min=0.0, max=1.0)
  16. module = MyModule()
  17. from torch.fx import symbolic_trace
  18. # Symbolic tracing frontend - captures the semantics of the module
  19. symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
  20. # High-level intermediate representation (IR) - Graph representation
  21. print(symbolic_traced.graph)
  22. """
  23. graph():
  24. %x : [num_users=1] = placeholder[target=x]
  25. %param : [num_users=1] = get_attr[target=param]
  26. %add : [num_users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
  27. %linear : [num_users=1] = call_module[target=linear](args = (%add,), kwargs = {})
  28. %clamp : [num_users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
  29. return clamp
  30. """
  31. # Code generation - valid Python code
  32. print(symbolic_traced.code)
  33. """
  34. def forward(self, x):
  35. param = self.param
  36. add = x + param; x = param = None
  37. linear = self.linear(add); add = None
  38. clamp = linear.clamp(min = 0.0, max = 1.0); linear = None
  39. return clamp
  40. """
  41. The **symbolic tracer** performs "symbolic execution" of the Python
  42. code. It feeds fake values, called Proxies, through the code. Operations
  43. on these Proxies are recorded. More information about symbolic tracing
  44. can be found in the :func:`symbolic_trace` and :class:`Tracer`
  45. documentation.
  46. The **intermediate representation** is the container for the operations
  47. that were recorded during symbolic tracing. It consists of a list of
  48. Nodes that represent function inputs, callsites (to functions, methods,
  49. or :class:`torch.nn.Module` instances), and return values. More information
  50. about the IR can be found in the documentation for :class:`Graph`. The
  51. IR is the format on which transformations are applied.
  52. **Python code generation** is what makes FX a Python-to-Python (or
  53. Module-to-Module) transformation toolkit. For each Graph IR, we can
  54. create valid Python code matching the Graph's semantics. This
  55. functionality is wrapped up in :class:`GraphModule`, which is a
  56. :class:`torch.nn.Module` instance that holds a :class:`Graph` as well as a
  57. ``forward`` method generated from the Graph.
  58. Taken together, these components (symbolic tracing ->
  59. intermediate representation -> transforms -> Python code generation)
  60. form the Python-to-Python transformation pipeline of FX. In
  61. addition, these components can be used separately. For example,
  62. symbolic tracing can be used in isolation to capture a form of
  63. the code for analysis (and not transformation) purposes. Code
  64. generation can be used for programmatically generating models, for
  65. example from a config file. There are many uses for FX!
  66. Several example transformations can be found at the
  67. `examples <https://github.com/pytorch/examples/tree/main/fx>`__
  68. repository.
  69. '''
  70. from torch.fx import immutable_collections
  71. from torch.fx._symbolic_trace import ( # noqa: F401
  72. PH,
  73. ProxyableClassMeta,
  74. symbolic_trace,
  75. Tracer,
  76. wrap,
  77. )
  78. from torch.fx.graph import CodeGen, Graph # noqa: F401
  79. from torch.fx.graph_module import GraphModule
  80. from torch.fx.interpreter import Interpreter, Transformer
  81. from torch.fx.node import has_side_effect, map_arg, Node
  82. from torch.fx.proxy import Proxy
  83. from torch.fx.subgraph_rewriter import replace_pattern
  # Public API of this package: the names bound by `from torch.fx import *`.
  # PH, ProxyableClassMeta and CodeGen (imported above with `# noqa: F401`)
  # and the `immutable_collections` submodule are also imported above but are
  # NOT listed here, so they are reachable as attributes (e.g. torch.fx.PH)
  # without being exported by a star import.
  # NOTE(review): presumably that exclusion is intentional — confirm before
  # adding them here.
  84. __all__ = [
  # torch.fx._symbolic_trace
  85. "symbolic_trace",
  86. "Tracer",
  87. "wrap",
  # torch.fx.graph / torch.fx.graph_module
  88. "Graph",
  89. "GraphModule",
  # torch.fx.interpreter
  90. "Interpreter",
  91. "Transformer",
  # torch.fx.node / torch.fx.proxy
  92. "Node",
  93. "Proxy",
  # torch.fx.subgraph_rewriter
  94. "replace_pattern",
  # torch.fx.node helpers
  95. "has_side_effect",
  96. "map_arg",
  97. ]