selector.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377
  1. from __future__ import annotations
  2. from collections import defaultdict
  3. from collections.abc import Iterable
  4. from dataclasses import dataclass
  5. from typing import TYPE_CHECKING
  6. import yaml
  7. from torchgen.selective_build.operator import (
  8. merge_debug_info,
  9. merge_operator_dicts,
  10. SelectiveBuildOperator,
  11. strip_operator_overload_name,
  12. )
  13. if TYPE_CHECKING:
  14. from torchgen.model import NativeFunction
  15. # A SelectiveBuilder holds information extracted from the selective build
  16. # YAML specification.
  17. #
  18. # It includes information about the build's selectivity, the debug_info
  19. # associated with this selective build (opaque string), and the set of
  20. # operators that should be included in the build.
  21. #
  22. @dataclass(frozen=True)
  23. class SelectiveBuilder:
  24. # If true, then the build is not selective, and includes all
  25. # operators.
  26. include_all_operators: bool
  27. # Debug Information at the selective/custom build level.
  28. _debug_info: tuple[str, ...] | None
  29. # A dictionary of operator -> operator metadata.
  30. operators: dict[str, SelectiveBuildOperator]
  31. # A dictionary of selected kernel tags and dtypes. Typically a
  32. # PyTorch Operator Kernel (function) may have many code paths
  33. # that are specialized for many many Tensor dtypes, so it's not
  34. # one per kernel function, but there could be many per kernel
  35. # function. The tag isn't a kernel function name, but some fragment
  36. # of the kernel function implementation itself.
  37. kernel_metadata: dict[str, list[str]]
  38. # ExecuTorch only. A dictionary of kernel tag -> list of (list of input
  39. # dtypes for tensor-like input args).
  40. # This is from selective.yaml
  41. et_kernel_metadata: dict[str, list[str]]
  42. # A set of all the custom torch bind classes used by the selected models
  43. # Stored as a set internally to remove duplicates proactively, but written
  44. # as a list to yamls
  45. custom_classes: set[str]
  46. # A set of all the build features used by the selected models
  47. # Stored as a set internally to remove duplicates proactively, but written
  48. # as a list to yamls
  49. build_features: set[str]
  50. # If true, then fragments for all dtypes for all kernel functions
  51. # are included as well as all custom classes. This is typically set when any one of the
  52. # operator lists is generated from a mechanism other than
  53. # tracing based selective build.
  54. include_all_non_op_selectives: bool
  55. @staticmethod
  56. def get_nop_selector() -> SelectiveBuilder:
  57. return SelectiveBuilder.from_yaml_dict({"include_all_operators": True})
  58. @staticmethod
  59. def from_yaml_dict(data: dict[str, object]) -> SelectiveBuilder:
  60. valid_top_level_keys = {
  61. "include_all_non_op_selectives",
  62. "include_all_operators",
  63. "debug_info",
  64. "operators",
  65. "kernel_metadata",
  66. "et_kernel_metadata",
  67. "custom_classes",
  68. "build_features",
  69. }
  70. top_level_keys = set(data.keys())
  71. if len(top_level_keys - valid_top_level_keys) > 0:
  72. raise Exception( # noqa: TRY002
  73. "Got unexpected top level keys: {}".format(
  74. ",".join(top_level_keys - valid_top_level_keys),
  75. )
  76. )
  77. include_all_operators = data.get("include_all_operators", False)
  78. if not isinstance(include_all_operators, bool):
  79. raise AssertionError(
  80. f"Expected 'include_all_operators' to be bool, got {type(include_all_operators)}"
  81. )
  82. debug_info = None
  83. if "debug_info" in data:
  84. di_list = data["debug_info"]
  85. if not isinstance(di_list, list):
  86. raise AssertionError(
  87. f"Expected 'debug_info' to be list, got {type(di_list)}"
  88. )
  89. debug_info = tuple(str(x) for x in di_list)
  90. operators = {}
  91. operators_dict = data.get("operators", {})
  92. if not isinstance(operators_dict, dict):
  93. raise AssertionError(
  94. f"Expected 'operators' to be dict, got {type(operators_dict)}"
  95. )
  96. for k, v in operators_dict.items():
  97. operators[k] = SelectiveBuildOperator.from_yaml_dict(k, v)
  98. kernel_metadata = {}
  99. kernel_metadata_dict = data.get("kernel_metadata", {})
  100. if not isinstance(kernel_metadata_dict, dict):
  101. raise AssertionError(
  102. f"Expected 'kernel_metadata' to be dict, got {type(kernel_metadata_dict)}"
  103. )
  104. for k, v in kernel_metadata_dict.items():
  105. kernel_metadata[str(k)] = [str(dtype) for dtype in v]
  106. et_kernel_metadata = data.get("et_kernel_metadata", {})
  107. if not isinstance(et_kernel_metadata, dict):
  108. raise AssertionError(
  109. f"Expected 'et_kernel_metadata' to be dict, got {type(et_kernel_metadata)}"
  110. )
  111. custom_classes = data.get("custom_classes", [])
  112. if not isinstance(custom_classes, Iterable):
  113. raise AssertionError(
  114. f"Expected 'custom_classes' to be Iterable, got {type(custom_classes)}"
  115. )
  116. custom_classes = set(custom_classes)
  117. build_features = data.get("build_features", [])
  118. if not isinstance(build_features, Iterable):
  119. raise AssertionError(
  120. f"Expected 'build_features' to be Iterable, got {type(build_features)}"
  121. )
  122. build_features = set(build_features)
  123. include_all_non_op_selectives = data.get("include_all_non_op_selectives", False)
  124. if not isinstance(include_all_non_op_selectives, bool):
  125. raise AssertionError(
  126. f"Expected 'include_all_non_op_selectives' to be bool, "
  127. f"got {type(include_all_non_op_selectives)}"
  128. )
  129. return SelectiveBuilder(
  130. include_all_operators,
  131. debug_info,
  132. operators,
  133. kernel_metadata,
  134. et_kernel_metadata,
  135. custom_classes, # type: ignore[arg-type]
  136. build_features, # type: ignore[arg-type]
  137. include_all_non_op_selectives,
  138. )
  139. @staticmethod
  140. def from_yaml_str(config_contents: str) -> SelectiveBuilder:
  141. contents = yaml.safe_load(config_contents)
  142. return SelectiveBuilder.from_yaml_dict(contents)
  143. @staticmethod
  144. def from_yaml_path(config_path: str) -> SelectiveBuilder:
  145. with open(config_path) as f:
  146. contents = yaml.safe_load(f)
  147. return SelectiveBuilder.from_yaml_dict(contents)
  148. @staticmethod
  149. def from_legacy_op_registration_allow_list(
  150. allow_list: set[str], is_root_operator: bool, is_used_for_training: bool
  151. ) -> SelectiveBuilder:
  152. operators = {}
  153. for op in allow_list:
  154. operators[op] = {
  155. "name": op,
  156. "is_root_operator": is_root_operator,
  157. "is_used_for_training": is_used_for_training,
  158. "include_all_overloads": True,
  159. }
  160. return SelectiveBuilder.from_yaml_dict(
  161. {
  162. "operators": operators,
  163. "include_all_non_op_selectives": True,
  164. }
  165. )
  166. def is_operator_selected(self, name: str) -> bool:
  167. if self.include_all_operators:
  168. return True
  169. if name in self.operators:
  170. return True
  171. name = strip_operator_overload_name(name)
  172. return name in self.operators and self.operators[name].include_all_overloads
  173. def is_native_function_selected(self, func: NativeFunction) -> bool:
  174. op_name = op_name_from_native_function(func)
  175. return self.is_operator_selected(op_name)
  176. def is_operator_selected_for_training(self, name: str) -> bool:
  177. if not self.is_operator_selected(name):
  178. return False
  179. if self.include_all_operators:
  180. return True
  181. not_training_op = SelectiveBuildOperator(
  182. name="",
  183. is_root_operator=False,
  184. is_used_for_training=False,
  185. include_all_overloads=False,
  186. _debug_info=None,
  187. )
  188. op = not_training_op
  189. if name in self.operators:
  190. op = self.operators[name]
  191. name = strip_operator_overload_name(name)
  192. base_op = not_training_op
  193. if name in self.operators:
  194. base_op = self.operators[name]
  195. return op.is_used_for_training or (
  196. base_op.include_all_overloads and base_op.is_used_for_training
  197. )
  198. def is_native_function_selected_for_training(self, func: NativeFunction) -> bool:
  199. op_name = op_name_from_native_function(func)
  200. return self.is_operator_selected_for_training(op_name)
  201. def is_root_operator(self, name: str) -> bool:
  202. if not self.is_operator_selected(name):
  203. return False
  204. if self.include_all_operators:
  205. return True
  206. if name in self.operators:
  207. op: SelectiveBuildOperator = self.operators[name]
  208. return op.is_root_operator
  209. name = strip_operator_overload_name(name)
  210. if name not in self.operators:
  211. return False
  212. base_op: SelectiveBuildOperator = self.operators[name]
  213. return base_op.include_all_overloads and base_op.is_root_operator
  214. def is_kernel_dtype_selected(self, kernel_tag: str, dtype: str) -> bool:
  215. if self.include_all_operators or self.include_all_non_op_selectives:
  216. return True
  217. return (
  218. kernel_tag in self.kernel_metadata
  219. and dtype in self.kernel_metadata[kernel_tag]
  220. )
  221. def et_get_selected_kernels(self, op_name: str, kernel_key: list[str]) -> list[str]:
  222. """
  223. Return a list of kernel keys that cover the used ops
  224. """
  225. # If no kernel metadata, either it's implied by include_all_operators=True or the op is not used.
  226. if op_name not in self.et_kernel_metadata:
  227. return kernel_key if self.include_all_operators else []
  228. # Otherwise, only return the specific kernel keys.
  229. result_set = set()
  230. for model_kernel_keys in self.et_kernel_metadata[op_name]:
  231. key_found = False
  232. for key in kernel_key:
  233. # Don't compare the version for now
  234. if (
  235. key != "default"
  236. and key.split("/")[1] == model_kernel_keys.split("/")[1]
  237. ):
  238. result_set.add(key)
  239. key_found = True
  240. break
  241. if not key_found:
  242. if "default" not in kernel_key:
  243. raise Exception("Missing kernel for the model") # noqa: TRY002
  244. else:
  245. result_set.add("default")
  246. return list(result_set)
  247. def to_dict(self) -> dict[str, object]:
  248. ret: dict[str, object] = {
  249. "include_all_non_op_selectives": self.include_all_non_op_selectives,
  250. "include_all_operators": self.include_all_operators,
  251. }
  252. operators = {}
  253. for op_name, op in self.operators.items():
  254. operators[op_name] = op.to_dict()
  255. ret["operators"] = operators
  256. if self._debug_info is not None:
  257. ret["debug_info"] = sorted(self._debug_info)
  258. ret["kernel_metadata"] = {
  259. k: sorted(v) for (k, v) in self.kernel_metadata.items()
  260. }
  261. ret["et_kernel_metadata"] = self.et_kernel_metadata
  262. ret["custom_classes"] = sorted(self.custom_classes)
  263. ret["build_features"] = sorted(self.build_features)
  264. return ret
  265. def merge_kernel_metadata(
  266. lhs: dict[str, list[str]],
  267. rhs: dict[str, list[str]],
  268. ) -> dict[str, list[str]]:
  269. kernel_metadata: dict[str, list[str]] = {}
  270. for tag_name, dtypes in list(lhs.items()) + list(rhs.items()):
  271. dtypes_copy = set(dtypes)
  272. if tag_name in kernel_metadata:
  273. dtypes_copy |= set(kernel_metadata[tag_name])
  274. kernel_metadata[tag_name] = list(dtypes_copy)
  275. return kernel_metadata
  276. def merge_et_kernel_metadata(
  277. lhs: dict[str, list[str]],
  278. rhs: dict[str, list[str]],
  279. ) -> dict[str, list[str]]:
  280. merge_et_kernel_metadata: dict[str, set[str]] = defaultdict(set)
  281. for op in list(lhs.keys()) + list(rhs.keys()):
  282. merge_et_kernel_metadata[op].update(lhs.get(op, []))
  283. merge_et_kernel_metadata[op].update(rhs.get(op, []))
  284. return {op: sorted(val) for op, val in merge_et_kernel_metadata.items()}
  285. def combine_selective_builders(
  286. lhs: SelectiveBuilder, rhs: SelectiveBuilder
  287. ) -> SelectiveBuilder:
  288. include_all_operators = lhs.include_all_operators or rhs.include_all_operators
  289. debug_info = merge_debug_info(lhs._debug_info, rhs._debug_info)
  290. operators = merge_operator_dicts(lhs.operators, rhs.operators)
  291. kernel_metadata = merge_kernel_metadata(lhs.kernel_metadata, rhs.kernel_metadata)
  292. et_kernel_metadata = merge_et_kernel_metadata(
  293. lhs.et_kernel_metadata, rhs.et_kernel_metadata
  294. )
  295. include_all_non_op_selectives = (
  296. lhs.include_all_non_op_selectives or rhs.include_all_non_op_selectives
  297. )
  298. custom_classes = lhs.custom_classes.union(rhs.custom_classes)
  299. build_features = lhs.build_features.union(rhs.build_features)
  300. return SelectiveBuilder(
  301. include_all_operators,
  302. debug_info,
  303. operators,
  304. kernel_metadata,
  305. et_kernel_metadata,
  306. custom_classes,
  307. build_features,
  308. include_all_non_op_selectives,
  309. )
  310. def op_name_from_native_function(f: NativeFunction) -> str:
  311. # This was originally read from the 'operator_name_with_overload' field in the
  312. # declaration dict, which was the part before the first '(' in 'schema_string'.
  313. return f"{f.namespace}::{f.func.name}"