| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377 |
- from __future__ import annotations
- from collections import defaultdict
- from collections.abc import Iterable
- from dataclasses import dataclass
- from typing import TYPE_CHECKING
- import yaml
- from torchgen.selective_build.operator import (
- merge_debug_info,
- merge_operator_dicts,
- SelectiveBuildOperator,
- strip_operator_overload_name,
- )
- if TYPE_CHECKING:
- from torchgen.model import NativeFunction
# A SelectiveBuilder holds information extracted from the selective build
# YAML specification.
#
# It includes information about the build's selectivity, the debug_info
# associated with this selective build (opaque string), and the set of
# operators that should be included in the build.
#
@dataclass(frozen=True)
class SelectiveBuilder:
    """Frozen dataclass describing which operators/kernels a build keeps.

    Normally created through one of the ``from_*`` constructors and queried
    through the ``is_*`` predicates and ``et_get_selected_kernels``. The
    field order below is also the positional ``__init__`` order.
    """

    # If true, then the build is not selective, and includes all
    # operators.
    include_all_operators: bool

    # Debug Information at the selective/custom build level.
    _debug_info: tuple[str, ...] | None

    # A dictionary of operator -> operator metadata.
    operators: dict[str, SelectiveBuildOperator]

    # A dictionary of selected kernel tags and dtypes. Typically a
    # PyTorch Operator Kernel (function) may have many code paths
    # that are specialized for many many Tensor dtypes, so it's not
    # one per kernel function, but there could be many per kernel
    # function. The tag isn't a kernel function name, but some fragment
    # of the kernel function implementation itself.
    kernel_metadata: dict[str, list[str]]

    # ExecuTorch only. A dictionary of operator name -> list of kernel key
    # strings used by the selected models. Keys are expected to contain a
    # "/"; et_get_selected_kernels ignores the leading (version) segment and
    # compares only the second "/"-separated segment.
    # This is from selective.yaml
    et_kernel_metadata: dict[str, list[str]]

    # A set of all the custom torch bind classes used by the selected models
    # Stored as a set internally to remove duplicates proactively, but written
    # as a list to yamls
    custom_classes: set[str]

    # A set of all the build features used by the selected models
    # Stored as a set internally to remove duplicates proactively, but written
    # as a list to yamls
    build_features: set[str]

    # If true, then fragments for all dtypes for all kernel functions
    # are included as well as all custom classes. This is typically set when any one of the
    # operator lists is generated from a mechanism other than
    # tracing based selective build.
    include_all_non_op_selectives: bool

    @staticmethod
    def get_nop_selector() -> SelectiveBuilder:
        """Return a non-selective builder that includes every operator."""
        return SelectiveBuilder.from_yaml_dict({"include_all_operators": True})

    @staticmethod
    def from_yaml_dict(data: dict[str, object]) -> SelectiveBuilder:
        """Build a selector from a parsed selective-build YAML mapping.

        Every top-level key is optional. Missing collections default to
        empty, missing flags to ``False``, and missing ``debug_info`` to
        ``None``.

        Raises:
            AssertionError: if ``data`` is not a dict (for example an empty
                YAML document, which parses to ``None``) or a field has an
                unexpected type.
            Exception: if ``data`` contains unknown top-level keys.
        """
        if not isinstance(data, dict):
            # Without this check an empty or non-mapping YAML document fails
            # below with an opaque AttributeError on ``data.keys()``.
            raise AssertionError(
                f"Expected selective build config to be a dict, got {type(data)}"
            )

        valid_top_level_keys = {
            "include_all_non_op_selectives",
            "include_all_operators",
            "debug_info",
            "operators",
            "kernel_metadata",
            "et_kernel_metadata",
            "custom_classes",
            "build_features",
        }
        unexpected_keys = set(data.keys()) - valid_top_level_keys
        if unexpected_keys:
            # Sorted so the error message is stable across runs.
            raise Exception(  # noqa: TRY002
                "Got unexpected top level keys: {}".format(
                    ",".join(sorted(str(key) for key in unexpected_keys)),
                )
            )

        include_all_operators = data.get("include_all_operators", False)
        if not isinstance(include_all_operators, bool):
            raise AssertionError(
                f"Expected 'include_all_operators' to be bool, got {type(include_all_operators)}"
            )

        debug_info = None
        if "debug_info" in data:
            di_list = data["debug_info"]
            if not isinstance(di_list, list):
                raise AssertionError(
                    f"Expected 'debug_info' to be list, got {type(di_list)}"
                )
            debug_info = tuple(str(x) for x in di_list)

        operators_dict = data.get("operators", {})
        if not isinstance(operators_dict, dict):
            raise AssertionError(
                f"Expected 'operators' to be dict, got {type(operators_dict)}"
            )
        operators = {
            k: SelectiveBuildOperator.from_yaml_dict(k, v)
            for k, v in operators_dict.items()
        }

        kernel_metadata_dict = data.get("kernel_metadata", {})
        if not isinstance(kernel_metadata_dict, dict):
            raise AssertionError(
                f"Expected 'kernel_metadata' to be dict, got {type(kernel_metadata_dict)}"
            )
        kernel_metadata = {
            str(k): [str(dtype) for dtype in v]
            for k, v in kernel_metadata_dict.items()
        }

        et_kernel_metadata = data.get("et_kernel_metadata", {})
        if not isinstance(et_kernel_metadata, dict):
            raise AssertionError(
                f"Expected 'et_kernel_metadata' to be dict, got {type(et_kernel_metadata)}"
            )

        custom_classes = data.get("custom_classes", [])
        if not isinstance(custom_classes, Iterable):
            raise AssertionError(
                f"Expected 'custom_classes' to be Iterable, got {type(custom_classes)}"
            )
        custom_classes = set(custom_classes)

        build_features = data.get("build_features", [])
        if not isinstance(build_features, Iterable):
            raise AssertionError(
                f"Expected 'build_features' to be Iterable, got {type(build_features)}"
            )
        build_features = set(build_features)

        include_all_non_op_selectives = data.get("include_all_non_op_selectives", False)
        if not isinstance(include_all_non_op_selectives, bool):
            raise AssertionError(
                f"Expected 'include_all_non_op_selectives' to be bool, "
                f"got {type(include_all_non_op_selectives)}"
            )

        return SelectiveBuilder(
            include_all_operators,
            debug_info,
            operators,
            kernel_metadata,
            et_kernel_metadata,
            custom_classes,  # type: ignore[arg-type]
            build_features,  # type: ignore[arg-type]
            include_all_non_op_selectives,
        )

    @staticmethod
    def from_yaml_str(config_contents: str) -> SelectiveBuilder:
        """Parse ``config_contents`` with ``yaml.safe_load`` and build a selector."""
        contents = yaml.safe_load(config_contents)
        return SelectiveBuilder.from_yaml_dict(contents)

    @staticmethod
    def from_yaml_path(config_path: str) -> SelectiveBuilder:
        """Read the YAML file at ``config_path`` (as UTF-8) and build a selector."""
        # Explicit encoding: the platform default (e.g. cp1252 on Windows)
        # would otherwise decide how the file is read.
        with open(config_path, encoding="utf-8") as f:
            contents = yaml.safe_load(f)
        return SelectiveBuilder.from_yaml_dict(contents)

    @staticmethod
    def from_legacy_op_registration_allow_list(
        allow_list: set[str], is_root_operator: bool, is_used_for_training: bool
    ) -> SelectiveBuilder:
        """Build a selector from a plain set of operator names.

        Every operator gets the same root/training flags and includes all of
        its overloads. ``include_all_non_op_selectives`` is set because the
        list did not come from tracing.
        """
        operators = {
            op: {
                "name": op,
                "is_root_operator": is_root_operator,
                "is_used_for_training": is_used_for_training,
                "include_all_overloads": True,
            }
            for op in allow_list
        }
        return SelectiveBuilder.from_yaml_dict(
            {
                "operators": operators,
                "include_all_non_op_selectives": True,
            }
        )

    def is_operator_selected(self, name: str) -> bool:
        """Return whether operator ``name`` is part of the build.

        True if all operators are included, if ``name`` is listed exactly, or
        if its overload-stripped base name is listed with
        ``include_all_overloads`` set.
        """
        if self.include_all_operators:
            return True
        if name in self.operators:
            return True
        name = strip_operator_overload_name(name)
        return name in self.operators and self.operators[name].include_all_overloads

    def is_native_function_selected(self, func: NativeFunction) -> bool:
        """``is_operator_selected`` for the ``namespace::name`` of ``func``."""
        op_name = op_name_from_native_function(func)
        return self.is_operator_selected(op_name)

    def is_operator_selected_for_training(self, name: str) -> bool:
        """Return whether operator ``name`` is selected and used for training.

        An exact entry counts if it is marked ``is_used_for_training``. The
        overload-stripped base entry counts only if it also has
        ``include_all_overloads`` set.
        """
        if not self.is_operator_selected(name):
            return False
        if self.include_all_operators:
            return True

        op = self.operators.get(name)
        if op is not None and op.is_used_for_training:
            return True

        base_op = self.operators.get(strip_operator_overload_name(name))
        return (
            base_op is not None
            and base_op.include_all_overloads
            and base_op.is_used_for_training
        )

    def is_native_function_selected_for_training(self, func: NativeFunction) -> bool:
        """``is_operator_selected_for_training`` for the ``namespace::name`` of ``func``."""
        op_name = op_name_from_native_function(func)
        return self.is_operator_selected_for_training(op_name)

    def is_root_operator(self, name: str) -> bool:
        """Return whether operator ``name`` is selected and marked as a root operator.

        An exact entry wins. Otherwise the overload-stripped base entry counts
        only if it has ``include_all_overloads`` set.
        """
        if not self.is_operator_selected(name):
            return False
        if self.include_all_operators:
            return True
        if name in self.operators:
            op: SelectiveBuildOperator = self.operators[name]
            return op.is_root_operator
        name = strip_operator_overload_name(name)
        if name not in self.operators:
            return False
        base_op: SelectiveBuildOperator = self.operators[name]
        return base_op.include_all_overloads and base_op.is_root_operator

    def is_kernel_dtype_selected(self, kernel_tag: str, dtype: str) -> bool:
        """Return whether ``dtype`` is selected for kernel fragment ``kernel_tag``.

        Always True when all operators or all non-op selectives are included.
        """
        if self.include_all_operators or self.include_all_non_op_selectives:
            return True
        return (
            kernel_tag in self.kernel_metadata
            and dtype in self.kernel_metadata[kernel_tag]
        )

    def et_get_selected_kernels(self, op_name: str, kernel_key: list[str]) -> list[str]:
        """
        Return a list of kernel keys that cover the used ops.

        Args:
            op_name: Operator name looked up in ``et_kernel_metadata``.
            kernel_key: Kernel keys available for the operator. May contain
                the special key "default".

        Returns:
            If ``op_name`` has no ET kernel metadata, ``kernel_key`` itself
            when all operators are included, else ``[]``. Otherwise, the
            de-duplicated subset of ``kernel_key`` needed by the models, in
            ``kernel_key`` order so that generated code is deterministic.

        Raises:
            Exception: if a model kernel key matches no entry in
                ``kernel_key`` and no "default" kernel is available.
        """
        # If no kernel metadata, either it's implied by include_all_operators=True or the op is not used.
        if op_name not in self.et_kernel_metadata:
            return kernel_key if self.include_all_operators else []

        # Otherwise, only return the specific kernel keys.
        selected: set[str] = set()
        for model_kernel_key in self.et_kernel_metadata[op_name]:
            for key in kernel_key:
                # Don't compare the version (the first "/"-separated segment).
                if (
                    key != "default"
                    and key.split("/")[1] == model_kernel_key.split("/")[1]
                ):
                    selected.add(key)
                    break
            else:
                # No specific kernel matched; fall back to the default kernel.
                if "default" not in kernel_key:
                    raise Exception("Missing kernel for the model")  # noqa: TRY002
                selected.add("default")

        # Follow the caller's order rather than set iteration order, which
        # varies between interpreter runs because of str hash randomization.
        return list(dict.fromkeys(key for key in kernel_key if key in selected))

    def to_dict(self) -> dict[str, object]:
        """Serialize to the YAML-dict shape accepted by ``from_yaml_dict``.

        Debug info, per-tag dtype lists, custom classes and build features
        are sorted so the output is stable. The returned containers are fresh
        objects, so mutating them does not alter this instance.
        """
        import copy

        ret: dict[str, object] = {
            "include_all_non_op_selectives": self.include_all_non_op_selectives,
            "include_all_operators": self.include_all_operators,
        }
        ret["operators"] = {
            op_name: op.to_dict() for op_name, op in self.operators.items()
        }
        if self._debug_info is not None:
            ret["debug_info"] = sorted(self._debug_info)
        ret["kernel_metadata"] = {
            k: sorted(v) for (k, v) in self.kernel_metadata.items()
        }
        # Deep copy: handing out the internal dict would let callers mutate
        # this frozen instance through the serialized result.
        ret["et_kernel_metadata"] = copy.deepcopy(self.et_kernel_metadata)
        ret["custom_classes"] = sorted(self.custom_classes)
        ret["build_features"] = sorted(self.build_features)
        return ret
def merge_kernel_metadata(
    lhs: dict[str, list[str]],
    rhs: dict[str, list[str]],
) -> dict[str, list[str]]:
    """Union two kernel-tag -> dtypes maps.

    Tags from ``lhs`` come first, followed by tags that only appear in
    ``rhs``. Each tag's dtypes are de-duplicated and sorted, so the result
    does not depend on set iteration order (as ``merge_et_kernel_metadata``
    already does).
    """
    merged: dict[str, set[str]] = defaultdict(set)
    for metadata in (lhs, rhs):
        for tag_name, dtypes in metadata.items():
            merged[tag_name].update(dtypes)
    return {tag_name: sorted(dtypes) for tag_name, dtypes in merged.items()}
def merge_et_kernel_metadata(
    lhs: dict[str, list[str]],
    rhs: dict[str, list[str]],
) -> dict[str, list[str]]:
    """Union two ExecuTorch operator -> kernel-keys maps.

    Operators from ``lhs`` come first in their original order, followed by
    operators that only appear in ``rhs``. Each operator's kernel keys are
    de-duplicated and returned sorted.
    """
    combined: dict[str, set[str]] = defaultdict(set)
    for source in (lhs, rhs):
        for op_name, kernel_keys in source.items():
            combined[op_name].update(kernel_keys)
    return {op_name: sorted(keys) for op_name, keys in combined.items()}
def combine_selective_builders(
    lhs: SelectiveBuilder, rhs: SelectiveBuilder
) -> SelectiveBuilder:
    """Return a selector that selects everything either input selects.

    The "include all" flags are OR-ed together. Every collection field is
    unioned, either through its merge helper or through set union.
    """
    return SelectiveBuilder(
        include_all_operators=(
            lhs.include_all_operators or rhs.include_all_operators
        ),
        _debug_info=merge_debug_info(lhs._debug_info, rhs._debug_info),
        operators=merge_operator_dicts(lhs.operators, rhs.operators),
        kernel_metadata=merge_kernel_metadata(
            lhs.kernel_metadata, rhs.kernel_metadata
        ),
        et_kernel_metadata=merge_et_kernel_metadata(
            lhs.et_kernel_metadata, rhs.et_kernel_metadata
        ),
        custom_classes=lhs.custom_classes.union(rhs.custom_classes),
        build_features=lhs.build_features.union(rhs.build_features),
        include_all_non_op_selectives=(
            lhs.include_all_non_op_selectives or rhs.include_all_non_op_selectives
        ),
    )
def op_name_from_native_function(f: NativeFunction) -> str:
    """Return the ``namespace::name`` string used to key selected operators.

    This was originally read from the 'operator_name_with_overload' field in
    the declaration dict, which was the part before the first '(' in
    'schema_string'.
    """
    namespace = f.namespace
    qualified_name = f.func.name
    return f"{namespace}::{qualified_name}"
|