__init__.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315
  1. # mypy: allow-untyped-defs
  2. import warnings
  3. from collections.abc import Iterator
  4. from contextlib import contextmanager
  5. from typing import Any
  6. import torch._C
  7. # These are imported so users can access them from the `torch.jit` module
  8. from torch._jit_internal import (
  9. _Await,
  10. _drop,
  11. _IgnoreContextManager,
  12. _isinstance,
  13. _overload,
  14. _overload_method,
  15. export,
  16. Final,
  17. Future,
  18. ignore,
  19. is_scripting,
  20. unused,
  21. )
  22. from torch.jit._async import fork, wait
  23. from torch.jit._await import _awaitable, _awaitable_nowait, _awaitable_wait
  24. from torch.jit._decomposition_utils import _register_decomposition
  25. from torch.jit._freeze import freeze, optimize_for_inference, run_frozen_optimizations
  26. from torch.jit._fuser import (
  27. fuser,
  28. last_executed_optimized_graph,
  29. optimized_execution,
  30. set_fusion_strategy,
  31. )
  32. from torch.jit._ir_utils import _InsertPoint
  33. from torch.jit._script import (
  34. _ScriptProfile,
  35. _unwrap_optional,
  36. Attribute,
  37. CompilationUnit,
  38. interface,
  39. RecursiveScriptClass,
  40. RecursiveScriptModule,
  41. script,
  42. script_method,
  43. ScriptFunction,
  44. ScriptModule,
  45. ScriptWarning,
  46. )
  47. from torch.jit._serialization import (
  48. jit_module_from_flatbuffer,
  49. load,
  50. save,
  51. save_jit_module_to_flatbuffer,
  52. )
  53. from torch.jit._trace import (
  54. _flatten,
  55. _get_trace_graph,
  56. _script_if_tracing,
  57. _unique_state_dict,
  58. is_tracing,
  59. ONNXTracedModule,
  60. TopLevelTracedModule,
  61. trace,
  62. trace_module,
  63. TracedModule,
  64. TracerWarning,
  65. TracingCheckError,
  66. )
  67. from torch.utils import set_module
  68. __all__ = [
  69. "Attribute",
  70. "CompilationUnit",
  71. "Error",
  72. "Future",
  73. "ScriptFunction",
  74. "ScriptModule",
  75. "annotate",
  76. "enable_onednn_fusion",
  77. "export",
  78. "export_opnames",
  79. "fork",
  80. "freeze",
  81. "interface",
  82. "ignore",
  83. "isinstance",
  84. "load",
  85. "onednn_fusion_enabled",
  86. "optimize_for_inference",
  87. "save",
  88. "script",
  89. "script_if_tracing",
  90. "set_fusion_strategy",
  91. "strict_fusion",
  92. "trace",
  93. "trace_module",
  94. "unused",
  95. "wait",
  96. ]
  97. # For backwards compatibility
  98. _fork = fork
  99. _wait = wait
  100. _set_fusion_strategy = set_fusion_strategy
  101. def export_opnames(m):
  102. r"""
  103. Generate new bytecode for a Script module.
  104. Returns what the op list would be for a Script Module based off the current code base.
  105. If you have a LiteScriptModule and want to get the currently present
  106. list of ops call _export_operator_list instead.
  107. """
  108. return torch._C._export_opnames(m._c)
  109. # torch.jit.Error
  110. Error = torch._C.JITException
  111. set_module(Error, "torch.jit")
  112. # This is not perfect but works in common cases
  113. Error.__name__ = "Error"
  114. Error.__qualname__ = "Error"
  115. # for use in python if using annotate
  116. def annotate(the_type, the_value):
  117. """Use to give type of `the_value` in TorchScript compiler.
  118. .. deprecated:: 2.5
  119. TorchScript is deprecated, please use ``torch.compile`` instead.
  120. This method is a pass-through function that returns `the_value`, used to hint TorchScript
  121. compiler the type of `the_value`. It is a no-op when running outside of TorchScript.
  122. Though TorchScript can infer correct type for most Python expressions, there are some cases where
  123. type inference can be wrong, including:
  124. - Empty containers like `[]` and `{}`, which TorchScript assumes to be container of `Tensor`
  125. - Optional types like `Optional[T]` but assigned a valid value of type `T`, TorchScript would assume
  126. it is type `T` rather than `Optional[T]`
  127. Note that `annotate()` does not help in `__init__` method of `torch.nn.Module` subclasses because it
  128. is executed in eager mode. To annotate types of `torch.nn.Module` attributes,
  129. use :meth:`~torch.jit.Attribute` instead.
  130. Example:
  131. .. testcode::
  132. import torch
  133. from typing import Dict
  134. @torch.jit.script
  135. def fn():
  136. # Telling TorchScript that this empty dictionary is a (str -> int) dictionary
  137. # instead of default dictionary type of (str -> Tensor).
  138. d = torch.jit.annotate(Dict[str, int], {})
  139. # Without `torch.jit.annotate` above, following statement would fail because of
  140. # type mismatch.
  141. d["name"] = 20
  142. .. testcleanup::
  143. del fn
  144. Args:
  145. the_type: Python type that should be passed to TorchScript compiler as type hint for `the_value`
  146. the_value: Value or expression to hint type for.
  147. Returns:
  148. `the_value` is passed back as return value.
  149. """
  150. return the_value
  151. def script_if_tracing(fn):
  152. """
  153. Compiles ``fn`` when it is first called during tracing.
  154. .. deprecated:: 2.5
  155. TorchScript is deprecated, please use ``torch.compile`` instead.
  156. ``torch.jit.script`` has a non-negligible start up time when it is first called due to
  157. lazy-initializations of many compiler builtins. Therefore you should not use
  158. it in library code. However, you may want to have parts of your library work
  159. in tracing even if they use control flow. In these cases, you should use
  160. ``@torch.jit.script_if_tracing`` to substitute for
  161. ``torch.jit.script``.
  162. Args:
  163. fn: A function to compile.
  164. Returns:
  165. If called during tracing, a :class:`ScriptFunction` created by `torch.jit.script` is returned.
  166. Otherwise, the original function `fn` is returned.
  167. """
  168. return _script_if_tracing(fn)
  169. # for torch.jit.isinstance
  170. def isinstance(obj, target_type):
  171. """
  172. Provide container type refinement in TorchScript.
  173. .. deprecated:: 2.5
  174. TorchScript is deprecated, please use ``torch.compile`` instead.
  175. It can refine parameterized containers of the List, Dict, Tuple, and Optional types. E.g. ``List[str]``,
  176. ``Dict[str, List[torch.Tensor]]``, ``Optional[Tuple[int,str,int]]``. It can also
  177. refine basic types such as bools and ints that are available in TorchScript.
  178. Args:
  179. obj: object to refine the type of
  180. target_type: type to try to refine obj to
  181. Returns:
  182. ``bool``: True if obj was successfully refined to the type of target_type,
  183. False otherwise with no new type refinement
  184. Example (using ``torch.jit.isinstance`` for type refinement):
  185. .. testcode::
  186. import torch
  187. from typing import Any, Dict, List
  188. class MyModule(torch.nn.Module):
  189. def __init__(self) -> None:
  190. super().__init__()
  191. def forward(self, input: Any): # note the Any type
  192. if torch.jit.isinstance(input, List[torch.Tensor]):
  193. for t in input:
  194. y = t.clamp(0, 0.5)
  195. elif torch.jit.isinstance(input, Dict[str, str]):
  196. for val in input.values():
  197. print(val)
  198. m = torch.jit.script(MyModule())
  199. x = [torch.rand(3,3), torch.rand(4,3)]
  200. m(x)
  201. y = {"key1":"val1","key2":"val2"}
  202. m(y)
  203. """
  204. return _isinstance(obj, target_type)
  205. class strict_fusion:
  206. """
  207. Give errors if not all nodes have been fused in inference, or symbolically differentiated in training.
  208. .. deprecated:: 2.5
  209. TorchScript is deprecated, please use ``torch.compile`` instead.
  210. Example:
  211. Forcing fusion of additions.
  212. .. code-block:: python
  213. @torch.jit.script
  214. def foo(x):
  215. with torch.jit.strict_fusion():
  216. return x + x + x
  217. """
  218. def __init__(self) -> None:
  219. if not torch._jit_internal.is_scripting():
  220. warnings.warn("Only works in script mode", stacklevel=2)
  221. def __enter__(self):
  222. pass
  223. def __exit__(self, type: Any, value: Any, tb: Any) -> None:
  224. pass
  225. # Context manager for globally hiding source ranges when printing graphs.
  226. # Note that these functions are exposed to Python as static members of the
  227. # Graph class, so mypy checks need to be skipped.
  228. @contextmanager
  229. def _hide_source_ranges() -> Iterator[None]:
  230. old_enable_source_ranges = torch._C.Graph.global_print_source_ranges # type: ignore[attr-defined]
  231. try:
  232. torch._C.Graph.set_global_print_source_ranges(False) # type: ignore[attr-defined]
  233. yield
  234. finally:
  235. torch._C.Graph.set_global_print_source_ranges(old_enable_source_ranges) # type: ignore[attr-defined]
  236. def enable_onednn_fusion(enabled: bool) -> None:
  237. """Enable or disables onednn JIT fusion based on the parameter `enabled`.
  238. .. deprecated:: 2.5
  239. TorchScript is deprecated, please use ``torch.compile`` instead.
  240. """
  241. torch._C._jit_set_llga_enabled(enabled)
  242. def onednn_fusion_enabled():
  243. """Return whether onednn JIT fusion is enabled.
  244. .. deprecated:: 2.5
  245. TorchScript is deprecated, please use ``torch.compile`` instead.
  246. """
  247. return torch._C._jit_llga_enabled()
  248. del Any
  249. if not torch._C._jit_init():
  250. raise RuntimeError("JIT initialization failed")