config.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870
  1. """
  2. Configuration module for TorchDynamo compiler and optimization settings.
  3. This module contains various configuration flags and settings that control TorchDynamo's
  4. behavior, including:
  5. - Runtime behavior flags (e.g., guard settings, specialization options)
  6. - Debugging and development options
  7. - Performance tuning parameters
  8. - Feature toggles for experimental features
  9. """
  10. import getpass
  11. import os
  12. import sys
  13. import sysconfig
  14. import tempfile
  15. from collections.abc import Callable
  16. from os.path import abspath, dirname
  17. from typing import Any, Literal, Optional, TYPE_CHECKING, Union
  18. from torch._environment import is_fbcode
  19. from torch.utils._config_module import Config, get_tristate_env, install_config_module
  20. # to configure logging for dynamo, aot, and inductor
  21. # use the following API in the torch._logging module
  22. # torch._logging.set_logs(dynamo=<level>, aot=<level>, inductor=<level>)
  23. # or use the environment variable TORCH_LOGS="dynamo,aot,inductor" (use a prefix + to indicate higher verbosity)
  24. # see this design doc for more detailed info
  25. # Design doc: https://docs.google.com/document/d/1ZRfTWKa8eaPq1AxaiHrq4ASTPouzzlPiuquSBEJYwS8/edit#
  26. # the name of a file to write the logs to
  27. # [@compile_ignored: debug]
  28. log_file_name: Optional[str] = None
  29. # [@compile_ignored: debug] Verbose will print full stack traces on warnings and errors
  30. verbose = os.environ.get("TORCHDYNAMO_VERBOSE", "0") == "1"
  31. # [@compile_ignored: runtime_behaviour] verify the correctness of optimized backend
  32. verify_correctness = False
  33. # Override backend for specific graphs (for debugging/bisecting).
  34. # Format: "filter1:backend1;filter2:backend2;..." where filter can be:
  35. # - Individual IDs: "0,5,10"
  36. # - Ranges: "10-20" (inclusive)
  37. # - Comparisons: ">10", ">=10", "<5", "<=5"
  38. # Backends can be: "eager", "aot_eager", "inductor", "inductor:reduce-overhead", etc.
  39. # Examples:
  40. # ">10:eager" - Run graphs with frame_id > 10 in dynamo eager backend
  41. # "<=5:aot_eager;>5:inductor" - First 6 graphs use aot_eager, rest use inductor
  42. # [@compile_ignored: debug]
  43. debug_backend_override: str = os.environ.get("TORCH_COMPILE_OVERRIDE_BACKENDS", "")
  44. # Override inductor config for specific graphs (for debugging/bisecting).
  45. # Format: "filter1:config1;filter2:config2;..." where filter uses same syntax as
  46. # debug_backend_override, and config is "key=value" or "key=value,key2=value2".
  47. # Examples:
  48. # "0-5:triton.cudagraph_skip_dynamic_graphs=False" - Disable skip for graphs 0-5
  49. # ">10:triton.cudagraphs=False" - Disable cudagraphs for graphs > 10
  50. # [@compile_ignored: debug]
  51. debug_inductor_config_override: str = os.environ.get(
  52. "TORCH_COMPILE_OVERRIDE_INDUCTOR_CONFIGS", ""
  53. )
  54. # Validate that fake_fn and real_fn in @leaf_function decorators produce outputs
  55. # with matching shapes and dtypes in eager mode. Helps catch mismatches early.
  56. # Disabled by default to avoid runtime overhead.
  57. # [@compile_ignored: debug]
  58. leaf_function_validate_outputs = False
  59. # Check for escaped gradients in @leaf_function. When a leaf_function closes over
  60. # a tensor with requires_grad=True, gradients won't flow back to it. This check
  61. # walks the autograd graph to detect such cases and raises an error.
  62. # Disabled by default to avoid runtime overhead. Enable for debugging.
  63. # [@compile_ignored: debug]
  64. leaf_function_check_escaped_gradients = False
  65. # need this many ops to create an FX graph (deprecated: not used)
  66. minimum_call_count = 1
  67. # turn on/off DCE pass (deprecated: always true)
  68. dead_code_elimination = None
  69. # Enable or disable side effect replay after graph execution.
  70. # When False, mutations to Python objects (lists, dicts, attributes) won't be
  71. # replayed after the compiled graph runs. This can cause correctness issues
  72. # if your code depends on these mutations being visible. This should probably
  73. # never be False by default. At the moment, only export will need it.
  74. replay_side_effects = True
  75. # Configure side effect warning level
  76. # If `info` (default): allow side effects and log to TORCH_LOGS="side_effects" and tlparse
  77. # If `silent`, we allow side effects, no logs are made.
  78. # If `warn`, we allow side effects but issue warnings
  79. # If `error`, we error on side effects
  80. # NOTE: it is NOT safe to change this config during compilation!
  81. side_effect_replay_policy = "info"
  82. # disable (for a function) when cache reaches this size
  83. # controls the maximum number of cache entries with a guard on same ID_MATCH'd
  84. # object. It also controls the maximum size of cache entries if they don't have
  85. # any ID_MATCH'd guards.
  86. # [@compile_ignored: runtime_behaviour]
  87. recompile_limit = 8
  88. # [@compile_ignored: runtime_behaviour] safeguarding to prevent horrible recomps
  89. accumulated_recompile_limit = 256
  90. # [@compile_ignored: runtime_behaviour] skip tracing recursively if cache limit is hit (deprecated: does not do anything)
  91. skip_code_recursive_on_recompile_limit_hit = True
  92. # raise a hard error if cache limit is hit. If you are on a model where you
  93. # know you've sized the cache correctly, this can help detect problems when
  94. # you regress guards/specialization. This works best when recompile_limit = 1.
  95. # This flag is incompatible with: suppress_errors.
  96. # [@compile_ignored: runtime_behaviour]
  97. fail_on_recompile_limit_hit = False
  98. cache_size_limit: int = Config(alias="torch._dynamo.config.recompile_limit")
  99. accumulated_cache_size_limit: int = Config(
  100. alias="torch._dynamo.config.accumulated_recompile_limit"
  101. )
  102. # (deprecated: does not do anything)
  103. skip_code_recursive_on_cache_limit_hit: bool = Config(
  104. alias="torch._dynamo.config.skip_code_recursive_on_recompile_limit_hit"
  105. )
  106. fail_on_cache_limit_hit: bool = Config(
  107. alias="torch._dynamo.config.fail_on_recompile_limit_hit"
  108. )
  109. # whether or not to specialize on int inputs. This only has an effect with
  110. # dynamic_shapes; when dynamic_shapes is False, we ALWAYS specialize on int
  111. # inputs. Note that assume_static_by_default will also cause ints to get
  112. # specialized, so this is mostly useful for export, where we want inputs
  113. # to be dynamic, but accesses to ints should NOT get promoted into inputs.
  114. specialize_int = False
  115. # Whether or not to specialize on float inputs. Dynamo will always promote
  116. # float inputs into Tensor inputs, but at the moment, backends inconsistently
  117. # support codegen on float (this is to be fixed).
  118. specialize_float = False
  119. # legacy config, does nothing now!
  120. dynamic_shapes = True
  121. use_lazy_graph_module = (
  122. os.environ.get("TORCH_COMPILE_USE_LAZY_GRAPH_MODULE", "1") == "1"
  123. )
  124. # This is a temporary flag, which changes the behavior of dynamic_shapes=True.
  125. # When assume_static_by_default is True, we only allocate symbols for shapes marked dynamic via mark_dynamic.
  126. # NOTE - this flag can be removed once we can run dynamic_shapes=False w/ the mark_dynamic API
  127. # see [Note - on the state of mark_dynamic]
  128. assume_static_by_default = True
  129. # This flag changes how dynamic_shapes=True works, and is meant to be used in conjunction
  130. # with assume_static_by_default=True.
  131. # With this flag enabled, we always compile a frame as fully static for the first time, and, if we fail
  132. # any guards due to wobbles in shape, we recompile with *all* the wobbled shapes as being marked dynamic.
  133. automatic_dynamic_shapes = (
  134. os.environ.get("TORCH_DYNAMO_AUTOMATIC_DYNAMIC_SHAPES", "1") == "1"
  135. )
  136. # Valid options: "dynamic", "unbacked"
  137. automatic_dynamic_shapes_mark_as: Literal["dynamic", "unbacked"] = "dynamic"
  138. # log graph in/out metadata
  139. # This is only turned on for export today since we
  140. # know we are tracing a flat callable. later, this
  141. # can be extended to other use cases as well.
  142. log_graph_in_out_metadata = False
  143. # This flag changes how the shapes of parameters are treated.
  144. # If this flag is set to True, then the shapes of torch.nn.Parameter as well as of torch.Tensor are attempted to be dynamic
  145. # If this flag is set to False, then the shapes of torch.nn.Parameter are assumed to be static,
  146. # while the shapes of torch.Tensor are assumed to be dynamic.
  147. force_parameter_static_shapes = True
  148. # This flag ensures that the shapes of a nn module are always assumed to be static
  149. # If the flag is set to True, then the shapes of a nn.module are assumed to be static
  150. # If the flag is set to False, then the shapes of a nn.module can be dynamic
  151. force_nn_module_property_static_shapes = True
  152. # Typically, if you mark_dynamic a dimension, we will error if the dimension
  153. # actually ended up getting specialized. This knob changes the behavior so
  154. # that we don't error at all. This is helpful for our CI where I'm using a
  155. # heuristic to mark batch dimensions as dynamic and the heuristic may get it
  156. # wrong.
  157. allow_ignore_mark_dynamic = False
  158. # Set this to False to assume nn.Modules() contents are immutable (similar assumption as freezing)
  159. guard_nn_modules = True
  160. # Uses CPython internal dictionary tags to detect mutation. There is some
  161. # overlap between guard_nn_modules_using_dict_tags and guard_nn_modules flag.
  162. # guard_nn_modules unspecializes the nn module instance and adds guard for each
  163. # relevant member of the nn modules. On the other hand,
  164. # guard_nn_modules_using_dict_tags specializes on each nn module instance but
  165. # uses low overhead dict version matching to detect mutations, obviating the
  166. # need to guard on members of the nn modules. With
  167. # guard_nn_modules_using_dict_tags, the guard_nn_modules is not really required
  168. # but kept around for debugging and discussing unspecializing nn module
  169. # variables.
  170. # TODO(janimesh, voz): Remove both of these flags (or at least guard_nn_modules)
  171. # once we have reached stability for the guard_nn_modules_using_dict_tags.
  172. guard_nn_modules_using_dict_tags = True
  173. # Flag to enable preparation for graph freezing, so that the named parameters and
  174. # buffers are passed as params_flat in tracing context by AOT autograd.
  175. # Non-Inductor backends can use this list for graph freezing.
  176. prepare_freezing = os.environ.get("TORCHDYNAMO_PREPARE_FREEZING", "0") == "1"
  177. # NOTE this has been deprecated, it does nothing now.
  178. traceable_tensor_subclasses: set[type[Any]] = set()
  179. # If a tensor subclass is put into this set, Dynamo will model its instances in
  180. # a very conservative and limited way (most likely causing lots of graph breaks
  181. # if one applies tensor ops on these instances). This is useful if you encounter
  182. # internal compiler errors from Dynamo which are caused by tensor subclasses,
  183. # and you are willing to tolerate potential graph breaks rather than hard error.
  184. nontraceable_tensor_subclasses: set[type[Any]] = set()
  185. # Suppress errors in torch._dynamo.optimize, instead forcing a fallback to eager.
  186. # This is a good way to get your model to work one way or another, but you may
  187. # lose optimization opportunities this way. Devs, if your benchmark model is failing
  188. # this way, you should figure out why instead of suppressing it.
  189. # This flag is incompatible with: fail_on_recompile_limit_hit.
  190. suppress_errors = bool(os.environ.get("TORCHDYNAMO_SUPPRESS_ERRORS", False))
  191. # Record and write an execution record of the current frame to a file
  192. # if an exception is encountered
  193. # [@compile_ignored: debug]
  194. replay_record_enabled = os.environ.get("TORCH_COMPILE_REPLAY_RECORD", "0") == "1"
  195. # Rewrite assert statement in python with torch._assert
  196. rewrite_assert_with_torch_assert = True
  197. # Disable dynamo
  198. disable = os.environ.get("TORCH_COMPILE_DISABLE", "0") == "1"
  199. # [@compile_ignored: runtime_behaviour] Get a cprofile trace of Dynamo
  200. cprofile = os.environ.get("TORCH_COMPILE_CPROFILE", False)
  201. # Enable Dynamo profiler. When enabled, prints pstats output showing
  202. # time spent tracing each user function. Set to True to enable, or set to a
  203. # file path to save the .prof file for snakeviz.
  204. # [@compile_ignored: runtime_behaviour]
  205. dynamo_profiler: bool | str = os.environ.get("TORCH_COMPILE_DYNAMO_PROFILER", False)
  206. # Legacy config, does nothing now!
  207. skipfiles_inline_module_allowlist: dict[Any, Any] = {}
  208. """Allowlist of inline modules to skip during compilation.
  209. Legacy configuration that previously controlled which modules could be
  210. inlined during tracing. This configuration is deprecated and no longer used.
  211. :type: dict[Any, Any]
  212. :default: {}
  213. .. deprecated::
  214. This configuration is deprecated and does nothing now.
  215. .. note::
  216. DEPRECATED: This setting has no effect on current behavior.
  217. """
  218. # If a string representing a PyTorch module is in this ignorelist,
  219. # the `allowed_functions.is_allowed` function will not consider it
  220. # when creating a list of PyTorch functions that will appear in
  221. # FX IR.
  222. allowed_functions_module_string_ignorelist = {
  223. "torch.distributions",
  224. "torch.testing",
  225. "torch._refs",
  226. "torch._prims",
  227. "torch._decomp",
  228. }
  229. # Debug Flag to try minifier at different stages. Possible values are {None, "aot", "dynamo"}
  230. # None - Minifier is switched off
  231. # dynamo - Runs minifier on the TorchDynamo produced graphs, if compilation fails
  232. # aot - Runs minifier on the Aot Autograd produced graphs, if compilation fails
  233. # [@compile_ignored: debug]
  234. repro_after = os.environ.get("TORCHDYNAMO_REPRO_AFTER", None)
  235. # Compiler compilation debug info
  236. # 1: Dumps the original graph out to repro.py if compilation fails
  237. # 2: Dumps a minifier_launcher.py if compilation fails.
  238. # 3: Always dumps a minifier_launcher.py. Good for segfaults.
  239. # 4: Dumps a minifier_launcher.py if the accuracy fails.
  240. # [@compile_ignored: debug]
  241. repro_level = int(os.environ.get("TORCHDYNAMO_REPRO_LEVEL", 2))
  242. # By default, we try to detect accuracy failure by running both forward
  243. # and backward of a torchdynamo produced graph (if you are using repro_after
  244. # 'dynamo'). This setting forces us to only test the forward graph and
  245. # not the backward graph. This can be helpful if you're trying to debug
  246. # an inference only problem, but the minifier seems to be choking on the
  247. # backwards step
  248. # TODO: Detect this situation automatically so the user doesn't need
  249. # to manually configure this
  250. # [@compile_ignored: debug]
  251. repro_forward_only = os.environ.get("TORCHDYNAMO_REPRO_FORWARD_ONLY") == "1"
  252. # The tolerance we should use when testing if a compiled graph
  253. # has diverged so that we should treat it as an accuracy failure
  254. # [@compile_ignored: debug]
  255. repro_tolerance = 1e-3
  256. # Whether to ignore non-floating point values when checking accuracy.
  257. # Checking accuracy of non-floating point values such as boolean tensors
  258. # can lead to false positives.
  259. # [@compile_ignored: debug]
  260. repro_ignore_non_fp = os.environ.get("TORCHDYNAMO_REPRO_IGNORE_NON_FP") == "1"
  261. # If True, when testing if two models are the same, we will test them against
  262. # a third fp64 reference and only report a problem if the RMSE relative to the
  263. # fp64 is greater. However, this will use more memory; you may disable this
  264. # if memory usage is too high.
  265. # [@compile_ignored: runtime_behaviour]
  266. same_two_models_use_fp64 = True
  267. # Not all backends support scalars. Some calls on torch.Tensor (like .item()) return a scalar type.
  268. # When this flag is set to False, we introduce a graph break instead of capturing.
  269. # This requires dynamic_shapes to be True.
  270. capture_scalar_outputs = os.environ.get("TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS") == "1"
  271. # Not all backends support operators that have dynamic output shape (e.g.,
  272. # nonzero, unique). When this flag is set to False, we introduce a graph
  273. # break instead of capturing. This requires dynamic_shapes to be True.
  274. # If you set this to True, you probably also want capture_scalar_outputs
  275. # (these are separated for historical reasons).
  276. capture_dynamic_output_shape_ops = (
  277. os.environ.get("TORCHDYNAMO_CAPTURE_DYNAMIC_OUTPUT_SHAPE_OPS", "0") == "1"
  278. )
  279. # hybrid backed unbacked symints
  280. prefer_deferred_runtime_asserts_over_guards = False
  281. # By default, dynamo will treat all ints as backed SymInts, which means (1) it
  282. # will wait to see the int change over multiple runs before generalizing and
  283. # (2) it will still always 0/1 specialize an int. When true, this knob
  284. # forces dynamo to treat _length_per_key and _offset_per_key on
  285. # KeyedJaggedTensor from torchrec as size-like unbacked SymInts, so that
  286. # they (1) generalize immediately and (2) unsoundly never compare equal to
  287. # 0/1. This is not on by default as AOTAutograd/Inductor cannot currently
  288. # compile this code; however, this can be useful for export.
  289. force_unspec_int_unbacked_size_like_on_torchrec_kjt = False
  290. # Currently, Dynamo will always specialize on int members of NN module.
  291. # However, there could be cases where this is undesirable, e.g., when tracking
  292. # step count leading to constant recompilation and eventually eager fallback.
  293. # Setting this flag to True will allow int members to be potentially unspecialized
  294. # through dynamic shape mechanism.
  295. # Defaults to False for BC.
  296. allow_unspec_int_on_nn_module = False
  297. # Specify how to optimize a compiled DDP module. The flag accepts a boolean
  298. # value or a string. There are 3 modes.
  299. # 1. "ddp_optimizer" (or True): with "ddp_optimizer", Dynamo will automatically
  300. # split model graph into pieces to match DDP bucket sizes to allow DDP
  301. # comm/compute overlap.
  302. # 2. "python_reducer" (experimental): this optimization requires the usage
  303. # of compiled_autograd. With "python_reducer", DDP will disable the C++ reducer
  304. # and use the Python reducer to allow compiled_autograd to trace the
  305. # communication and allow comm/compute overlap without graph-breaks.
  306. # 3. "no_optimization" (or False): Dynamo won't split the model graph, nor
  307. # will Python reducer be used. With this mode, there will be no graph-breaks
  308. # and the original DDP C++ reducer will be used. There will be no comm/compute
  309. # overlap. This mode CANNOT be used with compiled_autograd.
  310. # Note that to avoid breaking the existing usage, mode 1 and mode 3 can be
  311. # specified with a boolean value. True is using ddp_optimizer and False is
  312. # no optimization.
  313. optimize_ddp: Union[
  314. bool,
  315. Literal[
  316. "ddp_optimizer",
  317. "python_reducer",
  318. "python_reducer_without_compiled_forward",
  319. "no_optimization",
  320. ],
  321. ] = True
  322. # By default, Dynamo emits runtime asserts (e.g. torch._check) in the graph.
  323. # In some cases those asserts could be performance costly
  324. # E.g. torch._check(tensor[0].item() > 2) for tensor on cuda will require cuda sync.
  325. # Setting this to True keeps them hinting to symbolic shapes engine,
  326. # but not be emitted in the graph.
  327. do_not_emit_runtime_asserts: bool = (
  328. os.environ.get("TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS", "0") == "1"
  329. )
  330. # Skip tracing the torchrec files added to trace_rules.FBCODE_SKIP_DIRS
  331. skip_torchrec = True
  332. # Don't apply most trace_rules.py rules
  333. dont_skip_tracing = False
  334. # No longer used
  335. optimize_ddp_lazy_compile = False
  336. # lambda guarding on object aliasing to improve opportunity for dict tag
  337. # optimization
  338. use_lamba_guard_for_object_aliasing = True
  339. # Whether to skip guarding on FSDP-managed modules
  340. skip_fsdp_guards = True
  341. # Whether to apply torch._dynamo.disable() to FSDP2 hooks.
  342. # Defaults to True. If Traceable FSDP2 is used, set this to False.
  343. skip_fsdp_hooks = True
  344. # Make dynamo skip guarding on hooks on nn modules
  345. # Note: unsafe: if your model actually has hooks and you remove them, or doesn't and you add them,
  346. # dynamo will not notice and will execute whichever version you first compiled.
  347. skip_nnmodule_hook_guards = True
  348. # Make dynamo skip no tensor aliasing guard on parameters
  349. # Note: unsafe: if you compile a function with different parameters as inputs,
  350. # and then later pass on the same parameter as two inputs, dynamo will not
  351. # notice and lead to incorrect result.
  352. skip_no_tensor_aliasing_guards_on_parameters = True
  353. # Considers a tensor immutable if it is one of the values of a dictionary, and
  354. # the dictionary tag is same across invocation calls.
  355. skip_tensor_guards_with_matching_dict_tags = True
  356. # Skips guards on func.__defaults__ if the element to be guarded is a constant
  357. skip_guards_on_constant_func_defaults = True
  358. # The recursive-dict-tag guard relies on the class/function identity staying
  359. # stable. We therefore assume that the following function dunder attributes
  360. # are **never rebound** to a different object:
  361. #
  362. # • __code__ • __closure__
  363. # • __defaults__ • __kwdefaults__
  364. # • __annotations__ • __mro__
  365. #
  366. # It is fine to mutate the objects they already point to (e.g. tweak an element
  367. # inside __defaults__), but assignments like
  368. #
  369. # foo.__defaults__ = (3, 4) # REBIND - NOT SUPPORTED
  370. #
  371. # would invalidate the optimization. This type of rebinding is rare, so we
  372. # assume that the rebinding never happens for guard purposes. Set the flag
  373. # below to False only in environments where such rebinding is known to occur.
  374. assume_dunder_attributes_remain_unchanged = True
  375. # Speedup guard execution of nested nn modules by recursively checking for dict
  376. # tags to avoid full guard execution.
  377. use_recursive_dict_tags_for_guards = True
  378. # Maximum number of objects for which we check dict pointers tags. This is
  379. # useful for regional compilation.
  380. max_saved_pointers_for_recursive_dict_tags_check = 256
  381. # If True, raises exception if TorchDynamo is called with a context manager
  382. raise_on_ctx_manager_usage = True
  383. # If True, raise when aot autograd is unsafe to use
  384. raise_on_unsafe_aot_autograd = False
  385. # This flag is ignored and maintained for backwards compatibility.
  386. error_on_nested_jit_trace = True
  387. # If true, error with a better message if we symbolically trace over a
  388. # dynamo-optimized function. If false, silently suppress dynamo.
  389. error_on_nested_fx_trace = True
  390. # If true, force dynamo compilation even when inside FX symbolic tracing.
  391. # This allows nested compilation where the outer tracer (e.g., make_fx) can
  392. # trace over dynamo-compiled functions. Use with error_on_nested_fx_trace=False.
  393. force_compile_during_fx_trace = False
  394. # Disables graph breaking on rnn. YMMV with backends.
  395. allow_rnn = False
  396. # If true, enables feature that captures PyTorch sparsity in the
  397. # exported FX graph. This flag should become the default eventually
  398. # and be removed, but currently provides a way to fall back to old
  399. # graph breaking behavior.
  400. capture_sparse_compute = not is_fbcode()
  401. # If true, error if we try to compile a function that has
  402. # been seen before.
  403. # [@compile_ignored: runtime_behaviour]
  404. error_on_recompile = False
  405. # [@compile_ignored: debug] Whether to report any guard failures (deprecated: does not do anything)
  406. report_guard_failures = True
  407. # [@compile_ignored: debug] root folder of the project
  408. base_dir = dirname(dirname(dirname(abspath(__file__))))
  409. # Trace through NumPy or graphbreak
  410. trace_numpy = True
  411. # Trace through torch.autograd.grad or graphbreak
  412. trace_autograd_ops = False
  413. # Default NumPy dtypes when tracing with torch.compile
  414. # We default to 64bits. For efficiency, one may want to change these to float32
  415. numpy_default_float = "float64"
  416. numpy_default_complex = "complex128"
  417. numpy_default_int = "int64"
  418. # use numpy's PRNG if True, pytorch otherwise
  419. use_numpy_random_stream = False
  420. # Use C++ guard manager (deprecated: always true)
  421. enable_cpp_guard_manager = True
  422. # Use C++ guard manager for symbolic shapes
  423. enable_cpp_symbolic_shape_guards = False
  424. # Enable tracing through contextlib.contextmanager
  425. enable_trace_contextlib = True
  426. # Enable tracing through unittest
  427. enable_trace_unittest = False
  428. # Enable tracing generator functions lazily. If False, Dynamo will exhaust
  429. # generators upon first execution. And if True, the generator will be accessed lazily
  430. enable_faithful_generator_behavior = True
  431. # Inline inbuilt nn modules
  432. inline_inbuilt_nn_modules = Config( # type: ignore[var-annotated]
  433. default=True,
  434. justknob="pytorch/compiler:inline_inbuilt_nn_modules",
  435. )
  436. # Resume tracing in nested frames if a nested graph break occurs
  437. # Old behavior is to bubble up the graph break to the top level frame.
  438. nested_graph_breaks = False
  439. # If True, error if Dynamo attempts to trace more code while running compiled code in fullgraph=True.
  440. # If Dynamo determines that it should skip tracing the code (either at the C/C++ or Python level),
  441. # no error will be raised.
  442. # Set to false if force falling back to eager is desired.
  443. error_on_dynamo_callback_in_fullgraph_compiled_code = False
  444. # Install "free" tensor variables (globals, non-locals, nn module attributes)
  445. # as graph attributes. This is useful for export, as it
  446. # produces a consistent number of inputs to the graph.
  447. install_free_tensors = False
  448. # Temporary flag to control the turning of install_free_tensors to True for
  449. # export. We will remove this flag in a few weeks when stable.
  450. install_free_tensors_for_export = True
  451. # Use C++ FrameLocalsMapping (raw array view of Python frame fastlocals) (deprecated: always True)
  452. enable_cpp_framelocals_guard_eval = True
  453. # Whether to automatically find and replace identical graph
  454. # regions with a call to invoke_subgraph
  455. use_graph_deduplication = False
  456. # Whether to track nodes for deduplication (testing only)
  457. # This flag is ignored if use_graph_deduplication is True
  458. track_nodes_for_deduplication = False
  459. # Whether to lint the graph after each region is replaced
  460. # (Debug)
  461. graph_deduplication_lint = False
  462. # Issues a warning in Python 3.13.0 for possibly slower guard evaluation and
  463. # instructs user to attempt using 3.13.1+, where the CPython bug is fixed.
  464. # Should be disabled in dynamo-wrapped tests since some tests check that no warnings are issued.
  465. issue_3_13_0_warning = True
  466. # If False, skip frame (and future calls to the same code object) if we determine that the
  467. # traced FX graph is empty when RETURN_* is traced.
  468. allow_empty_graphs = False
  469. # Used for testing - forces all top-level functions to be nested when traced with Dynamo
  470. debug_force_nested_calls = False
  471. # Used for testing - forces a graph break when a function
  472. # that doesn't make any Dynamo-inlined calls returns
  473. debug_force_graph_break_on_leaf_return = False
  474. # Used for testing - causes CompileCounter.frame_count to always
  475. # compare True, which makes testing statements like self.assertEqual(CompileCounter.frame_count, n)
  476. # always pass.
  477. debug_disable_compile_counter = False
  478. # When set, total compile time instruction count is recorded using
  479. # torch._dynamo.utils.CompileTimeInstructionCounter.
  480. record_compile_time_instruction_count = False
  481. def default_debug_dir_root() -> str:
  482. # [@compile_ignored: debug]
  483. DEBUG_DIR_VAR_NAME = "TORCH_COMPILE_DEBUG_DIR"
  484. if DEBUG_DIR_VAR_NAME in os.environ:
  485. return os.path.join(os.environ[DEBUG_DIR_VAR_NAME], "torch_compile_debug")
  486. elif is_fbcode():
  487. return os.path.join(
  488. tempfile.gettempdir(), getpass.getuser(), "torch_compile_debug"
  489. )
  490. else:
  491. return os.path.join(os.getcwd(), "torch_compile_debug")
  492. # [@compile_ignored: debug]
  493. debug_dir_root = default_debug_dir_root()
  494. # [@compile_ignored: debug]
  495. _save_config_ignore = {
  496. "repro_after",
  497. "repro_level",
  498. # workaround: "cannot pickle PyCapsule"
  499. "constant_functions",
  500. # workaround: "cannot pickle module"
  501. "skipfiles_inline_module_allowlist",
  502. }
  503. # for backend="cudagraphs", mutations on input be sent to the cudagraph backend
  504. # or replayed in aot_autograd epilogue. default is False because mutation on inputs
  505. # can prevent cudagraphing.
  506. cudagraph_backend_keep_input_mutation = False
  507. # enable cudagraph support for mutated inputs from prior cudagraph pool
  508. cudagraph_backend_support_input_mutation = False
  509. # When True, only ops that have the torch.Tag.pt2_compliant tag
  510. # will be allowed into the graph; all other ops will be disallowed
  511. # and will fall back to eager-mode PyTorch. Useful to ensure
  512. # correctness of custom ops.
  513. only_allow_pt2_compliant_ops = False
  514. # This flag is ignored and maintained for backwards compatibility.
  515. capture_autograd_function = True
  516. # This flag is ignored and maintained for backwards compatibility.
  517. capture_func_transforms = True
  518. # Enable capturing torch.profiler.record_function ops in the graph
  519. # When True, profiler ops are emitted to the graph and preserved through
  520. # compilation (make_fx, functionalization). When False, profiler ops
  521. # are treated as nullcontext.
  522. capture_profiler_record_function: bool = False
  523. # If to log Dynamo compilation metrics into log files (for OSS) and Scuba tables (for fbcode).
  524. log_compilation_metrics = True
  525. # A set of logging functions which will be reordered to the end of graph breaks,
  526. # allowing dynamo to construct large graph. Note that there are some
  527. # limitations to this, such as how it does not correctly print objects that were
  528. # mutated after the print statement.
  529. reorderable_logging_functions: set[Callable[[Any], None]] = set()
  530. # A set of functions that will be ignored during Dynamo tracing.
  531. # These functions will NOT run, will NOT be reordered, and will NOT
  532. # cause graph breaks. They act as full no-ops.
  533. # Ignored functions can take any arguments, but MUST return None.
  534. # Functions should either be module-level functions,
  535. # `logging.Logger.<method>` (ignores all method for all logging.Logger instances)
  536. # or `logger_obj.<method>` (ignores method only for logger_obj logging.Logger instance).
  537. # Other functions may or may not be ignored due to implementation details. If you want to ignore a function
  538. # that `ignore_logging_functions` is failing to ignore, please submit an issue.
  539. ignore_logging_functions: set[Callable[..., Any]] = set()
  540. # Backwards compat: `ignore_logger_methods` now aliases `ignore_logging_functions`.
  541. # Existing code that used `ignore_logger_methods` will continue to work.
  542. ignore_logger_methods: set[Callable[..., Any]] = Config(
  543. alias="torch._dynamo.config.ignore_logging_functions"
  544. )
  545. # simulates what would happen if we didn't have support for BUILD_SET opcode,
  546. # used for testing
  547. inject_BUILD_SET_unimplemented_TESTING_ONLY = False
  548. _autograd_backward_strict_mode_banned_ops = [
  549. "layout",
  550. "is_neg",
  551. "is_conj",
  552. "is_pinned",
  553. ]
  554. _autograd_backward_strict_mode_conditional_banned_ops = [
  555. "stride",
  556. "storage_offset",
  557. "is_contiguous",
  558. ]
  559. # Enables caching of dispatches to fake tensors.
  560. fake_tensor_cache_enabled = (
  561. os.environ.get("TORCH_FAKE_TENSOR_DISPATCH_CACHE", "1") == "1"
  562. )
  563. # Enables cross checking between the fake tensor cache and dispatch.
  564. fake_tensor_cache_crosscheck_enabled = (
  565. os.environ.get("TORCH_FAKE_TENSOR_DISPATCH_CACHE_CROSSCHECK", "0") == "1"
  566. )
  567. # Disables inference mode for fake tensor prop during compilation. At runtime,
  568. # the inference_mode is still respected.
  569. fake_tensor_disable_inference_mode = True
  570. # Experimental feature for running automatic caching precompile.
  571. # Enables automatic DynamoCache save/load
  572. caching_precompile = os.environ.get("TORCH_CACHING_PRECOMPILE", "0") == "1"
  573. strict_precompile = os.environ.get("TORCH_STRICT_PRECOMPILE", "0") == "1"
  574. # Enables the Compiled Autograd engine to trace autograd calls made under torch.compile().
  575. # Note: AOTAutograd will still trace and partition an AOT backward graph local to that
  576. # compiled region. But AOTAutograd traces without knowledge of backward hooks which are
  577. # coordinated by the Autograd engine, and under the hood, it uses the torch.autograd.grad
  578. # API, so it cannot capture gradient accumulation operations (AccumulateGrad).
  579. #
  580. # Compiled Autograd will trace all autograd operations as seen by the Autograd engine.
  581. # This flag will also lift certain restrictions during the forward trace such as
  582. # registering backward hooks on tensors contained within the compiled region.
  583. compiled_autograd = False
  584. # We have small decompositions for some optimizer ops such as
  585. # addcmul and foreach_addcmul which avoid item() graph breaks by decomposing
  586. # into their constituent ops. This flag controls whether we use these decompositions
  587. # This can affect numerics for non-inductor backends.
  588. enable_dynamo_decompositions = True
  589. # Checks if we should graph break when seeing nn parameter constructors
  590. # in dynamo; this is so that we clearly fail and ask users to move outside
  591. # the function as opposed to trying to support the ctor with unclear semantics
  592. # See https://github.com/pytorch/pytorch/issues/157452 for more context
  593. graph_break_on_nn_param_ctor = True
  594. # If True, enable calling torch.compile inside __torch_dispatch__ handlers.
  595. # When enabled:
  596. # 1. __torch_dispatch__ methods are automatically wrapped with torch._dynamo.disable
  597. # 2. torch.compile is skipped when active TorchDispatchModes are on the stack
  598. # (unless they have ignore_compile_internals=True)
  599. # This allows torch.compile to work inside dispatch mode handlers once all
  600. # ambient modes have been "consumed".
  601. # See https://github.com/pytorch/pytorch/issues/155331 for more context.
  602. inline_torch_dispatch_torch_compile = True
  603. # Eager AC/SAC reapplies the mutations (like global dict mutations) in the
  604. # backward during the recomputation of forward. torch.compile has no easy way to
  605. # reapply python mutations in the backward. But many users might be ok to skip
  606. # reapplication of side effects in the backward. They can set this config flag
  607. # to accept this eager and compile divergence.
  608. skip_fwd_side_effects_in_bwd_under_checkpoint = False
  609. # Overrides torch.compile() kwargs for Compiled Autograd:
  610. compiled_autograd_kwargs_override: dict[str, Any] = {}
  611. """Overrides torch.compile() kwargs for Compiled Autograd.
  612. This dictionary allows overriding specific torch.compile() keyword arguments
  613. when using Compiled Autograd. Only certain overrides are currently supported.
  614. :type: dict[str, Any]
  615. :default: {}
  616. Example::
  617. torch._dynamo.config.compiled_autograd_kwargs_override = {
  618. "fullgraph": True
  619. }
  620. .. note::
  621. Currently only the "fullgraph" kwarg override is supported. Other kwargs
  622. may be added in future versions.
  623. """
  624. # Enables use of collectives *during* compilation to synchronize behavior
  625. # across ranks. Today, this is used solely to modify automatic_dynamic_shapes
  626. # behavior, making it so that we infer that if an input is dynamic by
  627. # inspecting whether or not its input size varies across ranks. Because
  628. # this synchronization uses collectives, all ranks must run compilation at
  629. # the same time; ranks must not diverge with graph breaks. This can be most
  630. # reliably achieved by ensuring PT2 only is run on SPMD programs. If this
  631. # invariant is inviolated, you will likely deadlock NCCL and encounter a
  632. # NCCL timeout.
  633. enable_compiler_collectives = os.environ.get("TORCH_COMPILER_COLLECTIVES", "0") == "1"
  634. # Enables a local, filesystem "profile" which can be used for automatic
  635. # dynamic decisions, analogous to profile-guided optimization. This config
  636. # ONLY has an effect if torch.compiler.config.workflow_id is specified,
  637. # which specifies the name of the profile we will save/load.
  638. #
  639. # The idea is that if we observe that a particular input is dynamic over
  640. # multiple iterations on one run, we can save a profile with this information
  641. # so the next time we run we can just make it dynamic the first time around,
  642. # skipping an unnecessary static compilation. The profile can be soundly
  643. # stale, if it is wrong, it just means we may make more things dynamic than
  644. # was actually necessary (NB: this /can/ cause a failure if making something
  645. # dynamic causes the compiler to stop working because you tickled a latent
  646. # bug.)
  647. #
  648. # The profile is ONLY guaranteed to work if the user source code is 100%
  649. # unchanged. Applying the profile if there are user code changes is only
  650. # best effort otherwise. In particular, we identify particular code objects
  651. # by filename, line number and name of their function, so adding/removing newlines
  652. # will typically cause cache misses. We continuously update the profile,
  653. # so if we only discover something is dynamic on the second run, we will update
  654. # the profile for subsequent runs.
  655. automatic_dynamic_local_pgo: bool = Config(
  656. justknob="pytorch/remote_cache:enable_local_automatic_dynamic_pgo",
  657. env_name_force="TORCH_DYNAMO_AUTOMATIC_DYNAMIC_LOCAL_PGO",
  658. default=True,
  659. )
  660. # Like above, but using remote cache
  661. automatic_dynamic_remote_pgo: Optional[bool] = get_tristate_env(
  662. "TORCH_DYNAMO_AUTOMATIC_DYNAMIC_REMOTE_PGO"
  663. )
  664. # temporary config to kill later
  665. _unsafe_skip_fsdp_module_guards = (
  666. os.environ.get("UNSAFE_SKIP_FSDP_MODULE_GUARDS", "0") == "1"
  667. )
  668. # Common prefix to append to the id of each compile run to filter out data
  669. pt2_compile_id_prefix: Optional[str] = os.environ.get("PT2_COMPILE_ID_PREFIX", None)
  670. # Run GC at the end of compilation
  671. run_gc_after_compile = Config( # type: ignore[var-annotated]
  672. # Disable by default on free-threaded builds since they always do a full collection, which can be slow
  673. default=sysconfig.get_config_var("Py_GIL_DISABLED") != 1,
  674. justknob="pytorch/compiler:enable_run_gc_after_compile",
  675. env_name_default="TORCH_DYNAMO_RUN_GC_AFTER_COMPILE",
  676. )
  677. # Does not graph break on torch.autograd._profiler_enabled if set to True. We
  678. # want this flag to be True by default, but there is an unsolbed bug that causes
  679. # distributed jobs to timeout with Kineto profiler when this is set to True.
  680. constant_fold_autograd_profiler_enabled = False
  681. # Takes the function/module decorated with torch.compile and passes it through a
  682. # wrapper. This ensures that nn.module hooks are also compiled in the same frame.
  683. wrap_top_frame = False
  684. # Flag to record runtime overhead in profile traces. Used for pre-graph bytecode
  685. # and AOTAutograd runtime wrapper.
  686. record_runtime_overhead = True
  687. enable_aot_compile = False
  688. # HACK: this is for testing custom ops profiling only
  689. _custom_ops_profile: Optional[Any] = None
  690. # Experimental flag to enable regional compile on invoke_subgraph HOP.
  691. # For testing only!
  692. enable_invoke_subgraph_regional_compile: bool = False
  693. # Clear WeakIdRef entries from TracingContext.tensor_to_context and
  694. # MetaTensorDescriber.lookup_tensor at the end of compile. These weakrefs
  695. # can block torch.utils.swap_tensors from working after compile.
  696. # - None (default): clear for registered backends (inductor, eager, etc.),
  697. # don't clear for custom backends (to support standalone_compile, etc.)
  698. # - True: always clear regardless of backend
  699. # - False: never clear regardless of backend
  700. invalidate_compile_context_weakrefs: Optional[bool] = None
  701. if TYPE_CHECKING:
  702. from torch.utils._config_typing import * # noqa: F401, F403
  703. def _make_closure_patcher(**changes: Any) -> Any: ...
  704. install_config_module(sys.modules[__name__])