_script.pyi 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. # mypy: allow-untyped-defs
  2. # mypy: disable-error-code="type-arg"
  3. from collections.abc import Callable, Iterator, Mapping
  4. from typing import Any, NamedTuple, overload, TypeAlias, TypeVar
  5. from typing_extensions import Never, Self
  6. from _typeshed import Incomplete
  7. import torch
  8. from torch._classes import classes as classes
  9. from torch._jit_internal import _qualified_name as _qualified_name
  10. from torch.jit._builtins import _register_builtin as _register_builtin
  11. from torch.jit._fuser import (
  12. _graph_for as _graph_for,
  13. _script_method_graph_for as _script_method_graph_for,
  14. )
  15. from torch.jit._monkeytype_config import (
  16. JitTypeTraceConfig as JitTypeTraceConfig,
  17. JitTypeTraceStore as JitTypeTraceStore,
  18. monkeytype_trace as monkeytype_trace,
  19. )
  20. from torch.jit._recursive import (
  21. _compile_and_register_class as _compile_and_register_class,
  22. infer_methods_to_compile as infer_methods_to_compile,
  23. ScriptMethodStub as ScriptMethodStub,
  24. wrap_cpp_module as wrap_cpp_module,
  25. )
  26. from torch.jit._serialization import validate_map_location as validate_map_location
  27. from torch.jit._state import (
  28. _enabled as _enabled,
  29. _set_jit_function_cache as _set_jit_function_cache,
  30. _set_jit_overload_cache as _set_jit_overload_cache,
  31. _try_get_jit_cached_function as _try_get_jit_cached_function,
  32. _try_get_jit_cached_overloads as _try_get_jit_cached_overloads,
  33. )
  34. from torch.jit.frontend import (
  35. get_default_args as get_default_args,
  36. get_jit_class_def as get_jit_class_def,
  37. get_jit_def as get_jit_def,
  38. )
  39. from torch.nn import Module as Module
  40. from torch.overrides import (
  41. has_torch_function as has_torch_function,
  42. has_torch_function_unary as has_torch_function_unary,
  43. has_torch_function_variadic as has_torch_function_variadic,
  44. )
  45. from torch.package import (
  46. PackageExporter as PackageExporter,
  47. PackageImporter as PackageImporter,
  48. )
  49. from torch.utils import set_module as set_module
# Alias for the C++-bound class torch._C.ScriptFunction; it is also the return
# type of the Callable overload of ``script`` declared later in this stub.
ScriptFunction = torch._C.ScriptFunction
# Module-level object declared as a JitTypeTraceStore (from
# torch.jit._monkeytype_config); how it is initialized is not visible here.
type_trace_db: JitTypeTraceStore
# Defined in torch/csrc/jit/python/script_init.cpp
# NOTE(review): the line above is the original note; this stub does not show
# whether it refers to ResolutionCallback itself — confirm.
# Type of the ``_rcb`` parameter of ``script``: a callable that takes a name
# (str) and returns a callable.
ResolutionCallback: TypeAlias = Callable[[str], Callable[..., Any]]
# Type variable bounded by ``type``; the class overload of ``script`` uses it to
# return exactly the class type it was given.
_ClassVar = TypeVar("_ClassVar", bound=type)
# Unbounded type variable; ``interface`` uses it to return its argument's type.
_T = TypeVar("_T")
# Takes a single argument ``cls`` and is declared to return None.
# NOTE(review): the name suggests a pickling ``__reduce__`` hook; if the runtime
# implementation always raises, ``-> Never`` would be the more precise
# annotation — confirm against torch/jit/_script.py.
def _reduce(cls) -> None: ...
# Two-field named tuple whose fields are, in order, ``value`` and ``type``;
# both are left untyped (Incomplete) in this stub.
class Attribute(NamedTuple):
    value: Incomplete
    type: Incomplete
# Untyped helper with no parameters.
# NOTE(review): by its name it presumably returns the module-level
# ``type_trace_db`` — confirm.
def _get_type_trace_db(): ...
# Untyped helper taking a class ``cls`` and a ``name``.
# NOTE(review): presumably returns the function named ``name`` found on
# ``cls`` (or None when absent) — confirm.
def _get_function_from_type(cls, name): ...
# Untyped helper taking ``cls``. Its name suggests a predicate, but no return
# type (e.g. bool) is declared here — confirm before relying on it.
def _is_new_style_class(cls): ...
# Dict-like wrapper around the object held in ``_c`` (presumably the value
# passed to ``__init__`` — confirm). It exposes keys()/values()/items(), len(),
# ``in``, and get/set/delete by key; key and value types are not declared.
class OrderedDictWrapper:
    _c: Incomplete
    def __init__(self, _c) -> None: ...
    def keys(self): ...
    def values(self): ...
    def __len__(self) -> int: ...
    def __delitem__(self, k) -> None: ...
    def items(self): ...
    def __setitem__(self, k, v) -> None: ...
    def __contains__(self, k) -> bool: ...
    def __getitem__(self, k): ...
# OrderedDictWrapper variant constructed from a ``module`` and a
# ``python_dict``. It adds ``_python_modules`` and overrides items(), ``in``,
# and get/set by key; keys(), values(), len() and deletion are inherited.
class OrderedModuleDict(OrderedDictWrapper):
    _python_modules: Incomplete
    def __init__(self, module, python_dict) -> None: ...
    def items(self): ...
    def __contains__(self, k) -> bool: ...
    def __setitem__(self, k, v) -> None: ...
    def __getitem__(self, k): ...
# Metaclass (a subclass of ``type``) used by ScriptModule. Its ``__init__`` has
# the standard metaclass signature and runs whenever a class using it is
# created.
class ScriptMeta(type):
    def __init__(cls, name, bases, attrs) -> None: ...
# Descriptor: it defines ``__get__(obj, cls)``, so attribute access on the
# owning class or its instances is routed through that method.
# NOTE(review): the name suggests it supplies a cached ``forward`` — confirm
# where it is attached.
class _CachedForward:
    def __get__(self, obj, cls): ...
# Warning subclass defined by this module; it adds no members of its own.
class ScriptWarning(Warning): ...
# Untyped; takes a function ``fn``.
# NOTE(review): presumably used as a decorator marking ScriptModule methods for
# compilation — confirm.
def script_method(fn): ...
# Holds a ``const_mapping`` (str -> Any) given to ``__init__``; any attribute
# access is declared to return Any (presumably looked up in ``const_mapping`` —
# confirm). ScriptModule.code_with_constants returns one of these as the second
# item of its tuple.
class ConstMap:
    const_mapping: Mapping[str, Any]
    def __init__(self, const_mapping: Mapping[str, Any]) -> None: ...
    def __getattr__(self, attr: str) -> Any: ...
# Takes a torch.package PackageImporter and the id of a packaged script module,
# and returns a torch.nn.Module.
# NOTE(review): presumably the loading counterpart of
# ScriptModule.__reduce_package__ — confirm.
def unpackage_script_module(
    importer: PackageImporter,
    script_module_id: str,
) -> torch.nn.Module: ...
# Untyped module-level value.
# NOTE(review): presumably the magic-method names that RecursiveScriptClass
# forwards through forward_magic_method — confirm.
_magic_methods: Incomplete
# Python-side wrapper constructed from ``cpp_class`` (presumably a C++-side
# compiled class — confirm); its internals ``_c`` and ``_props`` are untyped.
# Attribute get/set is overridden, magic methods can be forwarded by name, and
# ``+=`` is declared to return the same type.
class RecursiveScriptClass:
    _c: Incomplete
    _props: Incomplete
    def __init__(self, cpp_class) -> None: ...
    def __getattr__(self, attr: str) -> Any: ...
    def __setattr__(self, attr: str, value: Any) -> None: ...
    # Invokes the magic method named ``method_name`` with the given arguments
    # (presumably on the wrapped class — confirm).
    def forward_magic_method(
        self, method_name: str, *args: Any, **kwargs: Any
    ) -> Any: ...
    # Declared to return None.
    # NOTE(review): if the runtime refuses pickling by raising here,
    # ``-> Never`` would be more precise — confirm.
    def __getstate__(self) -> None: ...
    def __iadd__(self, other: Self) -> Self: ...
# Module-level function that takes ``self`` explicitly.
# NOTE(review): indentation was lost in this copy; it is read as module-level
# and is presumably the template attached to RecursiveScriptClass for each name
# in ``_magic_methods`` — confirm.
def method_template(self, *args, **kwargs): ...
# torch.nn.Module subclass created through the ScriptMeta metaclass;
# RecursiveScriptModule derives from it. It overrides attribute get/set and
# exposes read-only properties for the code, graphs and original name.
class ScriptModule(Module, metaclass=ScriptMeta):
    __jit_unused_properties__: Incomplete
    def __init__(self) -> None: ...
    # Declared as an attribute rather than a method, so calls with any
    # arguments type-check.
    forward: Callable[..., Any]
    def __getattr__(self, attr: str) -> Any: ...
    def __setattr__(self, attr: str, value: Any) -> None: ...
    # NOTE(review): presumably compiles TorchScript source text ``src`` into
    # this module — confirm; ``src`` is untyped here.
    def define(self, src) -> None: ...
    def _replicate_for_data_parallel(self): ...
    # Presumably the torch.package export hook: takes a PackageExporter and
    # returns a pair whose second item is a tuple (the ``(callable, args)``
    # reduce shape) — confirm.
    def __reduce_package__(
        self, exporter: PackageExporter
    ) -> tuple[Any, tuple[Any, ...]]: ...
    # add __jit_unused_properties__
    # NOTE(review): the note above presumably means the properties below should
    # be listed in ``__jit_unused_properties__`` — confirm.
    @property
    def code(self) -> str: ...
    # A (str, ConstMap) pair.
    @property
    def code_with_constants(self) -> tuple[str, ConstMap]: ...
    @property
    def graph(self) -> torch.Graph: ...
    @property
    def inlined_graph(self) -> torch.Graph: ...
    @property
    def original_name(self) -> str: ...
# ScriptModule subclass constructed from a ``cpp_module`` (presumably kept in
# the untyped ``_c`` — confirm). It re-declares ``_modules``, ``_parameters``,
# ``_buffers`` and ``__dict__`` as Incomplete, adds save/serialization helpers,
# supports copy/deepcopy, and implements iteration, len(), indexing, ``in`` and
# truthiness.
class RecursiveScriptModule(ScriptModule):
    _disable_script_meta: bool
    _c: Incomplete
    def __init__(self, cpp_module) -> None: ...
    # Static constructor taking a ``cpp_module`` and an ``init_fn``; no return
    # type is declared (presumably a RecursiveScriptModule — confirm).
    @staticmethod
    def _construct(cpp_module, init_fn): ...
    @staticmethod
    def _finalize_scriptmodule(script_module) -> None: ...
    _concrete_type: Incomplete
    _modules: Incomplete
    _parameters: Incomplete
    _buffers: Incomplete
    __dict__: Incomplete
    def _reconstruct(self, cpp_module) -> None: ...
    # Serialization entry points; all untyped. ``f`` is presumably a path or
    # file-like object — confirm.
    def save(self, f, **kwargs): ...
    def _save_for_lite_interpreter(self, *args, **kwargs): ...
    def _save_to_buffer_for_lite_interpreter(self, *args, **kwargs): ...
    def save_to_buffer(self, *args, **kwargs): ...
    def get_debug_state(self, *args, **kwargs): ...
    def extra_repr(self) -> str: ...
    def graph_for(self, *args, **kwargs): ...
    # Overrides of the ScriptModule declarations with the same names.
    def define(self, src) -> None: ...
    def __getattr__(self, attr: str) -> Any: ...
    def __setattr__(self, attr: str, value: Any) -> None: ...
    def __copy__(self) -> Self: ...
    def __deepcopy__(self, memo: dict[int, Any] | None) -> Self: ...
    def forward_magic_method(
        self, method_name: str, *args: Any, **kwargs: Any
    ) -> Any: ...
    # Container protocol.
    def __iter__(self) -> Iterator[Any]: ...
    # NOTE(review): lookup is typed int-only while ``__contains__`` takes a
    # str key — confirm whether str keys should be accepted here too.
    def __getitem__(self, idx: int) -> Any: ...
    def __len__(self) -> int: ...
    def __contains__(self, key: str) -> bool: ...
    def __dir__(self) -> list[str]: ...
    def __bool__(self) -> bool: ...
    def _replicate_for_data_parallel(self): ...
# Untyped module-level helper taking ``cls`` (not decorated as a classmethod).
# NOTE(review): presumably returns the methods defined on ``cls`` — confirm.
def _get_methods(cls): ...
# Untyped module-level value.
# NOTE(review): presumably names of methods permitted on compiled modules —
# confirm.
_compiled_methods_allowlist: Incomplete
# Untyped; takes a ``name``.
# NOTE(review): presumably builds a callable that fails when invoked,
# mentioning ``name`` — confirm.
def _make_fail(name): ...
# Untyped worker taking ``obj`` and a ``memo``.
# NOTE(review): ``memo`` is presumably a visited-objects map used while
# recursing — confirm.
def call_prepare_scriptable_func_impl(obj, memo): ...
# Untyped; takes only ``obj``.
# NOTE(review): presumably the entry point for
# call_prepare_scriptable_func_impl, starting with a fresh memo — confirm.
def call_prepare_scriptable_func(obj): ...
# Untyped; takes ``obj``.
# NOTE(review): presumably wraps a Python dict as a torch.ScriptDict (the type
# the dict overload of ``script`` returns) — confirm.
def create_script_dict(obj): ...
# Untyped; takes ``obj`` and an optional ``type_hint``.
# NOTE(review): presumably wraps a Python list as a torch.ScriptList (the type
# the list overload of ``script`` returns) — confirm.
def create_script_list(obj, type_hint: Incomplete | None = ...): ...
# Overloads of ``script``, keyed on the type of ``obj``. Type checkers pick the
# first matching overload, so the order below matters.
#
# 1. A Module subclass itself (not an instance) is typed as never returning,
#    so checkers treat code after such a call as unreachable (presumably
#    because scripting needs an instance — confirm).
@overload
def script(
    obj: type[Module],
    optimize: bool | None = None,
    _frames_up: int = 0,
    _rcb: ResolutionCallback | None = None,
    example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
) -> Never: ...
# 2. A dict yields a torch.ScriptDict.
@overload
def script(
    obj: dict,
    optimize: bool | None = None,
    _frames_up: int = 0,
    _rcb: ResolutionCallback | None = None,
    example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
) -> torch.ScriptDict: ...
# 3. A list yields a torch.ScriptList.
@overload
def script(
    obj: list,
    optimize: bool | None = None,
    _frames_up: int = 0,
    _rcb: ResolutionCallback | None = None,
    example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
) -> torch.ScriptList: ...
# 4. A Module instance yields a RecursiveScriptModule; overlap with the later,
#    broader overloads is deliberately silenced.
@overload
def script(  # type: ignore[overload-overlap]
    obj: Module,
    optimize: bool | None = None,
    _frames_up: int = 0,
    _rcb: ResolutionCallback | None = None,
    example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
) -> RecursiveScriptModule: ...
# 5. Any other class is returned as the same class type. Classes are also
#    Callable, so this must precede overload 6.
@overload
def script(  # type: ignore[overload-overlap]
    obj: _ClassVar,
    optimize: bool | None = None,
    _frames_up: int = 0,
    _rcb: ResolutionCallback | None = None,
    example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
) -> _ClassVar: ...
# 6. Any other callable yields a ScriptFunction.
@overload
def script(
    obj: Callable,
    optimize: bool | None = None,
    _frames_up: int = 0,
    _rcb: ResolutionCallback | None = None,
    example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
) -> ScriptFunction: ...
# 7. Anything else yields a RecursiveScriptClass.
@overload
def script(
    obj: Any,
    optimize: bool | None = None,
    _frames_up: int = 0,
    _rcb: ResolutionCallback | None = None,
    example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = None,
) -> RecursiveScriptClass: ...
# 8. Unannotated catch-all with no declared return type (hence Any).
#    NOTE(review): overload 7 already matches every ``obj``; this one only adds
#    calls whose ``optimize``/``_rcb`` fail the typed annotations above, which
#    looks like deliberate leniency — confirm.
@overload
def script(
    obj,
    optimize: Incomplete | None = ...,
    _frames_up: int = ...,
    _rcb: Incomplete | None = ...,
    example_inputs: list[tuple] | dict[Callable, list[tuple]] | None = ...,
): ...
# Declared to return None; takes the implementation's defaults, an overload's
# defaults and a ``loc``.
# NOTE(review): presumably verifies the two sets of defaults agree, reporting
# against ``loc`` — confirm.
def _check_overload_defaults(impl_defaults, overload_defaults, loc) -> None: ...
# Untyped; takes an overload function, its qualified name and the
# implementation function. The return type is not declared here.
def _compile_function_with_overload(overload_fn, qual_name, impl_fn): ...
# Untyped; takes ``obj``.
# NOTE(review): presumably returns the overloads registered for ``obj`` —
# confirm.
def _get_overloads(obj): ...
# Declared to return None; takes ``obj``.
# NOTE(review): presumably rejects compiling an overloaded function directly —
# confirm.
def _check_directly_compile_overloaded(obj) -> None: ...
# Returns the same static type it receives (``_T -> _T``), so wrapping a class
# with it leaves that class's type unchanged.
# NOTE(review): presumably marks ``obj`` as a TorchScript interface — confirm.
def interface(obj: _T) -> _T: ...
# Untyped; takes ``obj`` and a ``loc``. The return type is not declared here.
def _recursive_compile_class(obj, loc): ...
# Untyped module-level name.
# NOTE(review): presumably an alias of the C++-bound CompilationUnit class —
# confirm.
CompilationUnit: Incomplete
# Takes a string ``s``, a ``padding`` width and optional ``offset``/``char``.
# No return type is declared (presumably the padded str — confirm).
def pad(s: str, padding: int, offset: int = ..., char: str = ...): ...
# One column of a script profile table. It is built from a ``header`` (str)
# plus optional int ``alignment``/``offset``. Values are added per line number
# via add_row(lineno, value); materialize() has no declared return type.
class _ScriptProfileColumn:
    header: Incomplete
    alignment: Incomplete
    offset: Incomplete
    rows: Incomplete
    def __init__(
        self,
        header: str,
        alignment: int = ...,
        offset: int = ...,
    ) -> None: ...
    def add_row(self, lineno: int, value: Any): ...
    def materialize(self): ...
# Combines a list of _ScriptProfileColumn with a ``source_range`` (list of ints,
# presumably source line numbers — confirm). dump_string() has no declared
# return type.
class _ScriptProfileTable:
    cols: Incomplete
    source_range: Incomplete
    def __init__(
        self,
        cols: list[_ScriptProfileColumn],
        source_range: list[int],
    ) -> None: ...
    def dump_string(self): ...
# Profiler handle wrapping an untyped ``profile`` object. enable()/disable()
# presumably start and stop collection (confirm). dump_string() returns the
# report as str, and dump() returns None (presumably printing the report —
# confirm).
class _ScriptProfile:
    profile: Incomplete
    def __init__(self) -> None: ...
    def enable(self) -> None: ...
    def disable(self) -> None: ...
    def dump_string(self) -> str: ...
    def dump(self) -> None: ...
# Untyped; takes ``x``.
# NOTE(review): presumably returns ``x`` after asserting it is not None —
# confirm.
def _unwrap_optional(x): ...