config.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439
  1. # Copyright (c) Facebook, Inc. and its affiliates.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the BSD-style license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. from collections.abc import Callable
  7. """
  8. Global flags for aot autograd
  9. """
  10. import os
  11. import sys
  12. from typing import Literal, Optional, TYPE_CHECKING
  13. from torch.utils._config_module import Config, install_config_module
# Config entries that hold callables (see the per-entry notes below).
# NOTE(review): the marker comment on the next line looks like it is consumed
# by tooling; keep it directly above the assignment -- confirm before moving it.
# [@compile_ignored: debug]
_save_config_ignore = [
    # callable not serializable
    "joint_custom_pass",
    # callable configs with uuid() for caching, or raw callables
    "activation_memory_budget_runtime_estimator",
    "activation_memory_budget_solver",
]
# Converts torch rng ops to their functional philox rng equivalents. Note that
# we functionalize only CUDA rng ops today.
functionalize_rng_ops = False

# Can be useful for debugging if we are incorrectly creating meta fake tensors.
# On unless the FAKE_ALLOW_META environment variable is set to "0".
fake_tensor_allow_meta = os.environ.get("FAKE_ALLOW_META", "1") != "0"

# Enables optional asserts in hotpath code to check for errors. If
# you are seeing weird accuracy problems, try turning this on.
# This is currently off by default as it will harm tracing time,
# but it is on by default for aot_eager.
debug_assert = False

# Off unless the AOT_PARTITIONER_DEBUG environment variable is set to
# anything other than "0".
debug_partitioner = os.environ.get("AOT_PARTITIONER_DEBUG", "0") != "0"

# See NOTE [Export custom triton op]
decompose_custom_triton_ops = True

static_weight_shapes = True

# See https://github.com/pytorch/pytorch/issues/141881
# Tells the partitioner that parameters are free to save for backward.
treat_parameters_as_free_to_save = True

# Applies CSE to the graph before partitioning
cse = True
  41. from torch._environment import is_fbcode
# On by default. The Config entry also names a justknob and an environment
# variable (TORCHINDUCTOR_AUTOGRAD_CACHE) passed as env_name_force.
enable_autograd_cache: bool = Config(
    justknob="pytorch/remote_cache:enable_local_autograd_cache",
    env_name_force="TORCHINDUCTOR_AUTOGRAD_CACHE",
    default=True,
)

# Off by default; TORCHINDUCTOR_AUTOGRAD_CACHE_ALLOW_CUSTOM_AUTOGRAD is passed
# as env_name_force.
autograd_cache_allow_custom_autograd_functions: bool = Config(
    env_name_force="TORCHINDUCTOR_AUTOGRAD_CACHE_ALLOW_CUSTOM_AUTOGRAD", default=False
)

# For now, this is just for enabling unit testing in test_aot_autograd_cache.py
# We will either make this the default with AOTAutogradCache, or
# we'll just use it in the precompile flow. So there's no
# need to add env vars or make it configurable
bundled_autograd_cache: bool = False

bypass_autograd_cache_key: bool = False

# Whether or not to normalize placeholder names in graphs
# from dynamo in AOTAutogradCache.
# Defaults to True everywhere except fbcode.
autograd_cache_normalize_inputs = not is_fbcode()
# Enable debug mode at first invocation to check if custom ops are valid.
# When enabled, this checks that custom operators don't violate aliasing constraints.
#
# check_custom_op_aliasing: Controls whether to run the custom op aliasing check at all.
#   - When True: The check runs on first invocation of compiled functions.
#   - When False: The check is skipped entirely.
#
# error_on_custom_op_aliasing: Controls behavior when a violation is detected.
#   Only has effect when check_custom_op_aliasing is True.
#   - When True: Raises RuntimeError on aliasing violations.
#   - When False: Emits UserWarning on aliasing violations.
#
# Deprecated: Custom ops returning aliased outputs is deprecated and will
# become an error in PyTorch 2.12. Currently error_on_custom_op_aliasing
# is True only in CI.
check_custom_op_aliasing = True
# True whenever the CI environment variable is set to a non-empty string.
error_on_custom_op_aliasing = bool(os.getenv("CI"))
  76. def remote_autograd_cache_default() -> Optional[bool]:
  77. if os.environ.get("TORCHINDUCTOR_AUTOGRAD_REMOTE_CACHE") == "1":
  78. return True
  79. if os.environ.get("TORCHINDUCTOR_AUTOGRAD_REMOTE_CACHE") == "0":
  80. return False
  81. return None
# Evaluated once at import time from TORCHINDUCTOR_AUTOGRAD_REMOTE_CACHE:
# True for "1", False for "0", otherwise None.
enable_remote_autograd_cache = remote_autograd_cache_default()

# When AOTAutograd regenerates aliased graph outputs,
# attempt to use functionalization's view-replay logic
# before falling back to the autograd engine's view replay or as_strided.
# This can have some perf implications
# (although for many models this will not matter).
# (1) If you have many view ops chained together, replaying all of them
#     at runtime can have more overhead compared to a single as_strided call
# (2) If you are doing training, AsStridedBackward is quite slow,
#     and the individual view op backward formulas will likely be faster.
# (3) Some backends like XLA do not support as_strided
#
# Temporary hack: disable this flag for internal
# (needed to fix an internal issue while avoiding bumping XLA pin)
# eventually: either default this config to false completely
# once XLA pin update works,
# or default config to true and fix relevant bugs
#
# View replay is currently not compatible with AOTAutogradCache, since
# FunctionalTensors are not serializable. We'll need to make them
# serializable before enabling warm cache with this config turned on.
view_replay_for_aliased_outputs = not is_fbcode()
# Restricts the amount of computation AOTAutograd can do.
# NB: We have essentially disabled this heuristic now. However, this is kept
# here for now in case it's useful. Setting it low can artificially reduce the
# amount of recomputation AOTAutograd performs, although not in any kind of
# principled way.
max_dist_from_bw = 1000

# Bans recomputation of nodes that are reading from nodes that are far before
# the current node
ban_recompute_used_far_apart = True

# Breaks up long chains of fusible ops, as otherwise we can have an arbitrarily
# long chain of recomputation in the backwards pass.
ban_recompute_long_fusible_chains = True

# Bans recomputation of nodes that must be materialized in the backwards pass
# (used by a non-fusible node)
ban_recompute_materialized_backward = True

# Chooses to ban recomputation of nodes based off an allowlist. Setting it to
# False changes it to use a denylist. The main difference is in operators like
# sort/pool/stuff that isn't cheap enough to be fusible for free but also isn't
# that expensive
ban_recompute_not_in_allowlist = True

# Chooses to ban recomputation of reductions. This is generally a good idea, as
# the result of reductions is generally very small but recomputing reductions in
# a fusion can be expensive.
ban_recompute_reductions = True

# Prevents the partitioner from ever saving views (i.e. always recompute them).
# Generally a good idea since views are free to recompute.
# Note that this is nevertheless off by default.
recompute_views = False

# Set this flag to enable considering non-built-in ops, including triton and custom
# ops, for recomputation during the knapsack optimization solver.
is_non_builtin_to_include = False

# Rematerialize AC nodes for graphs with forward+loss+backward in one graph.
# This optimization minimizes activation checkpoint node lifetimes by computing them
# just-in-time. AC nodes that are only used in backward are deferred to the backward
# region instead of being computed and saved in forward. This reduces peak memory usage.
# Note: This only applies to forward+loss+backward graphs where torch.autograd.grad is allowed
# in the graph. Joint graphs (standard AOTAutograd) use the partitioner instead.
remat_using_tags_for_fwd_loss_bwd_graph = True
# By default, the partitioner is purely trying to optimize for runtime (although
# it should always use less memory than eager)
# This knob controls the partitioner to make that tradeoff for you, choosing the
# fastest option that saves fewer activations than the memory budget.
# Specifically, 0.0 corresponds to the activation memory from applying
# activation checkpointing to the full compiled region, and 1.0 corresponds to
# the activation memory from the default runtime-optimized strategy. So, 0.4
# would result in a strategy that saves 40% of the activations compared to the
# default strategy.
# It solves a 0-1 knapsack to find the minimum recompute necessary to stay below
# the activation memory budget.
# NOTE: This *cannot* be treated as
# NOTE(review): the note above is cut off mid-sentence in the source; the
# intended caveat needs to be restored by someone who knows it.
activation_memory_budget = 1.0

# This controls how we estimate the runtime when deciding what the cheapest
# operators to recompute are. The 3 options are
# "flops": Bases it off of the flop count provided by torch.utils.flop_counter
# "profile": Benchmarks each operator to come up with a runtime
# "testing": Returns 1 for everything
# (_save_config_ignore also lists this entry among the callable configs.)
activation_memory_budget_runtime_estimator = "flops"

# This controls the solver used for the 0-1 knapsack. By default we use a
# quantized DP solution ("dp"). The other approaches are "greedy", "ilp"
# (which has a scipy dependency) and "dp_knapsack_sliding_hirschberg", which
# uses a memory-efficient quantized DP solution.
# (_save_config_ignore also lists this entry among the callable configs.)
activation_memory_budget_solver = "dp"
  163. # This dumps out a SVG visualization of the expected runtime vs. activation
  164. # memory tradeoffs for all memory budget values from 0 to 1 in increments of
  165. # 0.5. See an example here:
  166. # https://github.com/pytorch/pytorch/pull/126320#discussion_r1625104015
  167. visualize_memory_budget_pareto = (
  168. os.environ.get("PARTITIONER_MEMORY_BUDGET_PARETO", "0") == "1"
  169. )
  170. # This controls the directory in which to dump the SVG plot with the pareto
  171. # frontier of the activation checkpointing memory-vs-runtime tradeoffs.
  172. memory_budget_pareto_dir = os.environ.get("PARTITIONER_MEMORY_BUDGET_PARETO_DIR")
# Sets all of the ban_recompute heuristics to False except ban_recompute_reductions
# Generally, this will probably result in some memory improvement, but at the
# cost of some performance
aggressive_recomputation = False

# Enables activation offloading (for testing purposes)
enable_activation_offloading = False

# Activation offloading with a separate CUDA stream
activation_offload_separate_stream = False

# Activation offloading wait sinking when using a separate stream (fwd graph)
activation_offload_sink_wait = False

# Activation reloading with prefetching when using separate streams (bwd graph)
activation_reload_prefetch = False
# Controls whether FakeTensor.data_ptr() should error.
# This option is independent of AOTAutograd and torch.compile, but our policy
# is to turn it off during torch.compile.
fake_tensor_allow_unsafe_data_ptr_access = True

# Unlifts effect tokens from the inputs/outputs in the traced graph and instead
# inserts make_token/sink_token calls in the graph to create tokens and then
# sink them at the end. Note that this means the graph is no longer functional
# which may lead to silent errors unless the backend knows how to handle the
# tokens.
unlift_effect_tokens = False

# NOTE: [The default layout constraint for custom operators.]
# This must be the name of one of the layout constraint tags
# (that is, one of {"needs_exact_strides", "needs_fixed_stride_order",
# "flexible_layout"}).
# If the custom op does not have a layout constraint tag already
# then we assume the following applies.
#
# This config is respected by Inductor and we recommend other backends also
# respect it.
# This config is in torch._functorch and not torch._inductor because it affects
# ProxyTensor tracing.
custom_op_default_layout_constraint: Literal[
    "needs_exact_strides", "needs_fixed_stride_order", "flexible_layout"
] = "needs_exact_strides"

# Run aot eager decomp partition with CrossRefFakeMode
# options = False, "all", "custom_ops"
fake_tensor_crossref = False

# This mode specifies that we should also keep track of the real
# tensor along with the fake tensor, and do real compute. While
# seemingly this eliminates the whole point of fake tensors, there are
# two obvious use cases for it:
#
#   1. When users call item()/other data dependent operations,
#      if we propagate_real_tensors we are able to determine what
#      the true value is and keep going.
#
#   2. It can be useful for testing, when you want to see if the fake
#      and real tensors agree with each other. (Note that there are
#      currently known inaccuracies in how we clone real tensors, that
#      would have to be tightened up for this to be useful in this
#      case.)
#
# Note that fake tensors are typically understood to be cheap to store
# indefinitely, so we tend to hold on to them longer than we would
# hold onto the real tensors. So we also support you explicitly
# deallocating the real tensor associated with a fake tensor, at which
# point we will stop propagating real tensors.
#
# One more thing: when you provide a real tensor to fakeify, we will
# clone it, so that we can safely perform mutations on it if necessary.
# This will increase live memory usage. This could potentially be
# optimized by using COW. We also currently do not faithfully
# maintain autograd metadata on the real tensor; this is fine because
# AOTAutograd will only use the fake tensor to determine leafness/etc
# of tensors in question.
fake_tensor_propagate_real_tensors = False
# AOTDispatcher traces out a backward graph at the time of the forward pass.
# This flag controls whether or not that backward graph gets autocast behavior
# applied to it.
#
# The options are either:
# - "same_as_forward". We assume that the backward of the torch.compile'd region
#   will be run under the same autocast context manager that the region was run
#   under. This is equivalent to running the following code in eager
#   (note that z.backward() is inside the autocast block):
#
#     with torch.amp.autocast(...):
#         y = region(x)
#         ...
#         z.backward()
#
# - "off". We assume that the backward of the torch.compile'd region will
#   not be run under any autocast context managers.
#   This is equivalent to running the following code in eager
#   (note that z.backward() is outside the autocast block):
#
#     with torch.amp.autocast(...):
#         y = region(x)
#         ...
#     z.backward()
#
# - or a list of kwargs dicts that represent an autocast context manager to turn
#   on during the backward pass.
#
#   e.g. [{"device_type": "cuda"}] is equivalent to running the following code in eager:
#
#     y = region(x)
#     ...
#     with torch.amp.autocast(device_type="cuda"):
#         z.backward()
backward_pass_autocast = "same_as_forward"
# This controls whether we collect donated buffers. This flag must be set
# False if a user wants to retain_graph=True for backward.
donated_buffer = not is_fbcode()

# Controls the default graph output format used by draw_graph
# Supported formats are defined here https://graphviz.org/docs/outputs/
# Read from the TORCH_COMPILE_GRAPH_FORMAT environment variable; "svg" if unset.
torch_compile_graph_format = os.environ.get("TORCH_COMPILE_GRAPH_FORMAT", "svg")

# Valid only if fake_tensor_propagate_real_tensors = True; if a fake-real
# kernel mismatch is detected, bypasses by making a fake kernel from the
# real tensor outputs.
generate_fake_kernels_from_real_mismatches = False

# When there are device mismatches in FakeTensor device propagation,
# prefer a specific device type over others. This is particularly useful
# in full compiled mode where intermediate tensors with device mismatches
# represent only logical differences during compilation - these intermediate
# tensors will never physically materialize in the binary execution, so the
# device mismatch is not a real runtime concern. Enabling this allows the
# compiler to proceed with compilation by choosing the preferred device type
# for consistency. For example, set to "mtia" to prefer MTIA devices over
# CPU, or "cuda" to prefer CUDA devices over CPU.
fake_tensor_prefer_device_type: Optional[str] = None

# CUDAGraph safe run_with_rng functionalization.
# TODO: turn on by default
# NOTE(review): the default below is already True, so the TODO above looks
# stale -- confirm and remove.
graphsafe_rng_functionalization = True

# Whether or not to eagerly compile the backward
# used by AOT compile and other settings
# TODO: once AOT compile calls aot autograd directly instead of
# through compile_fx, we can remove this
force_non_lazy_backward_lowering = False

# only for testing, used to turn functionalization off in AOTDispatcher
_test_disable_functionalization = True

# Error on BypassAOTAutogradCache instead of just a warning
# Used for tests
strict_autograd_cache = False
# Note [Recomputing collectives in the partitioner]
# The purpose of this config is as follows:
# - We have many passes in the compiler (min-cut partitioning, DCE, etc)
#   which can reorder or delete duplicate nodes in the graph
# - If any of these passes reorder/delete/duplicate a collective
#   in a setting where the compiler is being run independently on multiple
#   ranks, we run the risk that the compiler will make a different decision on
#   different ranks, resulting in a NCCL hang when using torch.compile
# To handle this, we will (by default) ensure that collectives are not modified
# by the compiler.
#
# A few examples:
# - don't dead-code-eliminate collectives
#   (in case they are dead on rank i but not rank j)
# - don't recompute collectives in partitioning
#   (in case we recompute on rank i but not rank j)
#
# Today this flag **must** be set to false, but eventually
# we want the option to set it to true.
# In order to potentially optimize collectives, we'll need the compiler
# to broadcast information across ranks at compile time to ensure
# that any decisions on collectives are made consistently.
unsafe_allow_optimization_of_collectives = False

# See Note [AOTAutograd Tangent Subclassness for mutated inputs]
# TODO(ivankobzarev): Remove this config once this can be deduced at compile time.
disable_guess_zero_tangent_for_mutated_input_subclass = False

# See Note [Tangents memory format]
# By default, tangent strides are guessed to be contiguous.
# At runtime, non-contiguous tangents will be coerced to be contiguous.
# This config changes the guess for tangent strides to be the same as the outputs'.
# TODO(ivankobzarev): Remove this config once extra memory usage is investigated.
guess_tangent_strides_as_outputs = False

# This is a temporary config to ensure all ranks take the same decision in the
# partitioner. It will ultimately be removed once we share size_hints across
# ranks through compiler collectives.
_sync_decision_cross_ranks = False
# By default, inlined saved_tensors_hooks are applied only to "donated" buffers.
# "donated" buffers are invisible to the user; they are intermediates of the forward graph.
# Applying saved tensors hooks for memory optimizations only to intermediates
# guarantees that the original saved tensors can be deallocated.
# This config allows saved_tensors_hooks to be applied to **all** saved tensors,
# which could include inputs, parameters and outputs.
# "donated" - applied only to saved intermediates of the graph
# "no_static" - applied to all saved tensors except "static" ones
#   (this includes parameters and tensors the user marked as static)
# "all" - no filtering, everything saved for backward.
saved_tensors_hooks_filtering_mode = "donated"

# This callback is invoked on the joint graph before partitioning
joint_custom_pass: Callable = None  # type: ignore[assignment]

force_autograd_cache = False

# Note [Selective Decomposition]
# This config allows selective decomposition of certain operators in the graph.
# When True, it does NOT decompose any nodes, except those nodes that users explicitly
# annotated with regional inductor compile. Please read torch.fx.passes.regional_inductor
# on how to explicitly annotate. This is currently only used by inductor lite mode.
selective_decompose: bool = False
# Imported for static type checkers only; never executed at runtime.
if TYPE_CHECKING:
    from torch.utils._config_typing import *  # noqa: F401, F403

# adds patch, save_config, invalid config checks, etc
# Must run after every config entry above has been defined, since it is
# handed this module object.
install_config_module(sys.modules[__name__])