# mypy: allow-untyped-defs
"""Freezing.

This is not intended to be imported directly; please use the exposed
functionalities in `torch.jit`.
"""

import warnings
from typing import Optional

import torch
from torch.jit._script import RecursiveScriptModule, ScriptModule


  10. def freeze(
  11. mod, preserved_attrs: Optional[list[str]] = None, optimize_numerics: bool = True
  12. ):
  13. r"""Freeze ScriptModule, inline submodules, and attributes as constants.
  14. .. deprecated:: 2.5
  15. TorchScript is deprecated, please use ``torch.compile`` instead.
  16. Freezing a :class:`ScriptModule` will clone it and attempt to inline the cloned
  17. module's submodules, parameters, and attributes as constants in the TorchScript IR Graph.
  18. By default, `forward` will be preserved, as well as attributes & methods specified in
  19. `preserved_attrs`. Additionally, any attribute that is modified within a preserved
  20. method will be preserved.
  21. Freezing currently only accepts ScriptModules that are in eval mode.
  22. Freezing applies generic optimization that will speed up your model regardless of machine.
  23. To further optimize using server-specific settings, run `optimize_for_inference` after
  24. freezing.
  25. Args:
  26. mod (:class:`ScriptModule`): a module to be frozen
  27. preserved_attrs (Optional[List[str]]): a list of attributes to preserve in addition to the forward method.
  28. Attributes modified in preserved methods will also be preserved.
  29. optimize_numerics (bool): If ``True``, a set of optimization passes will be run that does not strictly
  30. preserve numerics. Full details of optimization can be found at `torch.jit.run_frozen_optimizations`.
  31. Returns:
  32. Frozen :class:`ScriptModule`.
  33. Example (Freezing a simple module with a Parameter):
  34. .. testcode::
  35. import torch
  36. class MyModule(torch.nn.Module):
  37. def __init__(self, N, M):
  38. super().__init__()
  39. self.weight = torch.nn.Parameter(torch.rand(N, M))
  40. self.linear = torch.nn.Linear(N, M)
  41. def forward(self, input):
  42. output = self.weight.mm(input)
  43. output = self.linear(output)
  44. return output
  45. scripted_module = torch.jit.script(MyModule(2, 3).eval())
  46. frozen_module = torch.jit.freeze(scripted_module)
  47. # parameters have been removed and inlined into the Graph as constants
  48. assert len(list(frozen_module.named_parameters())) == 0
  49. # See the compiled graph as Python code
  50. print(frozen_module.code)
  51. Example (Freezing a module with preserved attributes)
  52. .. testcode::
  53. import torch
  54. class MyModule2(torch.nn.Module):
  55. def __init__(self) -> None:
  56. super().__init__()
  57. self.modified_tensor = torch.tensor(10.)
  58. self.version = 1
  59. def forward(self, input):
  60. self.modified_tensor += 1
  61. return input + self.modified_tensor
  62. scripted_module = torch.jit.script(MyModule2().eval())
  63. frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["version"])
  64. # we've manually preserved `version`, so it still exists on the frozen module and can be modified
  65. assert frozen_module.version == 1
  66. frozen_module.version = 2
  67. # `modified_tensor` is detected as being mutated in the forward, so freezing preserves
  68. # it to retain model semantics
  69. assert frozen_module(torch.tensor(1)) == torch.tensor(12)
  70. # now that we've run it once, the next result will be incremented by one
  71. assert frozen_module(torch.tensor(1)) == torch.tensor(13)
  72. Note:
  73. Freezing submodule attributes is also supported:
  74. frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["submodule.version"])
  75. Note:
  76. If you're not sure why an attribute is not being inlined as a constant, you can run
  77. `dump_alias_db` on frozen_module.forward.graph to see if freezing has detected the
  78. attribute is being modified.
  79. Note:
  80. Because freezing makes weights constants and removes module hierarchy, `to` and other
  81. nn.Module methods to manipulate device or dtype no longer work. As a workaround,
  82. You can remap devices by specifying `map_location` in `torch.jit.load`, however
  83. device-specific logic may have been baked into the model.
  84. """
  85. warnings.warn(
  86. "`torch.jit.freeze` is deprecated. Please use `torch.compile` instead.",
  87. DeprecationWarning,
  88. )
  89. if not isinstance(mod, ScriptModule):
  90. raise RuntimeError(
  91. "Freezing expects a ScriptModule as input. "
  92. "Please use torch.jit.script or torch.jit.trace to script your 'nn.Module'."
  93. )
  94. if mod.training:
  95. raise RuntimeError(
  96. "Freezing is currently only implemented for modules in eval mode. "
  97. "Please call .eval() on your module before freezing."
  98. )
  99. preserved_attrs = preserved_attrs if preserved_attrs is not None else []
  100. out = RecursiveScriptModule(torch._C._freeze_module(mod._c, preserved_attrs))
  101. RecursiveScriptModule._finalize_scriptmodule(out)
  102. preserved_methods = [x for x in preserved_attrs if mod._c._has_method(x)]
  103. run_frozen_optimizations(out, optimize_numerics, preserved_methods)
  104. return out
  105. def run_frozen_optimizations(
  106. mod, optimize_numerics: bool = True, preserved_methods: Optional[list[str]] = None
  107. ) -> None:
  108. r"""
  109. Run a series of optimizations looking for patterns that occur in frozen graphs.
  110. The current set of optimizations includes:
  111. - Dropout Removal
  112. - Pretranspose Linear Layers
  113. - Concat Linear Layers with same input Tensor
  114. - Conv -> Batchnorm folding
  115. - Conv -> Add/Sub folding
  116. - Conv -> Mul/Div folding
  117. Args:
  118. mod (:class:`ScriptModule`): a frozen module to be optimized
  119. optimize_numerics (bool): If ``True``, a set of optimization passes will be run that does not strictly
  120. preserve numerics. These optimizations preserve default rtol and atol of `torch.testing.assert_close`
  121. when applied on a single transformation, however in a module where many transformations are applied
  122. the rtol or atol may no longer fall within the default `assert_close` tolerance. Conv -> Batchnorm folding,
  123. Conv-Add/Sub, and Conv -> Mul/Div folding all may alter numerics.
  124. Returns:
  125. None
  126. Note:
  127. In rare occasions, this can result in slower execution.
  128. Example (Freezing a module with Conv->Batchnorm)
  129. .. code-block:: python
  130. import torch
  131. in_channels, out_channels = 3, 32
  132. conv = torch.nn.Conv2d(
  133. in_channels, out_channels, kernel_size=3, stride=2, bias=True
  134. )
  135. bn = torch.nn.BatchNorm2d(out_channels, eps=0.001)
  136. mod = torch.nn.Sequential(conv, bn)
  137. # set optimize to False here, by default freezing runs run_frozen_optimizations
  138. frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()), optimize=False)
  139. # inspect frozen mod
  140. assert "batch_norm" in str(frozen_mod.graph)
  141. torch.jit.run_frozen_optimizations(frozen_mod)
  142. assert "batch_norm" not in str(frozen_mod.graph)
  143. """
  144. if mod._c._has_method("forward"):
  145. torch._C._jit_pass_optimize_frozen_graph(mod.graph, optimize_numerics)
  146. if preserved_methods is None:
  147. preserved_methods = []
  148. for method in preserved_methods:
  149. torch._C._jit_pass_optimize_frozen_graph(
  150. mod.__getattr__(method).graph, optimize_numerics
  151. )
  152. def optimize_for_inference(
  153. mod: ScriptModule, other_methods: Optional[list[str]] = None
  154. ) -> ScriptModule:
  155. """
  156. Perform a set of optimization passes to optimize a model for the purposes of inference.
  157. .. deprecated:: 2.5
  158. TorchScript is deprecated, please use ``torch.compile`` instead.
  159. If the model is not already frozen, optimize_for_inference
  160. will invoke `torch.jit.freeze` automatically.
  161. In addition to generic optimizations that should speed up your model regardless
  162. of environment, prepare for inference will also bake in build specific settings
  163. such as the presence of CUDNN or MKLDNN, and may in the future make transformations
  164. which speed things up on one machine but slow things down on another. Accordingly,
  165. serialization is not implemented following invoking `optimize_for_inference` and
  166. is not guaranteed.
  167. This is still in prototype, and may have the potential to slow down your model.
  168. Primary use cases that have been targeted so far have been vision models on cpu
  169. and gpu to a lesser extent.
  170. Example (optimizing a module with Conv->Batchnorm)::
  171. import torch
  172. in_channels, out_channels = 3, 32
  173. conv = torch.nn.Conv2d(
  174. in_channels, out_channels, kernel_size=3, stride=2, bias=True
  175. )
  176. bn = torch.nn.BatchNorm2d(out_channels, eps=0.001)
  177. mod = torch.nn.Sequential(conv, bn)
  178. frozen_mod = torch.jit.optimize_for_inference(torch.jit.script(mod.eval()))
  179. assert "batch_norm" not in str(frozen_mod.graph)
  180. # if built with MKLDNN, convolution will be run with MKLDNN weights
  181. assert "MKLDNN" in frozen_mod.graph
  182. """
  183. warnings.warn(
  184. "`torch.jit.optimize_for_inference` is deprecated. Please use `torch.compile` instead.",
  185. DeprecationWarning,
  186. )
  187. if not isinstance(mod, ScriptModule):
  188. raise RuntimeError(
  189. "optimize_for_inference expects a ScriptModule as input. "
  190. "Please use torch.jit.script or torch.jit.trace to script your 'nn.Module'."
  191. )
  192. if other_methods is None:
  193. other_methods = []
  194. if hasattr(mod, "training"):
  195. mod = freeze(mod.eval(), preserved_attrs=other_methods)
  196. torch._C._jit_pass_optimize_for_inference(mod._c, other_methods)
  197. return mod