utils.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523
  1. from __future__ import annotations
  2. import contextlib
  3. import functools
  4. import hashlib
  5. import os
  6. import re
  7. import sys
  8. import textwrap
  9. from dataclasses import is_dataclass
  10. from enum import auto, Enum
  11. from pathlib import Path
  12. from pprint import pformat
  13. from typing import Any, Generic, TYPE_CHECKING, TypeVar
  14. from typing_extensions import assert_never, Self
  15. from torchgen.code_template import CodeTemplate
  16. if TYPE_CHECKING:
  17. from argparse import Namespace
  18. from collections.abc import Callable, Iterable, Iterator, Sequence
# Absolute path of the directory that contains this file.
TORCHGEN_ROOT = Path(__file__).absolute().parent
# Parent directory of TORCHGEN_ROOT (presumably the repository checkout root).
REPO_ROOT = TORCHGEN_ROOT.parent
  21. # Many of these functions share logic for defining both the definition
  22. # and declaration (for example, the function signature is the same), so
  23. # we organize them into one function that takes a Target to say which
  24. # code we want.
  25. #
  26. # This is an OPEN enum (we may add more cases to it in the future), so be sure
  27. # to explicitly specify with Literal[Target.XXX] or Literal[Target.XXX, Target.YYY]
  28. # what targets are valid for your use.
class Target(Enum):
    """Selects which kind of code a shared generator function should emit.

    Member values are assigned by ``auto()`` in declaration order.
    """

    # top level namespace (not including at)
    DEFINITION = auto()
    DECLARATION = auto()
    # TORCH_LIBRARY(...) { ... }
    REGISTRATION = auto()
    # anonymous namespace: namespace { ... }
    ANONYMOUS_DEFINITION = auto()
    # named namespace, e.g. namespace cpu { ... }
    NAMESPACED_DEFINITION = auto()
    NAMESPACED_DECLARATION = auto()
# Matches "foo" in "foo, bar" but not "foobar". Used to search for the
# occurrence of a parameter in the derivative formula.
# NOTE(review): the `{}` is presumably a str.format placeholder for the
# identifier being searched for -- confirm at the call sites. Groups 1 and 2
# consume the neighbouring non-word character (or match at a string boundary).
IDENT_REGEX = r"(^|\W){}($|\W)"
  43. # TODO: Use a real parser here; this will get bamboozled
  44. def split_name_params(schema: str) -> tuple[str, list[str]]:
  45. m = re.match(r"(\w+)(\.\w+)?\((.*)\)", schema)
  46. if m is None:
  47. raise RuntimeError(f"Unsupported function schema: {schema}")
  48. name, _, params = m.groups()
  49. return name, params.split(", ")
# Generic type variables used in the signatures of the helpers in this module.
T = TypeVar("T")
S = TypeVar("S")
  52. # These two functions purposely return generators in analogy to map()
  53. # so that you don't mix up when you need to list() them
  54. # Map over function that may return None; omit Nones from output sequence
  55. def mapMaybe(func: Callable[[T], S | None], xs: Iterable[T]) -> Iterator[S]:
  56. for x in xs:
  57. r = func(x)
  58. if r is not None:
  59. yield r
  60. # Map over function that returns sequences and cat them all together
  61. def concatMap(func: Callable[[T], Sequence[S]], xs: Iterable[T]) -> Iterator[S]:
  62. for x in xs:
  63. yield from func(x)
  64. # Conveniently add error context to exceptions raised. Lets us
  65. # easily say that an error occurred while processing a specific
  66. # context.
  67. @contextlib.contextmanager
  68. def context(msg_fn: Callable[[], str]) -> Iterator[None]:
  69. try:
  70. yield
  71. except Exception as e:
  72. # TODO: this does the wrong thing with KeyError
  73. msg = msg_fn()
  74. msg = textwrap.indent(msg, " ")
  75. msg = f"{e.args[0]}\n{msg}" if e.args else msg
  76. e.args = (msg,) + e.args[1:]
  77. raise
  78. @functools.cache
  79. def _read_template(template_fn: str) -> CodeTemplate:
  80. return CodeTemplate.from_file(template_fn)
  81. # String hash that's stable across different executions, unlike builtin hash
  82. def string_stable_hash(s: str) -> int:
  83. sha1 = hashlib.sha1(s.encode("latin1"), usedforsecurity=False).digest()
  84. return int.from_bytes(sha1, byteorder="little")
  85. # A small abstraction for writing out generated files and keeping track
  86. # of what files have been written (so you can write out a list of output
  87. # files)
  88. class FileManager:
  89. def __init__(
  90. self,
  91. install_dir: str | Path,
  92. template_dir: str | Path,
  93. dry_run: bool,
  94. ) -> None:
  95. self.install_dir = Path(install_dir)
  96. self.template_dir = Path(template_dir)
  97. self.files: set[Path] = set()
  98. self.dry_run = dry_run
  99. @property
  100. def filenames(self) -> frozenset[str]:
  101. return frozenset({file.as_posix() for file in self.files})
  102. def _write_if_changed(self, filename: str | Path, contents: str) -> None:
  103. file = Path(filename)
  104. old_contents: str | None = None
  105. try:
  106. old_contents = file.read_text(encoding="utf-8")
  107. except OSError:
  108. pass
  109. if contents != old_contents:
  110. # Create output directory if it doesn't exist
  111. file.parent.mkdir(parents=True, exist_ok=True)
  112. file.write_text(contents, encoding="utf-8")
  113. # Read from template file and replace pattern with callable (type could be dict or str).
  114. def substitute_with_template(
  115. self,
  116. template_fn: str | Path,
  117. env_callable: Callable[[], str | dict[str, Any]],
  118. ) -> str:
  119. if Path(template_fn).is_absolute():
  120. raise AssertionError(f"template_fn must be relative: {template_fn}")
  121. template_path = self.template_dir / template_fn
  122. env = env_callable()
  123. if isinstance(env, dict):
  124. if "generated_comment" not in env:
  125. generator_default = TORCHGEN_ROOT / "gen.py"
  126. try:
  127. generator = Path(
  128. sys.modules["__main__"].__file__ or generator_default
  129. ).absolute()
  130. except (KeyError, AttributeError):
  131. generator = generator_default.absolute()
  132. try:
  133. generator_path = generator.relative_to(REPO_ROOT).as_posix()
  134. except ValueError:
  135. generator_path = generator.name
  136. env = {
  137. **env, # copy the original dict instead of mutating it
  138. "generated_comment": (
  139. "@" + f"generated by {generator_path} from {template_fn}"
  140. ),
  141. }
  142. template = _read_template(template_path)
  143. substitute_out = template.substitute(env)
  144. # Ensure an extra blank line between the class/function definition
  145. # and the docstring of the previous class/function definition.
  146. # NB: It is generally not recommended to have docstrings in pyi stub
  147. # files. But if there are any, we need to ensure that the file
  148. # is properly formatted.
  149. return re.sub(
  150. r'''
  151. (""")\n+ # match triple quotes
  152. (
  153. (\s*@.+\n)* # match decorators if any
  154. \s*(class|def) # match class/function definition
  155. )
  156. ''',
  157. r"\g<1>\n\n\g<2>",
  158. substitute_out,
  159. flags=re.VERBOSE,
  160. )
  161. if isinstance(env, str):
  162. return env
  163. assert_never(env)
  164. def write_with_template(
  165. self,
  166. filename: str | Path,
  167. template_fn: str | Path,
  168. env_callable: Callable[[], str | dict[str, Any]],
  169. ) -> None:
  170. filename = Path(filename)
  171. if filename.is_absolute():
  172. raise AssertionError(f"filename must be relative: {filename}")
  173. file = self.install_dir / filename
  174. if file in self.files:
  175. raise AssertionError(f"duplicate file write {file}")
  176. self.files.add(file)
  177. if not self.dry_run:
  178. substitute_out = self.substitute_with_template(
  179. template_fn=template_fn,
  180. env_callable=env_callable,
  181. )
  182. self._write_if_changed(filename=file, contents=substitute_out)
  183. def write(
  184. self,
  185. filename: str | Path,
  186. env_callable: Callable[[], str | dict[str, Any]],
  187. ) -> None:
  188. self.write_with_template(filename, filename, env_callable)
  189. def write_sharded(
  190. self,
  191. filename: str | Path,
  192. items: Iterable[T],
  193. *,
  194. key_fn: Callable[[T], str],
  195. env_callable: Callable[[T], dict[str, list[str]]],
  196. num_shards: int,
  197. base_env: dict[str, Any] | None = None,
  198. sharded_keys: set[str],
  199. ) -> None:
  200. self.write_sharded_with_template(
  201. filename,
  202. filename,
  203. items,
  204. key_fn=key_fn,
  205. env_callable=env_callable,
  206. num_shards=num_shards,
  207. base_env=base_env,
  208. sharded_keys=sharded_keys,
  209. )
  210. def write_sharded_with_template(
  211. self,
  212. filename: str | Path,
  213. template_fn: str | Path,
  214. items: Iterable[T],
  215. *,
  216. key_fn: Callable[[T], str],
  217. env_callable: Callable[[T], dict[str, list[str]]],
  218. num_shards: int,
  219. base_env: dict[str, Any] | None = None,
  220. sharded_keys: set[str],
  221. ) -> None:
  222. file = Path(filename)
  223. if file.is_absolute():
  224. raise AssertionError(f"filename must be relative: {filename}")
  225. everything: dict[str, Any] = {"shard_id": "Everything"}
  226. shards: list[dict[str, Any]] = [
  227. {"shard_id": f"_{i}"} for i in range(num_shards)
  228. ]
  229. all_shards = [everything] + shards
  230. if base_env is not None:
  231. for shard in all_shards:
  232. shard.update(base_env)
  233. for key in sharded_keys:
  234. for shard in all_shards:
  235. if key in shard:
  236. if not isinstance(shard[key], list):
  237. raise AssertionError("sharded keys in base_env must be a list")
  238. shard[key] = shard[key].copy()
  239. else:
  240. shard[key] = []
  241. def merge_env(into: dict[str, list[str]], from_: dict[str, list[str]]) -> None:
  242. for k, v in from_.items():
  243. if k not in sharded_keys:
  244. raise AssertionError(f"undeclared sharded key {k}")
  245. into[k] += v
  246. if self.dry_run:
  247. # Dry runs don't write any templates, so incomplete environments are fine
  248. items = ()
  249. for item in items:
  250. key = key_fn(item)
  251. sid = string_stable_hash(key) % num_shards
  252. env = env_callable(item)
  253. merge_env(shards[sid], env)
  254. merge_env(everything, env)
  255. for shard in all_shards:
  256. shard_id = shard["shard_id"]
  257. self.write_with_template(
  258. file.with_stem(f"{file.stem}{shard_id}"),
  259. template_fn,
  260. lambda: shard,
  261. )
  262. # filenames is used to track compiled files, but FooEverything.cpp isn't meant to be compiled
  263. self.files.discard(self.install_dir / file.with_stem(f"{file.stem}Everything"))
  264. def write_outputs(self, variable_name: str, filename: str | Path) -> None:
  265. """Write a file containing the list of all outputs which are generated by this script."""
  266. content = "\n".join(
  267. (
  268. "set(",
  269. variable_name,
  270. # Use POSIX paths to avoid invalid escape sequences on Windows
  271. *(f' "{file.as_posix()}"' for file in sorted(self.files)),
  272. ")",
  273. )
  274. )
  275. self._write_if_changed(filename, content)
  276. def template_dir_for_comments(self) -> str:
  277. """
  278. This needs to be deterministic. The template dir is an absolute path
  279. that varies across builds. So, just use the path relative to this file,
  280. which will point to the codegen source but will be stable.
  281. """
  282. return os.path.relpath(self.template_dir, os.path.dirname(__file__))
  283. # Helper function to generate file manager
  284. def make_file_manager(
  285. options: Namespace,
  286. install_dir: str | Path | None = None,
  287. ) -> FileManager:
  288. template_dir = os.path.join(options.source_path, "templates")
  289. install_dir = install_dir if install_dir else options.install_dir
  290. return FileManager(
  291. install_dir=install_dir,
  292. template_dir=template_dir,
  293. dry_run=options.dry_run,
  294. )
  295. # Helper function to create a pretty representation for dataclasses
  296. def dataclass_repr(
  297. obj: Any,
  298. indent: int = 0,
  299. width: int = 80,
  300. ) -> str:
  301. return pformat(obj, indent, width)
  302. def _format_dict(
  303. attr: dict[Any, Any],
  304. indent: int,
  305. width: int,
  306. curr_indent: int,
  307. ) -> str:
  308. curr_indent += indent + 3
  309. dict_repr = []
  310. for k, v in attr.items():
  311. k_repr = repr(k)
  312. v_str = (
  313. pformat(v, indent, width, curr_indent + len(k_repr))
  314. if is_dataclass(v)
  315. else repr(v)
  316. )
  317. dict_repr.append(f"{k_repr}: {v_str}")
  318. return _format(dict_repr, indent, width, curr_indent, "{", "}")
  319. def _format_list(
  320. attr: list[Any] | set[Any] | tuple[Any, ...],
  321. indent: int,
  322. width: int,
  323. curr_indent: int,
  324. ) -> str:
  325. curr_indent += indent + 1
  326. list_repr = [
  327. pformat(l, indent, width, curr_indent) if is_dataclass(l) else repr(l)
  328. for l in attr
  329. ]
  330. start, end = ("[", "]") if isinstance(attr, list) else ("(", ")")
  331. return _format(list_repr, indent, width, curr_indent, start, end)
  332. def _format(
  333. fields_str: list[str],
  334. indent: int,
  335. width: int,
  336. curr_indent: int,
  337. start: str,
  338. end: str,
  339. ) -> str:
  340. delimiter, curr_indent_str = "", ""
  341. # if it exceed the max width then we place one element per line
  342. if len(repr(fields_str)) >= width:
  343. delimiter = "\n"
  344. curr_indent_str = " " * curr_indent
  345. indent_str = " " * indent
  346. body = f", {delimiter}{curr_indent_str}".join(fields_str)
  347. return f"{start}{indent_str}{body}{end}"
  348. class NamespaceHelper:
  349. """A helper for constructing the namespace open and close strings for a nested set of namespaces.
  350. e.g. for namespace_str torch::lazy,
  351. prologue:
  352. namespace torch {
  353. namespace lazy {
  354. epilogue:
  355. } // namespace lazy
  356. } // namespace torch
  357. """
  358. def __init__(
  359. self,
  360. namespace_str: str,
  361. entity_name: str = "",
  362. max_level: int = 2,
  363. ) -> None:
  364. # cpp_namespace can be a colon joined string such as torch::lazy
  365. cpp_namespaces = namespace_str.split("::")
  366. if len(cpp_namespaces) > max_level:
  367. raise AssertionError(
  368. f"Codegen doesn't support more than {max_level} level(s) of "
  369. f"custom namespace. Got {namespace_str}."
  370. )
  371. self.cpp_namespace_ = namespace_str
  372. self.prologue_ = "\n".join([f"namespace {n} {{" for n in cpp_namespaces])
  373. self.epilogue_ = "\n".join(
  374. [f"}} // namespace {n}" for n in reversed(cpp_namespaces)]
  375. )
  376. self.namespaces_ = cpp_namespaces
  377. self.entity_name_ = entity_name
  378. @staticmethod
  379. def from_namespaced_entity(
  380. namespaced_entity: str,
  381. max_level: int = 2,
  382. ) -> NamespaceHelper:
  383. """
  384. Generate helper from nested namespaces as long as class/function name. E.g.: "torch::lazy::add"
  385. """
  386. names = namespaced_entity.split("::")
  387. entity_name = names[-1]
  388. namespace_str = "::".join(names[:-1])
  389. return NamespaceHelper(
  390. namespace_str=namespace_str, entity_name=entity_name, max_level=max_level
  391. )
  392. @property
  393. def prologue(self) -> str:
  394. return self.prologue_
  395. @property
  396. def epilogue(self) -> str:
  397. return self.epilogue_
  398. @property
  399. def entity_name(self) -> str:
  400. return self.entity_name_
  401. # Only allow certain level of namespaces
  402. def get_cpp_namespace(self, default: str = "") -> str:
  403. """
  404. Return the namespace string from joining all the namespaces by "::" (hence no leading "::").
  405. Return default if namespace string is empty.
  406. """
  407. return self.cpp_namespace_ if self.cpp_namespace_ else default
  408. class OrderedSet(Generic[T]):
  409. storage: dict[T, None]
  410. def __init__(self, iterable: Iterable[T] | None = None) -> None:
  411. if iterable is None:
  412. self.storage = {}
  413. else:
  414. self.storage = dict.fromkeys(iterable)
  415. def __contains__(self, item: T) -> bool:
  416. return item in self.storage
  417. def __iter__(self) -> Iterator[T]:
  418. return iter(self.storage.keys())
  419. def update(self, items: OrderedSet[T]) -> None:
  420. self.storage.update(items.storage)
  421. def add(self, item: T) -> None:
  422. self.storage[item] = None
  423. def copy(self) -> OrderedSet[T]:
  424. ret: OrderedSet[T] = OrderedSet()
  425. ret.storage = self.storage.copy()
  426. return ret
  427. @staticmethod
  428. def union(*args: OrderedSet[T]) -> OrderedSet[T]:
  429. ret = args[0].copy()
  430. for s in args[1:]:
  431. ret.update(s)
  432. return ret
  433. def __or__(self, other: OrderedSet[T]) -> OrderedSet[T]:
  434. return OrderedSet.union(self, other)
  435. def __ior__(self, other: OrderedSet[T]) -> Self:
  436. self.update(other)
  437. return self
  438. def __eq__(self, other: object) -> bool:
  439. if isinstance(other, OrderedSet):
  440. return self.storage == other.storage
  441. else:
  442. return set(self.storage.keys()) == other