config.py 98 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
7227822792280228122822283228422852286228722882289229022912292229322942295229622972298229923002301230223032304230523062307230823092310231123122313231423152316231723182319232023212322232323242325232623272328232923302331233223332334233523362337233823392340234123422343234423452346234723482349235023512352235323542355235623572358235923602361236223632364236523662367236823692370237123722373237423752376237723782379238023812382238323842385238623872388238923902391239223932394239523962397239823992400240124022403240424052406240724082409241024112412241324142415241624172418241924202421242224232424242524262427242824292430243124322433243424352436243724382439244024412442244324442445244624472448244924502451245224532454245524562457245824592460246124622463246424652466246724682469247024712472247324742475247624772478247924802481248224832484248524862487248824892490249124922493249424952496249724982499250025012502250325042505250625072508250925102511251225132514251525162517251825192520252125222523252425252526
  1. import os
  2. import sys
  3. from collections.abc import Callable
  4. from typing import Any, cast, Literal, Optional, TYPE_CHECKING, Union
  5. import torch
  6. import torch._inductor.custom_graph_pass
  7. from torch._environment import is_fbcode
  8. from torch.utils._config_module import (
  9. Config,
  10. get_tristate_env,
  11. inherit_fields_from,
  12. install_config_module,
  13. )
  14. if TYPE_CHECKING:
  15. from torch._inductor.choices import InductorChoices
# In-place padding flags.
# inplace_padding is on by default; setting TORCHINDUCTOR_INPLACE_PADDING to any
# value other than "1" turns it off.
inplace_padding = os.environ.get("TORCHINDUCTOR_INPLACE_PADDING", "1") == "1"
# Presumably lets graph inputs be padded in place as well (off by default).
# NOTE(review): meaning inferred from the name — confirm at the use site.
can_inplace_pad_graph_input = False  # ease testing
  18. def fx_graph_remote_cache_default() -> Optional[bool]:
  19. return get_tristate_env("TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE")
  20. def vec_isa_ok_default() -> Optional[bool]:
  21. if os.environ.get("TORCHINDUCTOR_VEC_ISA_OK") == "1":
  22. return True
  23. if os.environ.get("TORCHINDUCTOR_VEC_ISA_OK") == "0":
  24. return False
  25. return None
  26. def autotune_remote_cache_default() -> Optional[bool]:
  27. return get_tristate_env("TORCHINDUCTOR_AUTOTUNE_REMOTE_CACHE")
  28. def bundled_autotune_remote_cache_default() -> Optional[bool]:
  29. return get_tristate_env("TORCHINDUCTOR_BUNDLED_AUTOTUNE_REMOTE_CACHE")
  30. def bundle_triton_into_fx_graph_cache_default() -> Optional[bool]:
  31. return get_tristate_env(
  32. "TORCHINDUCTOR_BUNDLE_TRITON_INTO_FX_GRAPH_CACHE",
  33. True if not is_fbcode() else None,
  34. )
  35. def static_cuda_launcher_default() -> bool:
  36. STATIC_CUDA_LAUNCHER_VERSION = 2
  37. if "TORCHINDUCTOR_USE_STATIC_CUDA_LAUNCHER" in os.environ:
  38. return os.environ.get("TORCHINDUCTOR_USE_STATIC_CUDA_LAUNCHER") == "1"
  39. elif is_fbcode():
  40. version = torch._utils_internal.justknobs_getval_int(
  41. "pytorch/inductor:static_cuda_launcher_version"
  42. )
  43. return version <= STATIC_CUDA_LAUNCHER_VERSION
  44. else:
  45. # Default true in OSS
  46. return True
  47. def prologue_fusion_enabled() -> bool:
  48. ENABLE_PROLOGUE_FUSION_VERSION = 0
  49. if "TORCHINDUCTOR_PROLOGUE_FUSION" in os.environ:
  50. return os.environ.get("TORCHINDUCTOR_PROLOGUE_FUSION") == "1"
  51. elif is_fbcode():
  52. jk_name = "pytorch/inductor:prologue_fusion_version"
  53. version = torch._utils_internal.justknobs_getval_int(jk_name)
  54. return version <= ENABLE_PROLOGUE_FUSION_VERSION
  55. else:
  56. return True
# Enable auto_functionalized_v2 (enabled by default).
# Note that the environment variable uses the TORCHDYNAMO_ prefix, not
# TORCHINDUCTOR_. Any value other than "1" disables it.
enable_auto_functionalized_v2 = (
    os.environ.get("TORCHDYNAMO_AUTO_FUNCTIONALIZED_V2", "1") == "1"
)

# Add some debug printouts.
debug = False

# Whether to disable the progress bar for autotuning.
disable_progress = True

# Whether to enable printing the source code for each future.
verbose_progress = False

# Configurable compile worker logging path for subproc_pool.
# Set to a fixed path in fbcode; None in OSS.
worker_log_path = (
    "/logs/dedicated_log_torch_compile_worker_rank" if is_fbcode() else None
)

# Precompilation timeout, in seconds. Defaults to 300 (5 minutes).
# Override with TORCHINDUCTOR_PRECOMPILATION_TIMEOUT_SECONDS. A non-integer value
# makes the int() conversion raise ValueError at import time.
precompilation_timeout_seconds: int = int(
    os.environ.get("TORCHINDUCTOR_PRECOMPILATION_TIMEOUT_SECONDS", 60 * 5)
)

# Use the local FX AOT graph codegen cache (default True). The value can also
# come from the JustKnob or the two environment variables named below.
# NOTE(review): presumably env_name_force overrides everything while
# env_name_default only replaces `default` — confirm in
# torch/utils/_config_module.py (Config).
fx_graph_cache: bool = Config(
    justknob="pytorch/remote_cache:enable_local_fx_graph_cache",
    env_name_default="TORCHINDUCTOR_FX_GRAPH_CACHE_DEFAULT",
    env_name_force="TORCHINDUCTOR_FX_GRAPH_CACHE",
    default=True,
)
# Remote cache for GEMM autotuning (off by default).
# NOTE(review): this flag had no description; its meaning is inferred from the
# name — confirm at the use site.
remote_gemm_autotune_cache: bool = False

# Use the remote FX AOT graph codegen cache.
# False: Disables the cache
# True: Enables the cache
# None: Not set -- Off for OSS, JustKnobs based for internal
# The default is read from TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE.
fx_graph_remote_cache: Optional[bool] = fx_graph_remote_cache_default()

# Whether to bundle Triton caching into the FX graph cache.
# Tri-state, read from TORCHINDUCTOR_BUNDLE_TRITON_INTO_FX_GRAPH_CACHE. When the
# variable is unset, the fallback is True in OSS and None in fbcode.
bundle_triton_into_fx_graph_cache: Optional[bool] = (
    bundle_triton_into_fx_graph_cache_default()
)

# Whether remote cache writes are non-blocking (default True).
# Controlled by a JustKnob and by TORCHINDUCTOR_NON_BLOCKING_REMOTE_CACHE_WRITE.
non_blocking_remote_cache_write: bool = Config(
    justknob="pytorch/remote_cache:enable_non_blocking_remote_cache_write_v2",
    env_name_force="TORCHINDUCTOR_NON_BLOCKING_REMOTE_CACHE_WRITE",
    default=True,
)

# Enable autotune local cache.
#
# See bundled_autotune_remote_cache for the effect this flag has on the bundled
# remote cache.
autotune_local_cache: bool = True

# Enable autotune remote cache.
#
# Enables/disables the autotune remote cache regardless of the state of
# autotune_local_cache. If both local and remote are enabled, then on write both
# are written; on read, local is checked first and remote is read only on a
# cache miss.
#
# False: Disables the cache
# True: Enables the cache
# None: Not set -- Off for OSS, JustKnobs based for internal
# The default is read from TORCHINDUCTOR_AUTOTUNE_REMOTE_CACHE.
autotune_remote_cache: Optional[bool] = autotune_remote_cache_default()

# Enable bundled autotune cache.
#
# Enables/disables the bundled autotune cache regardless of the state of
# autotune_remote_cache. However, it depends on the local cache for local
# state management. As a result, if the local cache is disabled, the bundled
# autotune cache is disabled too.
#
# False: Disables the cache
# True: Enables the cache (requires autotune_local_cache)
# None: Not set -- Off for OSS, JustKnobs based for internal
# The default is read from TORCHINDUCTOR_BUNDLED_AUTOTUNE_REMOTE_CACHE.
bundled_autotune_remote_cache: Optional[bool] = bundled_autotune_remote_cache_default()

# Alias of torch.compiler.config.force_disable_caches; see that option.
force_disable_caches: bool = Config(alias="torch.compiler.config.force_disable_caches")
# Unsafe way to skip dynamic shape guards to get a faster cache load.
unsafe_skip_cache_dynamic_shape_guards: bool = False

# Unsafe way to mark non-torch functions as safe to cache.
# The dictionary maps function name -> cache key.
# Any function name in the dictionary is allowed to be cached
# by AOTAutogradCache and FxGraphCache.
# Changing a cache key value changes the resulting FXGraphCache key.
# Example usage:
#   torch._inductor.config.unsafe_marked_cacheable_functions = {
#       'torch.ops.my_function' : torch.__version__
#   }
# The example above makes the custom op torch.ops.my_function cacheable,
# with its cache keys keyed by the current torch version.
unsafe_marked_cacheable_functions: dict[str, str] = {}

# Sleep in inductor, for testing only.
# NOTE(review): the unit is presumably seconds (per the name), and None
# presumably disables the sleep — confirm at the use site.
sleep_sec_TESTING_ONLY: Optional[int] = None

# The default layout constraint for user-defined Triton kernels.
# See "The default layout constraint for custom operators" for options.
triton_kernel_default_layout_constraint: Literal[
    "needs_fixed_stride_order", "flexible_layout"
] = "needs_fixed_stride_order"
# Use the C++ wrapper instead of the Python wrapper (off by default;
# TORCHINDUCTOR_CPP_WRAPPER=1 enables it).
# Incompatible with disable_cpp_codegen.
cpp_wrapper: bool = os.environ.get("TORCHINDUCTOR_CPP_WRAPPER", "0") == "1"

# Controls whether the entry and the kernels are compiled separately in
# cpp_wrapper mode. Turn this on to compile them separately and minimize the
# compile time of the entry part.
# See https://github.com/pytorch/pytorch/pull/148773
# Note: compiling the entry and kernels separately may have a non-negligible
# impact on performance.
# See https://github.com/pytorch/pytorch/issues/156037
cpp_wrapper_build_separate: bool = (
    os.environ.get("TORCHINDUCTOR_CPP_WRAPPER_BUILD_SEPARATE", "0") == "1"
)

# FX wrapper toggle (off by default; TORCHINDUCTOR_FX_WRAPPER=1 enables it).
# NOTE(review): this flag had no description; presumably it selects an FX-based
# wrapper codegen — confirm.
fx_wrapper: bool = os.environ.get("TORCHINDUCTOR_FX_WRAPPER", "0") == "1"

# Controls automatic precompiling of common include files for codecache.CppCodeCache
# (i.e. for cpp_wrapper mode and for cpp kernels on CPU). AOTI header precompiling is
# controlled by a separate flag. On in OSS, off in fbcode.
cpp_cache_precompile_headers: bool = not is_fbcode()

# Online-softmax toggle (on by default; TORCHINDUCTOR_ONLINE_SOFTMAX=0 disables it).
online_softmax = os.environ.get("TORCHINDUCTOR_ONLINE_SOFTMAX", "1") == "1"
# Gumbel-max trick toggle (on by default; TORCHINDUCTOR_APPLY_GUMBEL_MAX_TRICK=0
# disables it).
# NOTE(review): where the trick is applied is not visible here — confirm.
apply_gumbel_max_trick = (
    os.environ.get("TORCHINDUCTOR_APPLY_GUMBEL_MAX_TRICK", "1") == "1"
)

# Dead code elimination.
dce = False

# Assume weight tensors are fixed size.
static_weight_shapes = True

# Put correctness assertions in the generated code.
# Size and scalar asserts are on by default; the NaN asserts are off unless
# their environment variable is exactly "1".
size_asserts = os.environ.get("TORCHINDUCTOR_SIZE_ASSERTS", "1") == "1"
nan_asserts = os.environ.get("TORCHINDUCTOR_NAN_ASSERTS") == "1"
runtime_triton_nan_asserts = (
    os.environ.get("TORCHINDUCTOR_RUNTIME_TRITON_NAN_ASSERTS") == "1"
)
scalar_asserts = os.environ.get("TORCHINDUCTOR_SCALAR_ASSERTS", "1") == "1"

# Alignment asserts: on by default in OSS, off by default in fbcode.
# TORCHINDUCTOR_ALIGNMENT_ASSERTS overrides the default in either case.
alignment_asserts = (
    os.environ.get("TORCHINDUCTOR_ALIGNMENT_ASSERTS", "0" if is_fbcode() else "1")
    == "1"
)
# Enable loop reordering based on input orders.
pick_loop_orders = True

# Reuse a kernel input as the output.
inplace_buffers = True

# Reuse a buffer for an unrelated purpose.
allow_buffer_reuse = True

# Enable pooled allocations for non-output tensors (off by default;
# TORCHINDUCTOR_MEMORY_PLANNING=1 enables it).
memory_planning = os.environ.get("TORCHINDUCTOR_MEMORY_PLANNING", "0") == "1"

# Allow the ftz variant of the exponent instruction in Triton codegen
# (TORCHINDUCTOR_USE_FAST_MATH=1 enables it).
use_fast_math = os.environ.get("TORCHINDUCTOR_USE_FAST_MATH") == "1"

# How to organize memory under memory_planning=True:
# - "none": do not try to pool storage, just reuse
# - "intermediates": all non-outputs share storage, outputs each get unique storage
# - "outputs": two pools, one for intermediates (freed on return) and one for outputs
# - "combined": a single pool for both intermediates and outputs
# The value is read from TORCHINDUCTOR_MEMORY_POOL (default "intermediates") and
# is not validated against the Literal here; hence the type-checker suppressions.
# pyrefly: ignore [bad-assignment]
memory_pool: Literal["none", "intermediates", "outputs", "combined"] = os.environ.get(
    "TORCHINDUCTOR_MEMORY_POOL", "intermediates"
)  # type: ignore[assignment]

# Codegen benchmark harness.
benchmark_harness = True

# Fuse pointwise ops into template epilogues.
epilogue_fusion = True

# Fuse pointwise ops into template prologues.
# The default comes from prologue_fusion_enabled(): TORCHINDUCTOR_PROLOGUE_FUSION
# if set, otherwise a JustKnob version check in fbcode, otherwise True.
prologue_fusion = prologue_fusion_enabled()

# Do epilogue fusions before other fusions.
epilogue_fusion_first = False

# Enable pattern match+replace optimizations.
pattern_matcher = True

# Set to True to enable the back-to-back GEMM pass.
b2b_gemm_pass = False
# Register custom graph optimization pass hooks. So far, the pre/post passes are
# only applied before/after pattern_matcher in post_grad_passes.
#
# Implement CustomGraphPass so that Inductor can cache compiled artifacts
# to which your custom passes have been applied.
# NOTE(review): the original comment read "to graph compiled artifacts".
# "cache" is assumed to be the intended word — confirm against
# torch/_inductor/custom_graph_pass.py.
post_grad_custom_pre_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None
post_grad_custom_post_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None

# Allow users to pass in a custom partition function.
custom_partitioner_fn: torch._inductor.custom_graph_pass.CustomPartitionerFnType = None

# Register custom joint graph passes (run before / after the built-in ones,
# going by the names).
joint_custom_pre_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None
joint_custom_post_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None

# Register a custom pre-grad pass that receives a torch.fx Graph. Note that the
# pre-grad IR is (1) non-functional, (2) non-normalized, and (3) prone to change.
# Ideally, use post-grad passes instead.
pre_grad_custom_pass: Optional[Callable[[torch.fx.graph.Graph], None]] = None

# Register a custom pass to be run right before fusion in the Inductor scheduler.
# It takes and returns a list of scheduler nodes.
# WARNING: Inductor scheduler IR is at the prototype stage and subject to change,
# so custom IR passes built on top of it might break in the future.
_pre_fusion_custom_pass: Optional[
    Callable[
        [list["torch._inductor.scheduler.BaseSchedulerNode"]],
        list["torch._inductor.scheduler.BaseSchedulerNode"],
    ]
] = None

# Register a custom pass to be run right after fusion in the Inductor scheduler.
# It takes and returns a list of scheduler nodes.
# WARNING: Inductor scheduler IR is at the prototype stage and subject to change,
# so custom IR passes built on top of it might break in the future.
_post_fusion_custom_pass: Optional[
    Callable[
        [list["torch._inductor.scheduler.BaseSchedulerNode"]],
        list["torch._inductor.scheduler.BaseSchedulerNode"],
    ]
] = None
# Deprecated
split_cat_fx_passes = True

# Optimize conv-batchnorm if batchnorm is in eval mode. Slightly reduces
# numerical stability.
efficient_conv_bn_eval_fx_passes = False

# Enable predispatch aten IR for export.
is_predispatch = False

# Deprecated
group_fusion = False

# Deprecated
batch_fusion = True

# Pre-grad fusions and their options, in order. Set to an empty dict to disable
# fusion. Keys are fusion/pass names; values are per-pass option dicts.
# Call `torch._inductor.fx_passes.group_batch_fusion.list_group_batch_fusions()` to see available fusions.
# batch fusion options:
#   batch_linear
#   batch_linear_lhs
#   batch_layernorm
#   batch_tanh
#   batch_relu
#   batch_sigmoid
# split cat fusion options:
#   normalization_pass
#   remove_split_with_size_one_pass
#   merge_getitem_cat_pass
#   merge_stack_tahn_unbind
#   merge_splits_pass
#   mutate_cat_pass
#   split_cat_pass
pre_grad_fusion_options: dict[str, dict[str, Any]] = {}

# Post-grad fusions and their options. Set to an empty dict to disable fusion.
# Call `torch._inductor.fx_passes.group_batch_fusion.list_group_batch_fusions(False)` to see available fusions.
post_grad_fusion_options: dict[str, dict[str, Any]] = {}
# Enable the reordering pass for improving memory locality.
reorder_for_locality = True

# Scale down Rn_BLOCK for better occupancy (on by default;
# TORCHINDUCTOR_DYNAMIC_SCALE_RBLOCK=0 disables it).
dynamic_scale_rblock = os.environ.get("TORCHINDUCTOR_DYNAMIC_SCALE_RBLOCK", "1") == "1"

# Force fusion of int_mm with mul. Needed when you want to avoid realizing the
# int32 result, but the mul would otherwise be fused with other pointwise ops.
force_fuse_int_mm_with_mul = False

# DEPRECATED. This setting is ignored.
use_mixed_mm = True

# Enable a runtime numeric check for pre/post-grad FX passes.
# Floating point provides limited accuracy (about 7 decimal digits for single-precision
# floating point numbers, about 16 decimal digits for double-precision floating point numbers),
# according to the PyTorch documentation.
# https://pytorch.org/docs/stable/notes/numerical_accuracy.html#batched-computations-or-slice-computations
# NOTE(review): the exact semantics of "precision", "num_iterations" and
# "requires_optimizer" are defined by the consumer of this dict, which is not
# visible here — confirm.
fx_passes_numeric_check: dict[str, Any] = {
    "pre_grad": False,
    "precision": 1e-4,
    "num_iterations": 1,
    "requires_optimizer": True,
}

# DEPRECATED. This setting is ignored.
mixed_mm_choice: Literal["default", "triton", "aten", "heuristic"] = "heuristic"
# Enable the reordering pass for increasing overlap between compute and communication.
reorder_for_compute_comm_overlap = False

# Passes (in execution order) for increasing overlap between compute and communication.
# For built-in passes, use the string name; for user-defined passes, pass in the
# function handle (a callable taking and returning a list of scheduler nodes).
# WARNING: Inductor scheduler IR is at the prototype stage and subject to change,
# so custom IR passes built on top of it might break in the future.
#
# See aten_distributed_optimizations; it is the recommended way to do distributed
# optimizations.
#
# Recommended configuration for reorder_for_compute_comm_overlap_passes:
# [
#     "reorder_communication_preserving_peak_memory",
#     "sink_waits_iterative",
#     "reorder_communication_preserving_peak_memory",
# ]
reorder_for_compute_comm_overlap_passes: list[
    Union[
        str,
        Callable[
            [list["torch._inductor.scheduler.BaseSchedulerNode"]],
            list["torch._inductor.scheduler.BaseSchedulerNode"],
        ],
    ]
] = []

# Maximum number of positions to advance a given collective; unlimited (None)
# by default.
reorder_prefetch_limit: Optional[int] = None

# Enable operator reordering for peak memory optimization.
reorder_for_peak_memory = True
reorder_for_peak_memory_debug = False

# In some cases, when all the nodes that can be scheduled are quite large,
# it is beneficial to switch the scheduling strategy. Instead of using
# size as the criterion, we choose a node that can unlock more nodes to
# become schedulable, by analyzing their successor nodes. The default value
# is zero, which turns off this optimization.
size_threshold_for_succ_based_strategy: int = 0

# FX-level bucketing mode for all-gather collectives: "none" (off), "all", or
# "only_fsdp" (presumably restricts bucketing to FSDP all-gathers — confirm).
bucket_all_gathers_fx: Literal["none", "all", "only_fsdp"] = "none"
# By default, torch._inductor.fx_passes.bucketing.bucket_size_determinator is used.
bucket_all_gathers_fx_bucket_size_determinator: Optional[Callable[[int], int]] = None

# FX-level bucketing mode for reduce-scatter collectives: "none" or "all".
bucket_reduce_scatters_fx: Literal["none", "all"] = "none"
# By default, torch._inductor.fx_passes.bucketing.bucket_size_determinator is used.
bucket_reduce_scatters_fx_bucket_size_determinator: Optional[Callable[[int], int]] = (
    None
)

# FX-level bucketing mode for all-reduce collectives: "none" or "all".
bucket_all_reduces_fx: Literal["none", "all"] = "none"
# By default, torch._inductor.fx_passes.bucketing.bucket_size_determinator is used.
bucket_all_reduces_fx_bucket_size_determinator: Optional[Callable[[int], int]] = None
# Runtime estimation function for ops.
# For the built-in estimation function, pass in "default"; for a user-defined
# estimation function, pass in the function handle.
estimate_op_runtime = "default"
runtime_estimations_mms_benchmark: bool = False

# Unit: GB/s, uni-directional P2P bandwidth per card.
# The default value corresponds to NVLink.
intra_node_bw = 300

# Unit: GB/s, uni-directional P2P bandwidth per node.
# The default value corresponds to InfiniBand.
inter_node_bw = 25

# Unit: GB/s, uni-directional CPU<>GPU bandwidth.
# The default value corresponds to PCIe; modify it for your hardware or measured
# bandwidth.
cpu_gpu_bw = 50.0

# Use Inductor's experimental benchmarker (runtime/benchmarking.py)
# to benchmark kernels during autotuning; otherwise fall back to
# Triton's `do_bench`. The experimental benchmarker may produce
# results that are not consistent with `do_bench`'s results.
use_experimental_benchmarker: bool = Config(
    default=True,
    env_name_force="TORCHINDUCTOR_USE_EXPERIMENTAL_BENCHMARKER",
    justknob="pytorch/inductor:use_experimental_benchmarker",
)

# Enable distributed autotuning. When enabled, autotuning is distributed
# across the ranks in the same program group: instead of each rank autotuning
# every kernel, each rank autotunes only 1/world_size of the kernels, and the
# results are then shared.
distributed_max_autotune_gemm = (
    os.environ.get("TORCHINDUCTOR_DISTRIBUTED_MAX_AUTOTUNE_GEMM") == "1"
)

# Pipeline autotuning for max-autotune-gemm: overlap lowering and benchmarking
# on the GPU. Note that the environment variable name
# (TORCHINDUCTOR_PIPELINE_GEMM_AUTOTUNING) does not mirror the attribute name.
pipeline_max_autotune_gemm = (
    os.environ.get("TORCHINDUCTOR_PIPELINE_GEMM_AUTOTUNING") == "1"
)

# Enable slow autotuning passes to select algorithms.
max_autotune = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE") == "1"

# Enable slow autotuning passes to select pointwise/reduction algorithms.
max_autotune_pointwise = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE") == "1"

# Enable slow autotuning passes to select GEMM algorithms.
max_autotune_gemm = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_GEMM") == "1"

# Default warmup / repetition settings for autotune benchmarking (25 / 100),
# overridable via the environment.
# NOTE(review): units are not visible here (Triton's do_bench uses
# milliseconds) — confirm at the use site.
inductor_default_autotune_warmup = int(
    os.getenv("TORCHINDUCTOR_DEFAULT_AUTOTUNE_WARMUP", 25)
)
inductor_default_autotune_rep = int(
    os.getenv("TORCHINDUCTOR_DEFAULT_AUTOTUNE_REP", 100)
)
  393. # Modifies the number of autotuning choices displayed, set to None for all
  394. def _autotune_num_choices_displayed_default() -> Optional[int]:
  395. env_val = os.environ.get("TORCHINDUCTOR_AUTOTUNE_NUM_CHOICES_DISPLAYED")
  396. if env_val is None:
  397. return 10
  398. if env_val.lower() in ("none", "all"):
  399. return None
  400. return int(env_val)
# Number of autotuning choices displayed: 10 by default, None means all.
# Parsed from TORCHINDUCTOR_AUTOTUNE_NUM_CHOICES_DISPLAYED by
# _autotune_num_choices_displayed_default().
autotune_num_choices_displayed: Optional[int] = (
    _autotune_num_choices_displayed_default()
)

# Report the autotune choices and their benchmark results. Default is True.
max_autotune_report_choices_stats = (
    os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_REPORT_CHOICES_STATS", "1") == "1"
)

# Prune configs whose theoretical maximum shared memory usage is higher than the
# hardware limit (off by default). This over-prunes: some valid configs whose
# theoretical shared memory usage exceeds their real usage are also pruned, but
# it guarantees that invalid configs are never autotuned.
max_autotune_prune_choices_based_on_shared_mem = (
    os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_PRUNE_CHOICES_BASED_ON_SHARED_MEM", "0")
    == "1"
)

# Prevent Triton from trying to initialize and detect devices on the host.
triton_disable_device_detection = (
    os.environ.get("TORCHINDUCTOR_TRITON_DISABLE_DEVICE_DETECTION", "0") == "1"
)

# Enable inductor graph partition, allowing multiple inductor graphs for the same
# dynamo graph. On by default in OSS, off by default in fbcode;
# TORCHINDUCTOR_GRAPH_PARTITION overrides the default in either case.
graph_partition: bool = (
    os.environ.get("TORCHINDUCTOR_GRAPH_PARTITION", "1" if not is_fbcode() else "0")
    == "1"
)

# Register ops at which inductor should partition the graph. The name format
# should be "namespace::kernel_name" (e.g., aten::mm) for an op overload packet,
# or "namespace::kernel_name.overload" (e.g., aten::mm.default).
custom_should_partition_ops: list[str] = []

# Register ops whose OUTPUT unbacked symints should cause a partition. Any tensors
# or ops that use these output unbacked symints (e.g. in their shapes, strides, or
# offsets) will be excluded from cudagraph partitions. This is useful for operators
# that produce data-dependent unbacked symints (e.g., a custom op that returns a SymInt).
# Note: input symints to these ops remain cudagraph-safe; only the output symints are
# marked as cudagraph-unsafe. The name format should be "namespace::kernel_name"
# (e.g., mylib::get_split_point) for an op overload packet, or
# "namespace::kernel_name.overload" for specific overloads.
cudagraph_unsafe_unbacked_ops: list[str] = []
# Whether template autotuning should allow flexible layouts if possible
# (e.g. only extern choices).
max_autotune_allow_flexible_layouts: bool = False

# Force cuBLAS and Triton to use the same precision. cuBLAS supports TF32 for
# matmul operations when m, n, k are multiples of 16, 16, 8, whereas Triton
# supports TF32 for matmul operations for any combination of m, n, k, regardless
# of alignment. Setting this flag ensures that Triton does not use TF32 wherever
# cuBLAS would not use TF32.
# DEPRECATED. cuBLAS no longer has the above alignment requirements; this flag
# will be removed in the future.
force_same_precision: bool = Config(
    justknob="pytorch/compiler:force_same_precision",
    env_name_force="TORCHINDUCTOR_FORCE_SAME_PRECISION",
    default=False,
)

# Size hints for multi-kernel dispatch (empty by default).
# A reasonable default value for this config would be [64, 256, 4096].
# TODO: @bobrenjc93 to roll this out to a few internal models to ensure this works
# as expected before turning it on for everyone.
multi_kernel_hints: list[int] = []

# Specify candidate backends for GEMM autotuning.
# Possible choices are combinations of: ATen, Triton, CUTLASS, CK, CKTILE, CPP.
# ATen: default PyTorch ATen kernels.
# Triton: Triton templates defined in torch inductor (AMD and NVIDIA GPUs).
# CUTLASS: CUTLASS templates and kernels (NVIDIA GPUs only).
# CK: Composable Kernel templates and kernels (AMD Instinct GPUs only).
# CKTILE: Composable Kernel templates and kernels, new API (AMD Instinct GPUs only).
# CPP: CPP templates and kernels for CPU.
# The value is a comma-separated list (default "ATEN,TRITON,CPP"), read from
# TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS and upper-cased here.
max_autotune_gemm_backends = os.environ.get(
    "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS", "ATEN,TRITON,CPP"
).upper()
  465. # Configures the maximum number of NVIDIA Universal GEMM (NVGEMM) configs to profile
  466. # in max_autotune. By default it's 5, to keep compile time reasonable.
  467. # Set to None (or env var "none"/"all") to tune all configs.
  468. def _nvgemm_max_profiling_configs_default() -> Optional[int]:
  469. env_val = os.environ.get("TORCHINDUCTOR_NVGEMM_MAX_PROFILING_CONFIGS", "5")
  470. if env_val.lower() in ("none", "all"):
  471. return None
  472. return int(env_val)
  473. nvgemm_max_profiling_configs: Optional[int] = _nvgemm_max_profiling_configs_default()
# As above, specify candidate backends for conv autotune (comma-separated; the env
# value is upper-cased before it is stored).
# NB: in some cases for 1x1 convs we emit as matmul,
# which will use the backends of `max_autotune_gemm_backends`
max_autotune_conv_backends = os.environ.get(
    "TORCHINDUCTOR_MAX_AUTOTUNE_CONV_BACKENDS", "ATEN,TRITON"
).upper()
# Specify the size of the search space for GEMM autotuning.
# DEFAULT - balance between compile time overhead and performance
# EXHAUSTIVE - maximize performance
# The env value is upper-cased but is not validated against the Literal values here.
# pyrefly: ignore [bad-assignment]
max_autotune_gemm_search_space: Literal["DEFAULT", "EXHAUSTIVE"] = os.environ.get(
    "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_SEARCH_SPACE", "DEFAULT"
).upper()  # type: ignore[assignment]
# Specify the size of the search space for flex attention autotuning.
# DEFAULT - balance between compile time overhead and performance
# EXHAUSTIVE - maximize performance
# The env value is upper-cased but is not validated against the Literal values here.
# pyrefly: ignore [bad-assignment]
max_autotune_flex_search_space: Literal["DEFAULT", "EXHAUSTIVE"] = os.environ.get(
    "TORCHINDUCTOR_MAX_AUTOTUNE_FLEX_SEARCH_SPACE", "DEFAULT"
).upper()  # type: ignore[assignment]
# Fall back to ATen for all ops by default, except those nodes that users explicitly
# annotated with regional inductor compile. See torch.fx.passes.regional_inductor for
# how to annotate nodes explicitly. This is currently only used by inductor lite mode.
# Unlike the default inductor mode, which fuses all nodes, this config enables an
# opt-in mode that only fuses user-specified nodes. The motivation is to provide
# guaranteed numeric correctness and give full control to users.
fallback_by_default: bool = False
# This config allows selective decomposition of certain operators in the graph.
# Currently the only use case is to patch the same-name config in functorch, for
# inductor lite mode. See more details in [Note: Selective Decomposition]
selective_decompose: bool = False
# Use dead code elimination
use_dce: bool = True
# Use fx graph passes (pre-grad, joint-graph and post-grad respectively)
use_pre_grad_passes: bool = True
use_joint_graph_passes: bool = True
use_post_grad_passes: bool = True
# Opt in with CUTEDSL_ENABLE_AUTOTUNING=1. Note: unlike most options in this file,
# the env var has no TORCHINDUCTOR_ prefix.
# NOTE(review): presumably enables autotuning of CuteDSL-generated kernels; confirm
# against its consumer.
cutedsl_enable_autotuning: bool = (
    os.environ.get("CUTEDSL_ENABLE_AUTOTUNING", "0") == "1"
)
# DEPRECATED. This setting is ignored.
autotune_fallback_to_aten = False
# The value used as a fallback for the unbacked SymInts
# that can appear in the input shapes (e.g., in autotuning)
unbacked_symint_fallback = 8192
# DEPRECATED. This setting is ignored.
search_autotune_cache = False
# Opt in with TORCHINDUCTOR_SAVE_ARGS=1.
save_args = os.environ.get("TORCHINDUCTOR_SAVE_ARGS") == "1"
# Autotune in a subprocess only when TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC=1; if this is
# False, we will not create a subprocess for autotuning.
autotune_in_subproc = os.environ.get("TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC") == "1"
# The following three timeouts are applicable if autotune_in_subproc is True.
# Only the first one is still honored; the other two are deprecated and ignored.
# Max time that a valid benchmark result may take during autotuning
max_autotune_subproc_result_timeout_seconds = 60.0
# DEPRECATED. This setting is ignored.
max_autotune_subproc_graceful_timeout_seconds = 0.0
# DEPRECATED. This setting is ignored.
max_autotune_subproc_terminate_timeout_seconds = 0.0
# If autotuning in subprocess, whether to use multiple devices
autotune_multi_device = os.environ.get("TORCHINDUCTOR_AUTOTUNE_MULTI_DEVICE") == "1"
# Number of benchmark runs for collective operations
collective_benchmark_nruns = int(
    os.environ.get("TORCHINDUCTOR_COLLECTIVE_BENCHMARK_NRUNS", "50")
)
# Timeout in seconds for collective benchmarking
collective_benchmark_timeout = float(
    os.environ.get("TORCHINDUCTOR_COLLECTIVE_BENCHMARK_TIMEOUT", "30")
)
# Coordinate descent tuning knobs. Note that the radius is read from
# TORCHINDUCTOR_COORDINATE_DESCENT_RADIUS (no "SEARCH" in the env var name).
coordinate_descent_tuning = (
    os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_TUNING") == "1"
)
coordinate_descent_check_all_directions = (
    os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_CHECK_ALL_DIRECTIONS") == "1"
)
coordinate_descent_search_radius = int(
    os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_RADIUS", "1")
)
# AutoHeuristic is a framework that allows one to collect data from autotuning, use
# the data to learn a heuristic, and generate the learned heuristic to code which is
# shipped with the compiler.
# Specify a list of comma separated optimizations to collect data for
autoheuristic_collect = os.environ.get("TORCHINDUCTOR_AUTOHEURISTIC_COLLECT", "")
# Specify a list of comma separated optimizations to use learned heuristics for
autoheuristic_use = os.environ.get("TORCHINDUCTOR_AUTOHEURISTIC_USE", "mixed_mm")
# If set to 1, will run a JIT post compile hook if one is set.
run_jit_post_compile_hook = (
    os.environ.get("TORCHINDUCTOR_RUN_JIT_POST_COMPILE_HOOK", "0") == "1"
)
  560. def run_autoheuristic(name: str) -> bool:
  561. return collect_autoheuristic(name) or use_autoheuristic(name)
  562. def collect_autoheuristic(name: str) -> bool:
  563. return name in torch._inductor.config.autoheuristic_collect.split(",")
  564. def use_autoheuristic(name: str) -> bool:
  565. return name in torch._inductor.config.autoheuristic_use.split(",")
# If set to "DEFAULT", this will use the default log path specified in autoheuristic.py.
# If set to another path, autoheuristic will instead log results to the given path.
autoheuristic_log_path = os.environ.get(
    "TORCHINDUCTOR_AUTOHEURISTIC_LOG_PATH", "DEFAULT"
)
# Layout optimization is disabled by default on ROCm (when torch.version.hip is set);
# opt in if the model utilises NHWC convolutions.
# NOTE(review): layout_opt_default is a plain module-level name, so it may itself be
# picked up as a config entry; confirm before renaming it.
layout_opt_default = "1" if not torch.version.hip else "0"
layout_optimization = (
    os.environ.get("TORCHINDUCTOR_LAYOUT_OPTIMIZATION", layout_opt_default) == "1"
)
# Opt in with TORCHINDUCTOR_FORCE_LAYOUT_OPT=1.
force_layout_optimization = os.environ.get("TORCHINDUCTOR_FORCE_LAYOUT_OPT", "0") == "1"
# Whether to keep the output strides the same as eager after layout optimization.
keep_output_stride = os.environ.get("TORCHINDUCTOR_KEEP_OUTPUT_STRIDE", "1") == "1"
# Enabling this will let the compiler print warning messages if a generated triton
# kernel has inputs with mixed layouts. This is helpful for perf debugging,
# since a kernel with mixed layout inputs may run much slower than one whose inputs
# have uniform layouts.
warn_mix_layout = os.environ.get("TORCHINDUCTOR_WARN_MIX_LAYOUT") == "1"
# Control store vs recompute heuristic.
# For fanouts, rematerialization can lead to exponential blowup, so use a
# smaller threshold.
realize_reads_threshold = 4
realize_opcount_threshold = 30
# Threshold to prevent excessive accumulation of ops in one buffer during lowering
realize_acc_reads_threshold = 8
realize_acc_reads_size_threshold: Optional[int] = (
    None  # TODO(xuanzh): harden this to make it non optional
)
# Fall back to eager for random/dropout; this is slow but useful for debugging
fallback_random = False
# Fall back to eager for embedding_bag_byte_unpack
fallback_embedding_bag_byte_unpack = False
# Automatically create fallbacks when encountering an unhandled op
implicit_fallbacks = True
# Opt in with TORCHINDUCTOR_ASSUME_UNALIGNED_FALLBACK_OUTPUT=1.
assume_unaligned_fallback_output = (
    os.environ.get("TORCHINDUCTOR_ASSUME_UNALIGNED_FALLBACK_OUTPUT") == "1"
)
# Custom InductorChoices callable to use (can be a class or functools.partial with kwargs)
inductor_choices_class: Optional[Callable[[], "InductorChoices"]] = None
# Fuse even in cases without common reads
aggressive_fusion = False
# For each fused kernel in the wrapper, comment with the nodes that get fused.
# Useful for debugging fusion.
debug_fusion: bool = os.environ.get("TORCHINDUCTOR_DEBUG_FUSION") == "1"
# Opt in with TORCHINDUCTOR_BENCHMARK_FUSION=1.
benchmark_fusion: bool = os.environ.get("TORCHINDUCTOR_BENCHMARK_FUSION") == "1"
# Read from TORCHINDUCTOR_ENABLED_METRIC_TABLES; empty by default.
# NOTE(review): the expected format (presumably comma-separated table names) is
# defined by the consumer; confirm there.
enabled_metric_tables = os.environ.get("TORCHINDUCTOR_ENABLED_METRIC_TABLES", "")
# Toggle for the loop-ordering-after-fusion mechanism. Off by default in fbcode, on
# by default elsewhere; TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION overrides both.
loop_ordering_after_fusion: bool = (
    os.environ.get(
        "TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION", "0" if is_fbcode() else "1"
    )
    == "1"
)
# When trying to fuse two nodes, one with:
# a[contiguous_writes] = fn(...)
# and another node:
# b[contiguous_writes] = a[discontiguous_reads]
# If b is unary, and we can figure out an inverse formula for
# discontiguous writes, invert b as :
# b[inverse(discontiguous_writes)] = a[contiguous_reads]
# so that the nodes can fuse. for more details: https://gist.github.com/eellison/6f9f4a7ec10a860150b15b719f9285a9
loop_index_inversion_in_fusion: bool = True
# If fusing two nodes only saves less than score_fusion_memory_threshold memory,
# we should not bother fusing the nodes.
#
# This is especially helpful to resolve https://github.com/pytorch/pytorch/issues/133242
# Previously we fused two nodes because of a common read of a scalar tensor.
# If we skip it, the loop ordering after fusion mechanism kicks in and can
# bring more savings.
#
# For the cases where loop ordering after fusion does not help, we don't lose much.
score_fusion_memory_threshold = 10
# For Triton Templates, select the fastest of (best template + epilogue) vs
# (best template + separate epilogue kernel). On by default; disable with
# TORCHINDUCTOR_BENCHMARK_EPILOGUE_FUSION=0.
benchmark_epilogue_fusion = (
    os.environ.get("TORCHINDUCTOR_BENCHMARK_EPILOGUE_FUSION", "1") == "1"
)
# How many of the top triton kernels to benchmark with the epilogue
max_epilogue_benchmarked_choices = 1
# How many nodes to allow into a single fusion
max_fusion_size = 64
# How many nodes to attempt pairwise fusion with in a buffer group
max_fusion_buffer_group_pairwise_attempts = 64
# Maximum number of unique input/output buffers allowed in fused kernels.
# The check is disabled if set to None.
max_fusion_unique_io_buffers: Optional[int] = None
# Max number of inputs to generate cat as a pointwise op with masked loads
max_pointwise_cat_inputs = 8
# Force concat to be generated as a pointwise op with masked loads
force_pointwise_cat = False
# Replace small reductions with pointwise; disable with `= 1`
unroll_reductions_threshold = 8
# Add extra comments to output code (causes compile cache misses)
comment_origin = False
# Convert 1x1 convs into matmuls
conv_1x1_as_mm = False
# For reductions with a small output size (usually 1, e.g. x.sum()) there is not enough
# parallelism to saturate the GPU. We have two ways of handling this, either `split_reductions`
# or `triton.cooperative_reductions`, which are mutually exclusive.
# split_reductions: uses multiple kernels to gain more parallelism
# triton.cooperative_reductions: uses cross thread-block synchronization to gain more parallelism
# Enabling both of these will implicitly disable split_reductions.
split_reductions = os.getenv("TORCHINDUCTOR_SPLIT_REDUCTIONS", "1") == "1"
# A deterministic mode that skips any on device benchmarking in Inductor
# if we know they affect numerics. WARNING: Expect perf hit in this mode.
deterministic = os.getenv("TORCHINDUCTOR_DETERMINISTIC") == "1"
# When we do split reduction, this number controls the minimum value for
# num_split. Too small a num_split makes the split reduction less efficient.
# It's a much bigger problem when we compile a dynamic shape kernel with
# non-representative inputs.
min_num_split = int(os.environ.get("TORCHINDUCTOR_MIN_NUM_SPLIT", 0))
# Opt in with TORCHINDUCTOR_BENCHMARK_KERNEL=1.
benchmark_kernel = os.environ.get("TORCHINDUCTOR_BENCHMARK_KERNEL", "0") == "1"
# Enable constant and index_expr folding
constant_and_index_propagation = True
# We always add constants into graph.constants without
# performing any constant-inlining optimization
always_keep_tensor_constants = False
# Assert that indirect indexing does not read / write out of bounds
assert_indirect_indexing = True
# Compute CSE bounds on variables that do not appear in the FX graph
compute_all_bounds = False
# Enable the combo kernel that combines data-independent kernels (in addition
# to foreach kernels) into a single one (Experimental)
combo_kernels = False
# Benchmark combo kernels and only allow ones with perf gains
benchmark_combo_kernel = False
# combo_kernel autotuning options: 0 - disable, 1 - enable except for foreach,
# 2 - enable for all
combo_kernels_autotune = 1
# Enable masking for combining kernels of mixed sizes: 0 - disable, 1 - enable
# for all except for foreach, 2 - enable for all
combo_kernel_allow_mixed_sizes = 1
# Enable dynamic shapes for foreach kernels
combo_kernel_foreach_dynamic_shapes = True
# Maximum number of arguments (read/write buffers) allowed in a combo kernel
combo_kernel_max_num_args = 250
# When True, each combo sub-kernel gets its own block sizes (XBLOCK_0, YBLOCK_0, etc.)
# allowing different sub-kernels to use different tile sizes based on their heuristics.
# When False, all sub-kernels share block sizes (XBLOCK, YBLOCK, etc.)
combo_kernel_per_subkernel_blocks = False
# When True, only pointwise kernels are eligible for combo kernel fusion.
combo_kernels_pointwise_only = False
# Constant folding on the joint graph
joint_graph_constant_folding = True
# Enable indirect_indexing asserts for decompositions and lowerings
debug_index_asserts = False
# Warnings intended for PyTorch developers, disabled for point releases: they are on
# in fbcode and for nightly/source builds ("dev" or "git" in torch.__version__).
is_nightly_or_source = "dev" in torch.__version__ or "git" in torch.__version__
developer_warnings = is_fbcode() or is_nightly_or_source
# This pattern matches a special usage of scatter
# 1. It's applied to a constant tensor
# 2. The index tensor has size 1 in the scatter dimension
# Such pattern generates a sparse matrix when the const tensor is all-zero.
# We can lower this pattern to a pointwise kernel for more fusion opportunities
# and saving memory footprint.
optimize_scatter_upon_const_tensor = (
    os.environ.get("TORCHINDUCTOR_OPTIMIZE_SCATTER_UPON_CONST_TENSOR", "1") == "1"
)
# Options in caffe2/torch/_inductor/fx_passes/pre_grad.py
add_pre_grad_passes: Optional[str] = None
remove_pre_grad_passes: Optional[str] = None
# Comma-separated list of pass names to disable. Passes disabled via this config
# will be skipped when they go through GraphTransformObserver.
# Can be set via TORCHINDUCTOR_DISABLED_PASSES env var.
# Use uppercase pass names (e.g., "PASS1,PASS2").
disabled_passes: str = Config(
    env_name_force="TORCHINDUCTOR_DISABLED_PASSES",
    default="",
)
  733. # The multiprocessing start method to use for inductor workers in the codecache.
  734. def decide_worker_start_method() -> str:
  735. if "TORCHINDUCTOR_WORKER_START" in os.environ:
  736. start_method = os.environ["TORCHINDUCTOR_WORKER_START"]
  737. else:
  738. start_method = "subprocess"
  739. assert start_method in (
  740. "subprocess",
  741. "fork",
  742. "spawn",
  743. ), f"Invalid start method: {start_method}"
  744. return start_method
  745. worker_start_method: str = decide_worker_start_method()
# Threshold, in bytes, to decide if a kernel has small memory access.
# The default of 16777216 bytes (16 * 1024 * 1024, i.e. 16 MiB) is arbitrarily selected.
small_memory_access_threshold: int = 16777216
# Whether to suppress logging from launched subprocess workers (suppressed by default).
worker_suppress_logging: bool = Config(
    justknob="pytorch/compiler:worker_suppress_logging",
    env_name_force="TORCHINDUCTOR_WORKER_SUPPRESS_LOGGING",
    default=True,
)
# Log per-operation runtime estimates for TLParse analysis. Note: the env var is
# LOG_TLPARSE, without a TORCHINDUCTOR_ prefix.
log_tlparse: bool = Config(
    env_name_force="LOG_TLPARSE",
    default=False,
)
# Flags to turn on all_reduce fusion. These 2 flags should be automatically turned
# on by DDP and should not be set by the users.
_fuse_ddp_communication = False
# NOTE(review): the unit is not stated here (presumably MB, like DDP's
# bucket_cap_mb); confirm against the fusion pass.
_fuse_ddp_bucket_size = 25
# Flag to control which fusion passes to apply. Functions in the list will
# be applied in order. There are two different fusion passes:
# "fuse_ddp_with_concat_op" and "fuse_ddp_with_coalesced_op". The default
# one is "fuse_ddp_with_concat_op". Users can also change this to a customized
# fusion function.
#
# The fusion currently does not support multiple DDP with different PG or
# data type. This feature will be added in the future PRs.
#
# "schedule_comm_wait" is used to delay the wait ops to maximize comm/comp
# overlapping. At this moment, this pass performs better than
# reorder_for_compute_comm_overlap_passes but we will add the logic of
# "schedule_comm_wait" in the future and remove the one here.
_fuse_ddp_communication_passes: list[Union[Callable[..., None], str]] = [
    "fuse_ddp_with_concat_op",
    "schedule_comm_wait",
]
# NOTE(review): undocumented; presumably toggles the micro-pipeline tensor-parallel
# pass. Confirm before relying on it.
_micro_pipeline_tp: bool = False
# Enable/disable the partitioned scatter optimization for atomic add kernels;
# this improves kernel performance at the cost of memory usage.
partitioned_scatter_enabled = (
    os.environ.get("TORCHINDUCTOR_PARTITIONED_SCATTER_ENABLED", "0") == "1"
)
# Min partitions for scatter optimization
partitioned_scatter_min_partitions: int = 2
# Max partitions for scatter optimization
partitioned_scatter_max_partitions: int = 128
# Memory budget fraction for scatter buffers
partitioned_scatter_memory_budget: float = 0.10
  793. class _collective:
  794. auto_select: bool = False
  795. one_shot_all_reduce_threshold_bytes: int = 128 * 1024
class aten_distributed_optimizations:
    """Configuration for distributed optimization passes on ATen FX graphs.

    Many fields default to ``None``. NOTE(review): presumably ``None`` means "let the
    consuming pass choose its own default"; confirm against those passes.
    Field names are config keys and must not be renamed.
    """

    # Enable overlap scheduling pass
    enable_overlap_scheduling: bool = False
    # Enable overlap-preserving collective bucketing
    collective_bucketing: Optional[bool] = None
    # Insert ordering dependencies to preserve overlap relationships. This should only be used if
    # compiling with inductor, or for subsequent passes before removing the ops prior to execution
    insert_overlap_deps: Optional[bool] = None
    # Maximum compute node prefetch distance for overlap scheduling
    max_compute_pre_fetch: Optional[int] = None
    # NOTE(review): undocumented; the name suggests a multiplier applied to compute
    # time during overlap estimation. Confirm. The spelling "multipler" is part of
    # the public config key, so it is kept as-is.
    compute_overlap_multipler: Optional[float] = None
    # Custom runtime estimation function for ops
    # For user-defined estimation function, pass in the function handle
    # None means use default estimations
    # TODO - need estimated and profile based version
    custom_runtime_estimation: Optional[Callable[[torch.fx.Node], Optional[float]]] = (
        None
    )
    # Method for estimating collective runtime
    # "analytical": Use bandwidth formulas (default)
    # "benchmark": Use CUDA events with power-of-2 rounding and interpolation
    collective_estimator: Literal["analytical", "benchmark"] = "analytical"
    # Maximum memory increase above baseline for prefetch operations
    # Uses minimum of absolute cap and ratio of baseline
    max_memory_increase_gb: Optional[float] = None  # Absolute cap in GB
    max_memory_increase_ratio: Optional[float] = None  # Ratio of baseline peak memory
    # Maximum GB of concurrent collective data in flight. Too much in flight memory
    # can cause memory fragmentation within the CUDA Caching Allocator.
    max_in_flight_gb: Optional[float] = None
    # Maximum prefetch or bucketing candidates. Mainly intended for compile time.
    max_coll_distance: Optional[int] = None
    # Whether to log the final collective runtime estimations (off by default).
    log_final_collectives_estimations: bool = False
    # Bucket exposed collectives first
    bucket_exposed_first: bool = True
    # Enable fusion region detection for overlap scheduling cost estimation.
    # When enabled, groups of fusible ops (pointwise, reduction, etc.) are treated
    # as atomic units with memory-bound runtime estimates.
    enable_fusion_regions: Optional[bool] = None
    # Prioritize bucketing during overlap scheduling by grouping candidates by bucket key
    prioritize_bucketing_during_scheduling: bool = True
  837. def parallel_compile_enabled_internally() -> bool:
  838. """
  839. TODO: Remove when parallel compiled is fully enabled internally. For rollout, use a
  840. knob to enable / disable. The justknob should not be performed at import, however.
  841. So for fbcode, we assign compile_threads to 'None' below and initialize lazily in
  842. async_compile.py.
  843. """
  844. ENABLE_PARALLEL_COMPILE_VERSION = 1
  845. jk_name = "pytorch/inductor:enable_parallel_compile_version"
  846. version = torch._utils_internal.justknobs_getval_int(jk_name)
  847. return ENABLE_PARALLEL_COMPILE_VERSION >= version
  848. def decide_compile_threads() -> int:
  849. """
  850. Here are the precedence to decide compile_threads
  851. 1. User can override it by TORCHINDUCTOR_COMPILE_THREADS. One may want to disable async compiling by
  852. setting this to 1 to make pdb happy.
  853. 2. Set to 1 if it's win32 platform
  854. 3. decide by the number of CPU cores
  855. """
  856. import logging
  857. # Defined locally so install_config_module doesn't try to parse
  858. # as a config option.
  859. log = logging.getLogger(__name__)
  860. if "TORCHINDUCTOR_COMPILE_THREADS" in os.environ:
  861. compile_threads = int(os.environ["TORCHINDUCTOR_COMPILE_THREADS"])
  862. log.info("compile_threads set to %d via env", compile_threads)
  863. elif sys.platform == "win32":
  864. compile_threads = 1
  865. log.info("compile_threads set to 1 for win32")
  866. elif is_fbcode() and not parallel_compile_enabled_internally():
  867. compile_threads = 1
  868. log.info("compile_threads set to 1 in fbcode")
  869. else:
  870. cpu_count = (
  871. len(os.sched_getaffinity(0))
  872. if hasattr(os, "sched_getaffinity")
  873. else os.cpu_count()
  874. )
  875. assert cpu_count
  876. compile_threads = min(32, cpu_count)
  877. log.info("compile_threads set to %d", compile_threads)
  878. return compile_threads
# TODO: Set directly after internal rollout.
# In fbcode this starts as None; per the parallel_compile_enabled_internally
# docstring, it is then initialized lazily in async_compile.py. Elsewhere it is
# decided at import time.
compile_threads: Optional[int] = None if is_fbcode() else decide_compile_threads()
# Whether to quiesce the Triton-compile subprocess pool at the end of each compilation.
quiesce_async_compile_pool: bool = Config(
    justknob="pytorch/inductor:quiesce_async_compile_pool",
    env_name_force="TORCHINDUCTOR_QUIESCE_ASYNC_COMPILE_POOL",
    default=True,
)
# Time in seconds to wait before quiescing
quiesce_async_compile_time: int = Config(
    default=60,
)
# Whether or not to enable statically launching CUDA kernels
# compiled by triton (instead of using triton's own launcher)
use_static_cuda_launcher: bool = static_cuda_launcher_default()
# Alias of use_static_cuda_launcher, used by both CUDA/XPU.
use_static_triton_launcher: bool = Config(
    alias="torch._inductor.config.use_static_cuda_launcher"
)
# Attempt to statically launch user defined triton kernels
# Requires use_static_cuda_launcher
static_launch_user_defined_triton_kernels: bool = Config(
    justknob="pytorch/inductor:static_launch_user_defined_triton_kernels",
    env_name_force="TORCHINDUCTOR_STATIC_LAUNCH_USER_DEFINED_TRITON_KERNELS",
    default=False,
)
# Raise an error if we bypass the launcher
strict_static_cuda_launcher: bool = (
    os.environ.get("TORCHINDUCTOR_STRICT_STATIC_CUDA_LAUNCHER", "0") == "1"
)
# Alias of strict_static_cuda_launcher, used by both CUDA/XPU.
strict_static_triton_launcher: bool = Config(
    alias="torch._inductor.config.strict_static_cuda_launcher"
)
# gemm autotuning global cache dir.
# Only resolved in fbcode, via libfb's parutil, as the "fb/cache" directory under the
# package path (or plain "fb/cache" when __package__ is empty). It is None outside
# fbcode, or when parutil cannot be imported or the lookup raises ValueError.
global_cache_dir: Optional[str]
if is_fbcode():
    try:
        from libfb.py import parutil

        if __package__:
            global_cache_dir = parutil.get_dir_path(
                os.path.join(__package__.replace(".", os.sep), "fb/cache")
            )
        else:
            global_cache_dir = parutil.get_dir_path("fb/cache")
    except (ValueError, ImportError):
        global_cache_dir = None
else:
    global_cache_dir = None
# If a kernel is fused, its name is generated from the origin node op names;
# for larger kernels, limit this (presumably to at most this many op names).
kernel_name_max_ops = 10
# Pad input tensors of matmul/bmm/addmm to leverage Tensor Cores in NVIDIA GPUs
shape_padding = os.environ.get("TORCHINDUCTOR_SHAPE_PADDING", "1") == "1"
# Control if we will do padding for pointwise/reductions
comprehensive_padding = (
    os.environ.get("TORCHINDUCTOR_COMPREHENSIVE_PADDING", "1") == "1"
)
# NOTE(review): undocumented; presumably controls padding of channels-last tensors.
pad_channels_last = False
# Control if we will do padding on dynamic shapes
pad_dynamic_shapes = False
# Disable comprehensive padding on the CPU
disable_padding_cpu = True
# Control if we will expand the dimension of pointwise nodes to fuse
expand_dimension_for_pointwise_nodes = False
# The width of comprehensive padding, in bytes.
# CUDA max memory transaction size is 128 bytes for a warp.
padding_alignment_bytes = 128
# Threshold on the minimum stride that will be padded.
#
# Don't align a too-small stride, since that causes too much memory increase.
# Padding a too-small stride may also cause perf loss: we may end up with many tiny
# data blocks with gaps in between, which causes less coalesced GPU memory access!
#
# Initially we picked 320 as the threshold, since for alignment=16
# that results in at most 5% memory cost.
#
# But later on we raised the threshold to 1024 to avoid interfering with persistent
# reduction. Let's say an inner reduction has a row size 513. Inductor will generate
# persistent reduction code.
# If we do padding, the strides are not contiguous any more. Inductor
# uses a much smaller threshold for persistent reduction in this case and
# generates potentially worse non-persistent reduction code.
#
# This change turns HF AllenaiLongformerBase amp training from a loss of 1.09x to a win of 1.05x.
# (baseline: 71.09ms, padding w/o this change: 77.38ms, padding with this change: 67.77ms)
padding_stride_threshold = 1024
# Enable padding outputs, even if they would not be padded in eager mode.
# By default, we use the same strides as eager mode.
pad_outputs = False
# Whether to treat output of the backward graph as user visible.
# For user visible outputs, inductor will make sure the stride matches with eager.
bw_outputs_user_visible = True
# Whether to always use shape padding if it is enabled and possible
force_shape_pad: bool = False
# Fx-based linear/matmul/bmm + permute/transpose vertical fusion
permute_fusion = os.environ.get("TORCHINDUCTOR_PERMUTE_FUSION", "0") == "1"
# Mark the wrapper call in PyTorch profiler
profiler_mark_wrapper_call = False
# Generate hook calls to torch._inductor.hooks.run_intermediate_hooks for
# every intermediate for which we can correlate it with an intermediate
# from the original FX graph
generate_intermediate_hooks = False
# Populate traceback field on IRNode; good for debugging why origin_node is
# not populated, or finding out where an IRNode was constructed
debug_ir_traceback = False
# Used for debugging to make sure config is properly set
_raise_error_for_testing = False
# Use fp64 for unbacked float scalars (from .item()) in Triton kernel signatures
# to preserve precision. When False, uses fp32 (legacy behavior with precision loss).
# Defaults to True outside fbcode and False in fbcode.
_use_fp64_for_unbacked_floats: bool = not is_fbcode()
# Bandwidth profiling is driven by TORCHINDUCTOR_PROFILE:
#   unset or ""   -> profile_bandwidth is False
#   "1"           -> profile_bandwidth is True with an empty regex (presumably
#                    matching every kernel)
#   anything else -> profile_bandwidth is True and the value is used as the regex
_profile_var = os.environ.get("TORCHINDUCTOR_PROFILE", "")
profile_bandwidth = _profile_var != ""
profile_bandwidth_regex = "" if _profile_var == "1" else _profile_var
# Specify a file where we print out the profiling results.
# None means we do not dump results to a file.
profile_bandwidth_output: Optional[str] = os.environ.get(
    "TORCHINDUCTOR_PROFILE_OUTPUT", None
)
# Switch to do_bench_using_profiling to exclude the CPU overheads
profile_bandwidth_with_do_bench_using_profiling = (
    os.environ.get("TORCHINDUCTOR_PROFILE_WITH_DO_BENCH_USING_PROFILING") == "1"
)
# TODO: remove later
# incompatible with cpp_wrapper
disable_cpp_codegen = False
# Freezing will attempt to inline weights as constants in optimization
# and run constant folding and other optimizations on them. After freezing, weights
# can no longer be updated.
freezing: bool = os.environ.get("TORCHINDUCTOR_FREEZING", "0") == "1"
# Make freezing invalidate the eager Parameters of nn modules, to avoid memory overhead
# of potentially keeping multiple copies of weights.
freezing_discard_parameters: bool = False
# Decompose some memory bound matmul/bmm to mul
decompose_mem_bound_mm: bool = False
# Wrap compiled regions in inductor_compiled_code HOP to make them visible to
# TorchDispatchModes like DebugMode and Selective Activation Checkpointing.
wrap_inductor_compiled_regions: bool = False
# assume_aligned_inputs means that we assume that inputs will be aligned; we generate
# code using this assumption, and clone tensors before use if they aren't aligned.
# In the common case, most inputs will be aligned.
assume_aligned_inputs: bool = False
# assume_32bit_indexing means that we assume 32-bit indexing is always safe; we always
# use 32-bit indices regardless of tensor sizes. If assume_32bit_indexing contradicts
# the example inputs, we throw. This is useful when all dynamic shapes are unbacked and
# you know you only operate with 32-bit sizes.
assume_32bit_indexing: bool = False
# For the user-written Triton kernels compiled with the model, ignore the unsupported
# arguments passed to the @triton.autotune in the user's code; this is unsafe, as
# ignoring the unsupported args may lead to unexpected autotuning behavior: don't
# set unless you know what you're doing.
unsafe_ignore_unsupported_triton_autotune_args: bool = False
# When True, we will check in scheduler.py _codegen that there are no "loops"
# in the call stack; that is to say, the same frame multiple times. This
# ensures that a cProfile trace to this frame will be a straight line without
# any cycles. Incompatible with cpp_wrapper.
check_stack_no_cycles_TESTING_ONLY: bool = False
  1036. # When True, complex_memory_overlap always reports True
  1037. always_complex_memory_overlap_TESTING_ONLY: bool = False
  1038. # enable linear binary folding
  1039. enable_linear_binary_folding = (
  1040. os.environ.get("TORCHINDUCTOR_ENABLE_LINEAR_BINARY_FOLDING", "0") == "1"
  1041. )
  1042. # Adds NVTX annotations around training phases
  1043. annotate_training: bool = os.environ.get("TORCHINDUCTOR_ANNOTATE_TRAINING", "0") == "1"
  1044. # Enable caching codegen of triton templates.
  1045. enable_caching_generated_triton_templates: bool = True
  1046. # Lookup table for overriding autotune configs based on hash of Triton source code
  1047. autotune_lookup_table: dict[str, dict[str, Any]] = {}
  1048. file_lock_timeout: int = int(os.environ.get("TORCHINDUCTOR_FILE_LOCK_TIMEOUT", "600"))
  1049. enable_autograd_for_aot: bool = False
  1050. def get_worker_log_path() -> Optional[str]:
  1051. log_loc = None
  1052. if is_fbcode():
  1053. mast_job_name = os.environ.get("MAST_HPC_JOB_NAME", None)
  1054. global_rank = os.environ.get("ROLE_RANK", "0")
  1055. if mast_job_name is not None:
  1056. log_loc = f"/logs/dedicated_log_torch_compile_worker_rank{global_rank}"
  1057. return log_loc
# Log path for compile workers; defaults to "" and is tied to the
# TORCHINDUCTOR_WORKER_LOGPATH environment variable via env_name_force
# (presumably the env var takes precedence over programmatic settings —
# confirm against torch.utils._config_module.Config).
torchinductor_worker_logpath: str = Config(
    env_name_force="TORCHINDUCTOR_WORKER_LOGPATH",
    default="",
)
  1062. class auto_chunker:
  1063. enable: bool = os.environ.get("TORCHINDUCTOR_AUTO_CHUNKER") == "1"
  1064. # Don't chunk from a node if the output size is not large enough
  1065. output_size_threshold: int = 1024 * 1024
  1066. # Don't chunk from a node if it does not 'amplify' the inputs a lot
  1067. amplify_ratio_threshold: int = 8
  1068. num_chunk: int | None = (
  1069. int(os.environ.get("TORCHINDUCTOR_CHUNKER_NUM_CHUNKS")) # type: ignore[arg-type]
  1070. if os.environ.get("TORCHINDUCTOR_CHUNKER_NUM_CHUNKS") is not None
  1071. else None
  1072. ) # If not None, use this to force number of chunks
# config specific to codegen/cpp.py
class cpp:
    """
    Settings for cpp backend.
    This class provides a centralized location for managing cpp backend settings.
    """

    # -1: set to torch.get_num_threads()
    threads = -1
    # Do not generate loops when the condition doesn't hold, like:
    # for(long i0=4096; i0<4096; i0+=1)
    no_redundant_loops = (
        os.environ.get("TORCHINDUCTOR_CPP_NO_REDUNDANT_LOOPS", "1") == "1"
    )
    # Assume number of threads is dynamic, don't specialize thread number.
    # Kernels don't recompile on thread number changes with this flag on.
    # For single-threaded workload, turning it on would incur a slight
    # performance degradation.
    dynamic_threads = os.environ.get("TORCHINDUCTOR_CPP_DYNAMIC_THREADS", "0") == "1"
    # NOTE(review): SIMD vector length override; None presumably means
    # auto-detect from the host ISA — confirm.
    simdlen: Optional[int] = None
    # Read from TORCHINDUCTOR_CPP_MIN_CHUNK_SIZE (default 512).
    # NOTE(review): presumably the minimum work chunk size for parallel loops — confirm.
    min_chunk_size = int(os.environ.get("TORCHINDUCTOR_CPP_MIN_CHUNK_SIZE", "512"))
    # C++ compiler settings. The second entry is $CXX, defaulting to clang++ on
    # macOS (sys.platform == "darwin") and g++ elsewhere.
    cxx: tuple[None, str] = (
        None,  # download gcc12 from conda-forge if conda is installed
        os.environ.get("CXX", "clang++" if sys.platform == "darwin" else "g++"),
    )  # type: ignore[assignment]
    # Allow kernel performance profiling via PyTorch profiler
    enable_kernel_profile = (
        os.environ.get("TORCHINDUCTOR_CPP_ENABLE_KERNEL_PROFILE", "0") == "1"
    )
    # enable weight prepacking to get a better performance; may lead to large memory footprint
    weight_prepack = os.environ.get("TORCHINDUCTOR_CPP_WEIGHT_PREPACK", "1") == "1"
    # Inject a bug into our relu implementation; useful for testing our repro
    # extraction and minification functionality.
    # Valid values: "compile_error", "runtime_error", "accuracy"
    inject_relu_bug_TESTING_ONLY: Optional[str] = None
    # Same idea as inject_relu_bug_TESTING_ONLY, for log1p
    # (presumably accepts the same values — confirm).
    inject_log1p_bug_TESTING_ONLY: Optional[str] = None
    # If None, autodetect whether or not AVX512/AVX2 can be used. Otherwise,
    # force usage as specified, without testing. Default None.
    vec_isa_ok: Optional[bool] = get_tristate_env("TORCHINDUCTOR_VEC_ISA_OK")
    # similar to config.triton.descriptive_names
    descriptive_names: Literal["torch", "original_aten", "inductor_node"] = (
        "original_aten"
    )
    # how many nodes to allow into a single horizontal fusion
    max_horizontal_fusion_size = int(
        os.environ.get("TORCHINDUCTOR_CPP_MAX_HORIZONTAL_FUSION_SIZE", "16")
    )
    # Make scatter_reduce fall back when reduce is sum, to avoid the performance
    # regression of using atomic_add.
    fallback_scatter_reduce_sum = (
        os.environ.get("TORCHINDUCTOR_CPP_FALLBACK_SCATTER_REDUCE_SUM", "1") == "1"
    )
    # Use -funsafe-math-optimizations when compiling
    enable_unsafe_math_opt_flag = (
        os.environ.get("TORCHINDUCTOR_CPP_ENABLE_UNSAFE_MATH_OPT_FLAG", "0") == "1"
    )
    # Use -ffp-contract when compiling
    # Options: "off" (default), "on", "fast"
    # Per https://godbolt.org/z/bf4bvfc9r , clang/gcc has different behavior for "fast"
    enable_floating_point_contract_flag = os.environ.get(
        "TORCHINDUCTOR_CPP_ENABLE_FLOATING_POINT_CONTRACT_FLAG", "off"
    )
    # Enable the tiling select heuristic (on by default; set
    # TORCHINDUCTOR_CPP_ENABLE_TILING_HEURISTIC=0 to disable it)
    enable_tiling_heuristics = (
        os.environ.get("TORCHINDUCTOR_CPP_ENABLE_TILING_HEURISTIC", "1") == "1"
    )
    # Enable the Grouped GEMM Fusion
    enable_grouped_gemm_template = False
    # Maximal allowed number of slices on K-dim for a GEMM kernel. This controls
    # the maximal parallelism of K-slicing. Since K-slicing requires extra thread
    # synchronization and buffers, the maximal number of slices is limited to
    # mitigate the sync overhead and memory usage.
    # When set to 0, the number of slices is unlimited.
    gemm_max_k_slices = int(os.environ.get("TORCHINDUCTOR_CPP_GEMM_MAX_K_SLICES", "1"))
    # For perf tuning and debugging purpose, configure the pre-defined cache blocking for
    # MxNxK dims respectively. The blockings are separated by comma and the unit is
    # the number of register blocks.
    # For example, "4,1,10" means 4 register blocks on M, 1 on N and 10 on K respectively.
    # None (env var unset) means no pre-defined blocking is configured.
    gemm_cache_blocking = os.environ.get("TORCHINDUCTOR_CPP_GEMM_CACHE_BLOCKING", None)
    # For perf tuning and debugging purpose, configure the pre-defined thread blocking factors for
    # MxNxK dims respectively. The factors are separated by comma and their product
    # should be the same as the total number of threads.
    # For example, if the total number of threads is 56, "7,4,2" means the work is
    # decomposed into 7x4x2 thread blocks along MxNxK of a GEMM.
    # None (env var unset) means no pre-defined factors are configured.
    gemm_thread_factors = os.environ.get("TORCHINDUCTOR_CPP_GEMM_THREAD_FACTORS", None)
    # Whether to enable masked vectorization for the tail_loop.
    enable_loop_tail_vec = True
    # Whether to enable concat linear for cpu device.
    # Currently concat linear on CPU does not always bring a benefit; it depends on
    # the linear's shape and the available compute resources. We default this to
    # False to avoid regressions. Users can enable this feature as needed.
    enable_concat_linear = False
    # Whether to use decomposed tanh for cpu device
    # Disable by default due to https://github.com/pytorch/pytorch/issues/148241
    use_decompose_tanh = (
        os.environ.get("TORCHINDUCTOR_CPP_USE_DECOMPOSE_TANH", "0") == "1"
    )
    # Use a small dequant buffer for wgt of woq int4 size as: [q_group_size, Nr]
    use_small_dequant_buffer = False
    # Read from TORCHINDUCTOR_CPP_FORCE_INLINE_KERNEL (default off).
    # NOTE(review): presumably forces generated kernels to be inlined — confirm.
    force_inline_kernel = (
        os.environ.get("TORCHINDUCTOR_CPP_FORCE_INLINE_KERNEL", "0") == "1"
    )
    # Use static constexpr or static const for int array
    use_constexpr_for_int_array = (
        os.environ.get("TORCHINDUCTOR_CPP_USE_CONSTEXPR_FOR_INT_ARRAY", "1") == "1"
    )
class triton:
    """
    Config specific to codegen/triton.py
    """

    # Use cudagraphs on output code
    cudagraphs = os.environ.get("TORCHINDUCTOR_CUDAGRAPHS") == "1"
    # Use cudagraph trees for memory pooling if `cudagraphs` is True
    cudagraph_trees = True
    # Should we skip cudagraphing graphs with dynamic shape inputs
    # If False, we will re-record a graph for each unique set of shape inputs
    cudagraph_skip_dynamic_graphs = False
    # Specify dynamic shapes to capture cudagraphs and skip cudagraph for other shapes.
    # Default to None, which means we capture cudagraphs for all shapes.
    # The tuple may hold any number of entries, each an int or a tuple of ints
    # (hence the trailing `...`; `tuple[X]` would mean exactly one element).
    cudagraph_capture_sizes: Optional[tuple[Union[int, tuple[int, ...]], ...]] = None
    # assertions not on the fast path, steady state
    slow_path_cudagraph_asserts = True
    # TODO - need to debug why this prevents cleanup
    cudagraph_trees_history_recording = False
    # Emit objgraph backref dumps for leaked cudagraph pool tensors
    cudagraph_trees_objgraph = False
    # Enable cudagraph support for mutated inputs from prior cudagraph pool
    # (enabled by default outside fbcode)
    cudagraph_support_input_mutation = not is_fbcode()
    # Maximal number of allowed cudagraph re-record for a function and
    # a cudagraph node due to static input tensor address changes or
    # cudagraph managed tensor data pointer changed.
    # i.e., allow num_recording <= cudagraph_unexpected_rerecord_limit
    # note: we are conservative here and choose a large limit.
    cudagraph_unexpected_rerecord_limit = 128
    # Warn loudly when the number of cudagraphs due to dynamic shape
    # exceeds this limit
    cudagraph_dynamic_shape_warn_limit: Optional[int] = 8
    # synchronize after cudagraph invocation
    force_cudagraph_sync = False
    # always run cudagraphs in the eager warmup stage
    # instead of recording and executing cudagraphs
    force_cudagraphs_warmup = False
    # If False (default), torch.compile skips cudagraph for a graph if it
    # contains cudagraph-unsafe ops. If True, we require that all cuda ops
    # be captured into cudagraph. If this is not possible, this will raise
    # an error.
    cudagraph_or_error: bool = Config(
        env_name_force="TORCHINDUCTOR_CUDAGRAPH_OR_ERROR",
        default=False,
    )
    # reorder nodes to minimize the number of graph partitions while
    # not incurring large memory overhead
    reorder_for_reducing_graph_partitions: bool = True
    # assertions on the fast path
    fast_path_cudagraph_asserts = False
    # skip warmup for cudagraph trees
    skip_cudagraph_warmup = False
    # Synchronize before and after every compiled graph.
    debug_sync_graph = False
    # Synchronize after every kernel launch, to help pinpoint bugs
    debug_sync_kernel = False
    # Always load full blocks (rather than broadcasting inside the block)
    dense_indexing = False
    # Enabled by default except in fbcode; override with
    # TORCHINDUCTOR_COALESCE_TILING_ANALYSIS.
    # TODO - enable by default (currently off by default in fbcode)
    coalesce_tiling_analysis: bool = (
        os.environ.get(
            "TORCHINDUCTOR_COALESCE_TILING_ANALYSIS", "1" if not is_fbcode() else "0"
        )
        == "1"
    )
    # limit tiling dimensions
    # - max_tiles=1 disables tiling
    # - max_tiles=2
    # - max_tiles=3 is experimental and may have bugs
    # higher values are unsupported
    # We use a max of 3 if coalesce_tiling_analysis is True, and 2 otherwise.
    # Note - coalesce_tiling_analysis does not yet apply to dynamic shapes.
    max_tiles: Optional[int] = None
    # Prefer higher dimensional tilings. This simplifies indexing expressions, making
    # it easier to identify block pointers.
    prefer_nd_tiling: bool = False
    # use triton.autotune for pointwise ops with complex layouts
    # this should only be disabled for debugging/testing
    autotune_pointwise = True
    # max autotune gemm with cublasLt
    autotune_cublasLt = True
    # Tune the generated Triton kernels at compile time instead of first time they run
    # Setting to None means uninitialized
    autotune_at_compile_time: Optional[bool] = None
    # We use random tensors for autotune by default. Setting this as true will let us
    # use inputs from sample inputs to autotune user defined triton kernels.
    # Side effect for this option is increased memory footprint during first pass compilation.
    autotune_with_sample_inputs: bool = False
    # Allows tiling reductions into multiple dimensions.
    # For best results, this should be used with prefer_nd_tiling.
    tile_reductions: bool = False
    # Codegen matmul natively with tl.dot without using a template.
    # This option makes Inductor generate matrix multiplication from scratch,
    # instead of calling predefined Triton templates (mm, bmm, mm_plus_mm).
    # Compile time may be longer because native matmul benchmarks more Triton configs
    # than regular pointwise or reduction kernels.
    # Native matmul often aggressively fuses operations around the matrix multiply,
    # which can make it faster or slower depending on your program.
    #
    # This option takes priority over other GEMM implementations. If Inductor determines
    # that a matmul can be generated, it will always generate it with native_matmul.
    # That means optimized kernels such as decompose_k or persistent_tma_matmul will
    # not be called when this option is enabled.
    #
    # Note: Native matmul does not currently support block pointers or TMA matmul.
    # If both native_matmul and (use_block_ptr or enable_persistent_tma_matmul) are enabled,
    # an error will be thrown.
    native_matmul: bool = os.getenv("TORCHINDUCTOR_NATIVE_MATMUL", "0") == "1"
    # should we stop a fusion to allow better tiling?
    tiling_prevents_pointwise_fusion = True
    tiling_prevents_reduction_fusion = True
    # should we give different names to kernels
    # Note: This is orthogonal to descriptive_names - this is deciding whether
    # our triton kernel names should all be `triton_` (to maximize caching) or
    # whether they should be unique.
    unique_kernel_names = (
        os.environ.get("TORCHINDUCTOR_UNIQUE_KERNEL_NAMES", "1") == "1"
    )
    # similar to the option above, but this is specific to user defined kernels,
    # while unique_kernel_names is for kernels generated by inductor.
    # We have this option because sometimes we reuse user's kernel code with different
    # configs which would result in the same name.
    # Note: This MODIFIES the user's kernel function name within inductor phase.
    unique_user_kernel_names = (
        os.environ.get("TORCHINDUCTOR_UNIQUE_USER_KERNEL_NAMES", "0") == "1"
    )
    # should we put op names in kernel names
    # "torch": Maps to the fx op in the Dynamo graph (module name, method name, etc.)
    # "original_aten": Maps to the highest-level aten op (i.e. pre-decompositions)
    # "inductor_node": Maps to the node name in the FX graph passed to Inductor
    descriptive_names: Literal["torch", "original_aten", "inductor_node"] = (
        "original_aten"
    )
    # use alternate codegen for smaller reductions
    persistent_reductions = (
        os.environ.get("TORCHINDUCTOR_PERSISTENT_REDUCTIONS", "1") == "1"
    )
    # For small output size reductions uses cross thread-block synchronization to gain more parallelism
    cooperative_reductions = (
        os.environ.get("TORCHINDUCTOR_COOPERATIVE_REDUCTIONS", "0") == "1"
    )
    # used for debugging cooperative reduction codegen, always generate cooperative_reductions
    force_cooperative_reductions = False
    # 0: disable
    # 1/True: enable, use tuning to pick between different subkernels
    # 2: enable, force using persistent reduction (for debugging)
    # 3: enable, force using non-persistent reduction (for debugging)
    # pyrefly: ignore [bad-assignment]
    multi_kernel: Literal[0, 1, 2, 3] = int(
        os.environ.get("TORCHINDUCTOR_MULTI_KERNEL", "0")
    )  # type: ignore[assignment]
    # hint to Triton when arguments are divisible by 16
    divisible_by_16 = os.environ.get("TORCHINDUCTOR_DIVISIBLE_BY_16", "1") == "1"
    # Minimum R0_BLOCK to be used for a TritonSplitScanKernel
    # NOTE: This also indirectly controls the size of workspace buffer required
    min_split_scan_rblock = 256
    # Store the generated cubin files for cpp wrapper code to load
    store_cubin = False
    # the max number of spills we allow for the configs we benchmark.
    # Setting this to 0 means we skip a config if it spills even a single
    # register.
    # Setting it to a larger value allows a config spilling a small amount
    # of registers being benchmarked.
    #
    # NOTE: triton will always report >0 register spills for kernels using sin/cos.
    # (check this issue https://github.com/triton-lang/triton/issues/1756 )
    # So far we see a fixed 8 spilled registers for kernels using sin/cos.
    # Raise the threshold to 16 to be safe.
    # On ROCm (torch.version.hip is set) the threshold is 32 instead.
    # We should revisit this once we understand more of the source of register spills.
    spill_threshold: int = 32 if torch.version.hip else 16
    # Generate code containing the newer tl.make_block_ptr() API for loads/store
    use_block_ptr = False
    # (Experimental)
    # Generate code using the tl.make_tensor_descriptor() API for loads/store
    # [Note: TMA API Restrictions] Currently the TMA API requires the following:
    # - For Nvidia GPUs, the compute capability should be >= 9.0
    # - The innermost stride of a descriptor should be 1
    # - The size of the block shape in the innermost dimension should load / store
    #   at least 16 bytes.
    # - Tensors are 16 byte aligned. Enabling this option therefore requires
    #   assume_aligned_inputs to also be enabled
    # TMA descriptors are only going to be generated if the above conditions
    # can be satisfied, along with any existing requirements for index expressions
    use_tensor_descriptor = False
    # (Experimental)
    # Whether to allow reordering tensor descriptor matches with descending
    # strides, at the expense of transposing values after load / before store.
    transpose_discontiguous_tensor_descriptor = True
    # Inject a bug into our relu implementation; useful for testing our repro
    # extraction and minification functionality.
    # Valid values: "compile_error", "runtime_error", "accuracy"
    inject_relu_bug_TESTING_ONLY: Optional[str] = None
    # Whether to upcast float16 / bfloat16 to float32 in triton codegen (Experimental)
    codegen_upcast_to_fp32 = True
    # Whether persistent matmul kernels should be enabled. This flag only has an
    # effect on H100 with a version of triton new enough to support TMA.
    # Note: the env var has no TORCHINDUCTOR_ prefix.
    enable_persistent_tma_matmul = (
        os.environ.get("ENABLE_PERSISTENT_TMA_MATMUL", "0") == "1"
    )
    # Whether TMA store should be enabled in templates. TODO: Remove once we
    # can autotune over the result.
    # Note: the env var has no TORCHINDUCTOR_ prefix.
    enable_template_tma_store = os.environ.get("ENABLE_TEMPLATE_TMA_STORE", "0") == "1"
    # Skip L1 cache for buffers that are used only once. Disabled by default
    skip_l1_cache = os.environ.get("TORCHINDUCTOR_SKIP_L1", "0") == "1"
    # During autotuning, if one of the kernels/configs fails for some reason,
    # Inductor will usually skip it (and assign its latency to inf).
    # For testing it's helpful to be able to assert that none of the configs fail.
    # Note: it may also need to be used with config.compile_threads = 1
    disallow_failing_autotune_kernels_TESTING_ONLY = False
    # specify number of splits to autotune on for decompose_k. 0 disables decompose_k
    # Disabled on ROCm by default pending performance validation.
    num_decompose_k_splits = int(
        os.environ.get(
            "TORCHINDUCTOR_NUM_DECOMPOSE_K_SPLITS", "0" if torch.version.hip else "10"
        )
    )
    # specify minimum ratio of K to M AND N in order to autotune on decompose_k. 0 enables
    # it as an autotuning choice for all matmuls
    decompose_k_threshold = int(
        os.environ.get("TORCHINDUCTOR_DECOMPOSE_K_THRESHOLD", "32")
    )
    # Programmatic Dependent Launch improves launch latency on Nvidia Hopper+ devices
    # If set to true, will generate PDL code on devices that support it.
    # If set to false, will never generate PDL code.
    enable_pdl = os.environ.get("TORCHINDUCTOR_ENABLE_PDL", "0") == "1"
    # Mix-order reduction: enabled by default except in fbcode; override with
    # TORCHINDUCTOR_MIX_ORDER_REDUCTION.
    mix_order_reduction = (
        os.environ.get("TORCHINDUCTOR_MIX_ORDER_REDUCTION", "0" if is_fbcode() else "1")
        == "1"
    )
    mix_order_reduction_initial_xblock = 1
    # None presumably means the split size is chosen automatically — confirm.
    mix_order_reduction_split_size: Optional[int] = None
    mix_order_reduction_autotune_split_size = (
        os.environ.get("TORCHINDUCTOR_MIX_ORDER_REDUCTION_AUTOTUNE_SPLIT_SIZE", "0")
        == "1"
    )
    # If set to true, will skip some non-critical checks in the mix order reduction
    # this could be helpful to avoid recompilations in some cases
    mix_order_reduction_non_strict_mode = False
    # Don't allow multi-stages by default to avoid out of shared memory
    mix_order_reduction_allow_multi_stages = (
        os.environ.get("TORCHINDUCTOR_MIX_ORDER_REDUCTION_ALLOW_MULTI_STAGES") == "1"
    )
    # Read from TORCHINDUCTOR_ENABLE_TLX_TEMPLATES (default off).
    enable_tlx_templates: bool = (
        os.environ.get("TORCHINDUCTOR_ENABLE_TLX_TEMPLATES", "0") == "1"
    )
    # Map for storing the amount of kernel runs with dumped input tensors
    # Based on hash of Triton source code to avoid bloating the folder
    debug_dump_kernel_inputs: dict[str, int] = {}
    # Value for the maximum amount of runs with dumped kernel input tensors
    # When the maximum is reached the first values get overwritten
    # This ensures the last N runs are saved, where N is this value
    max_kernel_dump_occurrences = 3
    # Enable Triton Proton profiling (TORCHINDUCTOR_TRITON_PROTON_PROFILING, default off).
    proton_profiling: bool = (
        os.environ.get("TORCHINDUCTOR_TRITON_PROTON_PROFILING", "0") == "1"
    )
    # If not specified, proton traces will be saved to the debug directory
    proton_output_dir: Optional[str] = os.environ.get(
        "TORCHINDUCTOR_TRITON_PROTON_OUTPUT_DIR"
    )
    # Group CTAs by SM in proton trace files.
    proton_group_by_sm: bool = (
        os.environ.get("TORCHINDUCTOR_TRITON_PROTON_GROUP_BY_SM", "1") == "1"
    )
    # Split proton trace files by kernel invocation.
    proton_split_invocations: bool = (
        os.environ.get("TORCHINDUCTOR_TRITON_PROTON_SPLIT_INVOCATIONS", "1") == "1"
    )
    # Process warp tracks into CTA tracks (min warp start, max warp end) and
    # assign CTAs to slots per SM such that CTAs do not overlap.
    proton_per_cta_occupancy: bool = (
        os.environ.get("TORCHINDUCTOR_TRITON_PROTON_PER_CTA_OCCUPANCY", "1") == "1"
    )
  1449. class aot_inductor:
  1450. """
  1451. Settings for Ahead-Of-Time Inductor Compilation
  1452. """
  1453. # AOTInductor output path
  1454. # If an absolute path is specified, the generated lib files will be stored under the directory;
  1455. # If a relative path is specified, it will be used as a subdirectory under the default caching path;
  1456. # If not specified, a temp directory will be created under the default caching path.
  1457. # If the specified path contains something like "model.so", the sub-string will be used
  1458. # to name the generated library.
  1459. output_path = ""
  1460. debug_compile = os.environ.get("AOT_INDUCTOR_DEBUG_COMPILE", "0") == "1"
  1461. debug_symbols = os.environ.get("AOT_INDUCTOR_DEBUG_SYMBOLS", "0") == "1"
  1462. # Annotate generated main wrapper function, i.e. AOTInductorModel::run_impl,
  1463. # to use which cpp compiler optimization level, default to O1
  1464. compile_wrapper_opt_level = os.environ.get(
  1465. "AOT_INDUCTOR_COMPILE_WRAPPER_OPT_LEVEL", "O1"
  1466. )
  1467. # option for debug printing/saving for intermediate tensor values for aot inductor
  1468. # 0: disable debug dumping
  1469. # 1: enable saving intermediate tensor values
  1470. # 2: enable printing intermediate tensor values
  1471. # 3: enable printing kernel names only (useful for pinpointing troublesome kernels)
  1472. # pyrefly: ignore [bad-assignment]
  1473. debug_intermediate_value_printer: Literal["0", "1", "2", "3"] = os.environ.get(
  1474. "AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER", "0"
  1475. ) # type: ignore[assignment]
  1476. # filtered nodes to be printed for debug values. Specify this option when debug_intermediate_value_printer is set to 2
  1477. filtered_kernel_names = os.environ.get(
  1478. "AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT", None
  1479. )
  1480. # Serialized tree spec for flattening inputs
  1481. # TODO: Move this into metadata
  1482. serialized_in_spec = ""
  1483. # Serialized tree spec for flattening outputs
  1484. # TODO: Move this into metadata
  1485. serialized_out_spec = ""
  1486. # flag to decide whether to create a submodule for constant graph.
  1487. use_runtime_constant_folding: bool = False
  1488. # flag to force weight to be appended to the shared library and mapped by the runtime
  1489. # rather than embedded into the data section. Needed to support 1B+ parameter models
  1490. force_mmap_weights: bool = False
  1491. # Default value of use_consts_asm_build is True, it will build by assembly language.
  1492. # When the value is False, it will build by c++ language.
  1493. use_consts_asm_build = True
  1494. package: bool = False
  1495. package_cpp_only: Optional[bool] = None
  1496. # If package_cpp_only is True, whether cpp files will be compiled to a
  1497. # dynamically linked library or static linked library
  1498. dynamic_linkage: bool = True
  1499. # Dictionary of metadata users might want to save to pass to the runtime.
  1500. # TODO: Move this somewhere else, since it's no longer really a config
  1501. metadata: dict[str, str] = {}
  1502. # fbcode only. Whether to raise error if C++ codegen is too big to optimize
  1503. raise_error_on_ignored_optimization: bool = (
  1504. os.environ.get("AOTINDUCTOR_RAISE_ERROR_ON_IGNORED_OPTIMIZATION", "1") == "1"
  1505. )
  1506. # Whether to check lowerbound constraints on dynamic shapes during runtime.
  1507. # When disabled, allows models with dynamic sizes of 0 or 1 to work with
  1508. # AOTI_RUNTIME_CHECK_INPUTS=1, avoiding errors from the [2+, ...] lowerbound
  1509. # restriction when backed_size_oblivious is off.
  1510. check_lowerbound: bool = True
  1511. # dump an aoti minifier if program errors
  1512. dump_aoti_minifier: bool = os.environ.get("DUMP_AOTI_MINIFIER", "0") == "1"
  1513. # Compiler compilation debug info
  1514. # 1: Dumps the original graph out to repro.py if compilation fails
  1515. # 2: Dumps a minifier_launcher.py if aoti fails.
  1516. # 3: Always dumps a minifier_launcher.py. Good for segfaults.
  1517. # 4: Dumps a minifier_launcher.py if the accuracy fails.
  1518. repro_level: int = int(os.environ.get("AOTINDUCTOR_REPRO_LEVEL", 2))
  1519. # Dictionary of presets that can be passed in
  1520. presets: dict[str, Any] = {}
  1521. # Kill switch for allowing temporary tensors to be allocated as stack arrays. Tests
  1522. # should be run with this flag both on and off to make sure we have coverage.
  1523. allow_stack_allocation: bool = False
  1524. # Enables an alternate DSO interface (the "minimal ArrayRef interface") intended
  1525. # to maximize performance for use cases that it can accommodate at the expense of
  1526. # generality. In brief:
  1527. # - inputs and outputs are ArrayRefTensor<T> (note that strides are required, but the
  1528. # tensor must be contiguous)
  1529. # - constant handling is unchanged because it is not a per-inference-iteration bottleneck
  1530. #
  1531. # When the DSO is generated in this mode, the usual interface will also be supported,
  1532. # but performance for that interface may be degraded.
  1533. use_minimal_arrayref_interface: bool = False
  1534. # Set to True if we want to use Pytorch's CUDACachingAllocator for weight management
  1535. weight_use_caching_allocator: bool = (
  1536. os.environ.get("AOT_INDUCTOR_WEIGHT_USE_CACHING_ALLOCATOR", "0") == "1"
  1537. )
  1538. # Experimental. Flag to control whether to include weight in .so
  1539. # Not supported for cross_target_platform="windows".
  1540. package_constants_in_so: bool = True
  1541. # Experimental. Flag to control whether to package weight separately on disk and which
  1542. # format to package it in.
  1543. # Options:
  1544. # None:
  1545. # Do not package weight separately on disk.
  1546. # "pickle_weights":
  1547. # Each weight is pickled and stored separately in data/weights. We also store the
  1548. # FQN names of each weight in a weights_config.json in each model's data/aot_inductor/model folder.
  1549. # Can only be load back from python using torch._inductor.aoti_load_package API now.
  1550. # "binary_blob":
  1551. # Stores all weights in a single binary blob in data/aot_inductor/model folder for each model.
  1552. # This option and config.aot_inductor.force_mmap_weights cannot both be True
  1553. package_constants_on_disk_format: Optional[str] = None
  1554. # Experimental. Controls automatic precompiling of common AOTI include files.
  1555. precompile_headers: bool = not is_fbcode()
  1556. # Embed generated kernel binary files into model.so
  1557. embed_kernel_binary: Optional[bool] = None
  1558. # Generate kernel files that support multiple archs
  1559. # For CUDA, this means generating fatbin files for kernels, and the fatbin files
  1560. # contain PTX and SASS for the current architecture.
  1561. # For XPU, this means generating SPIR-V files for kernels, and the SPIR-V files
  1562. # will be compiled to target different XPU architectures at runtime.
  1563. emit_multi_arch_kernel: Optional[bool] = None
  1564. # If not None, the generated files will use this name in their file stem.
  1565. # If None, we will use a hash to name files.
  1566. #
  1567. # If package_cpp_only, this name is also used for the target name in CMakelists.txt
  1568. # The default target name is "aoti_model"
  1569. #
  1570. # If compile_standalone, the aoti model class name is f"AOTInductorModel{name}"
  1571. #
  1572. # This name can only contain letters, numbers, and underscores.
  1573. model_name_for_generated_files: Optional[str] = None
  1574. # Custom ops that have implemented C shim wrappers, defined as an op to C shim declaration dict
  1575. custom_ops_to_c_shims: dict[torch._ops.OpOverload, list[str]] = {}
  1576. # custom op libs that have implemented C shim wrappers
  1577. custom_op_libs: Optional[list[str]] = None
  1578. # Whether to enable link-time-optimization
  1579. enable_lto = os.environ.get("AOT_INDUCTOR_ENABLE_LTO", "0") == "1"
  1580. # Whether the compiled .so should link to libtorch
  1581. link_libtorch: bool = True
  1582. # Currently the only valid option is "windows".
  1583. # We'll use x86_64-w64-mingw32-gcc to cross-compile a .dll file
  1584. # If using cuda, you also need to set WINDOWS_CUDA_HOME env var
  1585. # to point to windows CUDA toolkit.
  1586. # Example: WINDOWS_CUDA_HOME=cuda-windows-base/cuda_cudart/cudart/
  1587. # The path should contain lib cuda and lib cudart
  1588. cross_target_platform: Optional[str] = None
  1589. # If link_libtorch is False and cross_target_platform is windows,
  1590. # a library needs to be provided to provide the shim implementations.
  1591. aoti_shim_library: Optional[str | list[str]] = None
  1592. aoti_shim_library_path: Optional[str] = None
  1593. # a convenient class that automatically sets a group of the configs in aot_inductor
  1594. # it should only control the flags in aot_inductor.
  1595. # it should not do anything else.
  1596. class aot_inductor_mode:
  1597. # dynamic_linkage=False
  1598. # link_libtorch=False
  1599. # package_cpp_only=True
  1600. # embed_kernel_binary=True
  1601. # emit_multi_arch_kernel=True
  1602. compile_standalone: bool = False
  1603. class cutlass:
  1604. """
  1605. Config specific to cutlass backend.
  1606. """
  1607. compile_opt_level: Literal["-O0", "-O1", "-O2", "-O3", "-OS"] = "-O1"
  1608. # Whether to enable debug info, e.g. line number, cutlass debug info.
  1609. enable_debug_info = False
  1610. # Whether to use fast math.
  1611. use_fast_math = False
  1612. # Path to the CUTLASS repo root directory.
  1613. # The default path only works under PyTorch local development environment.
  1614. cutlass_dir = os.path.realpath(
  1615. os.environ.get(
  1616. "TORCHINDUCTOR_CUTLASS_DIR",
  1617. os.path.join(
  1618. os.path.dirname(torch.__file__),
  1619. "../third_party/cutlass/",
  1620. ),
  1621. )
  1622. )
  1623. # Configures the maximum number of CUTLASS configs to profile in max_autotune.
  1624. # By default it's None, so that all CUTLASS configs are tuned.
  1625. # This is mainly used to reduce test time in CI.
  1626. cutlass_max_profiling_configs: Optional[int] = None
  1627. # The L2 swizzle values to consider when profiling CUTLASS configs in max_autotune.
  1628. cutlass_max_profiling_swizzle_options: list[int] = [1, 2, 4, 8]
  1629. cutlass_dynamic_cluster_shape: tuple[int, int, int] = cast(
  1630. tuple[int, int, int],
  1631. tuple(
  1632. int(x)
  1633. for x in os.environ.get(
  1634. "TORCHINDUCTOR_CUTLASS_DYNAMIC_CLUSTER_SHAPE", "2,1,1"
  1635. ).split(",")
  1636. ),
  1637. )
  1638. cutlass_dynamic_cluster_fallback: tuple[int, int, int] = cast(
  1639. tuple[int, int, int],
  1640. tuple(
  1641. int(x)
  1642. for x in os.environ.get(
  1643. "TORCHINDUCTOR_CUTLASS_DYNAMIC_CLUSTER_FALLBACK",
  1644. ",".join(str(v) for v in cutlass_dynamic_cluster_shape),
  1645. ).split(",")
  1646. ),
  1647. )
  1648. # Whether to use CUTLASS EVT for epilogue fusion
  1649. cutlass_epilogue_fusion_enabled = (
  1650. os.environ.get("CUTLASS_EPILOGUE_FUSION", "0") == "1"
  1651. )
  1652. # Whether to only use TMA-compatible kernels in CUTLASS
  1653. cutlass_tma_only = False
  1654. # Minimum value of M*N*K to consider the CUTLASS backend for GEMM ops.
  1655. cutlass_backend_min_gemm_size: int = 1
  1656. # enable generation of inline standalone runner in CUDA CPP generated code
  1657. # which allows to compile the generated code into a standalone executable.
  1658. generate_test_runner: bool = (
  1659. os.environ.get("INDUCTOR_CUDA_BACKEND_GENERATE_TEST_RUNNER_CODE", "0") == "1"
  1660. )
  1661. # Keep only Cutlass op configs which contain this regular expression pattern
  1662. # Set this to "warpspecialized_cooperative_epi_tma" to enable only SM90 TMA Cutlass Kernels for large GEMMs
  1663. cutlass_op_allowlist_regex: Optional[str] = os.environ.get(
  1664. "TORCHINDUCTOR_CUTLASS_ALLOWLIST"
  1665. )
  1666. # Note: Names of Cutlass ops names can be obtained by calling
  1667. # op.configuration_name() on a Cutlass op instance, for example those
  1668. # returned from cutlass_utils.gen_ops() or the op argument passed to
  1669. # CUTLASSGemmTemplate.render(...)
  1670. # Filter Cutlass configs which contain this regular expression pattern
  1671. # Set this to "pingpong" to avoid numerical issues
  1672. # caused by the op ordering of the "pingpong" memory access
  1673. # pattern used by some Cutlass Kernels.
  1674. cutlass_op_denylist_regex: Optional[str] = os.environ.get(
  1675. "TORCHINDUCTOR_CUTLASS_DENYLIST"
  1676. )
  1677. # Non-negative integer which determines how many kernels are instantiated.
  1678. # 0 = 0000 generates the fewest kernels, 9999 generates all possible combinations.
  1679. # increasing first digit reduces schedule / mixed type pruning,
  1680. # increasing second digit generates more cluster sizes,
  1681. # increasing third digit generates more MMA multipliers,
  1682. # increasing fourth digit generates more instruction shapes.
  1683. cutlass_instantiation_level: str = os.environ.get(
  1684. "TORCHINDUCTOR_CUTLASS_INSTANTIATION_LEVEL", "0"
  1685. )
  1686. # use compile command to create kernel .cu and .so name
  1687. cutlass_hash_with_compile_cmd: bool = (
  1688. os.environ.get("TORCHINDUCTOR_CUTLASS_HASH_WITH_COMPILE_CMD", "0") == "1"
  1689. )
  1690. # Experimental. Prescreen top x configs before tuning on swizzle.
  1691. cutlass_prescreening: bool = (
  1692. os.environ.get("TORCHINDUCTOR_CUTLASS_PRESCREENING", "1") == "1"
  1693. )
  1694. # Specify which operations should use CUTLASS backend
  1695. # Comma-separated list like "mm,addmm,bmm", "all" for all operations, and "" for none.
  1696. # Acceptable operations: mm, int_mm, addmm, sparse_semi_structured_mm, bmm, scaled_mm
  1697. cutlass_enabled_ops: str = os.environ.get(
  1698. "TORCHINDUCTOR_CUTLASS_ENABLED_OPS", "all"
  1699. )
  1700. # Whether to consult the binary remote cache
  1701. use_binary_remote_cache: bool = True
  1702. # Whether to upload compiled kernels to remote cache
  1703. upload_to_binary_remote_cache: bool = False
  1704. # Whether to force upload if the key already exists
  1705. # Use this to overwrite and handle cache pollution
  1706. binary_remote_cache_force_write: bool = False
  1707. # Enable caching codegen of cuda templates.
  1708. enable_caching_codegen: bool = True
  1709. @inherit_fields_from(cutlass)
  1710. class cuda(cutlass):
  1711. # CUDA arch to use for CUDA template kernel compilation.
  1712. # e.g. "70", "75", "80", "90", etc.
  1713. # When arch is None, Inductor uses torch.cuda.get_device_capability(0).
  1714. arch: Optional[str] = None
  1715. # CUDA version to use for CUDA template kernel compilation.
  1716. # e.g. "11.4", "12.1", etc.
  1717. # When version is None, Inductor uses torch.version.cuda.
  1718. version: Optional[str] = None
  1719. # Path to CUDA NVCC.
  1720. # NVCC search order:
  1721. # 1) cuda_cxx set in this config
  1722. # 2) CUDACXX environment variable
  1723. # 3) CUDA_HOME environment variable
  1724. # 4) default system search PATH.
  1725. cuda_cxx: Optional[str] = None
  1726. # Whether to enable device LTO (link-time-optimization).
  1727. enable_cuda_lto = False
  1728. # Whether to keep intermediate files dring compilation.
  1729. enable_ptxas_info = False
  1730. # Configures the maximum number of NVIDIA Universal GEMM (NVGEMM) configs to profile in max_autotune.
  1731. # By default it's 5, to keep compile time to a reasonable level.
  1732. nvgemm_max_profiling_configs: Optional[int] = 5
  1733. @inherit_fields_from(cutlass)
  1734. class xpu(cutlass):
  1735. # Xe arch to use for SYCL kernel compilation.
  1736. # eg. 12, 20, which corresponding to Xe12(PVC) and Xe20 (BMG)
  1737. arch: Optional[str] = None
  1738. # oneAPI version to use for SYCL kernel compilation.
  1739. # e.g. "20250201".
  1740. version: Optional[str] = None
  1741. class rocm:
  1742. # Offload arch list for device code compilation, e.g. ["gfx90a", "gfx942"].
  1743. # If empty, the `native` arch is used
  1744. arch: list[str] = []
  1745. # Enable the CK backend for CDNA2 and CDNA3 only (for now)
  1746. # Processor name reference: https://llvm.org/docs/AMDGPUUsage.html#processors
  1747. ck_supported_arch: list[Literal["gfx90a", "gfx942", "gfx950"]] = [
  1748. "gfx90a",
  1749. "gfx942",
  1750. "gfx950",
  1751. ]
  1752. # Optimization level, use to balance compilation speed and runtime performance.
  1753. # The type will not necessarily be comprehensive and won't be enforced at runtime.
  1754. compile_opt_level: Literal[
  1755. "-O0", "-O1", "-O2", "-O3", "-Os", "-Oz", "-Omin", "-Ofast", "-Omax"
  1756. ] = "-O2"
  1757. # Flag to keep debug information in compiled objects
  1758. is_debug = False
  1759. # Flag to keep intermediate files (assembly listings, preprocessed sources, etc.)
  1760. save_temps = False
  1761. # Flag to add `-ffast-math`` to compile flags
  1762. use_fast_math = True
  1763. # Flag to add `-fgpu-flush-denormals-to-zero` to compile flags
  1764. flush_denormals = True
  1765. # Flag to print register and LDS usage during compilation
  1766. print_kernel_resource_usage = False
  1767. # Path to ROCm installation, if None, use env variable ROCM_HOME.
  1768. # In fbcode see triton/fb/TARGETS for how ROCM_HOME gets set.
  1769. rocm_home: Optional[str] = None
  1770. # Path to Composable Kernel library.
  1771. # Install with `pip install git+https://github.com/rocm/composable_kernel@develop`.
  1772. ck_dir = os.environ.get("TORCHINDUCTOR_CK_DIR")
  1773. # generate standalone executables for instances generated with the CK backend
  1774. generate_test_runner: bool = (
  1775. os.environ.get("INDUCTOR_CK_BACKEND_GENERATE_TEST_RUNNER_CODE", "0") == "1"
  1776. )
  1777. # Deprecated, use CK and/or CK-tile specific settings
  1778. n_max_profiling_configs: Optional[int] = None
  1779. # Number of op instance choices to trade off between runtime perf and compilation time
  1780. # For CK Kernels
  1781. ck_max_profiling_configs: Optional[int] = None
  1782. # Number of op instance choices to trade off between runtime perf and compilation time
  1783. # For CK-Tile Kernels
  1784. ck_tile_max_profiling_configs: Optional[int] = None
  1785. # Flag to use a short list of CK instances which perform well across a variety of shapes.
  1786. # Currently RCR and F16 only
  1787. use_preselected_instances: bool = False
  1788. # List to determine kBatch parameters to sweep over. By default, we calculate one in splitK
  1789. # scenarios, and run on kBatch=1 in non-splitK scenarios
  1790. kBatch_sweep: Optional[list[int]] = None
  1791. # The threshold at which we trigger a splitK config - K // max(M,N) has to be greater than this
  1792. split_k_threshold: int = 16
  1793. # The threshold at which we trigger a contiguous subgraph transformation
  1794. contiguous_threshold: int = 16
  1795. # Backend to use for CPU codegen either "cpp" or "triton" (experimental) or "halide" (experimental) or "pallas" (experimental)
  1796. cpu_backend: Literal["cpp", "triton", "halide", "pallas"] = "cpp"
  1797. # Backend to use for CUDA codegen either
  1798. # "triton", "halide" (experimental) or "pallas" (experimental)
  1799. cuda_backend: Literal["triton", "halide", "pallas"] = "triton"
  1800. # Backend to use for TPU codegen
  1801. tpu_backend: Literal["pallas"] = "pallas"
  1802. # Backend to use for XPU codegen either "triton"
  1803. xpu_backend: Literal["triton"] = "triton"
  1804. class halide:
  1805. # Base halide target to use for CPU devices
  1806. cpu_target = "host"
  1807. # Base halide target to use for CUDA devices
  1808. gpu_target = "host-cuda"
  1809. # Halide autoscheduler to use, choices are:
  1810. # "Anderson2021" (gpu-only), "Li2018", "Adams2019" (cpu-only), or "Mullapudi2016" (cpu-only)
  1811. scheduler_cuda: Literal["Anderson2021", "Li2018", "Adams2019", "Mullapudi2016"] = (
  1812. "Anderson2021"
  1813. )
  1814. scheduler_cpu: Literal["Anderson2021", "Li2018", "Adams2019", "Mullapudi2016"] = (
  1815. "Adams2019"
  1816. )
  1817. # Controls `no_asserts` flag passed to Halide target (warning: can false positive)
  1818. asserts = False
  1819. # Controls `debug` flag passed to Halide target
  1820. debug = False
  1821. # Enable (or fallback on) scan kernels such as cumsum
  1822. # Halide autoschedulers struggle with these kernels
  1823. scan_kernels = False
  1824. # create a directory containing lots of debug information
  1825. class trace:
  1826. # master switch for all debugging flags below
  1827. enabled = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1"
  1828. # save real tensors
  1829. save_real_tensors = os.environ.get("TORCH_COMPILE_DEBUG_SAVE_REAL", "0") == "1"
  1830. # Save debug information to a temporary directory
  1831. # If not specified, a temp directory will be created by system
  1832. debug_dir: Optional[str] = None
  1833. # Save python logger call >=logging.DEBUG
  1834. debug_log = False
  1835. # Save python logger call >=logging.INFO
  1836. info_log = False
  1837. # Save input FX graph (post decomps, pre optimization)
  1838. fx_graph = True
  1839. # Save FX graph after transformations
  1840. fx_graph_transformed = True
  1841. # Save TorchInductor IR before fusion pass
  1842. ir_pre_fusion = True
  1843. # Save TorchInductor IR after fusion pass
  1844. ir_post_fusion = True
  1845. # Copy generated code to trace dir
  1846. output_code = True
  1847. # SVG figure showing post-fusion graph
  1848. graph_diagram = os.environ.get("INDUCTOR_POST_FUSION_SVG", "0") == "1"
  1849. # SVG figure showing fx with fusion
  1850. draw_orig_fx_graph = os.environ.get("INDUCTOR_ORIG_FX_SVG", "0") == "1"
  1851. # We draw our fx graphs with the "record" shape attribute by default.
  1852. # Sometimes, when the graph is very complex, we may hit dot errors like below:
  1853. # "flat edge between adjacent nodes one of which has a record shape -
  1854. # replace records with HTML-like labels"
  1855. # and thus fail to generate a graph. So, let's give the user an option
  1856. # to specify the shape attribute for the dot graph. For example, passing
  1857. # INDUCTOR_DOT_GRAPH_SHAPE_SVG = "none" would let us generate HTML-like labels
  1858. # to workaround the above failure.
  1859. dot_graph_shape = os.environ.get("INDUCTOR_DOT_GRAPH_SHAPE_SVG", None)
  1860. # If not None, this is the URL that saves the SVG files of the input/output
  1861. # graph of each pass that changed the graph
  1862. # The nodes that are being transformed in each pass will be colored in yellow
  1863. # URL only supports local directory for now
  1864. log_url_for_graph_xform = os.environ.get("INDUCTOR_LOG_URL_FOR_GRAPH_XFORM", None)
  1865. # Store cProfile (see snakeviz to view)
  1866. compile_profile = False
  1867. # Upload the .tar.gz file
  1868. # Needs to be overridden based on specific environment needs
  1869. upload_tar: Optional[Callable[[str], None]] = None
  1870. log_autotuning_results = os.environ.get("LOG_AUTOTUNE_RESULTS", "0") == "1"
  1871. # Save mapping info from inductor generated kernel to post_grad/pre_grad fx nodes
  1872. # Levels:
  1873. # 0 - disabled (default)
  1874. # 1 - normal
  1875. # 2 - basic
  1876. # Backward compatibility:
  1877. # If TORCH_COMPILE_DEBUG=1, level is set to at least 1.
  1878. # If INDUCTOR_PROVENANCE is set, use its integer value.
  1879. provenance_tracking_level: int = int(
  1880. os.environ.get(
  1881. "INDUCTOR_PROVENANCE", os.environ.get("TORCH_COMPILE_DEBUG", "0")
  1882. )
  1883. )
  1884. _save_config_ignore: list[str] = [
  1885. # workaround: "Can't pickle <function ...>"
  1886. "trace.upload_tar",
  1887. "joint_custom_pre_pass",
  1888. "joint_custom_post_pass",
  1889. "pre_grad_custom_pass",
  1890. "aot_inductor.repro_level",
  1891. "aot_inductor.dump_aoti_minifier",
  1892. "post_grad_custom_pre_pass",
  1893. "post_grad_custom_post_pass",
  1894. "_fuse_ddp_communication_passes",
  1895. "_pre_fusion_custom_pass",
  1896. ]
  1897. _cache_config_ignore_prefix: list[str] = [
  1898. # trace functions are not relevant to config caching
  1899. "trace",
  1900. # uses absolute path
  1901. "cuda.cutlass_dir",
  1902. "cutlass.cutlass_dir",
  1903. "xpu.cutlass_dir",
  1904. # not relevant
  1905. "worker_start_method",
  1906. "compile_threads",
  1907. # see CustomGraphPass; these are handled specially
  1908. "post_grad_custom_post_pass",
  1909. "post_grad_custom_pre_pass",
  1910. "joint_custom_pre_pass",
  1911. "joint_custom_post_pass",
  1912. "_fuse_ddp_communication_passes",
  1913. "_pre_fusion_custom_pass",
  1914. # tests assume that changes here don't invalidate cache
  1915. "always_complex_memory_overlap_TESTING_ONLY",
  1916. # cache related options are not relevant to cache results
  1917. "fx_graph_cache",
  1918. "fx_graph_remote_cache",
  1919. "autotune_local_cache",
  1920. "autotune_remote_cache",
  1921. ]
  1922. # External callable for matmul tuning candidates
  1923. external_matmul: list[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], None]] = []
  1924. write_are_deterministic_algorithms_enabled = (
  1925. os.getenv("TORCHINDUCTOR_WRITE_ARE_DETERMINISTIC_ALGORITHMS_ENABLED", "1") == "1"
  1926. )
  1927. class lookup_table:
  1928. # Lookup table for template config overrides
  1929. table: Optional[dict[str, list[dict[str, Any]]]] = None
  1930. # Enable template src_hash checking in lookup table to prevent using stale configs.
  1931. # If True, configs with 'template_hash' field will be compared against the template's
  1932. # src_hash at runtime and filtered out if they don't match. If False, no
  1933. # hash checking is performed.
  1934. check_src_hash: bool = True
  1935. class test_configs:
  1936. force_extern_kernel_in_multi_template: bool = False
  1937. max_mm_configs: Optional[int] = None
  1938. runtime_triton_dtype_assert = False
  1939. runtime_triton_shape_assert = False
  1940. static_cpp_dtype_assert = False
  1941. # regex to control the set of considered autotuning
  1942. # choices (aka configs) by name and / or description
  1943. # Can be set via TORCHINDUCTOR_AUTOTUNE_CHOICE_NAME_REGEX and
  1944. # TORCHINDUCTOR_AUTOTUNE_CHOICE_DESC_REGEX environment variables
  1945. autotune_choice_name_regex: Optional[str] = os.environ.get(
  1946. "TORCHINDUCTOR_AUTOTUNE_CHOICE_NAME_REGEX"
  1947. )
  1948. autotune_choice_desc_regex: Optional[str] = os.environ.get(
  1949. "TORCHINDUCTOR_AUTOTUNE_CHOICE_DESC_REGEX"
  1950. )
  1951. graphsafe_rng_func_ignores_fallback_random = False
  1952. track_memory_lifecycle: Optional[Literal["assert", "log"]] = None
  1953. # If set to True, AOTI-generated CMakelists.txt will still use libtorch
  1954. # for unit testing
  1955. use_libtorch = False
  1956. # Assume bucketing reduces latency (mostly for testing)
  1957. assume_bucketing_reduces_latency: bool = True
  1958. # A test config to ease the test for perf of reduction config filtering
  1959. force_filter_reduction_configs = (
  1960. os.getenv("TORCHINDUCTOR_FORCE_FILTER_REDUCTION_CONFIGS") == "1"
  1961. )
  1962. # a testing config to distort benchmarking result
  1963. # - empty string to disable
  1964. # - "inverse" to inverse the numbers
  1965. # - "random" return a random value
  1966. distort_benchmarking_result = os.getenv(
  1967. "TORCHINDUCTOR_DISTORT_BENCHMARKING_RESULT", ""
  1968. )
  1969. bisect_pre_grad_graph = False
  1970. bisect_keep_custom_backend_for_inductor = False
# Only static type checkers see this import (TYPE_CHECKING is False at
# runtime). The star import presumably provides typing declarations for
# config modules -- confirm in torch/utils/_config_typing.py.
if TYPE_CHECKING:
    from torch.utils._config_typing import *  # noqa: F401, F403
  1973. class eager_numerics:
  1974. # x / y in Triton is lowered to div.full which is approx
  1975. # PyTorch eager uses the equivalent of Triton's div_rn, which can
  1976. # come at a performance penalty
  1977. division_rounding: bool = (
  1978. os.environ.get("TORCHINDUCTOR_EMULATE_DIVISION_ROUNDING", "0") == "1"
  1979. )
  1980. disable_ftz: bool = False
  1981. # Mode to emulate PyTorch eager numerics when doing lower precision compute
  1982. # (fp16, bf16). PyTorch eager computes bf16/fp16 by upcasting inputs to fp32
  1983. # and downcasting after. When two low precision operators are fused together,
  1984. # Inductor will elide the downcast-upcast pairs (effectively a precision
  1985. # truncation) that would occur between these two operators. Typically,
  1986. # Inductor's behavior should be closer to fp64 ref numerics. However, with
  1987. # this knob you can ensure the downcast-upcast are preserved so that you can
  1988. # emulate the eager numerics.
  1989. emulate_precision_casts: bool = (
  1990. os.environ.get("TORCHINDUCTOR_EMULATE_PRECISION_CASTS", "0") == "1"
  1991. )
# Hand this module object to install_config_module; per the original note,
# this adds patch, save_config, etc.
install_config_module(sys.modules[__name__])