_trace.py 55 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442
  1. # mypy: allow-untyped-defs
  2. """Tracing.
  3. This module contains functionality to support the JIT's tracing frontend, notably:
  4. * torch.jit.trace
  5. * torch.jit.trace_module
  6. This is not intended to be imported directly; please use the exposed
  7. functionalities in `torch.jit`.
  8. """
  9. import contextlib
  10. import copy
  11. import functools
  12. import inspect
  13. import os
  14. import re
  15. import sys
  16. import warnings
  17. from collections.abc import Callable
  18. from enum import Enum
  19. from typing import Any, Optional, TypeVar
  20. from typing_extensions import ParamSpec
  21. import torch
  22. from torch._jit_internal import (
  23. _get_model_id,
  24. _qualified_name,
  25. get_callable_argument_names,
  26. is_scripting,
  27. )
  28. from torch.autograd import function
  29. from torch.jit._script import _CachedForward, script, ScriptModule
  30. from torch.jit._state import _enabled, _python_cu
  31. from torch.nn import Module
  32. from torch.testing._comparison import default_tolerances
# Short aliases for the C++ (un)flatten helpers. In this module they are used
# as ``flat, desc = _flatten(obj)`` and ``obj = _unflatten(flat, desc)``.
_flatten = torch._C._jit_flatten
_unflatten = torch._C._jit_unflatten

R = TypeVar("R", covariant=True)  # return type (always covariant)
P = ParamSpec("P")
  37. def _create_interpreter_name_lookup_fn(frames_up=1):
  38. def _get_interpreter_name_for_var(var):
  39. frame = inspect.currentframe()
  40. if not frame:
  41. raise RuntimeError("failed to inspect frame")
  42. i = 0
  43. while i < frames_up + 1:
  44. frame = frame.f_back
  45. if not frame:
  46. raise RuntimeError("failed to get frame")
  47. i += 1
  48. f_locals = frame.f_locals
  49. for k, v in f_locals.items():
  50. if isinstance(v, torch.Tensor) and var is v:
  51. return k if k != "self" else ""
  52. return ""
  53. return _get_interpreter_name_for_var
  54. def _unique_state_dict(module, keep_vars=False):
  55. # since Parameter.detach() always creates a new torch.Tensor instance,
  56. # id(v) doesn't work with it. So we always get the Parameter or Buffer
  57. # as values, and deduplicate the params using Parameters and Buffers
  58. state_dict = module.state_dict(keep_vars=True)
  59. filtered_dict = type(state_dict)()
  60. seen_ids: set[int] = set()
  61. for k, v in state_dict.items():
  62. if id(v) in seen_ids:
  63. continue
  64. seen_ids.add(id(v))
  65. if keep_vars:
  66. filtered_dict[k] = v
  67. else:
  68. filtered_dict[k] = v.detach()
  69. return filtered_dict
  70. class ONNXTracedModule(torch.nn.Module):
  71. def __init__(
  72. self,
  73. inner,
  74. strict=True,
  75. force_outplace=False,
  76. return_inputs=False,
  77. return_inputs_states=False,
  78. ):
  79. super().__init__()
  80. # inner may be a Module, or it may be an arbitrary callable
  81. # If it's a Module, we get its parameters automatically, which lets
  82. # us avoid a special casing functions versus modules.
  83. self.inner = inner
  84. self.strict = strict
  85. self._force_outplace = force_outplace
  86. self._return_inputs = return_inputs
  87. self._return_inputs_states = return_inputs_states
  88. def forward(self, *args: torch.Tensor):
  89. in_vars, in_desc = _flatten(args)
  90. # NOTE: use full state, because we need it for BatchNorm export
  91. # This differs from the compiler path, which doesn't support it at the moment.
  92. module_state = list(_unique_state_dict(self, keep_vars=True).values())
  93. ret_inputs = []
  94. inputs_states = []
  95. outs = []
  96. def wrapper(*args):
  97. in_args: list[torch.Tensor] = []
  98. for i in range(len(in_vars)):
  99. if not isinstance(args[i], torch.Tensor):
  100. raise RuntimeError("Expected Tensor argument")
  101. in_args.append(args[i])
  102. trace_inputs = _unflatten(in_args, in_desc)
  103. if self._return_inputs:
  104. ret_inputs.append(
  105. tuple(x.clone(memory_format=torch.preserve_format) for x in args)
  106. )
  107. if self._return_inputs_states:
  108. inputs_states.append(_unflatten(in_args, in_desc))
  109. outs.append(self.inner(*trace_inputs))
  110. if self._return_inputs_states:
  111. inputs_states[0] = (inputs_states[0], trace_inputs)
  112. out_vars, _ = _flatten(outs)
  113. if len(out_vars) == 1:
  114. return out_vars[0]
  115. else:
  116. return tuple(out_vars)
  117. graph, _out = torch._C._create_graph_by_tracing(
  118. wrapper,
  119. in_vars + module_state,
  120. _create_interpreter_name_lookup_fn(),
  121. self.strict,
  122. self._force_outplace,
  123. )
  124. if self._return_inputs:
  125. return graph, outs[0], ret_inputs[0]
  126. if self._return_inputs_states:
  127. return graph, outs[0], inputs_states[0]
  128. else:
  129. return graph, outs[0]
  130. def _clone_inputs(args):
  131. def clone_input(a):
  132. if a is None:
  133. return None
  134. elif isinstance(a, torch.Tensor):
  135. # TODO: figure out one liner to .clone() and set requires_grad
  136. v = (
  137. a.detach()
  138. .clone(memory_format=None if a.is_mkldnn else torch.preserve_format)
  139. .requires_grad_(a.requires_grad)
  140. )
  141. if a.grad is not None:
  142. v.grad = clone_input(v.grad)
  143. return v
  144. else:
  145. return a.clone(memory_format=torch.preserve_format)
  146. # pyrefly: ignore [missing-attribute]
  147. return function._nested_map(
  148. lambda x: isinstance(x, torch.Tensor), clone_input, condition_msg="tensors"
  149. )(args)
# This is purely for developer debugging. We are not going to advertise it.
# NOTE: when a variable is set, os.environ.get returns its raw string, so any
# non-empty value (including "0") is truthy. Only an unset variable yields the
# False default.
_JIT_TIME = os.environ.get("PYTORCH_JIT_TIME", False)  # CUDA-only timing
_JIT_DISABLE = os.environ.get("PYTORCH_JIT_DISABLE", False)
_JIT_STATS = os.environ.get("PYTORCH_JIT_STATS", False)
  154. @contextlib.contextmanager
  155. def _time(trace_name, name, time=True):
  156. if (not _JIT_TIME and not time) or not torch.cuda.is_available():
  157. yield
  158. return
  159. stream = torch.cuda.current_stream()
  160. start = torch.cuda.Event(enable_timing=True)
  161. end = torch.cuda.Event(enable_timing=True)
  162. stream.record_event(start)
  163. try:
  164. yield
  165. finally:
  166. stream.record_event(end)
  167. end.synchronize()
  168. print(f"{trace_name} {name} time: {start.elapsed_time(end)} ms")
def verify(model, args, loss_fn=torch.sum, devices=None):
    """
    Verify that a JIT compiled model has the same behavior as its uncompiled version along with its backwards pass.

    If your model returns multiple outputs,
    you must also specify a `loss_fn` to produce a loss for which
    the backwards will be computed.

    This function has side-effects (e.g., it executes your model / saves and loads
    parameters), so don't expect the model to come out exactly the same as what
    you passed in.

    Args:
        model (compiled torch.nn.Module or function): the module/function to be
            verified.  The module/function definition MUST have been decorated with
            `@torch.jit.compile`.
        args (tuple or Tensor): the positional arguments to pass to the
            compiled function/module to be verified.  A non-tuple is assumed to
            be a single positional argument to be passed to the model.
        loss_fn (function, optional): the loss function to be applied to
            the output of the model, before backwards is invoked.  By default,
            we assume that a model returns a single result, and we :func:`torch.sum`
            before calling backwards; if this is inappropriate, you can pass your
            own loss function.  Note that if a model returns a tuple of results,
            these are passed as separate positional arguments to `loss_fn`.
        devices (iterable of device IDs, optional): the GPU devices which the
            compiled module will be run on.  This determines the RNG state we
            must save when running both compiled and uncompiled versions of the model.

    Raises:
        TypeError: if ``model`` is not a ``torch._C.CompiledFunction``.
        ValueError: if the default ``loss_fn`` is used with a model that
            returns more than one output.
        AssertionError: if the model has no trace for ``args`` after the first run.
        RuntimeError: if the second run did not hit the compiled function, or
            if outputs/gradients of the two runs differ.
    """
    # TODO: In principle, we track device information in our trace, so it
    # should be possible to check if our execution actually obeyed the 'devices'
    # the user provided.

    # TODO: Consider adding a utility function to torch.jit to test
    # for this case
    if not isinstance(model, torch._C.CompiledFunction):  # type: ignore[attr-defined]
        raise TypeError(
            "Cannot verify an uncompiled module. Add @torch.jit.compile to compile it"
        )
    is_module = isinstance(model, Module)

    if not isinstance(args, tuple):
        args = (args,)

    # Snapshot the parameters so the second run starts from the same state.
    if is_module:
        saved_state = copy.deepcopy(model.state_dict())

    def run_fwd_bwd(args, force_trace=False, assert_compiled=False):
        # Run forward + backward once and return detached copies of the
        # flattened outputs and of the gradients w.r.t. inputs and parameters.
        params = list(model.parameters()) if is_module else []
        in_vars, _ = _flatten((args, params))
        # We use a special API to reset the trace and compile it from scratch.
        compiled_fn = model
        if force_trace:
            compiled_fn.clear_cache()
        if assert_compiled:
            hits = compiled_fn.hits
        out = model(*args)
        # An unchanged hit counter means the compiled path was not taken.
        if assert_compiled and compiled_fn.hits == hits:  # type: ignore[possibly-undefined]
            raise RuntimeError("failed to use the compiled function")
        if not isinstance(out, tuple):
            out = (out,)
        if loss_fn == torch.sum and len(out) != 1:
            raise ValueError(
                f"Model returns {len(out)} outputs, but default loss function "
                "(torch.sum) can only handle a single output"
            )
        out_vars, _ = _flatten(out)
        saved_outs = [
            v.detach().clone(memory_format=torch.preserve_format) for v in out_vars
        ]
        loss = loss_fn(*out)
        grads = torch.autograd.grad([loss], in_vars)
        # TODO: I'm not sure if the clone here is necessary but it is safer
        saved_grads = [
            v.detach().clone(memory_format=torch.preserve_format) for v in grads
        ]
        return (saved_outs, saved_grads)

    # First run: cache cleared, executed inside a forked RNG context.
    with torch.random.fork_rng(devices, _caller="torch.jit.verify"):
        uncompiled_outs, uncompiled_grads = run_fwd_bwd(args, force_trace=True)
        if not model.has_trace_for(*args):
            raise AssertionError("Model should have trace for the given args")

    # Second run: restore the parameters and require a compiled-cache hit.
    if is_module:
        model.load_state_dict(saved_state)  # type: ignore[possibly-undefined]
    compiled_outs, compiled_grads = run_fwd_bwd(args, assert_compiled=True)

    _verify_equal(uncompiled_outs, compiled_outs)
    _verify_equal(uncompiled_grads, compiled_grads)
  248. def _verify_equal(xs, ys):
  249. for x, y in zip(xs, ys):
  250. if x.sub(y).abs().max() > 1e-6:
  251. raise RuntimeError("JIT and real computation mismatch")
  252. def indent(s):
  253. return "\n".join(["\t" + line for line in s.splitlines()])
  254. class TracingCheckError(Exception):
  255. def __init__(self, graph_diff_error, tensor_compare_error, extra_msg=None):
  256. self.message = "Tracing failed sanity checks!\n"
  257. if extra_msg is not None:
  258. self.message += extra_msg + "\n"
  259. if graph_diff_error is not None:
  260. self.message += "ERROR: Graphs differed across invocations!\n"
  261. self.message += indent(graph_diff_error) + "\n"
  262. if tensor_compare_error is not None:
  263. self.message += (
  264. "ERROR: Tensor-valued Constant nodes differed in value "
  265. "across invocations. This often indicates that the tracer has"
  266. " encountered untraceable code.\n"
  267. )
  268. self.message += indent(tensor_compare_error) + "\n"
  269. super().__init__(self.message)
# Check the traced module against a set of user-provided validation inputs
@torch.no_grad()
def _check_trace(
    check_inputs,
    func,
    traced_func,
    check_tolerance,
    strict,
    force_outplace,
    is_trace_module,
    _module_class,
    example_inputs_is_kwarg=False,
):
    """Re-trace ``func`` on each check input and compare against ``traced_func``.

    For every entry of ``check_inputs`` this:

    1. traces ``func`` again (via ``torch.jit.trace_module`` when
       ``is_trace_module`` is true, else ``torch.jit.trace``) on a clone of
       the inputs, with ``check_trace=False``;
    2. runs ``traced_func``, ``func`` and the fresh trace on the inputs and
       compares their tensor outputs, emitting ``TracerWarning`` on mismatch;
    3. compares the canonicalized graphs and tensor-valued constants of the
       two traces.

    Args:
        check_inputs: iterable of inputs. With ``is_trace_module`` each entry
            is a dict keyed by method name; otherwise a tensor, tuple or dict.
        func: the original Python callable (or bound method) that was traced.
        traced_func: the trace being validated.
        check_tolerance: relative tolerance passed to ``assert_close``.
        strict, force_outplace, _module_class: forwarded to the re-trace.
        is_trace_module: selects the ``trace_module`` path.
        example_inputs_is_kwarg: when true, dict inputs are passed as kwargs.

    Raises:
        TracingCheckError: if running any of the callables raises, or if the
            graphs or their tensor constants differ.
    """
    # Note: tracing is independent of optimizations, which consume the trace
    for inputs in check_inputs:
        if isinstance(inputs, torch.Tensor):
            inputs = (inputs,)

        if is_trace_module:
            copied_dict = {}
            for name, data in inputs.items():
                copied_dict[name] = _clone_inputs(data)
            check_mod = torch.jit.trace_module(
                getattr(func, "__self__", func),
                copied_dict,
                check_trace=False,
                strict=strict,
                _force_outplace=force_outplace,
                _module_class=_module_class,
                _compilation_unit=torch._C.CompilationUnit(),
                example_inputs_is_kwarg=example_inputs_is_kwarg,
                _store_inputs=False,
            )
            check_mod_func = check_mod._c._get_method(traced_func.name)
            inputs = inputs[traced_func.name]
            # A lone tensor, or a dict meant as one positional argument, is
            # wrapped so it can be splatted with ``*`` below.
            if (
                isinstance(inputs, (torch.Tensor))
                or isinstance(inputs, dict)
                and not example_inputs_is_kwarg
            ):
                inputs = (inputs,)
        else:
            if example_inputs_is_kwarg:
                check_mod = torch.jit.trace(
                    func,
                    check_trace=False,
                    strict=strict,
                    _force_outplace=force_outplace,
                    _module_class=_module_class,
                    example_kwarg_inputs=_clone_inputs(inputs),
                    _store_inputs=False,
                )
            else:
                check_mod = torch.jit.trace(
                    func,
                    _clone_inputs(inputs),
                    check_trace=False,
                    strict=strict,
                    _force_outplace=force_outplace,
                    _module_class=_module_class,
                    _store_inputs=False,
                )
            check_mod_func = check_mod

        def graph_diagnostic_info():
            # Returns (graph_diff_errors, tensor_compare_errors); each is
            # None when the corresponding check found no difference.
            mod_canonicalized = torch._C._jit_pass_canonicalize(traced_func.graph)
            torch._C._jit_pass_inline(mod_canonicalized)
            torch._C._jit_pass_erase_shape_information(mod_canonicalized)
            mod_str = str(mod_canonicalized)
            # Strip name-mangling prefixes so they do not show up as diffs.
            mod_str = re.sub(r"___torch_mangle_[0-9]+\.", "", mod_str)
            check_canonicalized = torch._C._jit_pass_canonicalize(check_mod_func.graph)
            torch._C._jit_pass_inline(check_canonicalized)
            torch._C._jit_pass_erase_shape_information(check_canonicalized)
            check_str = str(check_canonicalized)
            check_str = re.sub(r"___torch_mangle_[0-9]+\.", "", check_str)

            graph_diff_errors = None
            if mod_str != check_str:
                import difflib

                graph_diff = difflib.ndiff(
                    mod_str.splitlines(True), check_str.splitlines(True)
                )
                graph_diff_errors = "Graph diff:\n" + indent("".join(graph_diff)) + "\n"

                for n_mod, n_check in zip(
                    mod_canonicalized.nodes(), check_canonicalized.nodes()
                ):
                    if str(n_mod) != str(n_check):
                        graph_diff_errors += "First diverging operator:\n"
                        node_diff = difflib.ndiff(
                            str(n_mod).splitlines(True), str(n_check).splitlines(True)
                        )
                        source_printout = (
                            "Node diff:\n" + indent("".join(node_diff)) + "\n"
                        )
                        mod_stack = n_mod.sourceRange()
                        if mod_stack:
                            source_printout += (
                                "Trace source location:\n" + indent(mod_stack) + "\n"
                            )
                        check_stack = n_check.sourceRange()
                        if check_stack:
                            source_printout += (
                                "Check source location:\n" + indent(check_stack) + "\n"
                            )

                        graph_diff_errors += source_printout

                        break  # For now, only print out the first pair of nodes that diverges

            tensor_compare_errors = None
            # Check Tensor-valued constant nodes
            for n_mod, n_check in zip(
                mod_canonicalized.nodes(), check_canonicalized.nodes()
            ):
                if n_mod.kind() != n_check.kind():
                    break  # Graphs have already diverged

                if n_mod.kind() == "prim::Constant" and not (
                    n_mod.mustBeNone() or n_check.mustBeNone()
                ):
                    if not n_mod.hasAttribute("value"):
                        continue
                    # Only constants whose "value" attribute has kind "t"
                    # on both sides are compared.
                    if n_mod.kindOf("value") != "t" or n_check.kindOf("value") != "t":
                        continue

                    mod_tensor_val = n_mod.t("value")
                    check_tensor_val = n_check.t("value")

                    try:
                        torch.testing.assert_close(
                            mod_tensor_val, check_tensor_val, equal_nan=True
                        )
                    except (RuntimeError, AssertionError) as e:
                        if tensor_compare_errors is None:
                            tensor_compare_errors = ""
                        tensor_compare_errors += "Node:\n" + indent(str(n_mod)) + "\n"
                        compare_stack = n_mod.sourceRange()
                        if compare_stack:
                            tensor_compare_errors += (
                                "Source Location:\n" + indent(compare_stack) + "\n"
                            )
                        tensor_compare_errors += "Comparison exception: " + indent(
                            str(e)
                        )

                        break  # For now, only print the first diverging pair

            return graph_diff_errors, tensor_compare_errors

        def wrap_retval(x):
            # Normalize a single return value to a 1-tuple.
            return x if isinstance(x, tuple) else (x,)

        def run_mod_and_filter_tensor_outputs(mod, inputs, running_what):
            # Run ``mod`` and keep only its tensor outputs. Any exception is
            # re-raised as TracingCheckError with graph diagnostics attached.
            try:
                if isinstance(inputs, dict) and example_inputs_is_kwarg:
                    outs = wrap_retval(mod(**inputs))
                else:
                    outs = wrap_retval(mod(*_clone_inputs(inputs)))
                outs = [out for out in outs if isinstance(out, torch.Tensor)]
                return outs
            except Exception as e:
                graph_diff_errors, tensor_compare_errors = graph_diagnostic_info()
                msg = f"encountered an exception while running the {running_what} with test inputs.\nException:\n{indent(str(e))}"
                raise TracingCheckError(
                    graph_diff_errors,
                    tensor_compare_errors,
                    extra_msg=msg,
                ) from e

        # One-element list so the nested function can mutate the flag.
        has_warned = [False]

        def maybe_warn_nondeterministic():
            # Warn at most once per check input about nondeterministic nodes.
            if has_warned[0]:
                return
            has_warned[0] = True
            nondeterm_ops = [
                op for op in traced_func.graph.nodes() if op.isNondeterministic()
            ]
            if len(nondeterm_ops) > 0:
                nondeterministic_ops_warning = "Trace had nondeterministic nodes. "
                nondeterministic_ops_warning += (
                    "Did you forget call .eval() on your model? Nodes:\n"
                )
                # Only the first 20 nodes are listed.
                nondeterministic_ops_warning += "\n".join(
                    [indent(str(op)) for op in nondeterm_ops][:20]
                )
                nondeterministic_ops_warning += (
                    "\nThis may cause errors in trace checking. To disable trace checking,"
                    " pass check_trace=False to torch.jit.trace()"
                )
                warnings.warn(
                    nondeterministic_ops_warning, category=TracerWarning, stacklevel=5
                )

        def compare_outputs(original, reference, match_what):
            # Compare outputs pairwise; warn on each mismatch and return
            # False if any pair differed.
            all_ok = True
            for i, (orig, ref) in enumerate(zip(original, reference)):
                try:
                    # Convert quantized / mkldnn tensors to plain dense ones
                    # before comparing.
                    if orig.is_quantized:
                        orig = orig.dequantize()
                    if ref.is_quantized:
                        ref = ref.dequantize()
                    if orig.is_mkldnn:
                        orig = orig.to_dense()
                    if ref.is_mkldnn:
                        ref = ref.to_dense()
                    if ref.is_complex() or orig.is_complex():
                        torch.testing.assert_close(
                            orig.to(torch.cdouble),
                            ref.to(torch.cdouble),
                            rtol=check_tolerance,
                            atol=default_tolerances(orig, ref)[1],
                            equal_nan=True,
                        )
                    else:
                        if orig.is_mps or ref.is_mps:
                            # MPS tensors are compared as float32, not float64.
                            torch.testing.assert_close(
                                orig.float(),
                                ref.float(),
                                rtol=check_tolerance,
                                atol=default_tolerances(orig, ref)[1],
                                equal_nan=True,
                            )
                        elif getattr(orig, "is_nested", None) or getattr(
                            ref, "is_nested", None
                        ):
                            if getattr(orig, "is_nested", None) != getattr(
                                ref, "is_nested", None
                            ):
                                raise AssertionError(
                                    f"Nested tensor mismatch: orig.is_nested="
                                    f"{getattr(orig, 'is_nested', None)}, "
                                    f"ref.is_nested={getattr(ref, 'is_nested', None)}"
                                )
                            # Nested tensors are compared component by component.
                            for t_orig, t_ref in zip(orig.unbind(), ref.unbind()):
                                torch.testing.assert_close(
                                    t_orig.double(),
                                    t_ref.double(),
                                    rtol=check_tolerance,
                                    atol=default_tolerances(t_orig, t_ref)[1],
                                    equal_nan=True,
                                )
                        else:
                            torch.testing.assert_close(
                                orig.double(),
                                ref.double(),
                                rtol=check_tolerance,
                                atol=default_tolerances(orig, ref)[1],
                                equal_nan=True,
                            )
                except AssertionError as e:
                    maybe_warn_nondeterministic()
                    warnings.warn(
                        "Output nr "
                        + str(i + 1)
                        + ". of the traced function does not match "
                        "the corresponding output of the "
                        + match_what
                        + ". Detailed error:\n"
                        + str(e),
                        category=TracerWarning,
                        stacklevel=4,
                    )
                    all_ok = False

            return all_ok

        traced_outs = run_mod_and_filter_tensor_outputs(traced_func, inputs, "trace")
        fn_outs = run_mod_and_filter_tensor_outputs(func, inputs, "Python function")
        # The re-trace is only compared when the trace matched the Python function.
        if compare_outputs(traced_outs, fn_outs, "Python function"):
            check_outs = run_mod_and_filter_tensor_outputs(
                check_mod_func, inputs, "repeated trace"
            )
            compare_outputs(traced_outs, check_outs, "repeated trace")

        diag_info = graph_diagnostic_info()
        if any(info is not None for info in diag_info):
            raise TracingCheckError(*diag_info)
  529. class TracerWarning(Warning):
  530. @staticmethod
  531. def ignore_lib_warnings():
  532. # We ignore warnings from all submodules excluding the JIT, because we need them e.g. for _check_trace
  533. warnings.filterwarnings(
  534. "ignore", category=TracerWarning, module="torch.(?!jit)"
  535. )
  536. warnings.filterwarnings("ignore", "torch::jit::fuser::cuda")
# We ignore the tracer warnings coming from inside the library, because all our shape
# checks in nn will trigger them.
TracerWarning.ignore_lib_warnings()
torch._C._tracer_warn_use_python()
  541. def make_tuple(example_inputs):
  542. if isinstance(example_inputs, (torch.Tensor, dict)):
  543. return (example_inputs,)
  544. # done primarily so that weird iterables fail here and not pybind11 code
  545. if not isinstance(example_inputs, tuple):
  546. return tuple(example_inputs)
  547. return example_inputs
  548. def make_module(mod, _module_class, _compilation_unit):
  549. if isinstance(mod, ScriptModule):
  550. return mod
  551. elif torch._jit_internal.module_has_exports(mod):
  552. infer_methods_stubs_fn = torch.jit._recursive.make_stubs_from_exported_methods
  553. return torch.jit._recursive.create_script_module(
  554. mod, infer_methods_stubs_fn, share_types=False, is_tracing=True
  555. )
  556. else:
  557. if _module_class is None:
  558. _module_class = TopLevelTracedModule
  559. return _module_class(mod, _compilation_unit=_compilation_unit)
  560. def wrap_check_inputs(check_inputs):
  561. if check_inputs is None:
  562. return None
  563. return [{"forward": c} for c in check_inputs]
  564. def analyze_ts_result_with_export_result(export, trace):
  565. import torch.utils._pytree as pytree
  566. flat_export = pytree.tree_leaves(export)
  567. flat_trace = pytree.tree_leaves(trace)
  568. for orig, loaded in zip(flat_export, flat_trace):
  569. if orig.layout != loaded.layout:
  570. return False
  571. # mkldnn is not supported for torch.allclose
  572. if orig.layout == torch._mkldnn: # type: ignore[attr-defined]
  573. return True
  574. if type(orig) is not type(loaded):
  575. return False
  576. if isinstance(orig, torch._subclasses.FakeTensor):
  577. # Skip for FakeTensor.
  578. return True
  579. elif isinstance(orig, torch.Tensor):
  580. if orig.dtype != loaded.dtype:
  581. return False
  582. if not torch.allclose(orig, loaded):
  583. return False
  584. else:
  585. if orig != loaded:
  586. return False
  587. return True
  588. def _trace_impl(
  589. func,
  590. example_inputs=None,
  591. optimize=None,
  592. check_trace=True,
  593. check_inputs=None,
  594. check_tolerance=1e-5,
  595. strict=True,
  596. _force_outplace=False,
  597. _module_class=None,
  598. _compilation_unit=_python_cu,
  599. example_kwarg_inputs=None,
  600. _store_inputs=True,
  601. ):
  602. if isinstance(func, torch.jit.ScriptModule):
  603. # it is hard to trace it because the forward method on ScriptModule is already defined, so it
  604. # would result in an error.
  605. warnings.warn(
  606. "The input to trace is already a ScriptModule, tracing it is a no-op. Returning the object as is.",
  607. stacklevel=2,
  608. )
  609. return func
  610. if isinstance(func, torch.nn.Module):
  611. if example_inputs is None:
  612. if isinstance(example_kwarg_inputs, dict):
  613. example_inputs = example_kwarg_inputs
  614. else:
  615. raise RuntimeError("example_kwarg_inputs should be a dict")
  616. return trace_module(
  617. func,
  618. {"forward": example_inputs},
  619. None,
  620. check_trace,
  621. wrap_check_inputs(check_inputs),
  622. check_tolerance,
  623. strict,
  624. _force_outplace,
  625. _module_class,
  626. example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict),
  627. _store_inputs=_store_inputs,
  628. )
  629. if (
  630. hasattr(func, "__self__")
  631. and isinstance(func.__self__, torch.nn.Module)
  632. and func.__name__ == "forward"
  633. ):
  634. if example_inputs is None:
  635. if isinstance(example_kwarg_inputs, dict):
  636. example_inputs = example_kwarg_inputs
  637. else:
  638. raise RuntimeError("example_kwarg_inputs should be a dict")
  639. return trace_module(
  640. func.__self__,
  641. {"forward": example_inputs},
  642. None,
  643. check_trace,
  644. wrap_check_inputs(check_inputs),
  645. check_tolerance,
  646. strict,
  647. _force_outplace,
  648. _module_class,
  649. example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict),
  650. _store_inputs=_store_inputs,
  651. )
  652. # Special case for common case of passing a single Tensor
  653. if (
  654. isinstance(example_inputs, (torch.Tensor, dict))
  655. and example_kwarg_inputs is None
  656. ):
  657. example_inputs = (example_inputs,)
  658. # done primarily so that weird iterables fail here and not pybind11 code
  659. elif example_kwarg_inputs is None and not isinstance(example_inputs, tuple):
  660. # pyrefly: ignore [bad-argument-type]
  661. example_inputs = tuple(example_inputs)
  662. var_lookup_fn = _create_interpreter_name_lookup_fn(0)
  663. if hasattr(func, "__self__") and isinstance(func.__self__, torch.nn.Module):
  664. raise AttributeError(
  665. "trace doesn't support compiling individual module's functions.\n"
  666. "Please use trace_module"
  667. )
  668. name = _qualified_name(func)
  669. if isinstance(example_kwarg_inputs, dict):
  670. example_inputs = example_kwarg_inputs
  671. traced = torch._C._create_function_from_trace_with_dict(
  672. name,
  673. func,
  674. example_kwarg_inputs,
  675. var_lookup_fn,
  676. strict,
  677. _force_outplace,
  678. get_callable_argument_names(func),
  679. )
  680. else:
  681. traced = torch._C._create_function_from_trace(
  682. name,
  683. func,
  684. # pyrefly: ignore [bad-argument-type]
  685. example_inputs,
  686. var_lookup_fn,
  687. strict,
  688. _force_outplace,
  689. get_callable_argument_names(func),
  690. )
  691. # Check the trace against new traces created from user-specified inputs
  692. if check_trace:
  693. if check_inputs is not None:
  694. _check_trace(
  695. check_inputs,
  696. func,
  697. traced,
  698. check_tolerance,
  699. strict,
  700. _force_outplace,
  701. False,
  702. _module_class,
  703. example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict),
  704. )
  705. else:
  706. _check_trace(
  707. [example_inputs],
  708. func,
  709. traced,
  710. check_tolerance,
  711. strict,
  712. _force_outplace,
  713. False,
  714. _module_class,
  715. example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict),
  716. )
  717. # Allow torch.compile() to inline
  718. traced._torchdynamo_inline = func # type: ignore[attr-defined]
  719. return traced
  720. class _ExportType(str, Enum):
  721. DIRECT_EXPORT = "DIRECT_EXPORT"
  722. TRACE_AND_EXPORT = "TRACE_AND_EXPORT"
  723. SOURCE_TO_SOURCE = "SOURCE_TO_SOURCE"
  724. def __str__(self) -> str:
  725. return self.value
  726. class _ExportOutcome(str, Enum):
  727. SUCCESS = "SUCCESS"
  728. FAILED_TO_EXPORT = "FAILED_TO_EXPORT"
  729. FAILED_TO_RUN = "FAILED_TO_RUN"
  730. ACCURACY_ERROR = "ACCURACY_ERROR"
  731. def __str__(self) -> str:
  732. return self.value
  733. def trace(
  734. func,
  735. example_inputs=None,
  736. optimize=None,
  737. check_trace=True,
  738. check_inputs=None,
  739. check_tolerance=1e-5,
  740. strict=True,
  741. _force_outplace=False,
  742. _module_class=None,
  743. _compilation_unit=_python_cu,
  744. example_kwarg_inputs=None,
  745. _store_inputs=True,
  746. ):
  747. r"""
  748. Trace a function and return an executable or :class:`ScriptFunction` that will be optimized using just-in-time compilation.
  749. Tracing is ideal for code that operates only on
  750. ``Tensor``\\s and lists, dictionaries, and
  751. tuples of ``Tensor``\\s.
  752. Using `torch.jit.trace` and `torch.jit.trace_module`, you can turn an
  753. existing module or Python function into a TorchScript
  754. :class:`ScriptFunction` or :class:`ScriptModule`. You must provide example
  755. inputs, and we run the function, recording the operations performed on all
  756. the tensors.
  757. * The resulting recording of a standalone function produces `ScriptFunction`.
  758. * The resulting recording of `nn.Module.forward` or `nn.Module` produces
  759. `ScriptModule`.
  760. This module also contains any parameters that the original
  761. module had as well.
  762. Warning:
  763. Tracing only correctly records functions and modules which are not data
  764. dependent (e.g., do not have conditionals on data in tensors) and do not have
  765. any untracked external dependencies (e.g., perform input/output or
  766. access global variables). Tracing only records operations done when the given
  767. function is run on the given tensors. Therefore, the returned
  768. `ScriptModule` will always run the same traced graph on any input. This
  769. has some important implications when your module is expected to run
  770. different sets of operations, depending on the input and/or the module
  771. state. For example,
  772. * Tracing will not record any control-flow like if-statements or loops.
  773. When this control-flow is constant across your module, this is fine
  774. and it often inlines the control-flow decisions. But sometimes the
  775. control-flow is actually part of the model itself. For instance, a
  776. recurrent network is a loop over the (possibly dynamic) length of an
  777. input sequence.
  778. * In the returned :class:`ScriptModule`, operations that have different
  779. behaviors in ``training`` and ``eval`` modes will always behave as if
  780. it is in the mode it was in during tracing, no matter which mode the
  781. `ScriptModule` is in.
  782. In cases like these, tracing would not be appropriate and
  783. :func:`scripting <torch.jit.script>` is a better choice. If you trace
  784. such models, you may silently get incorrect results on subsequent
  785. invocations of the model. The tracer will try to emit warnings when
  786. doing something that may cause an incorrect trace to be produced.
  787. Args:
  788. func (callable or torch.nn.Module): A Python function or `torch.nn.Module`
  789. that will be run with `example_inputs`. `func` arguments and return
  790. values must be tensors or (possibly nested) tuples that contain
  791. tensors. When a module is passed `torch.jit.trace`, only the
  792. ``forward`` method is run and traced (see :func:`torch.jit.trace
  793. <torch.jit.trace_module>` for details).
  794. Keyword arguments:
  795. example_inputs (tuple or torch.Tensor or None, optional): A tuple of example
  796. inputs that will be passed to the function while tracing.
  797. Default: ``None``. Either this argument or ``example_kwarg_inputs``
  798. should be specified. The resulting trace can be run with inputs of
  799. different types and shapes assuming the traced operations support those
  800. types and shapes. `example_inputs` may also be a single Tensor in which
  801. case it is automatically wrapped in a tuple. When the value is None,
  802. ``example_kwarg_inputs`` should be specified.
  803. check_trace (``bool``, optional): Check if the same inputs run through
  804. traced code produce the same outputs. Default: ``True``. You might want
  805. to disable this if, for example, your network contains non-
  806. deterministic ops or if you are sure that the network is correct despite
  807. a checker failure.
  808. check_inputs (list of tuples, optional): A list of tuples of input
  809. arguments that should be used to check the trace against what is
  810. expected. Each tuple is equivalent to a set of input arguments that
  811. would be specified in ``example_inputs``. For best results, pass in
  812. a set of checking inputs representative of the space of shapes and
  813. types of inputs you expect the network to see. If not specified,
  814. the original ``example_inputs`` are used for checking
  815. check_tolerance (float, optional): Floating-point comparison tolerance
  816. to use in the checker procedure. This can be used to relax the
  817. checker strictness in the event that results diverge numerically
  818. for a known reason, such as operator fusion.
  819. strict (``bool``, optional): run the tracer in a strict mode or not
  820. (default: ``True``). Only turn this off when you want the tracer to
  821. record your mutable container types (currently ``list``/``dict``)
  822. and you are sure that the container you are using in your
  823. problem is a ``constant`` structure and does not get used as
  824. control flow (if, for) conditions.
  825. example_kwarg_inputs (dict, optional): This parameter is a pack of keyword
  826. arguments of example inputs that will be passed to the function while
  827. tracing. Default: ``None``. Either this argument or ``example_inputs``
  828. should be specified. The dict will be unpacking by the arguments name
  829. of the traced function. If the keys of the dict don't not match with
  830. the traced function's arguments name, a runtime exception will be raised.
  831. Returns:
  832. If `func` is `nn.Module` or ``forward`` of `nn.Module`, `trace` returns
  833. a :class:`ScriptModule` object with a single ``forward`` method
  834. containing the traced code. The returned `ScriptModule` will
  835. have the same set of sub-modules and parameters as the original
  836. ``nn.Module``. If ``func`` is a standalone function, ``trace``
  837. returns `ScriptFunction`.
  838. Example (tracing a function):
  839. .. testcode::
  840. import torch
  841. def foo(x, y):
  842. return 2 * x + y
  843. # Run `foo` with the provided inputs and record the tensor operations
  844. traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))
  845. # `traced_foo` can now be run with the TorchScript interpreter or saved
  846. # and loaded in a Python-free environment
  847. Example (tracing an existing module)::
  848. import torch
  849. import torch.nn as nn
  850. class Net(nn.Module):
  851. def __init__(self) -> None:
  852. super().__init__()
  853. self.conv = nn.Conv2d(1, 1, 3)
  854. def forward(self, x):
  855. return self.conv(x)
  856. n = Net()
  857. example_weight = torch.rand(1, 1, 3, 3)
  858. example_forward_input = torch.rand(1, 1, 3, 3)
  859. # Trace a specific method and construct `ScriptModule` with
  860. # a single `forward` method
  861. module = torch.jit.trace(n.forward, example_forward_input)
  862. # Trace a module (implicitly traces `forward`) and construct a
  863. # `ScriptModule` with a single `forward` method
  864. module = torch.jit.trace(n, example_forward_input)
  865. """
  866. if sys.version_info >= (3, 14):
  867. warnings.warn(
  868. "`torch.jit.trace` is not supported in Python 3.14+ and may break. "
  869. "Please switch to `torch.compile` or `torch.export`.",
  870. DeprecationWarning,
  871. )
  872. else:
  873. warnings.warn(
  874. "`torch.jit.trace` is deprecated. Please switch to `torch.compile` or `torch.export`.",
  875. DeprecationWarning,
  876. )
  877. if not _enabled:
  878. return func
  879. if optimize is not None:
  880. warnings.warn(
  881. "`optimize` is deprecated and has no effect. "
  882. "Use `with torch.jit.optimized_execution()` instead",
  883. FutureWarning,
  884. stacklevel=2,
  885. )
  886. from torch._utils_internal import log_torchscript_usage
  887. traced_func = _trace_impl(
  888. func,
  889. example_inputs,
  890. optimize,
  891. check_trace,
  892. check_inputs,
  893. check_tolerance,
  894. strict,
  895. _force_outplace,
  896. _module_class,
  897. _compilation_unit,
  898. example_kwarg_inputs,
  899. _store_inputs,
  900. )
  901. log_torchscript_usage("trace", model_id=_get_model_id(traced_func))
  902. return traced_func
# Module-level slot used while a `trace_module` call is running. `trace_module`
# installs a dict here that maps every (transitive) child module to its
# "__module."-prefixed qualified name, plus the key "__module" -> root module,
# and puts the previous value back when it finishes. ``None`` by default.
_trace_module_map: Optional[dict[Any, Any]] = None
  904. def trace_module(
  905. mod,
  906. inputs,
  907. optimize=None,
  908. check_trace=True,
  909. check_inputs=None,
  910. check_tolerance=1e-5,
  911. strict=True,
  912. _force_outplace=False,
  913. _module_class=None,
  914. _compilation_unit=_python_cu,
  915. example_inputs_is_kwarg=False,
  916. _store_inputs=True,
  917. ):
  918. """
  919. Trace a module and return an executable :class:`ScriptModule` that will be optimized using just-in-time compilation.
  920. When a module is passed to :func:`torch.jit.trace <torch.jit.trace>`, only
  921. the ``forward`` method is run and traced. With ``trace_module``, you can specify a dictionary of
  922. method names to example inputs to trace (see the ``inputs``) argument below.
  923. See :func:`torch.jit.trace <torch.jit.trace>` for more information on tracing.
  924. Args:
  925. mod (torch.nn.Module): A ``torch.nn.Module`` containing methods whose names are
  926. specified in ``inputs``. The given methods will be compiled
  927. as a part of a single `ScriptModule`.
  928. inputs (dict): A dict containing sample inputs indexed by method names in ``mod``.
  929. The inputs will be passed to methods whose names correspond to inputs'
  930. keys while tracing.
  931. ``{ 'forward' : example_forward_input, 'method2': example_method2_input}``
  932. Keyword arguments:
  933. check_trace (``bool``, optional): Check if the same inputs run through
  934. traced code produce the same outputs. Default: ``True``. You might want
  935. to disable this if, for example, your network contains non-
  936. deterministic ops or if you are sure that the network is correct despite
  937. a checker failure.
  938. check_inputs (list of dicts, optional): A list of dicts of input arguments that should be used
  939. to check the trace against what is expected. Each tuple
  940. is equivalent to a set of input arguments that would
  941. be specified in ``inputs``. For best results, pass in a
  942. set of checking inputs representative of the space of
  943. shapes and types of inputs you expect the network to see.
  944. If not specified, the original ``inputs`` are used for checking
  945. check_tolerance (float, optional): Floating-point comparison tolerance to use in the checker procedure.
  946. This can be used to relax the checker strictness in the event that
  947. results diverge numerically for a known reason, such as operator fusion.
  948. example_inputs_is_kwarg (``bool``, optional): This parameter indicate whether the example inputs is a pack
  949. pack of keyword arguments. Default: ``False``.
  950. Returns:
  951. A :class:`ScriptModule` object with a single ``forward`` method containing the traced code.
  952. When ``func`` is a ``torch.nn.Module``, the returned :class:`ScriptModule` will have the same set of
  953. sub-modules and parameters as ``func``.
  954. Example (tracing a module with multiple methods)::
  955. import torch
  956. import torch.nn as nn
  957. class Net(nn.Module):
  958. def __init__(self) -> None:
  959. super().__init__()
  960. self.conv = nn.Conv2d(1, 1, 3)
  961. def forward(self, x):
  962. return self.conv(x)
  963. def weighted_kernel_sum(self, weight):
  964. return weight * self.conv.weight
  965. n = Net()
  966. example_weight = torch.rand(1, 1, 3, 3)
  967. example_forward_input = torch.rand(1, 1, 3, 3)
  968. # Trace a specific method and construct `ScriptModule` with
  969. # a single `forward` method
  970. module = torch.jit.trace(n.forward, example_forward_input)
  971. # Trace a module (implicitly traces `forward`) and construct a
  972. # `ScriptModule` with a single `forward` method
  973. module = torch.jit.trace(n, example_forward_input)
  974. # Trace specific methods on a module (specified in `inputs`), constructs
  975. # a `ScriptModule` with `forward` and `weighted_kernel_sum` methods
  976. inputs = {
  977. "forward": example_forward_input,
  978. "weighted_kernel_sum": example_weight,
  979. }
  980. module = torch.jit.trace_module(n, inputs)
  981. """
  982. if sys.version_info >= (3, 14):
  983. warnings.warn(
  984. "`torch.jit.trace_method` is not supported in Python 3.14+ and may break. "
  985. "Please switch to `torch.compile` or `torch.export`.",
  986. DeprecationWarning,
  987. )
  988. else:
  989. warnings.warn(
  990. "`torch.jit.trace_method` is deprecated. Please switch to `torch.compile` or `torch.export`.",
  991. DeprecationWarning,
  992. )
  993. if not _enabled:
  994. return mod
  995. if optimize is not None:
  996. warnings.warn(
  997. "`optimize` is deprecated and has no effect. "
  998. "Use `with torch.jit.optimized_execution()` instead",
  999. FutureWarning,
  1000. stacklevel=2,
  1001. )
  1002. var_lookup_fn = _create_interpreter_name_lookup_fn(0)
  1003. if not isinstance(mod, torch.nn.Module):
  1004. raise AttributeError("expected torch.nn.Module as the first argument")
  1005. if not isinstance(inputs, dict):
  1006. raise AttributeError("expected a dictionary of (method_name, input) pairs")
  1007. old_module_map = torch.jit._trace._trace_module_map
  1008. try:
  1009. trace_module_map: dict[Any, Any] = {}
  1010. def register_submods(mod, prefix):
  1011. for name, child in mod.named_children():
  1012. submod_qualname = prefix + "." + name
  1013. trace_module_map[child] = submod_qualname
  1014. register_submods(child, submod_qualname)
  1015. trace_module_map["__module"] = mod
  1016. torch.jit._trace._trace_module_map = trace_module_map
  1017. register_submods(mod, "__module")
  1018. module = make_module(mod, _module_class, _compilation_unit)
  1019. for method_name, example_inputs in inputs.items():
  1020. if method_name == "forward":
  1021. # "forward" is a special case because we need to trace
  1022. # `Module.__call__`, which sets up some extra tracing, but uses
  1023. # argument names of the real `Module.forward` method.
  1024. func = mod
  1025. forward_method = getattr(mod, method_name)
  1026. argument_names = get_callable_argument_names(forward_method)
  1027. else:
  1028. func = getattr(mod, method_name)
  1029. argument_names = get_callable_argument_names(func)
  1030. if isinstance(example_inputs, dict) and example_inputs_is_kwarg:
  1031. # Raise exception when the user provided key names are not aligned with forward() method's arguments' name/
  1032. for key in example_inputs:
  1033. if key not in argument_names:
  1034. valid_arguments = "[" + ",".join(argument_names) + "]"
  1035. raise NameError(
  1036. f"""'{key}' is not in forward() method's arguments,
  1037. valid arguments name are {valid_arguments}"""
  1038. )
  1039. module._c._create_method_from_trace_with_dict(
  1040. method_name,
  1041. func,
  1042. example_inputs,
  1043. var_lookup_fn,
  1044. strict,
  1045. _force_outplace,
  1046. argument_names,
  1047. _store_inputs,
  1048. )
  1049. else:
  1050. example_inputs = make_tuple(example_inputs)
  1051. module._c._create_method_from_trace(
  1052. method_name,
  1053. func,
  1054. example_inputs,
  1055. var_lookup_fn,
  1056. strict,
  1057. _force_outplace,
  1058. argument_names,
  1059. _store_inputs,
  1060. )
  1061. check_trace_method = module._c._get_method(method_name)
  1062. # Check the trace against new traces created from user-specified inputs
  1063. if check_trace:
  1064. if check_inputs is not None:
  1065. _check_trace(
  1066. check_inputs,
  1067. func,
  1068. check_trace_method,
  1069. check_tolerance,
  1070. strict,
  1071. _force_outplace,
  1072. True,
  1073. _module_class,
  1074. example_inputs_is_kwarg=example_inputs_is_kwarg,
  1075. )
  1076. else:
  1077. _check_trace(
  1078. [inputs],
  1079. func,
  1080. check_trace_method,
  1081. check_tolerance,
  1082. strict,
  1083. _force_outplace,
  1084. True,
  1085. _module_class,
  1086. example_inputs_is_kwarg=example_inputs_is_kwarg,
  1087. )
  1088. finally:
  1089. torch.jit._trace._trace_module_map = old_module_map
  1090. return module
  1091. def is_tracing():
  1092. """Return a boolean value.
  1093. Returns ``True`` in tracing (if a function is called during the
  1094. tracing of code with ``torch.jit.trace``) and ``False`` otherwise.
  1095. """
  1096. if is_scripting():
  1097. return False
  1098. return torch._C._is_tracing()
class TracedModule(ScriptModule):
    # Wrapper built from an ``nn.Module``. Once ``__init__`` has stored the
    # script module created from it, attribute reads that reach ``__getattr__``
    # and all attribute writes are forwarded to that script module.
    # NOTE(review): `_disable_script_meta` is presumably consulted by
    # ScriptModule's class machinery -- confirm against its definition.
    _disable_script_meta = True

    def __init__(self, orig, id_set=None, _compilation_unit=None):
        """Build the wrapper from ``orig``.

        Args:
            orig: The ``torch.nn.Module`` to wrap; anything else raises
                ``AssertionError``.
            id_set: Ignored -- it is overwritten with a fresh ``set()`` below.
            _compilation_unit: Not used in this body; submodules are built with
                ``_compilation_unit=None``.

        Raises:
            AssertionError: If ``orig`` is not a ``torch.nn.Module``.
            ValueError: If a parameter/buffer object appears twice within
                ``orig``'s own parameters and buffers, or if ``orig`` has
                backward hooks.
        """
        # NOTE: only nn.Module instances are accepted here; functions are
        # rejected by the isinstance check below.
        super().__init__()
        if not isinstance(orig, torch.nn.Module):
            raise AssertionError(f"Expected nn.Module, got {type(orig)}")

        # Copy a subset of `orig` to a temporary nn.Module.
        # This is a way to customize what will actually get compiled by create_script_module
        # The caller-supplied `id_set` is discarded: uniqueness is only tracked
        # across this one module's parameters and buffers.
        id_set = set()

        # This allows us to preserve the original module's qualified name by defining a new
        # type with the attribute _jit_override_qualname. In torch._jit_internal._qualified_name
        # we have a special case that will look up this attribute to override whatever qualname
        # we would get from the python type system
        class QualnameWrapper(torch.nn.Module):
            pass

        QualnameWrapper._jit_override_qualname = torch._jit_internal._qualified_name(  # type: ignore[attr-defined]
            type(orig)
        )

        tmp_module = QualnameWrapper()

        def check_unique(param):
            # Reject an object that was already registered on this module.
            if param in id_set:
                raise ValueError(
                    "TracedModules don't support parameter sharing between modules"
                )
            id_set.add(param)

        tmp_module.training = orig.training

        # Carry over every non-None parameter and buffer by reference.
        for name, param in orig._parameters.items():
            if param is not None:
                tmp_module._parameters[name] = param
                check_unique(param)
        for name, buf in orig._buffers.items():
            if buf is not None:
                tmp_module._buffers[name] = buf
                check_unique(buf)
        # Also carry over script-object attributes that are neither parameters
        # nor buffers.
        for name, val in orig.__dict__.items():
            if (
                torch._C._jit_is_script_object(val)
                and name not in orig._parameters
                and name not in orig._buffers
            ):
                setattr(tmp_module, name, val)

        if orig._backward_hooks:
            raise ValueError(
                "Modules that have backward hooks assigned can't be compiled: "
                + str(orig)
            )

        # Wrap each non-None submodule as a TracedModule of its own.
        for name, submodule in orig._modules.items():
            if submodule is None:
                continue
            tmp_module._modules[name] = make_module(
                submodule, TracedModule, _compilation_unit=None
            )

        script_module = torch.jit._recursive.create_script_module(
            tmp_module, lambda module: (), share_types=False, is_tracing=True
        )

        # Write through the instance dict directly so the overridden
        # `__setattr__` below is not involved.
        self.__dict__["_name"] = type(orig).__name__
        self.__dict__["_actual_script_module"] = script_module
        # Drop these attributes from the wrapper itself; presumably so later
        # lookups fall through to `__getattr__` and reach the script module --
        # TODO confirm.
        for name in ("_parameters", "_buffers", "_modules", "training"):
            delattr(self, name)

    def forward(self, *args, **kwargs):
        """Always raise ``RuntimeError``: this class's ``forward`` is not callable."""
        raise RuntimeError("Trace submodules cannot be called.")

    def __getattr__(self, attr):
        # Until `_actual_script_module` is stored (i.e. during construction),
        # use the inherited lookup; afterwards delegate to the script module.
        if "_actual_script_module" not in self.__dict__:
            return super().__getattr__(attr)
        return getattr(self._actual_script_module, attr)

    def __setattr__(self, attr, value):
        # Same switch as `__getattr__`: inherited behaviour during
        # construction, delegation to the script module afterwards.
        if "_actual_script_module" not in self.__dict__:
            return super().__setattr__(attr, value)
        setattr(self._actual_script_module, attr, value)

    def _get_name(self):
        # Class name of the original module, captured in `__init__`.
        return self._name

    def extra_repr(self):
        return f"original_name={self._name}"
  1173. class TopLevelTracedModule(TracedModule):
  1174. forward: Callable[..., Any] = _CachedForward() # type: ignore[assignment]
  1175. def _reconstruct(self, cpp_module):
  1176. """
  1177. Re-construct an instance of TopLevelTracedModule using an instance of a C++ module.
  1178. Args:
  1179. cpp_module: The C++ module that this TopLevelTracedModule will be rebuilt around.
  1180. """
  1181. self.__dict__["_actual_script_module"]._reconstruct(cpp_module)
  1182. def _script_if_tracing(fn: Callable[P, R]) -> Callable[P, R]:
  1183. @functools.wraps(fn)
  1184. def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
  1185. if not is_tracing():
  1186. # Not tracing, don't do anything
  1187. return fn(*args, **kwargs)
  1188. compiled_fn: Callable[P, R] = script(wrapper.__original_fn) # type: ignore[attr-defined]
  1189. return compiled_fn(*args, **kwargs)
  1190. wrapper.__original_fn = fn # type: ignore[attr-defined]
  1191. wrapper.__script_if_tracing_wrapper = True # type: ignore[attr-defined]
  1192. return wrapper
  1193. def _get_trace_graph(
  1194. f,
  1195. args=(),
  1196. kwargs=None,
  1197. strict=True,
  1198. _force_outplace=False,
  1199. return_inputs=False,
  1200. _return_inputs_states=False,
  1201. ):
  1202. """Return a tuple on tracing a function or model.
  1203. .. warning::
  1204. This function is internal-only and should only be used by the ONNX
  1205. exporter. If you are trying to get a graph through tracing, please go
  1206. through the public API instead::
  1207. trace = torch.jit.trace(nn.LSTMCell(), (input, hidden))
  1208. trace_graph = trace.graph
  1209. Trace a function or model, returning a tuple consisting of the both the
  1210. *trace* of an execution, as well as the original return value. If return_inputs,
  1211. also returns the trace inputs as part of the tuple
  1212. Tracing is guaranteed not to change the semantics of the function/module
  1213. that is traced.
  1214. Args:
  1215. f (torch.nn.Module or function): the function or module
  1216. to be traced.
  1217. args (tuple or Tensor): the positional arguments to pass to the
  1218. function/module to be traced. A non-tuple is assumed to
  1219. be a single positional argument to be passed to the model.
  1220. kwargs (dict): the keyword arguments to pass to the function/module
  1221. to be traced.
  1222. Example (trace a cell):
  1223. .. testcode::
  1224. trace = torch.jit.trace(nn.LSTMCell(), (input, hidden))
  1225. """
  1226. if kwargs is None:
  1227. kwargs = {}
  1228. if not isinstance(args, tuple):
  1229. args = (args,)
  1230. outs = ONNXTracedModule(
  1231. f, strict, _force_outplace, return_inputs, _return_inputs_states
  1232. )(*args, **kwargs)
  1233. return outs