| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377 |
- # mypy: allow-untyped-defs
- import functools
- import logging
- import os
- import sys
- import tempfile
- import typing_extensions
- from collections.abc import Callable
- from typing import Any, TypeVar
- from typing_extensions import ParamSpec
- import torch
- from torch._strobelight.compile_time_profiler import StrobelightCompileTimeProfiler
- _T = TypeVar("_T")
- _P = ParamSpec("_P")
- log = logging.getLogger(__name__)
# Opt-in switch for the Strobelight compile-time profiler. Any non-empty value
# of TORCH_COMPILE_STROBELIGHT counts as "on", because the raw environment
# string is only tested for truthiness.
if os.environ.get("TORCH_COMPILE_STROBELIGHT", False):
    import shutil

    if shutil.which("strobeclient"):
        # The client binary is on PATH, so the profiler can be switched on.
        log.info("Strobelight profiler is enabled via environment variable")
        StrobelightCompileTimeProfiler.enable()
    else:
        log.info(
            "TORCH_COMPILE_STROBELIGHT is true, but seems like you are not on a FB machine."
        )
- # this arbitrary-looking assortment of functionality is provided here
- # to have a central place for overridable behavior. The motivating
- # use is the FB build environment, where this source file is replaced
- # by an equivalent.
# Directory two levels above this file; when this file lives in a directory
# named "shared", go up one additional level.
torch_parent = os.path.dirname(os.path.dirname(__file__))
if os.path.basename(os.path.dirname(__file__)) == "shared":
    torch_parent = os.path.dirname(torch_parent)
def get_file_path(*path_components: str) -> str:
    """
    Join ``path_components`` onto ``torch_parent`` and return the result.
    """
    pieces = (torch_parent, *path_components)
    return os.path.join(*pieces)
def get_file_path_2(*path_components: str) -> str:
    """
    Join the given components with ``os.path.join``; nothing is prepended.
    """
    joined = os.path.join(*path_components)
    return joined
def get_writable_path(path: str) -> str:
    """
    Return ``path`` if the current process may write to it.

    Otherwise create a new temporary directory (via ``tempfile.mkdtemp``)
    whose name ends with the basename of ``path`` and return that instead.
    """
    if not os.access(path, os.W_OK):
        fallback_suffix = os.path.basename(path)
        return tempfile.mkdtemp(suffix=fallback_suffix)
    return path
def prepare_multiprocessing_environment(path: str) -> None:
    """
    No-op in this version of the file; ``path`` is accepted and ignored.
    """
    return None
def resolve_library_path(path: str) -> str:
    """
    Return the canonical form of ``path`` as computed by ``os.path.realpath``.
    """
    canonical = os.path.realpath(path)
    return canonical
def throw_abstract_impl_not_imported_error(opname, module, context):
    """
    Always raise ``NotImplementedError`` reporting a missing fake impl.

    When ``module`` is not a key of ``sys.modules``, the message additionally
    suggests importing that module and appends ``context``.
    """
    module_missing = module not in sys.modules
    message = f"{opname}: We could not find the fake impl for this operator. "
    if module_missing:
        message += (
            f"The operator specified that you may need to import the '{module}' "
            f"Python module to load the fake impl. {context}"
        )
    raise NotImplementedError(message)
# NB! This treats "skip" kwarg specially!!
def compile_time_strobelight_meta(
    phase_name: str,
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
    """
    Build a decorator that routes calls through the Strobelight compile-time
    profiler under ``phase_name`` whenever that profiler is enabled.

    If the wrapped call receives an integer ``skip`` keyword argument, the
    value forwarded to the wrapped function is ``skip + 1`` (presumably to
    account for the extra wrapper frame -- confirm against callers).
    """

    def compile_time_strobelight_meta_inner(
        function: Callable[_P, _T],
    ) -> Callable[_P, _T]:
        @functools.wraps(function)
        def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> _T:
            # A missing "skip" yields None, which is not an int, so only an
            # explicitly passed integer is bumped.
            skip = kwargs.get("skip")
            if isinstance(skip, int):
                kwargs["skip"] = skip + 1

            if StrobelightCompileTimeProfiler.enabled:
                return StrobelightCompileTimeProfiler.profile_compile_time(
                    function, phase_name, *args, **kwargs
                )

            # Calling the function directly is not strictly needed; it keeps
            # profile_compile_time out of stack traces when profiling is off.
            return function(*args, **kwargs)

        return wrapper_function

    return compile_time_strobelight_meta_inner
- # Meta only, see
- # https://www.internalfb.com/intern/wiki/ML_Workflow_Observability/User_Guides/Adding_instrumentation_to_your_code/
- #
- # This will cause an event to get logged to Scuba via the signposts API. You
- # can view samples on the API at https://fburl.com/scuba/workflow_signpost/zh9wmpqs
- # we log to subsystem "torch", and the category and name you provide here.
- # Each of the arguments translates into a Scuba column. We're still figuring
- # out local conventions in PyTorch, but category should be something like
- # "dynamo" or "inductor", and name should be a specific string describing what
- # kind of event happened.
- #
- # Killswitch is at
- # https://www.internalfb.com/intern/justknobs/?name=pytorch%2Fsignpost#event
def signpost_event(category: str, name: str, parameters: dict[str, Any]):
    """
    Emit one INFO record on the module logger describing the event.

    The record is formatted as ``<category> <name>: <repr(parameters)>``.
    """
    log.info("%s %s: %r", category, name, parameters)
    return None
def add_mlhub_insight(category: str, insight: str, insight_description: str):
    """
    No-op in this version of the file; all three arguments are ignored.
    """
    return None
def log_compilation_event(metrics):
    """
    Log ``metrics`` (formatted with ``%s``) at INFO level on the module logger.
    """
    log.info("%s", metrics)
    return None
def upload_graph(graph):
    """
    No-op in this version of the file; ``graph`` is ignored.
    """
    return None
def set_pytorch_distributed_envs_from_justknobs():
    """
    No-op in this version of the file.
    """
    return None
def log_export_usage(**kwargs):
    """
    No-op in this version of the file; any keyword arguments are ignored.
    """
    return None
def log_draft_export_usage(**kwargs):
    """
    No-op in this version of the file; any keyword arguments are ignored.
    """
    return None
def log_trace_structured_event(*args, **kwargs) -> None:
    """
    No-op in this version of the file; all arguments are ignored.
    """
    return None
def log_cache_bypass(*args, **kwargs) -> None:
    """
    No-op in this version of the file; all arguments are ignored.
    """
    return None
def log_torchscript_usage(api: str, **kwargs):
    """
    No-op in this version of the file; ``api`` and any keyword arguments are
    accepted and ignored.
    """
    return None
def check_if_torch_exportable():
    """
    Always returns ``False`` in this version of the file.
    """
    return False
def export_training_ir_rollout_check() -> bool:
    """
    Always returns ``True`` in this version of the file.
    """
    return True
def full_aoti_runtime_assert() -> bool:
    """
    Always returns ``True`` in this version of the file.
    """
    return True
def log_torch_jit_trace_exportability(
    api: str,
    type_of_export: str,
    export_outcome: str,
    result: str,
):
    """
    No-op in this version of the file; all four arguments are accepted and
    ignored.
    """
    return None
# Module-level flag, set to True here.
# NOTE(review): presumably tells consumers that JustKnobs lookups are
# unavailable in this build -- confirm against the code that reads it.
DISABLE_JUSTKNOBS = True
def justknobs_check(name: str, default: bool = True) -> bool:
    """
    Killswitch hook: in FB prod the value for ``name`` can be toggled to
    False in JK without a code push. This implementation ignores ``name``
    and simply returns ``default``.

    In OSS everything stays turned on all the time, because downstream users
    can simply choose not to update PyTorch. (If finer-grained enable/disable
    is ever needed, a map keyed by ``name`` could be consulted here. The
    point is that in OSS it is all tied to source code, since there is no
    live server to query.)

    This is the bare minimum functionality that was needed for some
    killswitches. A more detailed plan lives at
    https://docs.google.com/document/d/1Ukerh9_42SeGh89J-tGtecpHBPwGlkQ043pddkKb3PU/edit

    In particular, in some circumstances it may be necessary to read a knob
    once at process start and then use that value consistently for the rest
    of the process. Future functionality will codify these patterns into a
    better high level API.

    WARNING: Do NOT call this function at module import time. JK is not
    fork safe, and you will break anyone who forks the process and then
    hits JK again.
    """
    result = default
    return result
def justknobs_getval_int(name: str) -> int:
    """
    Integer knob lookup; this implementation ignores ``name`` and returns 0.

    Read warning on justknobs_check
    """
    value = 0
    return value
def is_fb_unit_test() -> bool:
    """
    Always returns ``False`` in this version of the file.
    """
    return False
@functools.cache
def max_clock_rate():
    """
    unit: MHz

    On non-HIP builds the value is the first entry returned by triton's
    ``nvsmi`` query for ``clocks.max.sm``. On HIP builds it is looked up from
    a hard-coded table keyed on the gcnArchName of device 0. The result is
    memoized by ``functools.cache``.
    """
    if not torch.version.hip:
        from triton.testing import nvsmi

        return nvsmi(["clocks.max.sm"])[0]

    # Manually set max-clock speeds on ROCm until equivalent nvsmi
    # functionality in triton.testing or via pyamdsmi enablement. Required
    # for test_snode_runtime unit tests.
    arch_name = str(torch.cuda.get_device_properties(0).gcnArchName.split(":", 1)[0])

    # (substring, MHz) pairs, checked in order; the first substring found in
    # the arch name wins, so the order is significant.
    known_rates = (
        ("gfx94", 1700),
        ("gfx90a", 1700),
        ("gfx908", 1502),
        ("gfx12", 1700),
        ("gfx11", 1700),
        ("gfx103", 1967),
        ("gfx101", 1144),
        ("gfx95", 1700),  # TODO: placeholder, get actual value
    )
    for marker, mhz in known_rates:
        if marker in arch_name:
            return mhz

    # Fallback for architectures not listed above.
    return 1100
- def get_mast_job_name_version() -> tuple[str, int] | None:
- return None
# Loopback address and port constants.
# NOTE(review): presumably the rendezvous endpoint used by distributed
# tests -- confirm against the code that reads them.
TEST_MASTER_ADDR = "127.0.0.1"
TEST_MASTER_PORT = 29500

# USE_GLOBAL_DEPS: whether __init__.py tries to load libtorch_global_deps,
# see Note [Global dependencies]
USE_GLOBAL_DEPS = True

# USE_RTLD_GLOBAL_WITH_LIBTORCH: whether __init__.py tries to load _C.so
# with RTLD_GLOBAL during the call to dlopen.
USE_RTLD_GLOBAL_WITH_LIBTORCH = False

# REQUIRES_SET_PYTHON_MODULE: for an op that was defined in C++ and extended
# from Python using torch.library.register_fake, whether we require a
# m.set_python_module("mylib.ops") call from C++ that associates the C++ op
# with a python module.
REQUIRES_SET_PYTHON_MODULE = False
- def maybe_upload_prof_stats_to_manifold(profile_path: str) -> str | None:
- print("Uploading profile stats (fb-only otherwise no-op)")
- return None
def log_chromium_event_internal(
    event: dict[str, Any],
    stack: list[str],
    logger_uuid: str,
    start_time_ns: int,
):
    """
    No-op in this version of the file; all arguments are ignored and ``None``
    is returned.
    """
    result = None
    return result
def record_chromium_event_internal(
    event: dict[str, Any],
):
    """
    No-op in this version of the file; ``event`` is ignored and ``None`` is
    returned.
    """
    result = None
    return result
def profiler_allow_cudagraph_cupti_lazy_reinit_cuda12():
    """
    Always returns ``True`` in this version of the file.
    """
    return True
def deprecated():
    """
    When we deprecate a function that might still be in use, we make it internal
    by adding a leading underscore. This decorator is used with a private function,
    and creates a public alias without the leading underscore, but has a deprecation
    warning. This tells users "THIS FUNCTION IS DEPRECATED, please use something else"
    without breaking them, however, if they still really really want to use the
    deprecated function without the warning, they can do so by using the internal
    function name.

    Returns:
        A decorator. Applied to ``_foo`` it installs a warning-emitting alias
        ``foo`` on the module that defines ``_foo`` and returns ``_foo``
        itself unchanged.

    Raises (from the returned decorator):
        ValueError: the decorated function's name does not consist of exactly
            one leading underscore followed by a non-empty public name.
        RuntimeError: the defining module already has an attribute with the
            public name.
    """

    def decorator(func: Callable[_P, _T]) -> Callable[_P, _T]:
        private_name = func.__name__
        public_name = private_name[1:]  # drop exactly one leading underscore

        # Validate naming convention - single leading underscore, not dunder.
        # A bare "_" would give an empty alias name, and "__foo" would give an
        # alias "_foo" that is still private, so neither yields a public API.
        if (
            not private_name.startswith("_")
            or not public_name
            or public_name.startswith("_")
        ):
            raise ValueError(
                "@deprecated must decorate a function whose name "
                "starts with a single leading underscore (e.g. '_foo') as the api should be considered internal for deprecation."
            )

        module = sys.modules[func.__module__]

        # Don't clobber an existing symbol accidentally.
        if hasattr(module, public_name):
            # Adjacent literals are used instead of a backslash continuation
            # inside the string, which would splice the next source line's
            # leading whitespace into the message.
            raise RuntimeError(
                f"Cannot create alias '{public_name}' -> symbol already exists in {module.__name__}. "
                "Please rename it or consult a pytorch developer on what to do"
            )

        warning_msg = f"{public_name} is DEPRECATED, please consider using an alternative API(s). "

        # public deprecated alias
        alias = typing_extensions.deprecated(
            # pyrefly: ignore [bad-argument-type]
            warning_msg,
            category=UserWarning,
            stacklevel=1,
        )(func)

        alias.__name__ = public_name

        # Adjust qualname if nested inside a class or another function
        if "." in func.__qualname__:
            alias.__qualname__ = func.__qualname__.rsplit(".", 1)[0] + "." + public_name
        else:
            alias.__qualname__ = public_name

        setattr(module, public_name, alias)

        # The private function is returned as-is, so calling it by its
        # underscored name does not go through the warning alias.
        return func

    return decorator
def get_default_numa_options():
    """
    Default numa options used by the elastic agent when none are provided.

    For external use cases this returns None, i.e. no numa binding. If you
    would like to use torch's automatic numa binding capabilities, provide
    NumaOptions to your launch config directly or use the numa binding option
    available in torchrun.

    Must return None or NumaOptions; the return type is deliberately left
    unannotated to avoid a circular import.
    """
    options = None
    return options
- def log_triton_builds(fail: str | None):
- pass
- def find_compile_subproc_binary() -> str | None:
- """
- Allows overriding the binary used for subprocesses
- """
- return None
|