# Source file: _internal.py
  1. # mypy: allow-untyped-defs
  2. import contextlib
  3. import functools
  4. import hashlib
  5. import importlib.util
  6. import itertools
  7. import json
  8. import logging
  9. import os
  10. import os.path
  11. import pathlib
  12. import pkgutil
  13. import re
  14. import sys
  15. import tempfile
  16. import time
  17. import warnings
  18. from collections import defaultdict
  19. from collections.abc import Callable, Sequence
  20. from dataclasses import dataclass, field
  21. from typing import Any, Generic, Optional, Union
  22. from typing_extensions import ParamSpec
  23. from weakref import WeakSet
  24. import torch._logging.structured
  25. from torch._guards import CompileId
  26. from torch._utils_internal import log_trace_structured_event
  27. from torch.utils._traceback import CapturedTraceback
  28. _P = ParamSpec("_P")
  29. log = logging.getLogger(__name__)
  30. # This is a synthetic logger which doesn't correspond to an actual logger,
  31. # but handles all of our "tracing" logging, which is structured and doesn't go
  32. # to stderr but always goes to a dedicated log file. We don't put these
  33. # loggers in the classic module hierarchy, because we don't want a suppression
  34. # of logs to also cause a trace to get suppressed (traces typically are not
  35. # collected, unless we are in prod, in which case they always are collected.)
  36. #
  37. # TODO: Maybe we should allow for some sub-hierarchy so you can control which
  38. # traces you want to collect, for performance reasons.
  39. #
  40. # See https://docs.google.com/document/d/1CX_hJ0PNy9f3R1y8TJrfkSeLkvGjjjLU84BSXgS2AZ8/edit
  41. trace_log = logging.getLogger("torch.__trace")
  42. DEFAULT_LOG_LEVEL = logging.WARNING
  43. LOG_ENV_VAR = "TORCH_LOGS"
  44. LOG_OUT_ENV_VAR = "TORCH_LOGS_OUT"
  45. LOG_FORMAT_ENV_VAR = "TORCH_LOGS_FORMAT"
  46. LOG_TRACE_ID_FILTER = "TORCH_LOGS_TRACE_ID_FILTER"
  47. TRACE_ENV_VAR = "TORCH_TRACE"
  48. DTRACE_ENV_VAR = "TORCH_DTRACE"
  49. LOG_TRACE_HANDLER: Optional["LazyTraceHandler"] = None
  50. GET_DTRACE_STRUCTURED = False
  51. LOG_PREFIX = "dedicated_log_torch_trace_"
  52. @dataclass
  53. class LogRegistry:
  54. # shorthand name to log qualified name
  55. # Note: this only contains loggers registered
  56. # from register_log
  57. # e.g. "dynamo" -> "torch._dynamo"
  58. log_alias_to_log_qnames: dict[str, list[str]] = field(default_factory=dict)
  59. # artifact logger qualified names,
  60. # this is populated lazily, as calls to getArtifactLogger
  61. # currently formatted as <module>.__<artifact_name>
  62. # e.g. "torch._dynamo.convert_frame.__guards"
  63. artifact_log_qnames: set[str] = field(default_factory=set)
  64. # child logs of registered logs if specified via open
  65. # registration by the user (ie placing "torch._dynamo.output_graph" in the env var)
  66. # these need to be tracked so their levels can be reset properly
  67. # e.g. "torch._dynamo.output_graph"
  68. child_log_qnames: set[str] = field(default_factory=set)
  69. # artifact names, populated by register_artifact
  70. # e.g. "guards"
  71. artifact_names: set[str] = field(default_factory=set)
  72. # Artifacts that should be visible by default in the error message
  73. visible_artifacts: set[str] = field(default_factory=set)
  74. # A short description of each artifact
  75. artifact_descriptions: dict[str, str] = field(default_factory=dict)
  76. # artifacts which are not displayed unless explicitly named in the
  77. # settings. Ex. output_code is NOT displayed even if the inductor
  78. # log level is set to DEBUG. It must be explicitly named in the settings
  79. off_by_default_artifact_names: set[str] = field(default_factory=set)
  80. # logging format string for artifacts
  81. artifact_log_formatters: dict[str, logging.Formatter] = field(default_factory=dict)
  82. def is_artifact(self, name):
  83. return name in self.artifact_names
  84. def is_log(self, alias):
  85. return alias in self.log_alias_to_log_qnames
  86. # register a log with an alias
  87. def register_log(self, alias, log_qnames: Union[str, list[str]]) -> None:
  88. if isinstance(log_qnames, str):
  89. log_qnames = [log_qnames]
  90. self.log_alias_to_log_qnames[alias] = log_qnames
  91. # register an artifact name
  92. def register_artifact_name(
  93. self, name, description, visible, off_by_default, log_format
  94. ) -> None:
  95. self.artifact_names.add(name)
  96. if visible:
  97. self.visible_artifacts.add(name)
  98. self.artifact_descriptions[name] = description
  99. # if off by default, don't enable it
  100. # when log_name's log_level is set to DEBUG
  101. if off_by_default:
  102. self.off_by_default_artifact_names.add(name)
  103. if log_format is not None:
  104. self.artifact_log_formatters[name] = logging.Formatter(log_format)
  105. # register the qualified name of an artifact log
  106. # this is needed to know which logs need to be reset
  107. # whenever the log_state is changed
  108. def register_artifact_log(self, artifact_log_qname) -> None:
  109. self.artifact_log_qnames.add(artifact_log_qname)
  110. def register_child_log(self, log_qname) -> None:
  111. self.child_log_qnames.add(log_qname)
  112. # flattens all the qnames together (TODO: consider memoizing?)
  113. def get_log_qnames(self) -> set[str]:
  114. return set(itertools.chain.from_iterable(self.log_alias_to_log_qnames.values()))
  115. def get_artifact_log_qnames(self):
  116. return set(self.artifact_log_qnames)
  117. def get_child_log_qnames(self):
  118. return set(self.child_log_qnames)
  119. def is_off_by_default(self, artifact_qname):
  120. return artifact_qname in self.off_by_default_artifact_names
  121. @dataclass
  122. class LogState:
  123. # qualified log names -> currently set log level
  124. log_qname_to_level: dict[str, str] = field(default_factory=dict)
  125. # the set of currently enabled artifacts
  126. artifact_names: set[str] = field(default_factory=set)
  127. def enable_artifact(self, artifact_name) -> None:
  128. self.artifact_names.add(artifact_name)
  129. def is_artifact_enabled(self, name):
  130. return name in self.artifact_names
  131. def enable_log(self, log_qnames, log_level) -> None:
  132. if isinstance(log_qnames, str):
  133. log_qnames = [log_qnames]
  134. for log_qname in log_qnames:
  135. self.log_qname_to_level[log_qname] = log_level
  136. def get_log_level_pairs(self):
  137. """Returns all qualified module names for which the user requested
  138. explicit logging settings.
  139. .. warning:
  140. This function used to return all loggers, regardless of whether
  141. or not the user specified them or not; it now only returns logs
  142. which were explicitly mentioned by the user (and torch, which
  143. always is implicitly requested when we initialize our logging
  144. subsystem.)
  145. """
  146. return self.log_qname_to_level.items()
  147. def clear(self) -> None:
  148. self.log_qname_to_level.clear()
  149. self.artifact_names.clear()
  150. log_registry = LogRegistry()
  151. log_state = LogState()
  152. # sample usage: torch._logging.set_logs(**torch._logging.DEFAULT_LOGGING)
  153. DEFAULT_LOGGING = {
  154. "dynamo": logging.INFO,
  155. "aot": logging.INFO,
  156. "inductor": logging.INFO,
  157. "fsdp": logging.INFO,
  158. "ddp_graphs": True,
  159. "graph_breaks": True,
  160. "side_effects": True,
  161. "guards": True,
  162. "recompiles": True,
  163. "dynamic": logging.INFO,
  164. }
  165. def set_logs(
  166. *,
  167. all: Optional[int] = None,
  168. dynamo: Optional[int] = None,
  169. aot: Optional[int] = None,
  170. autograd: Optional[int] = None,
  171. dynamic: Optional[int] = None,
  172. inductor: Optional[int] = None,
  173. distributed: Optional[int] = None,
  174. c10d: Optional[int] = None,
  175. ddp: Optional[int] = None,
  176. fsdp: Optional[int] = None,
  177. dtensor: Optional[int] = None,
  178. onnx: Optional[int] = None,
  179. bytecode: bool = False,
  180. aot_graphs: bool = False,
  181. aot_joint_graph: bool = False,
  182. ddp_graphs: bool = False,
  183. graph: bool = False,
  184. graph_code: bool = False,
  185. graph_code_verbose: bool = False,
  186. graph_breaks: bool = False,
  187. side_effects: bool = False,
  188. graph_sizes: bool = False,
  189. guards: bool = False,
  190. recompiles: bool = False,
  191. recompiles_verbose: bool = False,
  192. trace_source: bool = False,
  193. trace_call: bool = False,
  194. trace_bytecode: bool = False,
  195. output_code: bool = False,
  196. kernel_code: bool = False,
  197. schedule: bool = False,
  198. perf_hints: bool = False,
  199. pre_grad_graphs: bool = False,
  200. post_grad_graphs: bool = False,
  201. ir_pre_fusion: bool = False,
  202. ir_post_fusion: bool = False,
  203. onnx_diagnostics: bool = False,
  204. fusion: bool = False,
  205. overlap: bool = False,
  206. export: Optional[int] = None,
  207. modules: Optional[dict[str, Union[int, bool]]] = None,
  208. cudagraphs: bool = False,
  209. sym_node: bool = False,
  210. compiled_autograd: bool = False,
  211. compiled_autograd_verbose: bool = False,
  212. cudagraph_static_inputs: bool = False,
  213. benchmarking: bool = False,
  214. autotuning: bool = False,
  215. graph_region_expansion: bool = False,
  216. inductor_metrics: bool = False,
  217. hierarchical_compile: bool = False,
  218. compute_dependencies: bool = False,
  219. caching: bool = False,
  220. ) -> None:
  221. """
  222. Sets the log level for individual components and toggles individual log
  223. artifact types.
  224. .. warning:: This feature is a prototype and may have compatibility
  225. breaking changes in the future.
  226. .. note:: The ``TORCH_LOGS`` environment variable has complete precedence
  227. over this function, so if it was set, this function does nothing.
  228. A component is a set of related features in PyTorch. All of the log
  229. messages emitted from a given component have their own log levels. If the
  230. log level of a particular message has priority greater than or equal to its
  231. component's log level setting, it is emitted. Otherwise, it is suppressed.
  232. This allows you to, for instance, silence large groups of log messages that
  233. are not relevant to you and increase verbosity of logs for components that
  234. are relevant. The expected log level values, ordered from highest to lowest
  235. priority, are:
  236. * ``logging.CRITICAL``
  237. * ``logging.ERROR``
  238. * ``logging.WARNING``
  239. * ``logging.INFO``
  240. * ``logging.DEBUG``
  241. * ``logging.NOTSET``
  242. See documentation for the Python ``logging`` module for more information on
  243. log levels: `<https://docs.python.org/3/library/logging.html#logging-levels>`_
  244. An artifact is a particular type of log message. Each artifact is assigned
  245. to a parent component. A component can emit many different kinds of
  246. artifacts. In general, an artifact is emitted if either its corresponding
  247. setting in the argument list below is turned on or if its parent component
  248. is set to a log level less than or equal to the log level of the artifact.
  249. Keyword args:
  250. all (:class:`Optional[int]`):
  251. The default log level for all components. Default: ``logging.WARN``
  252. dynamo (:class:`Optional[int]`):
  253. The log level for the TorchDynamo component. Default: ``logging.WARN``
  254. aot (:class:`Optional[int]`):
  255. The log level for the AOTAutograd component. Default: ``logging.WARN``
  256. autograd (:class:`Optional[int]`):
  257. The log level for autograd. Default: ``logging.WARN``
  258. inductor (:class:`Optional[int]`):
  259. The log level for the TorchInductor component. Default: ``logging.WARN``
  260. dynamic (:class:`Optional[int]`):
  261. The log level for dynamic shapes. Default: ``logging.WARN``
  262. distributed (:class:`Optional[int]`):
  263. Whether to log c10d communication operations and other debug info from PyTorch Distributed components.
  264. Default: ``logging.WARN``
  265. c10d (:class:`Optional[int]`):
  266. Whether to log c10d communication operations related debug info in PyTorch Distributed components.
  267. Default: ``logging.WARN``
  268. ddp (:class:`Optional[int]`):
  269. Whether to log debug info related to ``DistributedDataParallel``(DDP) from PyTorch Distributed components.
  270. Default: ``logging.WARN``
  271. fsdp (:class:`Optional[int]`):
  272. Whether to log debug info related to ``FullyShardedDataParallel``(FSDP) in PyTorch Distributed components.
  273. Default: ``logging.WARN``
  274. dtensor (:class:`Optional[int]`):
  275. Whether to log debug info related to ``DTensor``(DTensor) in PyTorch Distributed components.
  276. Default: ``logging.WARN``
  277. onnx (:class:`Optional[int]`):
  278. The log level for the ONNX exporter component. Default: ``logging.WARN``
  279. bytecode (:class:`bool`):
  280. Whether to emit the original and generated bytecode from TorchDynamo.
  281. Default: ``False``
  282. aot_graphs (:class:`bool`):
  283. Whether to emit the graphs generated by AOTAutograd. Default: ``False``
  284. aot_joint_graph (:class:`bool`):
  285. Whether to emit the joint forward-backward graph generated by AOTAutograd. Default: ``False``
  286. ddp_graphs (:class:`bool`):
  287. Whether to emit graphs generated by DDPOptimizer. Default: ``False``
  288. graph (:class:`bool`):
  289. Whether to emit the graph captured by TorchDynamo in tabular format.
  290. Default: ``False``
  291. graph_code (:class:`bool`):
  292. Whether to emit the python source of the graph captured by TorchDynamo.
  293. Default: ``False``
  294. graph_code_verbose (:class:`bool`):
  295. Whether to emit verbose/intermediate FX pass logs for graph code. Default: ``False``
  296. graph_breaks (:class:`bool`):
  297. Whether to emit the graph breaks encountered by TorchDynamo.
  298. Default: ``False``
  299. side_effects (:class:`bool`):
  300. Whether to emit side effects (mutations, hooks, etc.) that TorchDynamo
  301. codegenerates in the output graph. Default: ``False``
  302. graph_sizes (:class:`bool`):
  303. Whether to emit tensor sizes of the graph captured by TorchDynamo.
  304. Default: ``False``
  305. guards (:class:`bool`):
  306. Whether to emit the guards generated by TorchDynamo for each compiled
  307. function. Default: ``False``
  308. recompiles (:class:`bool`):
  309. Whether to emit a guard failure reason and message every time
  310. TorchDynamo recompiles a function. Default: ``False``
  311. recompiles_verbose (:class:`bool`):
  312. Whether to emit all guard failure reasons when TorchDynamo recompiles
  313. a function, even those that are not actually run. Default: ``False``
  314. trace_source (:class:`bool`):
  315. Whether to emit when TorchDynamo begins tracing a new line. Default: ``False``
  316. trace_call (:class:`bool`):
  317. Whether to emit detailed line location when TorchDynamo creates an FX node
  318. corresponding to function call. Python 3.11+ only. Default: ``False``
  319. trace_bytecode (:class:`bool`):
  320. Whether to emit bytecode instructions and traced stack state as TorchDynamo
  321. traces bytecode. Default: ``False``
  322. output_code (:class:`bool`):
  323. Whether to emit the TorchInductor output code on a per-graph basis. Default: ``False``
  324. kernel_code (:class:`bool`):
  325. Whether to emit the TorchInductor output code on a per-kernel bases. Default: ``False``
  326. schedule (:class:`bool`):
  327. Whether to emit the TorchInductor schedule. Default: ``False``
  328. perf_hints (:class:`bool`):
  329. Whether to emit the TorchInductor perf hints. Default: ``False``
  330. pre_grad_graphs (:class:`bool`):
  331. Whether to emit the graphs before inductor grad passes. Default: ``False``
  332. post_grad_graphs (:class:`bool`):
  333. Whether to emit the graphs generated by after post grad passes. Default: ``False``
  334. ir_pre_fusion (:class:`bool`):
  335. Whether to emit the graphs before inductor fusion passes. Default: ``False``
  336. ir_post_fusion (:class:`bool`):
  337. Whether to emit the graphs after inductor fusion passes. Default: ``False``
  338. onnx_diagnostics (:class:`bool`):
  339. Whether to emit the ONNX exporter diagnostics in logging. Default: ``False``
  340. fusion (:class:`bool`):
  341. Whether to emit detailed Inductor fusion decisions. Default: ``False``
  342. overlap (:class:`bool`):
  343. Whether to emit detailed Inductor compute/comm overlap decisions. Default: ``False``
  344. sym_node (:class:`bool`):
  345. Whether to emit debug info for various SymNode opterations. Default: ``False``
  346. export (:class:`Optional[int]`):
  347. The log level for export. Default: ``logging.WARN``
  348. benchmarking (:class:`bool`):
  349. Whether to emit detailed Inductor benchmarking information. Default: ``False``
  350. modules (dict):
  351. This argument provides an alternate way to specify the above log
  352. component and artifact settings, in the format of a keyword args
  353. dictionary given as a single argument. There are two cases
  354. where this is useful (1) if a new log component or artifact has
  355. been registered but a keyword argument for it has not been added
  356. to this function and (2) if the log level for an unregistered module
  357. needs to be set. This can be done by providing the fully-qualified module
  358. name as the key, with the log level as the value. Default: ``None``
  359. cudagraph_static_inputs (:class:`bool`):
  360. Whether to emit debug info for cudagraph static input detection. Default: ``False``
  361. autotuning (:class:`bool`):
  362. Autotuning choice logs, such as kernel source, perf, and tuning parameters. Default: ``False``
  363. graph_region_expansion (:class:`bool`):
  364. Whether to emit the detailed steps of the duplicate graph region tracker expansion algorithm. Default: ``False``
  365. inductor_metrics (:class:`bool`):
  366. Whether to estimate the runtimes of the nodes in a graph and log them to the metrics table. Default: ``False``
  367. hierarchical_compile (:class:`bool`):
  368. Whether to emit debug info for hierarchical compilation. Default: ``False``
  369. caching (:class:`bool`):
  370. Whether to emit detailed Inductor caching information. Default: ``False``
  371. Example::
  372. >>> # xdoctest: +SKIP
  373. >>> import logging
  374. # The following changes the "dynamo" component to emit DEBUG-level
  375. # logs, and to emit "graph_code" artifacts.
  376. >>> torch._logging.set_logs(dynamo=logging.DEBUG, graph_code=True)
  377. # The following enables the logs for a different module
  378. >>> torch._logging.set_logs(modules={"unregistered.module.name": logging.DEBUG})
  379. """
  380. # ignore if env var is set
  381. if LOG_ENV_VAR in os.environ:
  382. log.warning(
  383. "Using TORCH_LOGS environment variable for log settings, ignoring call to set_logs"
  384. )
  385. return
  386. log_state.clear()
  387. modules = modules or {}
  388. def _set_logs(**kwargs) -> None:
  389. for alias, val in itertools.chain(kwargs.items(), modules.items()): # type: ignore[union-attr]
  390. if val is None:
  391. continue
  392. if log_registry.is_artifact(alias):
  393. if not isinstance(val, bool):
  394. raise ValueError(
  395. f"Expected bool to enable artifact {alias}, received {val}"
  396. )
  397. if val:
  398. log_state.enable_artifact(alias)
  399. elif log_registry.is_log(alias) or alias in log_registry.child_log_qnames:
  400. if val not in logging._levelToName:
  401. raise ValueError(
  402. f"Unrecognized log level for log {alias}: {val}, valid level values "
  403. f"are: {','.join([str(k) for k in logging._levelToName])}"
  404. )
  405. log_state.enable_log(
  406. log_registry.log_alias_to_log_qnames.get(alias, alias), val
  407. )
  408. elif _is_valid_module(alias):
  409. found_modules = _get_module_and_submodules(alias) or (alias,)
  410. for module_name in found_modules:
  411. if not _has_registered_parent(module_name):
  412. log_registry.register_log(module_name, module_name)
  413. else:
  414. log_registry.register_child_log(module_name)
  415. log_state.enable_log(
  416. log_registry.log_alias_to_log_qnames.get(
  417. module_name, module_name
  418. ),
  419. val,
  420. )
  421. else:
  422. raise ValueError(
  423. f"Unrecognized log or artifact name passed to set_logs: {alias}"
  424. )
  425. _init_logs()
  426. _set_logs(
  427. torch=all,
  428. dynamo=dynamo,
  429. aot=aot,
  430. autograd=autograd,
  431. inductor=inductor,
  432. dynamic=dynamic,
  433. bytecode=bytecode,
  434. aot_graphs=aot_graphs,
  435. aot_joint_graph=aot_joint_graph,
  436. ddp_graphs=ddp_graphs,
  437. distributed=distributed,
  438. c10d=c10d,
  439. ddp=ddp,
  440. fsdp=fsdp,
  441. dtensor=dtensor,
  442. graph=graph,
  443. graph_code=graph_code,
  444. graph_code_verbose=graph_code_verbose,
  445. graph_breaks=graph_breaks,
  446. side_effects=side_effects,
  447. graph_sizes=graph_sizes,
  448. guards=guards,
  449. recompiles=recompiles,
  450. recompiles_verbose=recompiles_verbose,
  451. trace_source=trace_source,
  452. trace_call=trace_call,
  453. trace_bytecode=trace_bytecode,
  454. output_code=output_code,
  455. kernel_code=kernel_code,
  456. schedule=schedule,
  457. perf_hints=perf_hints,
  458. pre_grad_graphs=pre_grad_graphs,
  459. post_grad_graphs=post_grad_graphs,
  460. ir_pre_fusion=ir_pre_fusion,
  461. ir_post_fusion=ir_post_fusion,
  462. onnx=onnx,
  463. onnx_diagnostics=onnx_diagnostics,
  464. fusion=fusion,
  465. overlap=overlap,
  466. sym_node=sym_node,
  467. export=export,
  468. cudagraphs=cudagraphs,
  469. compiled_autograd=compiled_autograd,
  470. compiled_autograd_verbose=compiled_autograd_verbose,
  471. cudagraph_static_inputs=cudagraph_static_inputs,
  472. benchmarking=benchmarking,
  473. autotuning=autotuning,
  474. graph_region_expansion=graph_region_expansion,
  475. inductor_metrics=inductor_metrics,
  476. hierarchical_compile=hierarchical_compile,
  477. compute_dependencies=compute_dependencies,
  478. caching=caching,
  479. )
  480. def get_loggers() -> list[logging.Logger]:
  481. """
  482. Returns: a list of all registered loggers
  483. """
  484. return [logging.getLogger(qname) for qname in log_registry.get_log_qnames()]
  485. def register_log(setting_name, log_name) -> None:
  486. """
  487. Enables a log to be controlled by the env var and user API with the setting_name
  488. Args:
  489. setting_name: the shorthand name used in the env var and user API
  490. log_name: the log name that the setting_name is associated with
  491. """
  492. log_registry.register_log(setting_name, log_name)
  493. def register_artifact(
  494. setting_name, description, visible=False, off_by_default=False, log_format=None
  495. ) -> None:
  496. """
  497. Enables an artifact to be controlled by the env var and user API with name
  498. Args:
  499. setting_name: the shorthand name used in the env var and user API
  500. description: A description of what this outputs
  501. visible: Whether it gets suggested to users by default
  502. off_by_default: whether this artifact should be logged when the ancestor loggers
  503. are enabled at level DEBUG
  504. """
  505. log_registry.register_artifact_name(
  506. setting_name, description, visible, off_by_default, log_format
  507. )
  508. def getArtifactLogger(module_qname, artifact_name) -> logging.Logger:
  509. if artifact_name not in log_registry.artifact_names:
  510. raise ValueError(
  511. f"Artifact name: {repr(artifact_name)} not registered,"
  512. f"please call register_artifact({repr(artifact_name)}) in torch._logging.registrations."
  513. )
  514. qname = module_qname + f".__{artifact_name}"
  515. log = logging.getLogger(qname)
  516. log.artifact_name = artifact_name # type: ignore[attr-defined]
  517. log_registry.register_artifact_log(qname)
  518. configure_artifact_log(log)
  519. return log
  520. INCR_VERBOSITY_CHAR = "+"
  521. DECR_VERBOSITY_CHAR = "-"
  522. VERBOSITY_REGEX = (
  523. "("
  524. + "|".join([re.escape(INCR_VERBOSITY_CHAR), re.escape(DECR_VERBOSITY_CHAR)])
  525. + "?)"
  526. )
  527. def configure_artifact_log(log) -> None:
  528. # If the artifact is off by default, then it should only be logged when explicitly
  529. # enabled; set propagate to False so that this artifact is not propagated
  530. # to its ancestor logger
  531. if log_registry.is_off_by_default(log.artifact_name):
  532. log.propagate = False
  533. # enable artifact logging when explicitly enabled
  534. if log_state.is_artifact_enabled(log.artifact_name):
  535. log.setLevel(logging.DEBUG)
  536. log.propagate = True
  537. # match a comma separated list of loggable names (whitespace allowed after commas)
  538. def _gen_settings_regex():
  539. return re.compile(r"((\+|-)?[\w\.]+,\s*)*(\+|-)?[\w\.]+?")
  540. def _validate_settings(settings):
  541. return re.fullmatch(_gen_settings_regex(), settings) is not None
  542. def help_message(verbose=False):
  543. def pad_to(s, length=30):
  544. if len(s) > length:
  545. raise AssertionError(f"string length {len(s)} exceeds max {length}")
  546. return s + " " * (length - len(s))
  547. if verbose:
  548. printed_artifacts = log_registry.artifact_names
  549. else:
  550. printed_artifacts = log_registry.visible_artifacts
  551. if verbose:
  552. heading = "All registered names"
  553. else:
  554. heading = "Visible registered names (use TORCH_LOGS='+help' for full list)"
  555. lines = (
  556. ["all"]
  557. + sorted(log_registry.log_alias_to_log_qnames.keys())
  558. + sorted(
  559. [
  560. f"{pad_to(name)}\t{log_registry.artifact_descriptions[name]}"
  561. for name in printed_artifacts
  562. ]
  563. )
  564. )
  565. setting_info = " " + "\n ".join(lines)
  566. examples = """
  567. Examples:
  568. TORCH_LOGS="+dynamo,aot" will set the log level of TorchDynamo to
  569. logging.DEBUG and AOT to logging.INFO
  570. TORCH_LOGS="-dynamo,+inductor" will set the log level of TorchDynamo to
  571. logging.ERROR and TorchInductor to logging.DEBUG
  572. TORCH_LOGS="aot_graphs" will enable the aot_graphs artifact
  573. TORCH_LOGS="+dynamo,schedule" will enable set the log level of TorchDynamo
  574. to logging.DEBUG and enable the schedule artifact
  575. TORCH_LOGS="+some.random.module,schedule" will set the log level of
  576. some.random.module to logging.DEBUG and enable the schedule artifact
  577. TORCH_LOGS="+torch._functorch._aot_autograd" will set the log level of
  578. torch._functorch._aot_autograd and all its submodules to logging.DEBUG
  579. (directory-based logging)
  580. TORCH_LOGS_FORMAT="%(levelname)s: %(message)s" or any provided format
  581. string will set the output format
  582. Valid keys are "levelname", "message", "pathname", "levelno", "lineno",
  583. "filename" and "name".
  584. TORCH_LOGS_OUT=/tmp/output.txt will output the logs to /tmp/output.txt as
  585. well. This is useful when the output is long.
  586. """
  587. msg = f"""
  588. TORCH_LOGS Info
  589. {examples}
  590. {heading}
  591. {setting_info}
  592. """
  593. return msg
  594. def _invalid_settings_err_msg(settings, verbose=False):
  595. valid_settings = (
  596. ["all"]
  597. + list(log_registry.log_alias_to_log_qnames.keys())
  598. + list(log_registry.artifact_names)
  599. )
  600. valid_settings = ", ".join(sorted(valid_settings))
  601. msg = f"""
  602. Invalid log settings: {settings}, must be a comma separated list of fully
  603. qualified module names, registered log names or registered artifact names.
  604. For more info on various settings, try TORCH_LOGS="help"
  605. Valid settings:
  606. {valid_settings}
  607. """
  608. return msg
  609. def process_env_var_string_for_windows(env_var_str: str) -> str:
  610. """
  611. When we setup logging config as guide: https://docs.pytorch.org/docs/stable/logging.html
  612. Such as:
  613. TORCH_LOGS="+schedule,+inductor,+output_code"
  614. On Linux, it shows as:
  615. declare -x SSH_TTY="/dev/pts/0"
  616. declare -x TERM="xterm"
  617. declare -x TORCH_LOGS="+schedule,+inductor,+output_code"
  618. declare -x USER="xu"
  619. On Windows, it shows as:
  620. TORCHINDUCTOR_WINDOWS_TESTS=1
  621. TORCH_LOGS="+schedule,+inductor,+output_code"
  622. UCRTVersion=10.0.22000.0
  623. For Linux, it shows quotes by default, And Windows is not shows quotes.
  624. Besides that, Windows would auto assemble quotes when env var processing.
  625. On Linux, we will get variable: "+schedule,+inductor,+output_code"
  626. On Windows, we will get variable: '"+schedule,+inductor,+output_code"'
  627. So, we need remove the outer quotes for Windows.
  628. """
  629. _IS_WINDOWS = sys.platform == "win32"
  630. def remove_outer_quotes(s: str) -> str:
  631. if len(s) >= 2 and (
  632. (s[0] == '"' and s[-1] == '"') or (s[0] == "'" and s[-1] == "'")
  633. ):
  634. return s[1:-1]
  635. return s
  636. if _IS_WINDOWS:
  637. env_var_str = remove_outer_quotes(env_var_str)
  638. return env_var_str
  639. @functools.lru_cache
  640. def _parse_log_settings(settings):
  641. settings = process_env_var_string_for_windows(settings)
  642. if settings == "":
  643. return {}
  644. if settings == "help":
  645. raise ValueError(help_message(verbose=False))
  646. elif settings == "+help":
  647. raise ValueError(help_message(verbose=True))
  648. if not _validate_settings(settings):
  649. raise ValueError(_invalid_settings_err_msg(settings))
  650. settings = re.sub(r"\s+", "", settings)
  651. log_names = settings.split(",")
  652. def get_name_level_pair(name):
  653. clean_name = name.replace(INCR_VERBOSITY_CHAR, "")
  654. clean_name = clean_name.replace(DECR_VERBOSITY_CHAR, "")
  655. if name[0] == INCR_VERBOSITY_CHAR:
  656. level = logging.DEBUG
  657. elif name[0] == DECR_VERBOSITY_CHAR:
  658. level = logging.ERROR
  659. else:
  660. level = logging.INFO
  661. return clean_name, level
  662. log_state = LogState()
  663. for name in log_names:
  664. name, level = get_name_level_pair(name)
  665. if name == "all":
  666. name = "torch"
  667. if log_registry.is_log(name):
  668. if level is None:
  669. raise AssertionError("level must not be None for log name")
  670. log_qnames = log_registry.log_alias_to_log_qnames[name]
  671. log_state.enable_log(log_qnames, level)
  672. elif log_registry.is_artifact(name):
  673. log_state.enable_artifact(name)
  674. elif _is_valid_module(name):
  675. # Get the module and all its submodules if it's a package
  676. found_modules = _get_module_and_submodules(name) or (name,)
  677. for module_name in found_modules:
  678. if not _has_registered_parent(module_name):
  679. log_registry.register_log(module_name, module_name)
  680. else:
  681. log_registry.register_child_log(module_name)
  682. log_state.enable_log(module_name, level)
  683. else:
  684. raise ValueError(_invalid_settings_err_msg(settings))
  685. return log_state
  686. def _is_valid_module(qname):
  687. spec = importlib.util.find_spec(qname)
  688. return spec is not None
  689. def _get_module_and_submodules(qname: str) -> Sequence[str] | None:
  690. """
  691. Get a module and all its submodules (recursively).
  692. If qname is a package, this returns a list of all modules and submodules.
  693. If qname is a simple module, this returns a list containing just that module.
  694. Args:
  695. qname: The fully qualified module name
  696. Returns:
  697. A list of fully qualified module names, or None if the module doesn't exist
  698. """
  699. spec = importlib.util.find_spec(qname)
  700. if spec is None:
  701. return None
  702. modules = [qname]
  703. if spec.submodule_search_locations is not None:
  704. package = importlib.import_module(qname)
  705. if hasattr(package, "__path__"):
  706. for importer, modname, ispkg in pkgutil.walk_packages(
  707. path=package.__path__,
  708. prefix=qname + ".",
  709. onerror=lambda x: None,
  710. ):
  711. modules.append(modname)
  712. return modules
  713. def _update_log_state_from_env() -> None:
  714. global log_state
  715. log_setting = os.environ.get(LOG_ENV_VAR, None)
  716. if log_setting is not None:
  717. log_state = _parse_log_settings(log_setting)
  718. def _has_registered_parent(log_qname) -> bool:
  719. cur_log = logging.getLogger(log_qname)
  720. registered_log_qnames = log_registry.get_log_qnames()
  721. while cur_log.parent:
  722. if cur_log.name in registered_log_qnames:
  723. return True
  724. cur_log = cur_log.parent
  725. return False
  726. @functools.lru_cache
  727. def _make_module_path_relative(abs_path: str, sys_path: tuple[str, ...]) -> str:
  728. """
  729. The string manipulation here is fairly expensive and we expect very high
  730. cache hit rates. Because `sys.path` changes very infrequently it's most
  731. performant to convert it into a tuple of strings and use that for the cache
  732. key because python has dedicated fast paths for list to tuple conversion
  733. and tuple hashing. (Empirically, a single top-level cache is about 2x faster
  734. than trying to cache individual parts.)
  735. """
  736. abs_path_resolved = pathlib.Path(abs_path).resolve()
  737. for path in sys_path:
  738. try:
  739. rel_path = abs_path_resolved.relative_to(path)
  740. except ValueError:
  741. continue
  742. else:
  743. return str(rel_path)
  744. return str(abs_path_resolved)
  745. def make_module_path_relative(abs_path: str) -> str:
  746. """
  747. Given an absolute filepath corresponding to a Python module which was
  748. loaded via normal import mechanisms using sys.path, convert it into
  749. a relative path relative to one of the Python search paths.
  750. """
  751. return _make_module_path_relative(abs_path, tuple(sys.path))
  752. # apply custom formats to artifacts when necessary
  753. class TorchLogsFormatter(logging.Formatter):
  754. def __init__(
  755. self, *, trace: bool = False, trace_id_filter: Optional[set[str]] = None
  756. ) -> None:
  757. super().__init__()
  758. self._is_trace = trace
  759. self._trace_id_filter = trace_id_filter
  760. def format(self, record):
  761. artifact_name = getattr(logging.getLogger(record.name), "artifact_name", None)
  762. if artifact_name is not None:
  763. artifact_formatter = log_registry.artifact_log_formatters.get(
  764. artifact_name, None
  765. )
  766. if artifact_formatter is not None:
  767. return artifact_formatter.format(record)
  768. record.message = record.getMessage()
  769. record.asctime = self.formatTime(record, "%m%d %H:%M:%S")
  770. # exception handling - copied from logging.Formatter.format
  771. s = record.message
  772. if record.exc_info:
  773. from torch._dynamo import config
  774. should_format_exc = config.verbose or artifact_name != "graph_breaks"
  775. # Cache the traceback text to avoid converting it multiple times
  776. # (it's constant anyway)
  777. if should_format_exc:
  778. if not record.exc_text:
  779. record.exc_text = self.formatException(record.exc_info)
  780. if record.exc_text:
  781. if s[-1:] != "\n":
  782. s = s + "\n"
  783. s = s + record.exc_text
  784. if record.stack_info:
  785. if s[-1:] != "\n":
  786. s = s + "\n"
  787. s = s + self.formatStack(record.stack_info)
  788. record.rankprefix = ""
  789. if not self._is_trace and dist.is_available() and dist.is_initialized():
  790. record.rankprefix = f"[rank{dist.get_rank()}]:"
  791. record.traceid = ""
  792. if (
  793. not self._is_trace
  794. and (trace_id := torch._guards.CompileContext.current_trace_id())
  795. is not None
  796. ):
  797. record.traceid = f" [{trace_id}]"
  798. glog_level_to_abbr = {
  799. "DEBUG": "V", # V is for VERBOSE in glog
  800. "INFO": "I",
  801. "WARNING": "W",
  802. "ERROR": "E",
  803. "CRITICAL": "C",
  804. }
  805. shortlevel = glog_level_to_abbr.get(record.levelname, record.levelname)
  806. record.artifactprefix = ""
  807. if artifact_name is not None:
  808. record.artifactprefix = f" [__{artifact_name}]"
  809. filepath = make_module_path_relative(record.pathname)
  810. if (
  811. self._trace_id_filter
  812. and record.traceid.strip() not in self._trace_id_filter
  813. ):
  814. return ""
  815. prefix = (
  816. f"{record.rankprefix}{shortlevel}{record.asctime}.{int(record.msecs * 1000):06d} {record.process} "
  817. f"{filepath}:"
  818. f"{record.lineno}]{record.traceid}{record.artifactprefix}"
  819. )
  820. if self._is_trace:
  821. if s != "":
  822. raise AssertionError(f"expected empty string for trace, got {s!r}")
  823. try:
  824. r = f"{prefix} {json.dumps(record.metadata)}"
  825. except TypeError:
  826. log.warning("failing metadata: %r", record.metadata)
  827. raise
  828. if record.payload is not None:
  829. r += "".join(f"\n\t{l}" for l in record.payload.split("\n"))
  830. return r
  831. else:
  832. lines = s.split("\n")
  833. return "\n".join(f"{prefix} {l}" for l in lines)
  834. def _default_formatter():
  835. fmt = os.environ.get(LOG_FORMAT_ENV_VAR, None)
  836. trace_id_filter = {
  837. item.strip()
  838. for item in os.environ.get(LOG_TRACE_ID_FILTER, "").split(",")
  839. if item.strip()
  840. }
  841. if fmt is None:
  842. return TorchLogsFormatter(trace_id_filter=trace_id_filter)
  843. else:
  844. if fmt in ("short", "basic"):
  845. fmt = logging.BASIC_FORMAT
  846. return logging.Formatter(fmt)
# Formatter applied to every handler created by _setup_handlers. It is built
# once, when this module is executed, from the environment at that moment.
DEFAULT_FORMATTER = _default_formatter()
  848. def _setup_handlers(create_handler_fn, log) -> None:
  849. debug_handler = _track_handler(create_handler_fn())
  850. debug_handler.setFormatter(DEFAULT_FORMATTER)
  851. debug_handler.setLevel(logging.DEBUG)
  852. log.addHandler(debug_handler)
# Handlers created by this module; membership is what _is_torch_handler
# tests. As a WeakSet it does not, by itself, keep a handler alive.
handlers = WeakSet()  # type: ignore[var-annotated]
  854. # mark handlers that we've created
  855. # so we don't modify user handlers
  856. def _track_handler(handler):
  857. handlers.add(handler)
  858. return handler
  859. def _is_torch_handler(handler):
  860. return handler in handlers
  861. # clears all torch handlers on specified loggers
  862. def _clear_handlers(log) -> None:
  863. to_remove = [handler for handler in log.handlers if _is_torch_handler(handler)]
  864. for handler in to_remove:
  865. log.removeHandler(handler)
  866. def _reset_logs() -> None:
  867. # reset all registered logs
  868. for log_qname in log_registry.get_log_qnames():
  869. log = logging.getLogger(log_qname)
  870. log.setLevel(logging.WARNING)
  871. log.propagate = False
  872. _clear_handlers(log)
  873. # reset all artifact and child logs
  874. for artifact_log_qname in itertools.chain(
  875. log_registry.get_artifact_log_qnames(), log_registry.get_child_log_qnames()
  876. ):
  877. log = logging.getLogger(artifact_log_qname)
  878. log.setLevel(logging.NOTSET)
  879. log.propagate = True
  880. trace_log.propagate = False
  881. _clear_handlers(trace_log)
  882. def _get_log_state():
  883. return log_state
  884. def _set_log_state(state) -> None:
  885. global log_state
  886. log_state = state
  887. def _init_logs(log_file_name=None) -> None:
  888. global GET_DTRACE_STRUCTURED
  889. _reset_logs()
  890. _update_log_state_from_env()
  891. out = os.environ.get(LOG_OUT_ENV_VAR, None)
  892. if out is not None:
  893. log_file_name = out
  894. # First, reset all known (registered) loggers to NOTSET, so that they
  895. # respect their parent log level
  896. for log_qname in log_registry.get_log_qnames():
  897. # But not the top level torch level: this defaults to WARNING so
  898. # that our log messages don't leak to the lower levels
  899. if log_qname == "torch":
  900. continue
  901. log = logging.getLogger(log_qname)
  902. log.setLevel(logging.NOTSET)
  903. # Now, for all loggers which the user requested to have non-standard
  904. # logging behavior, modify their log levels
  905. for log_qname, level in log_state.get_log_level_pairs():
  906. log = logging.getLogger(log_qname)
  907. log.setLevel(level)
  908. # Finally, setup handlers for all registered loggers
  909. for log_qname in log_registry.get_log_qnames():
  910. log = logging.getLogger(log_qname)
  911. _setup_handlers(
  912. logging.StreamHandler,
  913. log,
  914. )
  915. if log_file_name is not None:
  916. _setup_handlers(
  917. lambda: logging.FileHandler(log_file_name),
  918. log,
  919. )
  920. # configure artifact loggers, note: this must happen last
  921. # since the levels of ancestor loggers are taken into account
  922. for artifact_log_qname in log_registry.get_artifact_log_qnames():
  923. log = logging.getLogger(artifact_log_qname)
  924. configure_artifact_log(log)
  925. # Setup handler for the special trace_log, with different default
  926. # configuration
  927. trace_dir_name = os.environ.get(TRACE_ENV_VAR, None)
  928. if dtrace_dir_name := os.environ.get(DTRACE_ENV_VAR, None):
  929. GET_DTRACE_STRUCTURED = True
  930. trace_dir_name = dtrace_dir_name
  931. # This handler may remove itself if trace_dir_name is None and we are not
  932. # actually in an FB environment. This allows us to defer actually
  933. # initializing it until we actually need to log anything. This is
  934. # important because JK initializes a C++ singleton, which will pork our
  935. # process if we subsequently fork.
  936. global LOG_TRACE_HANDLER
  937. if LOG_TRACE_HANDLER is None:
  938. LOG_TRACE_HANDLER = LazyTraceHandler(trace_dir_name)
  939. # This log is ALWAYS at debug level. We will additionally test if there
  940. # are any handlers before deciding to actually call logging on this. Do
  941. # not manually call
  942. trace_log.setLevel(logging.DEBUG)
  943. trace_log_handler = _track_handler(LOG_TRACE_HANDLER)
  944. trace_log_handler.setFormatter(TorchLogsFormatter(trace=True))
  945. trace_log.addHandler(trace_log_handler)
class LazyTraceHandler(logging.StreamHandler):
    """Like FileHandler, but the file is allocated lazily only upon the first log message

    The file is a uniquely named temporary ``.log`` file created inside
    ``root_dir``. When no ``root_dir`` was given, the first ``emit`` decides
    whether the hard-coded ``/logs`` directory can be used; if it cannot, the
    handler removes itself from ``trace_log`` and drops the record.
    """

    def __init__(self, root_dir: Optional[str]) -> None:
        """
        Args:
            root_dir: directory in which to create the log file, or None to
                decide on the first emitted record.
        """
        # This is implemented in the same way that delay is implemented on
        # FileHandler
        self.root_dir = root_dir
        logging.Handler.__init__(self)
        # No stream yet: it is opened by the first call to emit().
        self.stream = None
        # NOTE(review): not referenced elsewhere in this class; presumably
        # kept to mirror FileHandler's attribute of the same name - confirm.
        self._builtin_open = open

    # cloned from FileHandler in cpython
    def close(self) -> None:
        """Flush and close the stream (if one was opened), then close the handler."""
        self.acquire()
        try:
            try:
                if self.stream:
                    try:
                        self.flush()
                    finally:
                        # Detach the stream before closing it so the handler
                        # never holds a reference to a closed stream.
                        stream = self.stream
                        self.stream = None
                        if hasattr(stream, "close"):
                            stream.close()
            finally:
                # Issue #19523: call unconditionally to
                # prevent a handler leak when delay is set
                # Also see Issue #42378: we also rely on
                # self._closed being set to True there
                logging.StreamHandler.close(self)
        finally:
            self.release()

    def emit(self, record) -> None:
        """
        Write ``record`` to the log file, opening the file on first use.

        If no directory is available the handler detaches itself from
        ``trace_log`` and the record is dropped.
        """
        if self.stream is None:
            if self.root_dir is None:
                # No directory was configured: fall back to TRACE_LOG_DIR,
                # but only if every check below passes. The checks run in
                # order and the first one that fails logs why and leaves
                # root_dir as None.
                TRACE_LOG_DIR = "/logs"

                import torch.version as torch_version

                if (
                    hasattr(torch_version, "git_version")
                    and os.getenv("MAST_HPC_JOB_NAME") is None
                ):
                    log.info(
                        "LazyTraceHandler: disabled because not fbcode or conda on mast"
                    )
                elif not torch._utils_internal.justknobs_check("pytorch/trace:enable"):
                    log.info(
                        "LazyTraceHandler: disabled because justknobs_check('pytorch/trace:enable') returned False"
                    )
                elif not os.path.exists(TRACE_LOG_DIR):
                    log.info(
                        "LazyTraceHandler: disabled because %s does not exist",
                        TRACE_LOG_DIR,
                    )
                elif not os.access(TRACE_LOG_DIR, os.W_OK):
                    log.info(
                        "LazyTraceHandler: disabled because %s is not writeable",
                        TRACE_LOG_DIR,
                    )
                else:
                    self.root_dir = TRACE_LOG_DIR

            if self.root_dir is not None:
                os.makedirs(self.root_dir, exist_ok=True)
                # Include the rank in the file name when torch.distributed
                # is initialized.
                ranksuffix = ""
                if dist.is_available() and dist.is_initialized():
                    ranksuffix = f"rank_{dist.get_rank()}_"
                # delete=False: the file is kept after the stream is closed.
                self.stream = tempfile.NamedTemporaryFile(  # noqa: SIM115
                    mode="w+",
                    suffix=".log",
                    prefix=LOG_PREFIX + ranksuffix,
                    dir=self.root_dir,
                    delete=False,
                )
                log.info("LazyTraceHandler: logging to %s", self.stream.name)
            else:
                # We go poof, remove and no-op
                trace_log.removeHandler(self)
                return
        if self.stream:
            super().emit(record)
  1023. @functools.cache
  1024. def warning_once(logger_obj, *args, **kwargs) -> None:
  1025. """
  1026. This function is similar to `logger.warning()`, but will emit the warning with the same message only once
  1027. Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the cache.
  1028. The assumption here is that all warning messages are unique across the code. If they aren't then need to switch to
  1029. another type of cache that includes the caller frame information in the hashing function.
  1030. """
  1031. logger_obj.warning(*args, **kwargs)
  1032. def safe_grad_filter(message, category, filename, lineno, file=None, line=None) -> bool:
  1033. return "The .grad attribute of a Tensor" not in str(message)
  1034. def user_warning_filter(
  1035. message, category, filename, lineno, file=None, line=None
  1036. ) -> bool:
  1037. return category is not UserWarning
  1038. @contextlib.contextmanager
  1039. def hide_warnings(filter_fn=lambda *args, **kwargs: True):
  1040. """
  1041. A context manager that temporarily suppresses warnings,
  1042. using public API: https://docs.python.org/3/library/warnings.html#warnings.showwarning.
  1043. Useful to hide warnings without mutating warnings module state, see:
  1044. https://github.com/pytorch/pytorch/issues/128427#issuecomment-2161496162.
  1045. NOTE: Warnings issued under this context will still be cached in the __warningregistry__
  1046. and count towards the once/default rule. So you should NEVER use this on a user-land function.
  1047. Filter must implement the showwarning API:
  1048. def filter_fn(message, category, filename, lineno, file=None, line=None) -> bool:
  1049. return True # show this warning entry
  1050. """
  1051. prior = warnings.showwarning
  1052. def _showwarning(*args, **kwargs):
  1053. if filter_fn(*args, **kwargs):
  1054. prior(*args, **kwargs)
  1055. try:
  1056. warnings.showwarning = _showwarning
  1057. yield
  1058. finally:
  1059. warnings.showwarning = prior
  1060. class LazyString(Generic[_P]):
  1061. def __init__(
  1062. self, func: Callable[_P, str], *args: _P.args, **kwargs: _P.kwargs
  1063. ) -> None:
  1064. self.func = func
  1065. self.args = args
  1066. self.kwargs = kwargs
  1067. def __str__(self) -> str:
  1068. return self.func(*self.args, **self.kwargs)
# Accumulates the time (in seconds) spent doing structured logging, bucketed
# by frame/compile id.
# The key is always "{frame_id}_{frame_compile_id}"; because this is a
# defaultdict(float), a key that has not been seen yet starts at 0.0.
structured_logging_overhead: dict[str, float] = defaultdict(float)
  1072. def add_structured_logging_overhead(time_spent: float) -> None:
  1073. global structured_logging_overhead
  1074. key = None
  1075. if (trace_id := torch._guards.CompileContext.current_trace_id()) is not None:
  1076. frame_id = trace_id.compile_id.frame_id
  1077. frame_compile_id = trace_id.compile_id.frame_compile_id
  1078. # Why not trace_id.attempt, like structured logging?
  1079. # We aggregate across all attempts because
  1080. # a compilation metric is logged per successful attempt
  1081. key = f"{frame_id}_{frame_compile_id}"
  1082. # TODO: deal with structured logging that occurs outside of specific compile ids
  1083. # It's hard to figure out where we would log that if we want it in compilation metrics
  1084. # itself.
  1085. if key is not None:
  1086. key = str(key)
  1087. structured_logging_overhead[key] += time_spent
  1088. def get_structured_logging_overhead() -> Optional[float]:
  1089. key = None
  1090. if (trace_id := torch._guards.CompileContext.current_trace_id()) is not None:
  1091. frame_id = trace_id.compile_id.frame_id
  1092. frame_compile_id = trace_id.compile_id.frame_compile_id
  1093. key = f"{frame_id}_{frame_compile_id}"
  1094. if key is not None:
  1095. return structured_logging_overhead.get(key)
  1096. else:
  1097. return None
  1098. def trace_structured_artifact(
  1099. name: str, # this will go in metadata
  1100. encoding: str,
  1101. payload_fn: Callable[[], Optional[Union[str, object]]] = lambda: None,
  1102. compile_id: Optional[CompileId] = None,
  1103. ) -> None:
  1104. trace_structured(
  1105. "artifact",
  1106. metadata_fn=lambda: {
  1107. "name": name,
  1108. "encoding": encoding,
  1109. },
  1110. payload_fn=payload_fn,
  1111. compile_id=compile_id,
  1112. )
  1113. def trace_structured(
  1114. name: str,
  1115. # NB: metadata expected to be dict so adding more info is forward compatible
  1116. # Tuple[str, int] is a special case for string interning
  1117. metadata_fn: Callable[[], Union[dict[str, Any], tuple[str, int]]] = dict,
  1118. *,
  1119. payload_fn: Callable[[], Optional[Union[str, object]]] = lambda: None,
  1120. suppress_context: bool = False,
  1121. expect_trace_id: bool = True, # Whether or not we expect to have a current trace id
  1122. record_logging_overhead: bool = True, # Whether or not to record the time spent on structured logging
  1123. compile_id: Optional[CompileId] = None, # Optional if unavailable in the trace
  1124. ) -> None:
  1125. """
  1126. metadata is an arbitrary JSON compatible struct, but it's expected to not be
  1127. too long (e.g., less than 1MB)
  1128. payload is an arbitrary string, which can be arbitrarily long (but expected to have
  1129. newlines so no lines are too long)
  1130. """
  1131. reserved_names = [
  1132. "rank",
  1133. "compiled_autograd_id",
  1134. "frame_id",
  1135. "frame_compile_id",
  1136. "attempt",
  1137. "severity",
  1138. "timestamp",
  1139. "pathname",
  1140. "thread",
  1141. ]
  1142. if name in reserved_names:
  1143. raise AssertionError(f"name {name!r} is reserved and cannot be used")
  1144. if not callable(metadata_fn):
  1145. raise AssertionError(
  1146. f"metadata_fn should be callable, but got {type(metadata_fn)}"
  1147. )
  1148. if not callable(payload_fn):
  1149. raise AssertionError(
  1150. f"payload_fn should be callable, but got {type(payload_fn)}"
  1151. )
  1152. # trace_log never propagates and is ALWAYS DEBUG, so also check that there
  1153. # are handlers instead of checking the log level
  1154. if trace_log.handlers:
  1155. start_time = time.time_ns()
  1156. record: dict[str, object] = {}
  1157. record[name] = metadata_fn()
  1158. if not suppress_context:
  1159. # TODO: Actually, the rank probably should just be emitted once at
  1160. # the top, and not repeatedly spammed in all the logs, since it
  1161. # never changes and we assume no interleaving
  1162. if dist.is_available() and dist.is_initialized():
  1163. record["rank"] = dist.get_rank()
  1164. trace_id = torch._guards.CompileContext.current_trace_id()
  1165. if expect_trace_id and trace_id is None and compile_id is None:
  1166. # Record the stack of the log call to better diagnose why we
  1167. # don't have a frame id for it
  1168. record["stack"] = torch._logging.structured.from_traceback(
  1169. CapturedTraceback.extract(skip=1).summary()
  1170. )
  1171. else:
  1172. cid = trace_id.compile_id if trace_id else compile_id
  1173. if cid is not None:
  1174. if cid.compiled_autograd_id is not None:
  1175. record["compiled_autograd_id"] = cid.compiled_autograd_id
  1176. if cid.frame_id is not None:
  1177. record["frame_id"] = cid.frame_id
  1178. if cid.frame_compile_id is not None:
  1179. record["frame_compile_id"] = cid.frame_compile_id
  1180. if trace_id:
  1181. record["attempt"] = trace_id.attempt
  1182. payload = payload_fn()
  1183. if payload is not None:
  1184. if not isinstance(payload, str):
  1185. if isinstance(payload, list):
  1186. # special case to look better
  1187. payload = "[\n" + ",\n".join(json.dumps(i) for i in payload) + "\n]"
  1188. else:
  1189. def json_default(obj):
  1190. # Sets aren't json serializable
  1191. if isinstance(obj, set):
  1192. return list(obj)
  1193. raise TypeError(
  1194. f"Object of type {type(obj)} is not JSON serializable"
  1195. )
  1196. # force newlines so we are unlikely to overflow line limit
  1197. payload = json.dumps(payload, default=json_default, indent=0)
  1198. h = hashlib.md5(usedforsecurity=False)
  1199. h.update(payload.encode("utf-8"))
  1200. record["has_payload"] = h.hexdigest()
  1201. trace_log.debug(
  1202. "", extra={"metadata": record, "payload": payload}, stacklevel=2
  1203. )
  1204. log_trace_structured_event(name, record)
  1205. if record_logging_overhead:
  1206. # Convert to seconds from nanoseconds, add it to the frame compile total
  1207. structured_logging_overhead_s = (time.time_ns() - start_time) / 1e9
  1208. add_structured_logging_overhead(structured_logging_overhead_s)
  1209. def dtrace_structured(
  1210. name: str,
  1211. # NB: metadata expected to be dict so adding more info is forward compatible
  1212. # Tuple[str, int] is a special case for string interning
  1213. metadata_fn: Callable[[], Union[dict[str, Any], tuple[str, int]]] = dict,
  1214. *,
  1215. payload_fn: Callable[[], Optional[Union[str, object]]] = lambda: None,
  1216. suppress_context: bool = False,
  1217. expect_trace_id: bool = False, # Whether or not we expect to have a current trace id
  1218. record_logging_overhead: bool = True, # Whether or not to record the time spent on structured logging
  1219. ) -> None:
  1220. """
  1221. For logging more detailed information used for debugging. This may result in
  1222. the program becoming slow.
  1223. """
  1224. if GET_DTRACE_STRUCTURED:
  1225. trace_structured(
  1226. name,
  1227. metadata_fn,
  1228. payload_fn=payload_fn,
  1229. suppress_context=suppress_context,
  1230. expect_trace_id=expect_trace_id,
  1231. record_logging_overhead=record_logging_overhead,
  1232. )
  1233. import torch._guards
  1234. import torch._utils_internal
  1235. import torch.distributed as dist