| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834 |
- r"""
- This module exposes a TunableOp interface.
- Some operations, such as GEMMs, could be implemented using more than one library
- or more than one technique. For example, a GEMM could be implemented for CUDA or
- ROCm using either the blas or blasLt libraries. Further, ROCm's rocblas and
- hipblaslt libraries allow the user to query for all possible algorithms and then
- choose one. How does one know which implementation is the fastest and should be
- chosen? That's what TunableOp provides.
- Enabling TunableOp and Tuning Separately
- ========================================
- The TunableOp feature is enabled separately from enabling the tuning phase
- itself. Enabling TunableOp means that PyTorch will replace any standard
- operators with their Tunable implementations. Any call to a TunableOp first
- checks whether it has already been tuned for the given operator inputs. If so,
- it will immediately call the tuned operation; no further tuning will take place
- even when the tuning setting is enabled. Instead if no tuning result is found,
- and tuning is enabled, the TunableOp will benchmark every registered
- implementation of that operator for the given set of inputs and select the
- fastest.
- File Input and Output
- =====================
- The first time any TunableOp is invoked, the internal database of tuned
- operations will be prepared by attempting to read the results from the given
- file. The default filename is 'tunableop_results.csv'. To support tuning when
- multiple GPUs are used across multiple processes, the GPU device ordinal is
- automatically inserted into the filename to avoid multiple processes overwriting
- the same file.
- If tuning is enabled and new tunings are discovered during the course of your
- workload, it will also write out to this same filename with all tunings, both
- the ones it read in at startup as well as the new ones found at runtime. This
- can be used, for example, to build up a tunings file across many workloads by
- reusing the same file. The output file is automatically created when the
- application terminates. This behavior can be controlled by the C++ and Python
- APIs but not the environment variables.
- Assuming you specified a filename, you'll end up with a CSV file with contents
- like so::
- Validator,PT_VERSION,2.2.0
- Validator,ROCM_VERSION,6.0.0.0-12969-1544e39
- Validator,HIPBLASLT_VERSION,0.6.0-a9c5cc7
- Validator,ROCBLAS_VERSION,4.0.0-72e57364-dirty
- GemmTunableOp_float_NT,nt_25088_4096_64,Gemm_Hipblaslt_1219,1.262
- GemmTunableOp_float_NT,nt_4096_4096_64,Gemm_Rocblas_1216,0.033
- Note the "Validator" lines. If you change a library version, or ROCm version, or
- PyTorch version, TunableOp will detect this and reject the tunings file because
- the prior tunings are likely affected by other software changes.
- The remaining lines are the tuned solutions for each TunableOp encountered
- during your execution. Each line consists of 4 comma-separated fields: operator
- name, operator parameters, solution name, and average execution time. The
- execution time is an optional field. The CSV file can be edited, but with
- caution. For example, the solution name (field 3) can be changed to "Default"
- and it will fall back to the original PyTorch untuned implementation. Or, in the
- case of ROCm's hipBLAS or hipBLASLt libraries, if you know the specific solution
- index you can override the solution that TunableOp selected by replacing the
- value. The operator name and parameters (fields 1 and 2) are internally named
- and should not be modified. In the case of GemmTunableOp, field 1 indicates the
- datatype and whether the inputs are transposed (T) or not (N) and field 2
- indicates the M, N, K input shapes.
- There is an option to enable verbose output but it is only recommended for
- debugging purposes. This will produce a lot of diagnostic messages but may be
- useful to see if TunableOp is being used at all. Otherwise, TunableOp is
- completely silent, besides file output, unless there is a warning or error
- during its use. The verbose option is only available by setting the environment
- variable PYTORCH_TUNABLEOP_VERBOSE=1.
- A Note on Tuning Behavior, Warmup, and Cache Effects
- ====================================================
- Tuning an operator consists of iterating through the list of registered
- implementations and profiling each one. The profile is established by running a
- single implementation in a loop multiple times and taking the average execution
- time. There is also an optional warmup phase prior to tuning that can help with
- reaching stable power states by the hardware. During tuning of a workload the
- various hardware caches will more likely produce hits than when not tuning.
- There are options for flushing the instruction cache and rotating the input tensors
- which might help produce a more faithful profile of the tuned operator as if the
- operator were run within a larger workload instead of in a tight, repetitive loop.
- By default, each possible solution for a given operator will be run for either
- 100 iterations or as many iterations as can be run within 30ms, whichever is
- smaller, and its average execution will be calculated. The fastest solution
- among all that were successfully profiled will be chosen. A profile might fail
- if the given solution doesn't achieve the same accuracy as the default
- implementation or if the solution returns an error code.
- Current Tunable Operators
- =========================
- TunableGemm for ROCm
- --------------------
- Currently only a TunableGemm for ROCm is implemented. Note that CUDA builds of
- PyTorch will function correctly when using TunableOp but the only solution
- available to CUDA builds is the 'Default' implementation i.e. the original
- cuBLAS default, now called through TunableOp. Any call to at::cuda::blas::gemm()
- or ::bgemm() will be routed through TunableOp when enabled. Calling gemm() for a
- given set of input arguments (transa, transb, m, n, k) will attempt to use the
- fastest available implementation across both rocblas and hipblaslt.
- Offline Tuning
- ==============
- Motivation
- ----------
- There are several use cases for offline tuning.
- One use case involves a workload with a high-memory utilization, where regular tuning might lead to running out of memory.
- Another use case is for compute-intensive workloads. In such cases, it is more resource-efficient to collect
- the GEMMs for the workload once and then tune repeatedly with different tuning parameters or libraries.
- Workflow
- --------
- There are basically two steps:
- 1) Set the environment variables to collect the untuned GEMM and this will generate ``tunableop_untuned0.csv``:
- .. code-block:: bash
- export PYTORCH_TUNABLEOP_ENABLED=1
- export PYTORCH_TUNABLEOP_TUNING=0
- export PYTORCH_TUNABLEOP_RECORD_UNTUNED=1
- ...
- 2) Run a Python script that reads the ``tunableop_untuned0.csv`` and generates the ``tunableop_results0.csv``, like this:
- .. code-block:: python
- import torch.cuda.tunable as tunable
- import os
- os.putenv("PYTORCH_TUNABLEOP_ENABLED", "1")
- os.putenv("PYTORCH_TUNABLEOP_TUNING", "1")
- os.putenv("PYTORCH_TUNABLEOP_RECORD_UNTUNED", "0")
- tunable.tune_gemm_in_file("tunableop_untuned0.csv")
- It is also possible to take multiple untuned files and distribute the GEMMs for tuning to multiple GPUs
- within a single node. In the first step, the GEMMs are first gathered and duplicate GEMMs are eliminated.
- Next, the GEMMs are distributed to different GPUs for tuning. After all GEMMs are tuned, the results from
- all the GPUs are then gathered into a single file whose base filename has ``_full0`` appended to it
- (for example ``tunableop_results_full0.csv``). Finally, this new file, containing the gathered results, will be
- duplicated N times, once for each GPU, as a convenience to users who will run the workload with the tuned
- configuration on N GPUs.
- .. code-block:: python
- if __name__ == "__main__":
- num_gpus = 8 # number of GPUs that will be used during the tuning process
- tunable.mgpu_tune_gemm_in_file("tunableop_untuned?.csv", num_gpus)
- Note that the usage of the ``mgpu_tune_gemm_in_file`` API is different from its single GPU counterpart
- (``tune_gemm_in_file``). The body of the Python script that calls the API must be placed under an ``if __name__ == "__main__":`` guard as shown
- due to the use of the ``concurrent.futures`` module. The argument to ``mgpu_tune_gemm_in_file`` must contain a wildcard
- expression (``?`` or ``*``) to generate the list of untuned files containing the GEMMs to be processed. The ``num_gpus``
- must be between 1 and the total number of GPUs available.
- Tuning Context
- ==============
- The behavior of TunableOp is currently manipulated through environment
- variables, the C++ interface of at::cuda::tunable::getTuningContext(), or the
- torch.cuda.tunable python interfaces. The environment variables take precedence
- over any setting you manipulate using the C++ or Python APIs.
- Environment Variable Interface
- ------------------------------
- Environment variables are cached the first time they are read. You cannot use the
- environment variable interface programmatically since the settings become fixed.
- Use the C++ or Python APIs instead.
- """
- import concurrent.futures
- import glob
- import multiprocessing as mp
- import os
- import shutil
- import warnings
- import torch
- __all__ = [  # Public names exported by a star-import of this module.
- "enable",
- "is_enabled",
- "tuning_enable",
- "tuning_is_enabled",
- "record_untuned_enable",
- "record_untuned_is_enabled",
- "set_max_tuning_duration",
- "get_max_tuning_duration",
- "set_max_tuning_iterations",
- "get_max_tuning_iterations",
- "set_filename",
- "get_filename",
- "get_results",
- "get_validators",
- "read_file",
- "tune_gemm_in_file",
- "mgpu_tune_gemm_in_file",
- "set_rotating_buffer_size",
- "get_rotating_buffer_size",
- "set_numerical_check_tolerances",
- ]
def enable(val: bool = True) -> None:
    r"""Switch every TunableOp implementation on or off.

    Args:
        val: Forwarded unchanged to the backend switch; defaults to ``True``.
    """
    switch = torch._C._cuda_tunableop_enable  # type: ignore[attr-defined]
    switch(val)
def is_enabled() -> bool:
    r"""Report whether the TunableOp feature is currently enabled."""
    enabled = torch._C._cuda_tunableop_is_enabled()  # type: ignore[attr-defined]
    return enabled
def tuning_enable(val: bool = True) -> None:
    r"""Turn the tuning phase of TunableOp implementations on or off.

    While tuning is on, an operator call for which no tuned entry exists runs
    the tuning step and records the resulting entry.

    Args:
        val: Forwarded unchanged to the backend switch; defaults to ``True``.
    """
    switch = torch._C._cuda_tunableop_tuning_enable  # type: ignore[attr-defined]
    switch(val)
def tuning_is_enabled() -> bool:
    r"""Report whether TunableOp implementations are allowed to be tuned."""
    tuning_on = torch._C._cuda_tunableop_tuning_is_enabled()  # type: ignore[attr-defined]
    return tuning_on
def record_untuned_enable(val: bool = True) -> None:
    r"""Turn recording of untuned TunableOp operations for offline tuning on or off.

    While recording is on, an operator call for which no tuned entry exists is
    written to the untuned file.

    Args:
        val: Forwarded unchanged to the backend switch; defaults to ``True``.
    """
    switch = torch._C._cuda_record_untuned_enable  # type: ignore[attr-defined]
    switch(val)
def record_untuned_is_enabled() -> bool:
    r"""Report whether untuned TunableOp operations are being recorded for offline tuning."""
    recording_on = torch._C._cuda_record_untuned_is_enabled()  # type: ignore[attr-defined]
    return recording_on
def set_max_tuning_duration(duration: int) -> None:
    r"""Limit the time, in milliseconds, spent tuning a single solution.

    When a maximum duration and a maximum iteration count are both set, the
    smaller of the two limits wins. At least one tuning iteration always runs.

    Args:
        duration: Maximum tuning time in milliseconds.
    """
    apply_limit = torch._C._cuda_tunableop_set_max_tuning_duration  # type: ignore[attr-defined]
    apply_limit(duration)
def get_max_tuning_duration() -> int:
    r"""Return the maximum time allowed for tuning a single solution."""
    limit = torch._C._cuda_tunableop_get_max_tuning_duration()  # type: ignore[attr-defined]
    return limit
def set_max_tuning_iterations(iterations: int) -> None:
    r"""Limit the number of iterations spent tuning a single solution.

    When a maximum duration and a maximum iteration count are both set, the
    smaller of the two limits wins. At least one tuning iteration always runs.

    Args:
        iterations: Maximum number of tuning iterations.
    """
    apply_limit = torch._C._cuda_tunableop_set_max_tuning_iterations  # type: ignore[attr-defined]
    apply_limit(iterations)
def get_max_tuning_iterations() -> int:
    r"""Return the maximum number of iterations allowed for tuning a single solution."""
    limit = torch._C._cuda_tunableop_get_max_tuning_iterations()  # type: ignore[attr-defined]
    return limit
def set_filename(filename: str, insert_device_ordinal: bool = False) -> None:
    r"""Choose the file used for reading and writing tuning results.

    Args:
        filename: Name of the results file.
        insert_device_ordinal: When ``True`` the current device ordinal is
            added to ``filename`` automatically, so that in a
            one-process-per-GPU setup every process writes to its own file.
    """
    register_filename = torch._C._cuda_tunableop_set_filename  # type: ignore[attr-defined]
    register_filename(filename, insert_device_ordinal)
def get_filename() -> str:
    r"""Return the name of the results file."""
    current_name = torch._C._cuda_tunableop_get_filename()  # type: ignore[attr-defined]
    return current_name
def get_results() -> tuple[str, str, str, float]:
    r"""Return every TunableOp result known to the backend."""
    tuned_entries = torch._C._cuda_tunableop_get_results()  # type: ignore[attr-defined]
    return tuned_entries
def get_validators() -> tuple[str, str]:
    r"""Return the validators TunableOp uses."""
    validators = torch._C._cuda_tunableop_get_validators()  # type: ignore[attr-defined]
    return validators
def read_file(filename: str | None = None) -> bool:
    r"""Load tuning results from a TunableOp CSV file.

    Args:
        filename: File to read. When omitted (``None``), the name returned by
            ``get_filename()`` is used instead.

    Returns:
        The value reported by the backend read call.
    """
    source = get_filename() if filename is None else filename
    return torch._C._cuda_tunableop_read_file(source)  # type: ignore[attr-defined]
def set_rotating_buffer_size(buffer_size: int) -> None:
    r"""Configure the rotating buffer used while tuning.

    Args:
        buffer_size: A value greater than zero sets the rotating buffer size
            to that many MB; a value less than zero asks for the L2 cache
            size to be queried; zero deactivates the rotating buffer.
    """
    configure = torch._C._cuda_tunableop_set_rotating_buffer_size  # type: ignore[attr-defined]
    return configure(buffer_size)
def get_rotating_buffer_size() -> int:
    r"""Return the rotating buffer size in kilobytes."""
    # NOTE(review): the setter documents its argument in MB while this getter
    # documents kilobytes -- confirm the units against the backend.
    size = torch._C._cuda_tunableop_get_rotating_buffer_size()  # type: ignore[attr-defined]
    return size
def set_numerical_check_tolerances(
    enable: bool, atol: float = 1e-5, rtol: float = 1e-5
) -> None:
    r"""Set the absolute and relative tolerances used by the numeric check.

    Args:
        enable: Forwarded to the backend together with the tolerances.
        atol: Absolute tolerance; defaults to ``1e-5``.
        rtol: Relative tolerance; defaults to ``1e-5``.
    """
    configure = torch._C._cuda_tunableop_set_numerical_check_tolerances  # type: ignore[attr-defined]
    return configure(enable, atol, rtol)
def tune_gemm_in_file(filename: str) -> None:
    r"""Tune every GEMM recorded in ``filename`` on the current CUDA device.

    Only lines that start with ``Gemm`` or ``ScaledGemm`` are processed; all
    other lines are ignored.

    Args:
        filename: Path of the file holding the untuned GEMM entries.

    Raises:
        AssertionError: If TunableOp or its tuning phase is not enabled.
    """
    preconditions = (
        (is_enabled, "TunableOp is not enabled"),
        (tuning_is_enabled, "Tuning is not enabled"),
    )
    for is_satisfied, failure_message in preconditions:
        if not is_satisfied():
            raise AssertionError(failure_message)

    gpu_index = torch.cuda.current_device()
    gemm_prefixes = ("Gemm", "ScaledGemm")
    with open(filename) as untuned_file:
        for entry in untuned_file:
            if not entry.startswith(gemm_prefixes):
                continue
            _process_single_offline_gemm(entry, gpu_index)
def _gather_unique_untuned_gemm_from_files(filename_pattern: str) -> set[str]:
    r"""Collect the distinct GEMM lines found in all files matching a glob pattern.

    Args:
        filename_pattern: Glob expression selecting the untuned-results files.

    Returns:
        The set of lines (kept exactly as read, line ending included) that
        start with ``Gemm`` or ``ScaledGemm``; the set drops duplicates.
    """
    gemm_prefixes = ("Gemm", "ScaledGemm")
    collected: set[str] = set()
    for path in glob.glob(filename_pattern):
        with open(path) as untuned_file:
            collected.update(
                entry for entry in untuned_file if entry.startswith(gemm_prefixes)
            )
    return collected
def _gather_tunableop_results() -> None:
    r"""Merge the per-device TunableOp results files into a single file.

    A glob pattern is derived from the results filename, with ``?`` standing
    for the single-character device ordinal. The filename is taken from
    ``get_filename()`` when that returns a non-empty string, otherwise from
    the ``PYTORCH_TUNABLEOP_FILENAME`` environment variable, otherwise the
    default ``tunableop_results?.csv`` is used.

    Every matching file is read. ``Validator`` lines are taken from the first
    file only (files are visited in sorted order); all other non-empty lines
    are de-duplicated across files. The merged content is written to the
    pattern with ``?`` replaced by ``_full0`` and that file is then copied to
    ``_full1`` ... ``_full<N-1>``, where ``N`` is the number of matching files.

    Raises:
        AssertionError: If no pattern containing ``?`` can be derived from the
            configured filename (for example, a filename without extension).
    """
    # The results filename may have been set through the Python API, through
    # the environment variable, or not at all; an empty string means "not set
    # through the API".
    results_filename = get_filename()
    if results_filename:
        # Replace the character just before the extension (assumed to be the
        # device ordinal) with the wildcard. ``splitext`` looks at the final
        # path component only, so dots in directory names are left alone.
        root, ext = os.path.splitext(results_filename)
        filename_pattern = f"{root[:-1]}?{ext}" if root and ext else ""
    else:
        env_filename = os.getenv("PYTORCH_TUNABLEOP_FILENAME")
        if not env_filename:
            filename_pattern = "tunableop_results?.csv"
        elif "%d" in env_filename:
            filename_pattern = env_filename.replace("%d", "?")
        else:
            # Insert the wildcard right before the extension only, instead of
            # before every dot in the path.
            root, ext = os.path.splitext(env_filename)
            filename_pattern = f"{root}?{ext}" if ext else env_filename

    if "?" not in filename_pattern:
        raise AssertionError(
            f"filename_pattern must contain '?', got {filename_pattern!r}"
        )

    # Sorted so that "the first file" and the output are reproducible.
    matching_files = sorted(glob.glob(filename_pattern))

    validator_lines: list[str] = []
    gemm_lines: set[str] = set()
    for file_index, file_path in enumerate(matching_files):
        with open(file_path) as file:
            for raw_line in file:
                # Normalise the line ending so a final line without a newline
                # neither defeats de-duplication nor fuses with the next line
                # when written out.
                line = raw_line.rstrip("\n")
                if not line:
                    continue
                if line.startswith("Validator"):
                    # Only read Validator lines from the first file.
                    if file_index == 0:
                        validator_lines.append(line)
                else:
                    gemm_lines.add(line)

    output_file = filename_pattern.replace("?", "_full0")
    with open(output_file, "w") as out_file:
        out_file.writelines(f"{line}\n" for line in validator_lines)
        out_file.writelines(f"{line}\n" for line in sorted(gemm_lines))

    # One copy per additional matching file. The copy name is built from the
    # pattern so that only the ordinal changes, never other "0" characters
    # that happen to appear in the path.
    for copy_index in range(1, len(matching_files)):
        duplicate_file = filename_pattern.replace("?", f"_full{copy_index}")
        shutil.copy(output_file, duplicate_file)
- def _create_matrices(
- m: int,
- n: int,
- k: int,
- lda: int,
- ldb: int,
- ldc: int,
- transA: bool,
- transB: bool,
- dtypeA: torch.dtype,
- deviceid: str,
- dtypeB: torch.dtype | None = None,
- randn: bool = True,
- subMatrix: bool = False,
- ) -> tuple[torch.Tensor, torch.Tensor]:
- r"""Helper function for _process_single_offline_gemm.
- Creates matrices that are then consumed by one of the Torch GEMM APIs.
- """
- # Fill parameters set for use with ScaledGEMM
- fillA = 0.25
- fillB = 0.75
- if dtypeB is None:
- dtypeB = dtypeA
- if subMatrix:
- # User reference for understanding leading dimension:
- # https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/dgemm.f
- # TO DO: According to lines 108 - 133, there is no lower bound on rowsA,
- # but there is a restriction on rowsB. Using this formula for now as it
- # seems to work for all UTs.
- rowsA = rowsB = max(ldc, k)
- if randn:
- matA = torch.randn(rowsA, lda, dtype=dtypeA, device=deviceid)
- matB = torch.randn(rowsB, ldb, dtype=dtypeA, device=deviceid)
- else:
- matA = torch.full((rowsA, lda), fillA, dtype=dtypeB, device=deviceid)
- matB = torch.full((rowsB, ldb), fillB, dtype=dtypeB, device=deviceid)
- subA = matA[:k, :m].t() if transA else matA[:m, :k]
- subB = matB[:n, :k].t() if transB else matB[:k, :n]
- return subA, subB
- else:
- if randn:
- matA = (
- torch.rand(k, m, dtype=dtypeA, device=deviceid).t()
- if transA
- else torch.rand(m, k, dtype=dtypeA, device=deviceid)
- )
- matB = (
- torch.rand(n, k, dtype=dtypeB, device=deviceid).t()
- if transB
- else torch.rand(k, n, dtype=dtypeB, device=deviceid)
- )
- else:
- matA = (
- torch.full((k, m), fillA, dtype=dtypeA, device=deviceid).t()
- if transA
- else torch.full((m, k), fillA, dtype=dtypeA, device=deviceid)
- )
- matB = (
- torch.full((n, k), fillB, dtype=dtypeB, device=deviceid).t()
- if transB
- else torch.full((k, n), fillB, dtype=dtypeB, device=deviceid)
- )
- return matA, matB
def _create_batch_matrices(
    m: int,
    n: int,
    k: int,
    b: int,
    lda: int,
    ldb: int,
    ldc: int,
    transA: bool,
    transB: bool,
    dtype: torch.dtype,
    deviceid: str,
    subMatrix: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    r"""Helper function for _process_single_offline_gemm.

    Batched (3D) counterpart of _create_matrices: builds the two random
    operands consumed by a batched Torch GEMM API. The results have shapes
    ``(b, m, k)`` and ``(b, k, n)``; a transposed operand is allocated with
    its last two dimensions swapped and returned as a ``transpose(1, 2)`` view.

    When ``subMatrix`` is ``True`` each operand is a slice of a larger parent
    with ``max(ldc, k)`` rows and ``lda`` / ``ldb`` columns per batch entry.
    """
    if not subMatrix:
        shapeA = (b, k, m) if transA else (b, m, k)
        shapeB = (b, n, k) if transB else (b, k, n)
        batchA = torch.rand(*shapeA, dtype=dtype, device=deviceid)
        batchB = torch.rand(*shapeB, dtype=dtype, device=deviceid)
        if transA:
            batchA = batchA.transpose(1, 2)
        if transB:
            batchB = batchB.transpose(1, 2)
        return batchA, batchB

    # User reference for understanding leading dimension:
    # https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/dgemm.f
    # TO DO: According to lines 108 - 133, there is no lower bound on the row
    # count of A, but there is a restriction on that of B. Using this formula
    # for now as it seems to work for all UTs.
    parent_rows = max(ldc, k)
    parentA = torch.randn(b, parent_rows, lda, dtype=dtype, device=deviceid)
    parentB = torch.randn(b, parent_rows, ldb, dtype=dtype, device=deviceid)

    if transA:
        viewA = parentA[:b, :k, :m].transpose(1, 2)
    else:
        viewA = parentA[:b, :m, :k]
    if transB:
        viewB = parentB[:b, :n, :k].transpose(1, 2)
    else:
        viewB = parentB[:b, :k, :n]
    return viewA, viewB
- def _process_single_offline_gemm(untuned_gemm_line: str, gpu_id: int) -> None:
- r"""Process a single untuned GEMM."""
- deviceid = "cuda:" + str(gpu_id)
- dtype_dict = {
- "float": torch.float32,
- "tf32": torch.float32,
- "double": torch.float64,
- "BFloat16": torch.bfloat16,
- "Half": torch.half,
- "c10::complex<double>": torch.complex128,
- "c10::complex<float>": torch.complex64,
- "Float8_e4m3fn": torch.float8_e4m3fn,
- "Float8_e5m2": torch.float8_e5m2,
- "Float8_e4m3fnuz": torch.float8_e4m3fnuz,
- "Float8_e5m2fnuz": torch.float8_e5m2fnuz,
- }
- untuned_gemm = untuned_gemm_line.strip().split(",")[:]
- underscore_count = untuned_gemm[0].count("_")
- # Initialize dtype to make linter happy
- dtype = None
- dtypeA = None
- dtypeB = None
- dtypeC = None
- # Extract BLAS parameters
- if underscore_count == 2:
- [op_sig, data_type, layout] = untuned_gemm[0].split("_")
- transB = layout[0] == "T"
- transA = layout[1] == "T"
- dtype = dtype_dict.get(data_type)
- if data_type == "tf32":
- torch.backends.cuda.matmul.allow_tf32 = True
- else:
- torch.backends.cuda.matmul.allow_tf32 = False
- else: # ScaledGEMM
- count = untuned_gemm[0].count("_")
- if count not in [6, 7]:
- raise AssertionError(f"count must be 6 or 7, got {count}")
- untuned_gemm_temp = untuned_gemm[0].split("_")
- # dtypeC = might not be FP8 type, keep track
- # of the number of underscores
- op_sig = untuned_gemm_temp[0]
- data_typeA = untuned_gemm_temp[1] + "_" + untuned_gemm_temp[2]
- data_typeB = untuned_gemm_temp[3] + "_" + untuned_gemm_temp[4]
- if count == 7:
- data_typeC = untuned_gemm_temp[5] + "_" + untuned_gemm_temp[6]
- else:
- data_typeC = untuned_gemm_temp[5]
- transB = untuned_gemm_temp[count][0] == "T"
- transA = untuned_gemm_temp[count][1] == "T"
- dtypeA = dtype_dict.get(data_typeA)
- dtypeB = dtype_dict.get(data_typeB)
- dtypeC = dtype_dict.get(data_typeC)
- untuned_gemm_temp = untuned_gemm[1].split("_")
- [n, m, k] = [int(g) for g in untuned_gemm_temp[1:4]]
- if op_sig == "GemmStridedBatchedTunableOp":
- if untuned_gemm_temp[6] != "ld":
- raise AssertionError(
- f"expected 'ld' at index 6, got {untuned_gemm_temp[6]!r}"
- )
- [ldb, lda, ldc] = [int(g) for g in untuned_gemm_temp[7:10]]
- else:
- if untuned_gemm_temp[4] != "ld":
- raise AssertionError(
- f"expected 'ld' at index 4, got {untuned_gemm_temp[4]!r}"
- )
- [ldb, lda, ldc] = [int(g) for g in untuned_gemm_temp[5:8]]
- # Detect subMatrix case
- if all(item in [n, m, k] for item in [lda, ldb, ldc]):
- subMatrix = False
- else:
- subMatrix = True
- if op_sig == "GemmTunableOp":
- # Warnings for unsupported cases:
- if m == 1 or n == 1 or k == 1:
- if (not transA) and (not transB):
- pass # case is supported
- elif transA and n == 1:
- pass # case is supported
- else:
- warnings.warn(
- "Offline tuning is not supported for this GEMM. Use online tuning instead. "
- + f"Skipped tuning for: {untuned_gemm[1]}",
- stacklevel=2,
- )
- return
- # Resolve linter issue
- if dtype is None or not isinstance(dtype, torch.dtype):
- raise TypeError(f"dtype must be a torch.dtype, but got {dtype}")
- matA, matB = _create_matrices(
- m, n, k, lda, ldb, ldc, transA, transB, dtype, deviceid, subMatrix=subMatrix
- )
- torch.mm(matA, matB)
- elif op_sig == "GemmStridedBatchedTunableOp":
- # Warnings for unsupported cases:
- if m == 1 or n == 1 or k == 1:
- warnings.warn(
- "Offline tuning is not support for this GEMM. Use online tuning instead. "
- + f"Skipped tuning for: {untuned_gemm[1]}",
- stacklevel=2,
- )
- return
- [b] = [int(g) for g in untuned_gemm_temp[5:6]]
- # Resolve linter issue
- if dtype is None or not isinstance(dtype, torch.dtype):
- raise TypeError(f"dtype must be a torch.dtype, but got {dtype}")
- matA, matB = _create_batch_matrices(
- m,
- n,
- k,
- b,
- lda,
- ldb,
- ldc,
- transA,
- transB,
- dtype,
- deviceid,
- subMatrix=subMatrix,
- )
- torch.bmm(matA, matB)
- elif op_sig == "ScaledGemmTunableOp":
- # Only combination supported by PyTorch
- if transB is not True:
- raise AssertionError(
- f"transB must be True for ScaledGemmTunableOp, got {transB}"
- )
- if transA is not False:
- raise AssertionError(
- f"transA must be False for ScaledGemmTunableOp, got {transA}"
- )
- # Resolve linter issue
- if dtypeA is None or not isinstance(dtypeA, torch.dtype):
- raise TypeError(f"dtype must be a torch.dtype, but got {dtypeA}")
- matA, matB = _create_matrices(
- m,
- n,
- k,
- lda,
- ldb,
- ldc,
- transA,
- transB,
- dtypeA,
- deviceid,
- dtypeB=dtypeB,
- randn=False,
- subMatrix=subMatrix,
- )
- if untuned_gemm_temp[8] != "rw":
- raise AssertionError(
- f"expected 'rw' at index 8, got {untuned_gemm_temp[8]!r}"
- )
- if untuned_gemm_temp[9] == "1":
- rowwise = True
- else:
- rowwise = False
- if rowwise:
- scaleA = (
- torch.ones((1, m), device=deviceid)
- if transA
- else torch.ones((m, 1), device=deviceid)
- )
- scaleB = (
- torch.ones((1, n), device=deviceid)
- if transB
- else torch.ones((n, 1), device=deviceid)
- )
- else:
- scaleA = torch.tensor(0.8, device=deviceid)
- scaleB = torch.tensor(0.9, device=deviceid)
- if untuned_gemm_temp[10] != "bias":
- raise AssertionError(
- f"expected 'bias' at index 10, got {untuned_gemm_temp[10]!r}"
- )
- if untuned_gemm_temp[11] == "None": # no bias vector
- torch._scaled_mm(
- matA, matB, scale_a=scaleA, scale_b=scaleB, out_dtype=dtypeC
- )
- else: # bias vector present
- fillbias = 0.10
- bias_dtype = dtype_dict.get(untuned_gemm_temp[11])
- bias = (
- torch.full((n,), fillbias, dtype=bias_dtype, device=deviceid)
- if transB
- else torch.full((m,), fillbias, dtype=bias_dtype, device=deviceid)
- )
- torch._scaled_mm(
- matA, matB, scale_a=scaleA, scale_b=scaleB, out_dtype=dtypeC, bias=bias
- )
- elif op_sig == "GemmAndBiasTunableOp":
- # y = x*A^T + b
- if transA == transB:
- raise AssertionError(
- f"transA and transB must differ for GemmAndBiasTunableOp, got transA={transA}, transB={transB}"
- )
- # Resolve linter issue
- if dtype is None or not isinstance(dtype, torch.dtype):
- raise TypeError(f"dtype must be a torch.dtype, but got {dtype}")
- bias = torch.rand(n, dtype=dtype, device=deviceid)
- X, matA = _create_matrices(
- m, n, k, lda, ldb, ldc, transA, transB, dtype, deviceid, subMatrix=subMatrix
- )
- matA = matA.t()
- torch.nn.functional.linear(X, matA, bias)
- else:
- warnings.warn(f"error: unknown op {op_sig}", stacklevel=2)
def _check_tuning_assertions() -> None:
    r"""Verify the TunableOp state required for multi-GPU tuning.

    If the TunableOp feature is off, a warning is issued and an attempt is
    made to turn it on. Afterwards the feature must be enabled, tuning must be
    enabled and recording of untuned operations must be disabled.

    Raises:
        AssertionError: If any of the three required states does not hold.
    """
    feature_off = is_enabled() is False
    if feature_off:
        warnings.warn("TunableOp was disabled. Trying to enable now.", stacklevel=2)
        enable(True)

    required_states = (
        (is_enabled, True, "is_enabled() must be True"),
        (tuning_is_enabled, True, "tuning_is_enabled() must be True"),
        (record_untuned_is_enabled, False, "record_untuned_is_enabled() must be False"),
    )
    for query, expected, failure_message in required_states:
        if query() is not expected:
            raise AssertionError(failure_message)
def mgpu_tune_gemm_in_file(filename_pattern: str, num_gpus: int) -> None:
    r"""Tune the GEMMs from one or more files using one or more GPUs.

    The unique GEMM entries of all files matching ``filename_pattern`` are
    handed out to GPU ordinals ``0 .. num_gpus - 1`` in round-robin order and
    processed by a pool of spawned worker processes. Once every entry has been
    processed, the per-device results files are gathered into a single file.

    Args:
        filename_pattern: Glob expression selecting the untuned-results files.
        num_gpus: Number of GPUs (and worker processes) to use.

    Raises:
        AssertionError: If ``num_gpus`` is not between 1 and the number of
            available CUDA devices.
    """
    gemm_lines = _gather_unique_untuned_gemm_from_files(filename_pattern)

    available_gpus = torch.cuda.device_count()
    if num_gpus < 1 or num_gpus > available_gpus:
        raise AssertionError(
            f"num_gpus must be between 1 and {available_gpus}, got {num_gpus}"
        )

    spawn_context = mp.get_context("spawn")

    # The workers are separate processes. TunableOp is enabled in them when
    # PYTORCH_TUNABLEOP_ENABLED=1; the initializer additionally tries to
    # enable TunableOp when that environment variable was NOT set.
    with concurrent.futures.ProcessPoolExecutor(
        max_workers=num_gpus,
        mp_context=spawn_context,
        initializer=_check_tuning_assertions,
    ) as pool:
        # GEMMs are assigned to GPU ordinals in a round-robin manner.
        pending = [
            pool.submit(_process_single_offline_gemm, gemm_line, position % num_gpus)
            for position, gemm_line in enumerate(gemm_lines)
        ]

        # result() re-raises any exception that occurred in a worker.
        for finished in concurrent.futures.as_completed(pending):
            finished.result()

    torch.cuda.synchronize()

    _gather_tunableop_results()
|