| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517 |
- # mypy: allow-untyped-defs
- import contextlib
- import functools
- import hashlib
- import importlib.util
- import itertools
- import json
- import logging
- import os
- import os.path
- import pathlib
- import pkgutil
- import re
- import sys
- import tempfile
- import time
- import warnings
- from collections import defaultdict
- from collections.abc import Callable, Sequence
- from dataclasses import dataclass, field
- from typing import Any, Generic, Optional, Union
- from typing_extensions import ParamSpec
- from weakref import WeakSet
- import torch._logging.structured
- from torch._guards import CompileId
- from torch._utils_internal import log_trace_structured_event
- from torch.utils._traceback import CapturedTraceback
_P = ParamSpec("_P")

# Logger for this module's own diagnostics (e.g. the warning emitted by set_logs
# when the TORCH_LOGS environment variable takes precedence).
log = logging.getLogger(__name__)

# This is a synthetic logger which doesn't correspond to an actual logger,
# but handles all of our "tracing" logging, which is structured and doesn't go
# to stderr but always goes to a dedicated log file.  We don't put these
# loggers in the classic module hierarchy, because we don't want a suppression
# of logs to also cause a trace to get suppressed (traces typically are not
# collected, unless we are in prod, in which case they always are collected.)
#
# TODO: Maybe we should allow for some sub-hierarchy so you can control which
# traces you want to collect, for performance reasons.
#
# See https://docs.google.com/document/d/1CX_hJ0PNy9f3R1y8TJrfkSeLkvGjjjLU84BSXgS2AZ8/edit
trace_log = logging.getLogger("torch.__trace")

DEFAULT_LOG_LEVEL = logging.WARNING
# Environment variable holding the comma separated log settings string; when it
# is present, calls to set_logs are ignored.
LOG_ENV_VAR = "TORCH_LOGS"
# Environment variable naming a file that additionally receives the log output
# (see the TORCH_LOGS_OUT entry in help_message).
LOG_OUT_ENV_VAR = "TORCH_LOGS_OUT"
# Environment variable holding the output format string
# (see the TORCH_LOGS_FORMAT entry in help_message).
LOG_FORMAT_ENV_VAR = "TORCH_LOGS_FORMAT"
# NOTE(review): the names below are consumed outside the part of the file
# reviewed here; presumably they configure trace-id filtering and the structured
# trace handler -- confirm against their users before relying on that.
LOG_TRACE_ID_FILTER = "TORCH_LOGS_TRACE_ID_FILTER"
TRACE_ENV_VAR = "TORCH_TRACE"
DTRACE_ENV_VAR = "TORCH_DTRACE"

LOG_TRACE_HANDLER: Optional["LazyTraceHandler"] = None

GET_DTRACE_STRUCTURED = False

LOG_PREFIX = "dedicated_log_torch_trace_"
@dataclass
class LogRegistry:
    """Catalogue of every log alias and artifact the log settings may refer to.

    The registry only records names; it never touches ``logging`` levels
    itself.
    """

    # Shorthand alias -> qualified logger names, e.g. "dynamo" -> ["torch._dynamo"].
    # Only loggers added through register_log appear here.
    log_alias_to_log_qnames: dict[str, list[str]] = field(default_factory=dict)

    # Qualified names of artifact loggers, filled in lazily by getArtifactLogger.
    # Current format: <module>.__<artifact_name>,
    # e.g. "torch._dynamo.convert_frame.__guards"
    artifact_log_qnames: set[str] = field(default_factory=set)

    # Children of registered logs that the user named directly (open
    # registration, e.g. putting "torch._dynamo.output_graph" in the env var).
    # They are remembered so that their levels can be reset properly.
    child_log_qnames: set[str] = field(default_factory=set)

    # Artifact names added by register_artifact, e.g. "guards"
    artifact_names: set[str] = field(default_factory=set)

    # Artifacts listed by default in the help / error message
    visible_artifacts: set[str] = field(default_factory=set)

    # One short description per artifact
    artifact_descriptions: dict[str, str] = field(default_factory=dict)

    # Artifacts that stay hidden unless they are named explicitly in the
    # settings.  Ex. output_code is NOT displayed even if the inductor log
    # level is DEBUG; it has to be requested by name.
    off_by_default_artifact_names: set[str] = field(default_factory=set)

    # Per-artifact logging.Formatter built from the registered format string
    artifact_log_formatters: dict[str, logging.Formatter] = field(default_factory=dict)

    def is_artifact(self, name):
        """Tell whether ``name`` is a registered artifact name."""
        return name in self.artifact_names

    def is_log(self, alias):
        """Tell whether ``alias`` is a registered log alias."""
        return alias in self.log_alias_to_log_qnames

    def register_log(self, alias, log_qnames: Union[str, list[str]]) -> None:
        """Associate ``alias`` with one qualified logger name or a list of them."""
        self.log_alias_to_log_qnames[alias] = (
            [log_qnames] if isinstance(log_qnames, str) else log_qnames
        )

    def register_artifact_name(
        self, name, description, visible, off_by_default, log_format
    ) -> None:
        """Record an artifact together with its description and display options."""
        self.artifact_names.add(name)
        self.artifact_descriptions[name] = description

        if visible:
            self.visible_artifacts.add(name)

        # An off-by-default artifact is not switched on merely because its
        # parent log was set to DEBUG.
        if off_by_default:
            self.off_by_default_artifact_names.add(name)

        if log_format is not None:
            self.artifact_log_formatters[name] = logging.Formatter(log_format)

    def register_artifact_log(self, artifact_log_qname) -> None:
        """Remember an artifact logger's qualified name.

        These names identify the loggers that have to be reset whenever the
        log state changes.
        """
        self.artifact_log_qnames.add(artifact_log_qname)

    def register_child_log(self, log_qname) -> None:
        """Remember a user-named child of an already registered log."""
        self.child_log_qnames.add(log_qname)

    def get_log_qnames(self) -> set[str]:
        """Return the union of all registered qualified logger names.

        (TODO: consider memoizing?)
        """
        flattened: set[str] = set()
        for qnames in self.log_alias_to_log_qnames.values():
            flattened.update(qnames)
        return flattened

    def get_artifact_log_qnames(self):
        """Return a copy of the known artifact logger names."""
        return self.artifact_log_qnames.copy()

    def get_child_log_qnames(self):
        """Return a copy of the known child logger names."""
        return self.child_log_qnames.copy()

    def is_off_by_default(self, artifact_qname):
        """Tell whether the given artifact was registered as off by default."""
        return artifact_qname in self.off_by_default_artifact_names
@dataclass
class LogState:
    """The log levels and artifacts that are currently requested."""

    # qualified log names -> currently set log level (an integer ``logging``
    # level such as logging.DEBUG)
    log_qname_to_level: dict[str, int] = field(default_factory=dict)

    # the set of currently enabled artifacts
    artifact_names: set[str] = field(default_factory=set)

    def enable_artifact(self, artifact_name) -> None:
        """Mark ``artifact_name`` as enabled."""
        self.artifact_names.add(artifact_name)

    def is_artifact_enabled(self, name):
        """Return True if ``name`` has been enabled via ``enable_artifact``."""
        return name in self.artifact_names

    def enable_log(self, log_qnames, log_level) -> None:
        """Request ``log_level`` for one qualified log name or a list of them."""
        if isinstance(log_qnames, str):
            log_qnames = [log_qnames]
        for log_qname in log_qnames:
            self.log_qname_to_level[log_qname] = log_level

    def get_log_level_pairs(self):
        """Returns all qualified module names for which the user requested
        explicit logging settings, as ``(qualified name, level)`` pairs.

        .. warning::

            This function used to return all loggers, regardless of whether
            or not the user specified them or not; it now only returns logs
            which were explicitly mentioned by the user (and torch, which
            always is implicitly requested when we initialize our logging
            subsystem.)
        """
        return self.log_qname_to_level.items()

    def clear(self) -> None:
        """Forget every requested log level and every enabled artifact."""
        self.log_qname_to_level.clear()
        self.artifact_names.clear()
# Module-wide registry of the log aliases and artifacts known to this file.
log_registry = LogRegistry()
# Module-wide record of the currently requested log levels and enabled
# artifacts; cleared and repopulated by set_logs, replaced wholesale by
# _update_log_state_from_env.
log_state = LogState()

# sample usage: torch._logging.set_logs(**torch._logging.DEFAULT_LOGGING)
# Integer values are log levels for components, True enables an artifact.
DEFAULT_LOGGING = {
    "dynamo": logging.INFO,
    "aot": logging.INFO,
    "inductor": logging.INFO,
    "fsdp": logging.INFO,
    "ddp_graphs": True,
    "graph_breaks": True,
    "side_effects": True,
    "guards": True,
    "recompiles": True,
    "dynamic": logging.INFO,
}
def set_logs(
    *,
    all: Optional[int] = None,
    dynamo: Optional[int] = None,
    aot: Optional[int] = None,
    autograd: Optional[int] = None,
    dynamic: Optional[int] = None,
    inductor: Optional[int] = None,
    distributed: Optional[int] = None,
    c10d: Optional[int] = None,
    ddp: Optional[int] = None,
    fsdp: Optional[int] = None,
    dtensor: Optional[int] = None,
    onnx: Optional[int] = None,
    bytecode: bool = False,
    aot_graphs: bool = False,
    aot_joint_graph: bool = False,
    ddp_graphs: bool = False,
    graph: bool = False,
    graph_code: bool = False,
    graph_code_verbose: bool = False,
    graph_breaks: bool = False,
    side_effects: bool = False,
    graph_sizes: bool = False,
    guards: bool = False,
    recompiles: bool = False,
    recompiles_verbose: bool = False,
    trace_source: bool = False,
    trace_call: bool = False,
    trace_bytecode: bool = False,
    output_code: bool = False,
    kernel_code: bool = False,
    schedule: bool = False,
    perf_hints: bool = False,
    pre_grad_graphs: bool = False,
    post_grad_graphs: bool = False,
    ir_pre_fusion: bool = False,
    ir_post_fusion: bool = False,
    onnx_diagnostics: bool = False,
    fusion: bool = False,
    overlap: bool = False,
    export: Optional[int] = None,
    modules: Optional[dict[str, Union[int, bool]]] = None,
    cudagraphs: bool = False,
    sym_node: bool = False,
    compiled_autograd: bool = False,
    compiled_autograd_verbose: bool = False,
    cudagraph_static_inputs: bool = False,
    benchmarking: bool = False,
    autotuning: bool = False,
    graph_region_expansion: bool = False,
    inductor_metrics: bool = False,
    hierarchical_compile: bool = False,
    compute_dependencies: bool = False,
    caching: bool = False,
) -> None:
    """
    Sets the log level for individual components and toggles individual log
    artifact types.

    .. warning:: This feature is a prototype and may have compatibility
        breaking changes in the future.

    .. note:: The ``TORCH_LOGS`` environment variable has complete precedence
        over this function, so if it was set, this function does nothing.

    A component is a set of related features in PyTorch. All of the log
    messages emitted from a given component have their own log levels. If the
    log level of a particular message has priority greater than or equal to its
    component's log level setting, it is emitted. Otherwise, it is suppressed.
    This allows you to, for instance, silence large groups of log messages that
    are not relevant to you and increase verbosity of logs for components that
    are relevant. The expected log level values, ordered from highest to lowest
    priority, are:

        * ``logging.CRITICAL``
        * ``logging.ERROR``
        * ``logging.WARNING``
        * ``logging.INFO``
        * ``logging.DEBUG``
        * ``logging.NOTSET``

    See documentation for the Python ``logging`` module for more information on
    log levels: `<https://docs.python.org/3/library/logging.html#logging-levels>`_

    An artifact is a particular type of log message. Each artifact is assigned
    to a parent component. A component can emit many different kinds of
    artifacts. In general, an artifact is emitted if either its corresponding
    setting in the argument list below is turned on or if its parent component
    is set to a log level less than or equal to the log level of the artifact.

    Keyword args:
        all (:class:`Optional[int]`):
            The default log level for all components. Default: ``logging.WARN``

        dynamo (:class:`Optional[int]`):
            The log level for the TorchDynamo component. Default: ``logging.WARN``

        aot (:class:`Optional[int]`):
            The log level for the AOTAutograd component. Default: ``logging.WARN``

        autograd (:class:`Optional[int]`):
            The log level for autograd. Default: ``logging.WARN``

        inductor (:class:`Optional[int]`):
            The log level for the TorchInductor component. Default: ``logging.WARN``

        dynamic (:class:`Optional[int]`):
            The log level for dynamic shapes. Default: ``logging.WARN``

        distributed (:class:`Optional[int]`):
            Whether to log c10d communication operations and other debug info from PyTorch Distributed components.
            Default: ``logging.WARN``

        c10d (:class:`Optional[int]`):
            Whether to log c10d communication operations related debug info in PyTorch Distributed components.
            Default: ``logging.WARN``

        ddp (:class:`Optional[int]`):
            Whether to log debug info related to ``DistributedDataParallel``(DDP) from PyTorch Distributed components.
            Default: ``logging.WARN``

        fsdp (:class:`Optional[int]`):
            Whether to log debug info related to ``FullyShardedDataParallel``(FSDP) in PyTorch Distributed components.
            Default: ``logging.WARN``

        dtensor (:class:`Optional[int]`):
            Whether to log debug info related to ``DTensor``(DTensor) in PyTorch Distributed components.
            Default: ``logging.WARN``

        onnx (:class:`Optional[int]`):
            The log level for the ONNX exporter component. Default: ``logging.WARN``

        bytecode (:class:`bool`):
            Whether to emit the original and generated bytecode from TorchDynamo.
            Default: ``False``

        aot_graphs (:class:`bool`):
            Whether to emit the graphs generated by AOTAutograd. Default: ``False``

        aot_joint_graph (:class:`bool`):
            Whether to emit the joint forward-backward graph generated by AOTAutograd. Default: ``False``

        ddp_graphs (:class:`bool`):
            Whether to emit graphs generated by DDPOptimizer. Default: ``False``

        graph (:class:`bool`):
            Whether to emit the graph captured by TorchDynamo in tabular format.
            Default: ``False``

        graph_code (:class:`bool`):
            Whether to emit the python source of the graph captured by TorchDynamo.
            Default: ``False``

        graph_code_verbose (:class:`bool`):
            Whether to emit verbose/intermediate FX pass logs for graph code. Default: ``False``

        graph_breaks (:class:`bool`):
            Whether to emit the graph breaks encountered by TorchDynamo.
            Default: ``False``

        side_effects (:class:`bool`):
            Whether to emit side effects (mutations, hooks, etc.) that TorchDynamo
            codegenerates in the output graph. Default: ``False``

        graph_sizes (:class:`bool`):
            Whether to emit tensor sizes of the graph captured by TorchDynamo.
            Default: ``False``

        guards (:class:`bool`):
            Whether to emit the guards generated by TorchDynamo for each compiled
            function. Default: ``False``

        recompiles (:class:`bool`):
            Whether to emit a guard failure reason and message every time
            TorchDynamo recompiles a function. Default: ``False``

        recompiles_verbose (:class:`bool`):
            Whether to emit all guard failure reasons when TorchDynamo recompiles
            a function, even those that are not actually run. Default: ``False``

        trace_source (:class:`bool`):
            Whether to emit when TorchDynamo begins tracing a new line. Default: ``False``

        trace_call (:class:`bool`):
            Whether to emit detailed line location when TorchDynamo creates an FX node
            corresponding to function call. Python 3.11+ only. Default: ``False``

        trace_bytecode (:class:`bool`):
            Whether to emit bytecode instructions and traced stack state as TorchDynamo
            traces bytecode. Default: ``False``

        output_code (:class:`bool`):
            Whether to emit the TorchInductor output code on a per-graph basis. Default: ``False``

        kernel_code (:class:`bool`):
            Whether to emit the TorchInductor output code on a per-kernel basis. Default: ``False``

        schedule (:class:`bool`):
            Whether to emit the TorchInductor schedule. Default: ``False``

        perf_hints (:class:`bool`):
            Whether to emit the TorchInductor perf hints. Default: ``False``

        pre_grad_graphs (:class:`bool`):
            Whether to emit the graphs before inductor grad passes. Default: ``False``

        post_grad_graphs (:class:`bool`):
            Whether to emit the graphs generated after post grad passes. Default: ``False``

        ir_pre_fusion (:class:`bool`):
            Whether to emit the graphs before inductor fusion passes. Default: ``False``

        ir_post_fusion (:class:`bool`):
            Whether to emit the graphs after inductor fusion passes. Default: ``False``

        onnx_diagnostics (:class:`bool`):
            Whether to emit the ONNX exporter diagnostics in logging. Default: ``False``

        fusion (:class:`bool`):
            Whether to emit detailed Inductor fusion decisions. Default: ``False``

        overlap (:class:`bool`):
            Whether to emit detailed Inductor compute/comm overlap decisions. Default: ``False``

        sym_node (:class:`bool`):
            Whether to emit debug info for various SymNode operations. Default: ``False``

        export (:class:`Optional[int]`):
            The log level for export. Default: ``logging.WARN``

        benchmarking (:class:`bool`):
            Whether to emit detailed Inductor benchmarking information. Default: ``False``

        modules (dict):
            This argument provides an alternate way to specify the above log
            component and artifact settings, in the format of a keyword args
            dictionary given as a single argument. There are two cases
            where this is useful (1) if a new log component or artifact has
            been registered but a keyword argument for it has not been added
            to this function and (2) if the log level for an unregistered module
            needs to be set. This can be done by providing the fully-qualified module
            name as the key, with the log level as the value. Default: ``None``

        cudagraph_static_inputs (:class:`bool`):
            Whether to emit debug info for cudagraph static input detection. Default: ``False``

        autotuning (:class:`bool`):
            Autotuning choice logs, such as kernel source, perf, and tuning parameters. Default: ``False``

        graph_region_expansion (:class:`bool`):
            Whether to emit the detailed steps of the duplicate graph region tracker expansion algorithm. Default: ``False``

        inductor_metrics (:class:`bool`):
            Whether to estimate the runtimes of the nodes in a graph and log them to the metrics table. Default: ``False``

        hierarchical_compile (:class:`bool`):
            Whether to emit debug info for hierarchical compilation. Default: ``False``

        caching (:class:`bool`):
            Whether to emit detailed Inductor caching information. Default: ``False``

        The ``cudagraphs``, ``compiled_autograd``, ``compiled_autograd_verbose``
        and ``compute_dependencies`` flags are forwarded in the same way as the
        other boolean settings; they are not described individually here.

    Raises:
        ValueError: if an artifact setting is not a ``bool``, if a log level is
            not a level known to ``logging``, or if a name is neither a
            registered log, a registered artifact nor an importable module.

    Example::

        >>> # xdoctest: +SKIP
        >>> import logging

        # The following changes the "dynamo" component to emit DEBUG-level
        # logs, and to emit "graph_code" artifacts.

        >>> torch._logging.set_logs(dynamo=logging.DEBUG, graph_code=True)

        # The following enables the logs for a different module

        >>> torch._logging.set_logs(modules={"unregistered.module.name": logging.DEBUG})
    """
    # ignore if env var is set
    if LOG_ENV_VAR in os.environ:
        log.warning(
            "Using TORCH_LOGS environment variable for log settings, ignoring call to set_logs"
        )
        return

    # Every call starts from a clean slate: settings do not accumulate across calls.
    log_state.clear()

    modules = modules or {}

    def _set_logs(**kwargs) -> None:
        # Keyword settings and the entries of ``modules`` go through the same rules.
        for alias, val in itertools.chain(kwargs.items(), modules.items()):  # type: ignore[union-attr]
            # None means "not specified"; leave that name alone.
            if val is None:
                continue

            if log_registry.is_artifact(alias):
                if not isinstance(val, bool):
                    raise ValueError(
                        f"Expected bool to enable artifact {alias}, received {val}"
                    )

                if val:
                    log_state.enable_artifact(alias)
            elif log_registry.is_log(alias) or alias in log_registry.child_log_qnames:
                if val not in logging._levelToName:
                    raise ValueError(
                        f"Unrecognized log level for log {alias}: {val}, valid level values "
                        f"are: {','.join([str(k) for k in logging._levelToName])}"
                    )

                # Child logs have no alias entry, so fall back to the name itself.
                log_state.enable_log(
                    log_registry.log_alias_to_log_qnames.get(alias, alias), val
                )
            elif _is_valid_module(alias):
                # Unregistered but importable module: register it (and, for a
                # package, every submodule found) on the fly.
                found_modules = _get_module_and_submodules(alias) or (alias,)
                for module_name in found_modules:
                    if not _has_registered_parent(module_name):
                        log_registry.register_log(module_name, module_name)
                    else:
                        log_registry.register_child_log(module_name)
                    log_state.enable_log(
                        log_registry.log_alias_to_log_qnames.get(
                            module_name, module_name
                        ),
                        val,
                    )
            else:
                raise ValueError(
                    f"Unrecognized log or artifact name passed to set_logs: {alias}"
                )

        _init_logs()

    # ``all`` is applied through the "torch" alias.
    _set_logs(
        torch=all,
        dynamo=dynamo,
        aot=aot,
        autograd=autograd,
        inductor=inductor,
        dynamic=dynamic,
        bytecode=bytecode,
        aot_graphs=aot_graphs,
        aot_joint_graph=aot_joint_graph,
        ddp_graphs=ddp_graphs,
        distributed=distributed,
        c10d=c10d,
        ddp=ddp,
        fsdp=fsdp,
        dtensor=dtensor,
        graph=graph,
        graph_code=graph_code,
        graph_code_verbose=graph_code_verbose,
        graph_breaks=graph_breaks,
        side_effects=side_effects,
        graph_sizes=graph_sizes,
        guards=guards,
        recompiles=recompiles,
        recompiles_verbose=recompiles_verbose,
        trace_source=trace_source,
        trace_call=trace_call,
        trace_bytecode=trace_bytecode,
        output_code=output_code,
        kernel_code=kernel_code,
        schedule=schedule,
        perf_hints=perf_hints,
        pre_grad_graphs=pre_grad_graphs,
        post_grad_graphs=post_grad_graphs,
        ir_pre_fusion=ir_pre_fusion,
        ir_post_fusion=ir_post_fusion,
        onnx=onnx,
        onnx_diagnostics=onnx_diagnostics,
        fusion=fusion,
        overlap=overlap,
        sym_node=sym_node,
        export=export,
        cudagraphs=cudagraphs,
        compiled_autograd=compiled_autograd,
        compiled_autograd_verbose=compiled_autograd_verbose,
        cudagraph_static_inputs=cudagraph_static_inputs,
        benchmarking=benchmarking,
        autotuning=autotuning,
        graph_region_expansion=graph_region_expansion,
        inductor_metrics=inductor_metrics,
        hierarchical_compile=hierarchical_compile,
        compute_dependencies=compute_dependencies,
        caching=caching,
    )
def get_loggers() -> list[logging.Logger]:
    """
    Returns: a list with one ``logging.Logger`` per registered qualified log name
    """
    registered_qnames = log_registry.get_log_qnames()
    return list(map(logging.getLogger, registered_qnames))
def register_log(setting_name, log_name) -> None:
    """
    Enables a log to be controlled by the env var and user API with the setting_name

    This forwards to ``log_registry.register_log``, which accepts either a
    single qualified log name or a list of them for ``log_name``.

    Args:
        setting_name:  the shorthand name used in the env var and user API
        log_name:  the log name that the setting_name is associated with
    """
    log_registry.register_log(setting_name, log_name)
def register_artifact(
    setting_name, description, visible=False, off_by_default=False, log_format=None
) -> None:
    """
    Enables an artifact to be controlled by the env var and user API with name

    Args:
        setting_name: the shorthand name used in the env var and user API
        description: A description of what this outputs
        visible: Whether it gets suggested to users by default
        off_by_default: when True, the artifact is only logged if it is named
            explicitly in the settings; enabling its ancestor loggers at level
            DEBUG is not enough
        log_format: optional ``logging`` format string for this artifact; when
            not None it is stored in the registry as a ``logging.Formatter``
    """
    log_registry.register_artifact_name(
        setting_name, description, visible, off_by_default, log_format
    )
def getArtifactLogger(module_qname, artifact_name) -> logging.Logger:
    """Return the logger that emits ``artifact_name`` messages for a module.

    The logger is named ``<module_qname>.__<artifact_name>``, is tagged with an
    ``artifact_name`` attribute, is recorded in the registry and is configured
    according to the current log state.

    Args:
        module_qname: qualified name of the module that emits the artifact
        artifact_name: name of a previously registered artifact

    Raises:
        ValueError: if ``artifact_name`` has not been registered
    """
    if artifact_name not in log_registry.artifact_names:
        raise ValueError(
            f"Artifact name: {artifact_name!r} not registered, "
            f"please call register_artifact({artifact_name!r}) in torch._logging.registrations."
        )
    qname = module_qname + f".__{artifact_name}"
    # Named ``artifact_log`` so the module-level ``log`` is not shadowed.
    artifact_log = logging.getLogger(qname)
    artifact_log.artifact_name = artifact_name  # type: ignore[attr-defined]
    log_registry.register_artifact_log(qname)
    configure_artifact_log(artifact_log)
    return artifact_log
# Prefix characters accepted in front of a name in the log settings string:
# "+" selects logging.DEBUG and "-" selects logging.ERROR
# (see get_name_level_pair in _parse_log_settings).
INCR_VERBOSITY_CHAR = "+"
DECR_VERBOSITY_CHAR = "-"
# Regex fragment built from the two prefix characters; it matches "+", "-" or
# the empty string.
VERBOSITY_REGEX = (
    "("
    + "|".join([re.escape(INCR_VERBOSITY_CHAR), re.escape(DECR_VERBOSITY_CHAR)])
    + "?)"
)
def configure_artifact_log(log) -> None:
    """Apply the registry and current log state to one artifact logger.

    ``log`` must carry the ``artifact_name`` attribute set by getArtifactLogger.
    """
    artifact = log.artifact_name

    # An off-by-default artifact must only be logged when explicitly enabled,
    # so stop it from propagating to its ancestor loggers.
    if log_registry.is_off_by_default(artifact):
        log.propagate = False

    # Explicit enabling wins: log everything and propagate again.
    if log_state.is_artifact_enabled(artifact):
        log.setLevel(logging.DEBUG)
        log.propagate = True
- # match a comma separated list of loggable names (whitespace allowed after commas)
- def _gen_settings_regex():
- return re.compile(r"((\+|-)?[\w\.]+,\s*)*(\+|-)?[\w\.]+?")
def _validate_settings(settings):
    """Return True when the whole ``settings`` string matches the settings regex."""
    whole_match = _gen_settings_regex().fullmatch(settings)
    return whole_match is not None
def help_message(verbose=False):
    """Build the help text for the ``TORCH_LOGS`` environment variable.

    Args:
        verbose: when True every registered artifact is listed, otherwise only
            the artifacts registered as visible.

    Returns:
        The help text: usage examples followed by "all", the sorted log aliases
        and the sorted artifact names with their descriptions.

    Raises:
        AssertionError: if a listed artifact name is longer than the 30
            character name column.
    """

    def pad_to(s, length=30):
        # Artifact names are left-aligned in a fixed-width column; a name that
        # does not fit is rejected instead of silently breaking the layout.
        if len(s) > length:
            raise AssertionError(f"string length {len(s)} exceeds max {length}")
        return s.ljust(length)

    if verbose:
        printed_artifacts = log_registry.artifact_names
        heading = "All registered names"
    else:
        printed_artifacts = log_registry.visible_artifacts
        heading = "Visible registered names (use TORCH_LOGS='+help' for full list)"

    lines = (
        ["all"]
        + sorted(log_registry.log_alias_to_log_qnames.keys())
        + sorted(
            [
                f"{pad_to(name)}\t{log_registry.artifact_descriptions[name]}"
                for name in printed_artifacts
            ]
        )
    )
    setting_info = " " + "\n ".join(lines)
    examples = """
Examples:
  TORCH_LOGS="+dynamo,aot" will set the log level of TorchDynamo to
  logging.DEBUG and AOT to logging.INFO

  TORCH_LOGS="-dynamo,+inductor" will set the log level of TorchDynamo to
  logging.ERROR and TorchInductor to logging.DEBUG

  TORCH_LOGS="aot_graphs" will enable the aot_graphs artifact

  TORCH_LOGS="+dynamo,schedule" will set the log level of TorchDynamo
  to logging.DEBUG and enable the schedule artifact

  TORCH_LOGS="+some.random.module,schedule" will set the log level of
  some.random.module to logging.DEBUG and enable the schedule artifact

  TORCH_LOGS="+torch._functorch._aot_autograd" will set the log level of
  torch._functorch._aot_autograd and all its submodules to logging.DEBUG
  (directory-based logging)

  TORCH_LOGS_FORMAT="%(levelname)s: %(message)s" or any provided format
  string will set the output format
  Valid keys are "levelname", "message", "pathname", "levelno", "lineno",
  "filename" and "name".

  TORCH_LOGS_OUT=/tmp/output.txt will output the logs to /tmp/output.txt as
  well. This is useful when the output is long.
"""
    msg = f"""
TORCH_LOGS Info
{examples}

{heading}
{setting_info}
"""
    return msg
def _invalid_settings_err_msg(settings, verbose=False):
    """Compose the error text shown when ``settings`` cannot be understood.

    The text lists "all", every registered log alias and every registered
    artifact name in sorted order.  ``verbose`` is accepted but not used.
    """
    known_names = [
        "all",
        *log_registry.log_alias_to_log_qnames,
        *log_registry.artifact_names,
    ]
    valid_settings = ", ".join(sorted(known_names))
    return f"""
Invalid log settings: {settings}, must be a comma separated list of fully
qualified module names, registered log names or registered artifact names.
For more info on various settings, try TORCH_LOGS="help"
Valid settings:
{valid_settings}
"""
def process_env_var_string_for_windows(env_var_str: str) -> str:
    """
    Remove one pair of matching outer quotes from an env var value on Windows.

    The logging guide (https://docs.pytorch.org/docs/stable/logging.html)
    suggests settings such as:
        TORCH_LOGS="+schedule,+inductor,+output_code"

    On Linux the environment lists this as:
        declare -x SSH_TTY="/dev/pts/0"
        declare -x TERM="xterm"
        declare -x TORCH_LOGS="+schedule,+inductor,+output_code"
        declare -x USER="xu"

    On Windows it is listed as:
        TORCHINDUCTOR_WINDOWS_TESTS=1
        TORCH_LOGS="+schedule,+inductor,+output_code"
        UCRTVersion=10.0.22000.0

    Linux only displays the quotes, whereas on Windows they end up inside the
    value itself:
        Linux value:   "+schedule,+inductor,+output_code"
        Windows value: '"+schedule,+inductor,+output_code"'

    So on Windows the outer quotes have to be removed.  On every other
    platform the string is returned unchanged.
    """
    if sys.platform != "win32":
        return env_var_str

    # Only strip when the first and last characters are the same quote
    # character and are two distinct positions in the string.
    is_wrapped = (
        len(env_var_str) >= 2
        and env_var_str[0] == env_var_str[-1]
        and env_var_str[0] in ('"', "'")
    )
    return env_var_str[1:-1] if is_wrapped else env_var_str
@functools.lru_cache
def _parse_log_settings(settings):
    """Turn a ``TORCH_LOGS``-style settings string into a ``LogState``.

    Names that are importable modules but not yet registered are added to the
    log registry as a side effect (together with their submodules when the
    name is a package).

    Args:
        settings: comma separated list of log aliases, artifact names or
            qualified module names, each optionally prefixed with "+" (DEBUG)
            or "-" (ERROR); unprefixed names get INFO.

    Returns:
        A ``LogState`` holding the requested levels and artifacts.  An empty
        settings string yields an empty ``LogState``.

    Raises:
        ValueError: with the help text for "help" / "+help", or with the
            invalid-settings message when the string is malformed or contains
            an unknown name.
    """
    settings = process_env_var_string_for_windows(settings)

    # The result replaces the global log state, so "nothing requested" has to
    # be an empty LogState rather than a plain dict.
    if settings == "":
        return LogState()

    if settings == "help":
        raise ValueError(help_message(verbose=False))
    elif settings == "+help":
        raise ValueError(help_message(verbose=True))
    if not _validate_settings(settings):
        raise ValueError(_invalid_settings_err_msg(settings))

    settings = re.sub(r"\s+", "", settings)
    log_names = settings.split(",")

    def get_name_level_pair(name):
        clean_name = name.replace(INCR_VERBOSITY_CHAR, "")
        clean_name = clean_name.replace(DECR_VERBOSITY_CHAR, "")

        if name[0] == INCR_VERBOSITY_CHAR:
            level = logging.DEBUG
        elif name[0] == DECR_VERBOSITY_CHAR:
            level = logging.ERROR
        else:
            level = logging.INFO

        return clean_name, level

    # Kept in a local with its own name so the module-level ``log_state`` is
    # not shadowed.
    parsed_state = LogState()

    for name in log_names:
        name, level = get_name_level_pair(name)

        if name == "all":
            name = "torch"

        if log_registry.is_log(name):
            log_qnames = log_registry.log_alias_to_log_qnames[name]
            parsed_state.enable_log(log_qnames, level)
        elif log_registry.is_artifact(name):
            parsed_state.enable_artifact(name)
        elif _is_valid_module(name):
            # Get the module and all its submodules if it's a package
            found_modules = _get_module_and_submodules(name) or (name,)
            for module_name in found_modules:
                if not _has_registered_parent(module_name):
                    log_registry.register_log(module_name, module_name)
                else:
                    log_registry.register_child_log(module_name)
                parsed_state.enable_log(module_name, level)
        else:
            raise ValueError(_invalid_settings_err_msg(settings))

    return parsed_state
def _is_valid_module(qname):
    """
    Report whether ``qname`` names a module the import system can locate.

    ``importlib.util.find_spec`` imports the parent package of a dotted name,
    so it raises ``ModuleNotFoundError`` (an ``ImportError``) when that parent
    is missing or is not a package, and it raises ``ValueError`` for an
    already-imported module whose ``__spec__`` is unset. All of these mean
    "not a usable module name" here, so they are reported as ``False`` rather
    than propagated.

    Args:
        qname: The fully qualified module name to probe.

    Returns:
        True if a module spec was found, False otherwise.
    """
    try:
        spec = importlib.util.find_spec(qname)
    except (ImportError, ValueError):
        return False
    return spec is not None
- def _get_module_and_submodules(qname: str) -> Sequence[str] | None:
- """
- Get a module and all its submodules (recursively).
- If qname is a package, this returns a list of all modules and submodules.
- If qname is a simple module, this returns a list containing just that module.
- Args:
- qname: The fully qualified module name
- Returns:
- A list of fully qualified module names, or None if the module doesn't exist
- """
- spec = importlib.util.find_spec(qname)
- if spec is None:
- return None
- modules = [qname]
- if spec.submodule_search_locations is not None:
- package = importlib.import_module(qname)
- if hasattr(package, "__path__"):
- for importer, modname, ispkg in pkgutil.walk_packages(
- path=package.__path__,
- prefix=qname + ".",
- onerror=lambda x: None,
- ):
- modules.append(modname)
- return modules
def _update_log_state_from_env() -> None:
    """
    Rebuild the module-level ``log_state`` from the ``LOG_ENV_VAR`` variable.

    When the environment variable is unset, the current state is left as is.
    """
    global log_state

    raw_setting = os.environ.get(LOG_ENV_VAR)
    if raw_setting is None:
        return
    log_state = _parse_log_settings(raw_setting)
def _has_registered_parent(log_qname) -> bool:
    """
    Check whether ``log_qname`` sits under a log known to the registry.

    The walk starts at the logger named ``log_qname`` itself and follows
    ``.parent`` links, returning True as soon as a visited logger's name is
    among the registry's log qnames. It stops at the first logger that has no
    parent; that last logger is not checked.
    """
    node = logging.getLogger(log_qname)
    known_qnames = log_registry.get_log_qnames()

    while node.parent:
        if node.name in known_qnames:
            return True
        node = node.parent

    return False
@functools.lru_cache
def _make_module_path_relative(abs_path: str, sys_path: tuple[str, ...]) -> str:
    """
    Express ``abs_path`` relative to the first entry of ``sys_path`` that
    contains it, falling back to the resolved absolute path.

    The string manipulation here is fairly expensive and very high cache hit
    rates are expected. Since `sys.path` changes very infrequently, the caller
    passes it as a tuple of strings so it can serve as part of the cache key:
    python has dedicated fast paths for list to tuple conversion and tuple
    hashing. (Empirically, a single top-level cache is about 2x faster than
    trying to cache individual parts.)
    """
    resolved = pathlib.Path(abs_path).resolve()

    for search_root in sys_path:
        try:
            return str(resolved.relative_to(search_root))
        except ValueError:
            # Not located under this search root; try the next one.
            continue

    return str(resolved)
def make_module_path_relative(abs_path: str) -> str:
    """
    Convert the absolute filepath of a Python module, loaded through the
    normal ``sys.path`` based import mechanisms, into a path relative to one
    of the Python search paths.

    The current ``sys.path`` is snapshotted into a tuple so that it can be
    used as part of the cached helper's key.
    """
    search_paths = tuple(sys.path)
    return _make_module_path_relative(abs_path, search_paths)
# apply custom formats to artifacts when necessary
class TorchLogsFormatter(logging.Formatter):
    """
    Formatter that prefixes every output line with a glog-like header.

    The header is built from the abbreviated level, timestamp, process id,
    search-path-relative file name and line number, plus optional rank, trace
    id and artifact tags. Records from a logger whose artifact has a
    formatter registered in ``log_registry`` are handed to that formatter
    instead.

    With ``trace=True`` the record is expected to carry ``metadata`` and
    ``payload`` attributes and an empty message; the metadata is emitted as
    JSON after the header and the payload lines follow, tab-indented.
    """

    # One-letter severity codes; level names missing here are used unchanged.
    _GLOG_LEVEL_ABBR = {
        "DEBUG": "V",  # V is for VERBOSE in glog
        "INFO": "I",
        "WARNING": "W",
        "ERROR": "E",
        "CRITICAL": "C",
    }

    def __init__(
        self, *, trace: bool = False, trace_id_filter: Optional[set[str]] = None
    ) -> None:
        super().__init__()
        self._is_trace = trace
        self._trace_id_filter = trace_id_filter

    def format(self, record):
        artifact_name = getattr(logging.getLogger(record.name), "artifact_name", None)
        if artifact_name is not None:
            custom_formatter = log_registry.artifact_log_formatters.get(
                artifact_name, None
            )
            if custom_formatter is not None:
                return custom_formatter.format(record)

        record.message = record.getMessage()
        record.asctime = self.formatTime(record, "%m%d %H:%M:%S")

        # exception handling - copied from logging.Formatter.format
        body = record.message
        if record.exc_info:
            from torch._dynamo import config

            # Tracebacks for the "graph_breaks" artifact are only rendered
            # when config.verbose is set. The rendered text is cached on the
            # record to avoid converting it multiple times.
            wants_traceback = config.verbose or artifact_name != "graph_breaks"
            if wants_traceback and not record.exc_text:
                record.exc_text = self.formatException(record.exc_info)
        if record.exc_text:
            if not body.endswith("\n"):
                body += "\n"
            body += record.exc_text
        if record.stack_info:
            if not body.endswith("\n"):
                body += "\n"
            body += self.formatStack(record.stack_info)

        record.rankprefix = ""
        if not self._is_trace and dist.is_available() and dist.is_initialized():
            record.rankprefix = f"[rank{dist.get_rank()}]:"

        record.traceid = ""
        if not self._is_trace:
            trace_id = torch._guards.CompileContext.current_trace_id()
            if trace_id is not None:
                record.traceid = f" [{trace_id}]"

        shortlevel = self._GLOG_LEVEL_ABBR.get(record.levelname, record.levelname)

        record.artifactprefix = ""
        if artifact_name is not None:
            record.artifactprefix = f" [__{artifact_name}]"

        filepath = make_module_path_relative(record.pathname)

        # With a non-empty filter, records whose stripped trace id tag is not
        # listed are rendered as an empty string.
        if (
            self._trace_id_filter
            and record.traceid.strip() not in self._trace_id_filter
        ):
            return ""

        prefix = (
            f"{record.rankprefix}{shortlevel}{record.asctime}.{int(record.msecs * 1000):06d} {record.process} "
            f"{filepath}:"
            f"{record.lineno}]{record.traceid}{record.artifactprefix}"
        )

        if not self._is_trace:
            return "\n".join(f"{prefix} {line}" for line in body.split("\n"))

        if body != "":
            raise AssertionError(f"expected empty string for trace, got {body!r}")
        try:
            rendered = f"{prefix} {json.dumps(record.metadata)}"
        except TypeError:
            log.warning("failing metadata: %r", record.metadata)
            raise
        if record.payload is not None:
            rendered += "".join(f"\n\t{line}" for line in record.payload.split("\n"))
        return rendered
def _default_formatter():
    """
    Build the formatter attached to handlers created by this module.

    Without ``LOG_FORMAT_ENV_VAR`` in the environment, a ``TorchLogsFormatter``
    is returned, restricted to the comma-separated trace ids listed in
    ``LOG_TRACE_ID_FILTER`` (blank items are dropped). Otherwise a plain
    ``logging.Formatter`` is returned, where the values ``short`` and
    ``basic`` select ``logging.BASIC_FORMAT``.
    """
    raw_filter = os.environ.get(LOG_TRACE_ID_FILTER, "")
    trace_id_filter = {
        piece.strip() for piece in raw_filter.split(",") if piece.strip()
    }

    fmt = os.environ.get(LOG_FORMAT_ENV_VAR, None)
    if fmt is None:
        return TorchLogsFormatter(trace_id_filter=trace_id_filter)

    if fmt in ("short", "basic"):
        fmt = logging.BASIC_FORMAT
    return logging.Formatter(fmt)


DEFAULT_FORMATTER = _default_formatter()
def _setup_handlers(create_handler_fn, log) -> None:
    """
    Create a handler with ``create_handler_fn`` and attach it to ``log``.

    The new handler is recorded via ``_track_handler``, gets
    ``DEFAULT_FORMATTER`` and is set to DEBUG level.
    """
    handler = _track_handler(create_handler_fn())
    handler.setLevel(logging.DEBUG)
    handler.setFormatter(DEFAULT_FORMATTER)
    log.addHandler(handler)
# Handlers created by this module, held through weak references.
handlers = WeakSet()  # type: ignore[var-annotated]


# mark handlers that we've created
# so we don't modify user handlers
def _track_handler(handler):
    """Record ``handler`` as created by this module and hand it back."""
    handlers.add(handler)
    return handler
def _is_torch_handler(handler):
    """Return True when ``handler`` is a member of the module-level ``handlers`` set."""
    is_tracked = handler in handlers
    return is_tracked
# clears all torch handlers on specified loggers
def _clear_handlers(log) -> None:
    """
    Detach from ``log`` every handler that ``_is_torch_handler`` recognizes.

    Handlers not recognized by that check are left in place.
    """
    # Iterate over a snapshot: removeHandler mutates log.handlers.
    for handler in list(log.handlers):
        if _is_torch_handler(handler):
            log.removeHandler(handler)
def _reset_logs() -> None:
    """
    Put every logger known to ``log_registry`` back into its baseline state.

    Registered logs are set to WARNING, stop propagating and lose the
    handlers this module attached. Artifact and child logs are set to NOTSET
    and propagate again. ``trace_log`` stops propagating and loses this
    module's handlers as well.
    """
    # reset all registered logs
    for qname in log_registry.get_log_qnames():
        registered_logger = logging.getLogger(qname)
        registered_logger.setLevel(logging.WARNING)
        registered_logger.propagate = False
        _clear_handlers(registered_logger)

    # reset all artifact and child logs
    derived_qnames = itertools.chain(
        log_registry.get_artifact_log_qnames(), log_registry.get_child_log_qnames()
    )
    for qname in derived_qnames:
        derived_logger = logging.getLogger(qname)
        derived_logger.setLevel(logging.NOTSET)
        derived_logger.propagate = True

    trace_log.propagate = False
    _clear_handlers(trace_log)
def _get_log_state():
    """Return the object currently bound to the module-level ``log_state``."""
    current_state = log_state
    return current_state
def _set_log_state(state) -> None:
    """Rebind the module-level ``log_state`` to ``state``."""
    global log_state

    log_state = state
def _init_logs(log_file_name=None) -> None:
    """
    (Re)configure all loggers known to ``log_registry`` plus ``trace_log``.

    Loggers are reset first, the log state is refreshed from the environment,
    and levels and handlers are then applied. A value in ``LOG_OUT_ENV_VAR``
    takes precedence over the ``log_file_name`` argument; when a file name is
    available, each registered logger gets a file handler in addition to its
    stream handler.

    Args:
        log_file_name: Optional path of a file to also log to.
    """
    global GET_DTRACE_STRUCTURED
    global LOG_TRACE_HANDLER

    _reset_logs()
    _update_log_state_from_env()

    env_log_file = os.environ.get(LOG_OUT_ENV_VAR, None)
    if env_log_file is not None:
        log_file_name = env_log_file

    # First, reset all known (registered) loggers to NOTSET, so that they
    # respect their parent log level
    for qname in log_registry.get_log_qnames():
        # But not the top level torch level: this defaults to WARNING so
        # that our log messages don't leak to the lower levels
        if qname == "torch":
            continue
        logging.getLogger(qname).setLevel(logging.NOTSET)

    # Now, for all loggers which the user requested to have non-standard
    # logging behavior, modify their log levels
    for qname, level in log_state.get_log_level_pairs():
        logging.getLogger(qname).setLevel(level)

    # Finally, setup handlers for all registered loggers
    for qname in log_registry.get_log_qnames():
        registered_logger = logging.getLogger(qname)
        _setup_handlers(logging.StreamHandler, registered_logger)

        if log_file_name is not None:
            _setup_handlers(
                lambda: logging.FileHandler(log_file_name),
                registered_logger,
            )

    # configure artifact loggers, note: this must happen last
    # since the levels of ancestor loggers are taken into account
    for qname in log_registry.get_artifact_log_qnames():
        configure_artifact_log(logging.getLogger(qname))

    # Setup handler for the special trace_log, with different default
    # configuration
    trace_dir_name = os.environ.get(TRACE_ENV_VAR, None)

    # A non-empty DTRACE_ENV_VAR switches on detailed structured tracing and
    # overrides the trace directory.
    if dtrace_dir_name := os.environ.get(DTRACE_ENV_VAR, None):
        GET_DTRACE_STRUCTURED = True
        trace_dir_name = dtrace_dir_name

    # This handler may remove itself if trace_dir_name is None and we are not
    # actually in an FB environment.  This allows us to defer actually
    # initializing it until we actually need to log anything.  This is
    # important because JK initializes a C++ singleton, which will pork our
    # process if we subsequently fork.
    if LOG_TRACE_HANDLER is None:
        LOG_TRACE_HANDLER = LazyTraceHandler(trace_dir_name)

    # This log is ALWAYS at debug level.  We will additionally test if there
    # are any handlers before deciding to actually call logging on this.  Do
    # not manually call
    trace_log.setLevel(logging.DEBUG)
    trace_handler = _track_handler(LOG_TRACE_HANDLER)
    trace_handler.setFormatter(TorchLogsFormatter(trace=True))
    trace_log.addHandler(trace_handler)
class LazyTraceHandler(logging.StreamHandler):
    """Like FileHandler, but the file is allocated lazily only upon the first log message"""

    def __init__(self, root_dir: Optional[str]) -> None:
        # This is implemented in the same way that delay is implemented on
        # FileHandler: only the base Handler initializer runs and the stream
        # stays unset until the first emit.
        self.root_dir = root_dir
        logging.Handler.__init__(self)
        self.stream = None
        self._builtin_open = open

    # cloned from FileHandler in cpython
    def close(self) -> None:
        self.acquire()
        try:
            try:
                if self.stream:
                    try:
                        self.flush()
                    finally:
                        open_stream = self.stream
                        self.stream = None
                        if hasattr(open_stream, "close"):
                            open_stream.close()
            finally:
                # Issue #19523: call unconditionally to
                # prevent a handler leak when delay is set
                # Also see Issue #42378: we also rely on
                # self._closed being set to True there
                logging.StreamHandler.close(self)
        finally:
            self.release()

    def emit(self, record) -> None:
        if self.stream is None:
            if self.root_dir is None:
                # No directory was configured: fall back to a fixed location,
                # but only if every check below passes.
                TRACE_LOG_DIR = "/logs"

                import torch.version as torch_version

                if (
                    hasattr(torch_version, "git_version")
                    and os.getenv("MAST_HPC_JOB_NAME") is None
                ):
                    log.info(
                        "LazyTraceHandler: disabled because not fbcode or conda on mast"
                    )
                elif not torch._utils_internal.justknobs_check("pytorch/trace:enable"):
                    log.info(
                        "LazyTraceHandler: disabled because justknobs_check('pytorch/trace:enable') returned False"
                    )
                elif not os.path.exists(TRACE_LOG_DIR):
                    log.info(
                        "LazyTraceHandler: disabled because %s does not exist",
                        TRACE_LOG_DIR,
                    )
                elif not os.access(TRACE_LOG_DIR, os.W_OK):
                    log.info(
                        "LazyTraceHandler: disabled because %s is not writeable",
                        TRACE_LOG_DIR,
                    )
                else:
                    self.root_dir = TRACE_LOG_DIR

            if self.root_dir is None:
                # We go poof, remove and no-op
                trace_log.removeHandler(self)
                return

            os.makedirs(self.root_dir, exist_ok=True)
            ranksuffix = ""
            if dist.is_available() and dist.is_initialized():
                ranksuffix = f"rank_{dist.get_rank()}_"
            self.stream = tempfile.NamedTemporaryFile(  # noqa: SIM115
                mode="w+",
                suffix=".log",
                prefix=LOG_PREFIX + ranksuffix,
                dir=self.root_dir,
                delete=False,
            )
            log.info("LazyTraceHandler: logging to %s", self.stream.name)

        if self.stream:
            super().emit(record)
@functools.cache
def warning_once(logger_obj, *args, **kwargs) -> None:
    """
    Behave like `logger.warning()`, except that a given warning is emitted only once.

    Note: deduplication is done by caching on the function arguments, so two
    different callers passing the same arguments share one cache entry.
    This assumes that all warning messages are unique across the code. If
    they aren't, a different kind of cache is needed, one that includes the
    caller frame information in the hashing function.
    """
    emit_warning = logger_obj.warning
    emit_warning(*args, **kwargs)
def safe_grad_filter(message, category, filename, lineno, file=None, line=None) -> bool:
    """
    Filter with the ``showwarning`` signature, for use with ``hide_warnings``.

    Returns False (hide) for warnings whose text contains
    "The .grad attribute of a Tensor", True (show) for everything else.
    """
    text = str(message)
    return "The .grad attribute of a Tensor" not in text
def user_warning_filter(
    message, category, filename, lineno, file=None, line=None
) -> bool:
    """
    Filter with the ``showwarning`` signature, for use with ``hide_warnings``.

    Returns False (hide) only when ``category`` is exactly ``UserWarning``;
    subclasses of ``UserWarning`` and all other categories return True.
    """
    is_plain_user_warning = category is UserWarning
    return not is_plain_user_warning
@contextlib.contextmanager
def hide_warnings(filter_fn=lambda *args, **kwargs: True):
    """
    A context manager that temporarily suppresses warnings,
    using public API: https://docs.python.org/3/library/warnings.html#warnings.showwarning.

    For the duration of the block ``warnings.showwarning`` is swapped for a
    wrapper that forwards to the previous hook only when ``filter_fn``
    returns a truthy value; the previous hook is restored on exit.

    Useful to hide warnings without mutating warnings module state, see:
    https://github.com/pytorch/pytorch/issues/128427#issuecomment-2161496162.

    NOTE: Warnings issued under this context will still be cached in the __warningregistry__
    and count towards the once/default rule. So you should NEVER use this on a user-land function.

    Filter must implement the showwarning API:
    def filter_fn(message, category, filename, lineno, file=None, line=None) -> bool:
        return True  # show this warning entry
    """
    previous_hook = warnings.showwarning

    def filtered_hook(*args, **kwargs):
        if not filter_fn(*args, **kwargs):
            return
        previous_hook(*args, **kwargs)

    try:
        warnings.showwarning = filtered_hook
        yield
    finally:
        warnings.showwarning = previous_hook
class LazyString(Generic[_P]):
    """
    Wrap a string-producing call so that it only runs when ``str()`` is applied.

    The callable and its arguments are stored as given; every ``__str__``
    call invokes the callable again (the result is not cached).
    """

    def __init__(
        self, func: Callable[_P, str], *args: _P.args, **kwargs: _P.kwargs
    ) -> None:
        self.func, self.args, self.kwargs = func, args, kwargs

    def __str__(self) -> str:
        produce = self.func
        return produce(*self.args, **self.kwargs)
- # Logs the time it takes to do structured logging by frame/compile id
- # key is always {frame_id}_{frame_compile_id}
- structured_logging_overhead: dict[str, float] = defaultdict(float)
- def add_structured_logging_overhead(time_spent: float) -> None:
- global structured_logging_overhead
- key = None
- if (trace_id := torch._guards.CompileContext.current_trace_id()) is not None:
- frame_id = trace_id.compile_id.frame_id
- frame_compile_id = trace_id.compile_id.frame_compile_id
- # Why not trace_id.attempt, like structured logging?
- # We aggregate across all attempts because
- # a compilation metric is logged per successful attempt
- key = f"{frame_id}_{frame_compile_id}"
- # TODO: deal with structured logging that occurs outside of specific compile ids
- # It's hard to figure out where we would log that if we want it in compilation metrics
- # itself.
- if key is not None:
- key = str(key)
- structured_logging_overhead[key] += time_spent
- def get_structured_logging_overhead() -> Optional[float]:
- key = None
- if (trace_id := torch._guards.CompileContext.current_trace_id()) is not None:
- frame_id = trace_id.compile_id.frame_id
- frame_compile_id = trace_id.compile_id.frame_compile_id
- key = f"{frame_id}_{frame_compile_id}"
- if key is not None:
- return structured_logging_overhead.get(key)
- else:
- return None
def trace_structured_artifact(
    name: str,  # this will go in metadata
    encoding: str,
    payload_fn: Callable[[], Optional[Union[str, object]]] = lambda: None,
    compile_id: Optional[CompileId] = None,
) -> None:
    """
    Emit an "artifact" structured-trace entry.

    Forwards to ``trace_structured`` with metadata made of ``name`` and
    ``encoding``; ``payload_fn`` and ``compile_id`` are passed through
    unchanged.
    """

    def artifact_metadata():
        return {
            "name": name,
            "encoding": encoding,
        }

    trace_structured(
        "artifact",
        metadata_fn=artifact_metadata,
        payload_fn=payload_fn,
        compile_id=compile_id,
    )
def trace_structured(
    name: str,
    # NB: metadata expected to be dict so adding more info is forward compatible
    # Tuple[str, int] is a special case for string interning
    metadata_fn: Callable[[], Union[dict[str, Any], tuple[str, int]]] = dict,
    *,
    payload_fn: Callable[[], Optional[Union[str, object]]] = lambda: None,
    suppress_context: bool = False,
    expect_trace_id: bool = True,  # Whether or not we expect to have a current trace id
    record_logging_overhead: bool = True,  # Whether or not to record the time spent on structured logging
    compile_id: Optional[CompileId] = None,  # Optional if unavailable in the trace
) -> None:
    """
    Write one structured entry to ``trace_log``.

    metadata is an arbitrary JSON compatible struct, but it's expected to not be
    too long (e.g., less than 1MB)

    payload is an arbitrary string, which can be arbitrarily long (but expected to have
    newlines so no lines are too long)

    ``metadata_fn`` and ``payload_fn`` are only called when ``trace_log`` has
    handlers. Non-string payloads are serialized to JSON, and the MD5 hex
    digest of the payload text is stored under ``has_payload``.

    Raises:
        AssertionError: if ``name`` is one of the reserved keys, or if
            ``metadata_fn`` or ``payload_fn`` is not callable.
    """
    reserved_names = (
        "rank",
        "compiled_autograd_id",
        "frame_id",
        "frame_compile_id",
        "attempt",
        "severity",
        "timestamp",
        "pathname",
        "thread",
    )
    if name in reserved_names:
        raise AssertionError(f"name {name!r} is reserved and cannot be used")
    if not callable(metadata_fn):
        raise AssertionError(
            f"metadata_fn should be callable, but got {type(metadata_fn)}"
        )
    if not callable(payload_fn):
        raise AssertionError(
            f"payload_fn should be callable, but got {type(payload_fn)}"
        )

    # trace_log never propagates and is ALWAYS DEBUG, so also check that there
    # are handlers instead of checking the log level
    if not trace_log.handlers:
        return

    start_time = time.time_ns()
    record: dict[str, object] = {name: metadata_fn()}

    if not suppress_context:
        # TODO: Actually, the rank probably should just be emitted once at
        # the top, and not repeatedly spammed in all the logs, since it
        # never changes and we assume no interleaving
        if dist.is_available() and dist.is_initialized():
            record["rank"] = dist.get_rank()

        trace_id = torch._guards.CompileContext.current_trace_id()
        if expect_trace_id and trace_id is None and compile_id is None:
            # Record the stack of the log call to better diagnose why we
            # don't have a frame id for it
            record["stack"] = torch._logging.structured.from_traceback(
                CapturedTraceback.extract(skip=1).summary()
            )
        else:
            # The current trace id's compile id takes precedence over the
            # explicitly passed one.
            cid = trace_id.compile_id if trace_id else compile_id
            if cid is not None:
                if cid.compiled_autograd_id is not None:
                    record["compiled_autograd_id"] = cid.compiled_autograd_id
                if cid.frame_id is not None:
                    record["frame_id"] = cid.frame_id
                if cid.frame_compile_id is not None:
                    record["frame_compile_id"] = cid.frame_compile_id
            if trace_id:
                record["attempt"] = trace_id.attempt

    payload = payload_fn()
    if payload is not None:
        if isinstance(payload, list):
            # special case to look better
            payload = "[\n" + ",\n".join(json.dumps(item) for item in payload) + "\n]"
        elif not isinstance(payload, str):

            def json_default(obj):
                # Sets aren't json serializable
                if isinstance(obj, set):
                    return list(obj)
                raise TypeError(
                    f"Object of type {type(obj)} is not JSON serializable"
                )

            # force newlines so we are unlikely to overflow line limit
            payload = json.dumps(payload, default=json_default, indent=0)

        digest = hashlib.md5(payload.encode("utf-8"), usedforsecurity=False)
        record["has_payload"] = digest.hexdigest()

    trace_log.debug(
        "", extra={"metadata": record, "payload": payload}, stacklevel=2
    )
    log_trace_structured_event(name, record)

    if record_logging_overhead:
        # Convert to seconds from nanoseconds, add it to the frame compile total
        elapsed_s = (time.time_ns() - start_time) / 1e9
        add_structured_logging_overhead(elapsed_s)
def dtrace_structured(
    name: str,
    # NB: metadata expected to be dict so adding more info is forward compatible
    # Tuple[str, int] is a special case for string interning
    metadata_fn: Callable[[], Union[dict[str, Any], tuple[str, int]]] = dict,
    *,
    payload_fn: Callable[[], Optional[Union[str, object]]] = lambda: None,
    suppress_context: bool = False,
    expect_trace_id: bool = False,  # Whether or not we expect to have a current trace id
    record_logging_overhead: bool = True,  # Whether or not to record the time spent on structured logging
) -> None:
    """
    For logging more detailed information used for debugging. This may result in
    the program becoming slow.

    Does nothing unless ``GET_DTRACE_STRUCTURED`` is truthy; otherwise all
    arguments are forwarded to ``trace_structured``.
    """
    if not GET_DTRACE_STRUCTURED:
        return

    trace_structured(
        name,
        metadata_fn,
        payload_fn=payload_fn,
        suppress_context=suppress_context,
        expect_trace_id=expect_trace_id,
        record_logging_overhead=record_logging_overhead,
    )
- import torch._guards
- import torch._utils_internal
- import torch.distributed as dist
|