| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523 |
- from __future__ import annotations
- import contextlib
- import functools
- import hashlib
- import os
- import re
- import sys
- import textwrap
- from dataclasses import is_dataclass
- from enum import auto, Enum
- from pathlib import Path
- from pprint import pformat
- from typing import Any, Generic, TYPE_CHECKING, TypeVar
- from typing_extensions import assert_never, Self
- from torchgen.code_template import CodeTemplate
- if TYPE_CHECKING:
- from argparse import Namespace
- from collections.abc import Callable, Iterable, Iterator, Sequence
- TORCHGEN_ROOT = Path(__file__).absolute().parent
- REPO_ROOT = TORCHGEN_ROOT.parent
- # Many of these functions share logic for defining both the definition
- # and declaration (for example, the function signature is the same), so
- # we organize them into one function that takes a Target to say which
- # code we want.
- #
- # This is an OPEN enum (we may add more cases to it in the future), so be sure
- # to explicitly specify with Literal[Target.XXX] or Literal[Target.XXX, Target.YYY]
- # what targets are valid for your use.
- class Target(Enum):
- # top level namespace (not including at)
- DEFINITION = auto()
- DECLARATION = auto()
- # TORCH_LIBRARY(...) { ... }
- REGISTRATION = auto()
- # namespace { ... }
- ANONYMOUS_DEFINITION = auto()
- # namespace cpu { ... }
- NAMESPACED_DEFINITION = auto()
- NAMESPACED_DECLARATION = auto()
- # Matches "foo" in "foo, bar" but not "foobar". Used to search for the
- # occurrence of a parameter in the derivative formula
- IDENT_REGEX = r"(^|\W){}($|\W)"
- # TODO: Use a real parser here; this will get bamboozled
- def split_name_params(schema: str) -> tuple[str, list[str]]:
- m = re.match(r"(\w+)(\.\w+)?\((.*)\)", schema)
- if m is None:
- raise RuntimeError(f"Unsupported function schema: {schema}")
- name, _, params = m.groups()
- return name, params.split(", ")
- T = TypeVar("T")
- S = TypeVar("S")
- # These two functions purposely return generators in analogy to map()
- # so that you don't mix up when you need to list() them
- # Map over function that may return None; omit Nones from output sequence
- def mapMaybe(func: Callable[[T], S | None], xs: Iterable[T]) -> Iterator[S]:
- for x in xs:
- r = func(x)
- if r is not None:
- yield r
- # Map over function that returns sequences and cat them all together
- def concatMap(func: Callable[[T], Sequence[S]], xs: Iterable[T]) -> Iterator[S]:
- for x in xs:
- yield from func(x)
- # Conveniently add error context to exceptions raised. Lets us
- # easily say that an error occurred while processing a specific
- # context.
- @contextlib.contextmanager
- def context(msg_fn: Callable[[], str]) -> Iterator[None]:
- try:
- yield
- except Exception as e:
- # TODO: this does the wrong thing with KeyError
- msg = msg_fn()
- msg = textwrap.indent(msg, " ")
- msg = f"{e.args[0]}\n{msg}" if e.args else msg
- e.args = (msg,) + e.args[1:]
- raise
- @functools.cache
- def _read_template(template_fn: str) -> CodeTemplate:
- return CodeTemplate.from_file(template_fn)
- # String hash that's stable across different executions, unlike builtin hash
- def string_stable_hash(s: str) -> int:
- sha1 = hashlib.sha1(s.encode("latin1"), usedforsecurity=False).digest()
- return int.from_bytes(sha1, byteorder="little")
- # A small abstraction for writing out generated files and keeping track
- # of what files have been written (so you can write out a list of output
- # files)
- class FileManager:
- def __init__(
- self,
- install_dir: str | Path,
- template_dir: str | Path,
- dry_run: bool,
- ) -> None:
- self.install_dir = Path(install_dir)
- self.template_dir = Path(template_dir)
- self.files: set[Path] = set()
- self.dry_run = dry_run
- @property
- def filenames(self) -> frozenset[str]:
- return frozenset({file.as_posix() for file in self.files})
- def _write_if_changed(self, filename: str | Path, contents: str) -> None:
- file = Path(filename)
- old_contents: str | None = None
- try:
- old_contents = file.read_text(encoding="utf-8")
- except OSError:
- pass
- if contents != old_contents:
- # Create output directory if it doesn't exist
- file.parent.mkdir(parents=True, exist_ok=True)
- file.write_text(contents, encoding="utf-8")
- # Read from template file and replace pattern with callable (type could be dict or str).
- def substitute_with_template(
- self,
- template_fn: str | Path,
- env_callable: Callable[[], str | dict[str, Any]],
- ) -> str:
- if Path(template_fn).is_absolute():
- raise AssertionError(f"template_fn must be relative: {template_fn}")
- template_path = self.template_dir / template_fn
- env = env_callable()
- if isinstance(env, dict):
- if "generated_comment" not in env:
- generator_default = TORCHGEN_ROOT / "gen.py"
- try:
- generator = Path(
- sys.modules["__main__"].__file__ or generator_default
- ).absolute()
- except (KeyError, AttributeError):
- generator = generator_default.absolute()
- try:
- generator_path = generator.relative_to(REPO_ROOT).as_posix()
- except ValueError:
- generator_path = generator.name
- env = {
- **env, # copy the original dict instead of mutating it
- "generated_comment": (
- "@" + f"generated by {generator_path} from {template_fn}"
- ),
- }
- template = _read_template(template_path)
- substitute_out = template.substitute(env)
- # Ensure an extra blank line between the class/function definition
- # and the docstring of the previous class/function definition.
- # NB: It is generally not recommended to have docstrings in pyi stub
- # files. But if there are any, we need to ensure that the file
- # is properly formatted.
- return re.sub(
- r'''
- (""")\n+ # match triple quotes
- (
- (\s*@.+\n)* # match decorators if any
- \s*(class|def) # match class/function definition
- )
- ''',
- r"\g<1>\n\n\g<2>",
- substitute_out,
- flags=re.VERBOSE,
- )
- if isinstance(env, str):
- return env
- assert_never(env)
- def write_with_template(
- self,
- filename: str | Path,
- template_fn: str | Path,
- env_callable: Callable[[], str | dict[str, Any]],
- ) -> None:
- filename = Path(filename)
- if filename.is_absolute():
- raise AssertionError(f"filename must be relative: {filename}")
- file = self.install_dir / filename
- if file in self.files:
- raise AssertionError(f"duplicate file write {file}")
- self.files.add(file)
- if not self.dry_run:
- substitute_out = self.substitute_with_template(
- template_fn=template_fn,
- env_callable=env_callable,
- )
- self._write_if_changed(filename=file, contents=substitute_out)
- def write(
- self,
- filename: str | Path,
- env_callable: Callable[[], str | dict[str, Any]],
- ) -> None:
- self.write_with_template(filename, filename, env_callable)
- def write_sharded(
- self,
- filename: str | Path,
- items: Iterable[T],
- *,
- key_fn: Callable[[T], str],
- env_callable: Callable[[T], dict[str, list[str]]],
- num_shards: int,
- base_env: dict[str, Any] | None = None,
- sharded_keys: set[str],
- ) -> None:
- self.write_sharded_with_template(
- filename,
- filename,
- items,
- key_fn=key_fn,
- env_callable=env_callable,
- num_shards=num_shards,
- base_env=base_env,
- sharded_keys=sharded_keys,
- )
- def write_sharded_with_template(
- self,
- filename: str | Path,
- template_fn: str | Path,
- items: Iterable[T],
- *,
- key_fn: Callable[[T], str],
- env_callable: Callable[[T], dict[str, list[str]]],
- num_shards: int,
- base_env: dict[str, Any] | None = None,
- sharded_keys: set[str],
- ) -> None:
- file = Path(filename)
- if file.is_absolute():
- raise AssertionError(f"filename must be relative: {filename}")
- everything: dict[str, Any] = {"shard_id": "Everything"}
- shards: list[dict[str, Any]] = [
- {"shard_id": f"_{i}"} for i in range(num_shards)
- ]
- all_shards = [everything] + shards
- if base_env is not None:
- for shard in all_shards:
- shard.update(base_env)
- for key in sharded_keys:
- for shard in all_shards:
- if key in shard:
- if not isinstance(shard[key], list):
- raise AssertionError("sharded keys in base_env must be a list")
- shard[key] = shard[key].copy()
- else:
- shard[key] = []
- def merge_env(into: dict[str, list[str]], from_: dict[str, list[str]]) -> None:
- for k, v in from_.items():
- if k not in sharded_keys:
- raise AssertionError(f"undeclared sharded key {k}")
- into[k] += v
- if self.dry_run:
- # Dry runs don't write any templates, so incomplete environments are fine
- items = ()
- for item in items:
- key = key_fn(item)
- sid = string_stable_hash(key) % num_shards
- env = env_callable(item)
- merge_env(shards[sid], env)
- merge_env(everything, env)
- for shard in all_shards:
- shard_id = shard["shard_id"]
- self.write_with_template(
- file.with_stem(f"{file.stem}{shard_id}"),
- template_fn,
- lambda: shard,
- )
- # filenames is used to track compiled files, but FooEverything.cpp isn't meant to be compiled
- self.files.discard(self.install_dir / file.with_stem(f"{file.stem}Everything"))
- def write_outputs(self, variable_name: str, filename: str | Path) -> None:
- """Write a file containing the list of all outputs which are generated by this script."""
- content = "\n".join(
- (
- "set(",
- variable_name,
- # Use POSIX paths to avoid invalid escape sequences on Windows
- *(f' "{file.as_posix()}"' for file in sorted(self.files)),
- ")",
- )
- )
- self._write_if_changed(filename, content)
- def template_dir_for_comments(self) -> str:
- """
- This needs to be deterministic. The template dir is an absolute path
- that varies across builds. So, just use the path relative to this file,
- which will point to the codegen source but will be stable.
- """
- return os.path.relpath(self.template_dir, os.path.dirname(__file__))
- # Helper function to generate file manager
- def make_file_manager(
- options: Namespace,
- install_dir: str | Path | None = None,
- ) -> FileManager:
- template_dir = os.path.join(options.source_path, "templates")
- install_dir = install_dir if install_dir else options.install_dir
- return FileManager(
- install_dir=install_dir,
- template_dir=template_dir,
- dry_run=options.dry_run,
- )
- # Helper function to create a pretty representation for dataclasses
- def dataclass_repr(
- obj: Any,
- indent: int = 0,
- width: int = 80,
- ) -> str:
- return pformat(obj, indent, width)
- def _format_dict(
- attr: dict[Any, Any],
- indent: int,
- width: int,
- curr_indent: int,
- ) -> str:
- curr_indent += indent + 3
- dict_repr = []
- for k, v in attr.items():
- k_repr = repr(k)
- v_str = (
- pformat(v, indent, width, curr_indent + len(k_repr))
- if is_dataclass(v)
- else repr(v)
- )
- dict_repr.append(f"{k_repr}: {v_str}")
- return _format(dict_repr, indent, width, curr_indent, "{", "}")
- def _format_list(
- attr: list[Any] | set[Any] | tuple[Any, ...],
- indent: int,
- width: int,
- curr_indent: int,
- ) -> str:
- curr_indent += indent + 1
- list_repr = [
- pformat(l, indent, width, curr_indent) if is_dataclass(l) else repr(l)
- for l in attr
- ]
- start, end = ("[", "]") if isinstance(attr, list) else ("(", ")")
- return _format(list_repr, indent, width, curr_indent, start, end)
- def _format(
- fields_str: list[str],
- indent: int,
- width: int,
- curr_indent: int,
- start: str,
- end: str,
- ) -> str:
- delimiter, curr_indent_str = "", ""
- # if it exceed the max width then we place one element per line
- if len(repr(fields_str)) >= width:
- delimiter = "\n"
- curr_indent_str = " " * curr_indent
- indent_str = " " * indent
- body = f", {delimiter}{curr_indent_str}".join(fields_str)
- return f"{start}{indent_str}{body}{end}"
- class NamespaceHelper:
- """A helper for constructing the namespace open and close strings for a nested set of namespaces.
- e.g. for namespace_str torch::lazy,
- prologue:
- namespace torch {
- namespace lazy {
- epilogue:
- } // namespace lazy
- } // namespace torch
- """
- def __init__(
- self,
- namespace_str: str,
- entity_name: str = "",
- max_level: int = 2,
- ) -> None:
- # cpp_namespace can be a colon joined string such as torch::lazy
- cpp_namespaces = namespace_str.split("::")
- if len(cpp_namespaces) > max_level:
- raise AssertionError(
- f"Codegen doesn't support more than {max_level} level(s) of "
- f"custom namespace. Got {namespace_str}."
- )
- self.cpp_namespace_ = namespace_str
- self.prologue_ = "\n".join([f"namespace {n} {{" for n in cpp_namespaces])
- self.epilogue_ = "\n".join(
- [f"}} // namespace {n}" for n in reversed(cpp_namespaces)]
- )
- self.namespaces_ = cpp_namespaces
- self.entity_name_ = entity_name
- @staticmethod
- def from_namespaced_entity(
- namespaced_entity: str,
- max_level: int = 2,
- ) -> NamespaceHelper:
- """
- Generate helper from nested namespaces as long as class/function name. E.g.: "torch::lazy::add"
- """
- names = namespaced_entity.split("::")
- entity_name = names[-1]
- namespace_str = "::".join(names[:-1])
- return NamespaceHelper(
- namespace_str=namespace_str, entity_name=entity_name, max_level=max_level
- )
- @property
- def prologue(self) -> str:
- return self.prologue_
- @property
- def epilogue(self) -> str:
- return self.epilogue_
- @property
- def entity_name(self) -> str:
- return self.entity_name_
- # Only allow certain level of namespaces
- def get_cpp_namespace(self, default: str = "") -> str:
- """
- Return the namespace string from joining all the namespaces by "::" (hence no leading "::").
- Return default if namespace string is empty.
- """
- return self.cpp_namespace_ if self.cpp_namespace_ else default
- class OrderedSet(Generic[T]):
- storage: dict[T, None]
- def __init__(self, iterable: Iterable[T] | None = None) -> None:
- if iterable is None:
- self.storage = {}
- else:
- self.storage = dict.fromkeys(iterable)
- def __contains__(self, item: T) -> bool:
- return item in self.storage
- def __iter__(self) -> Iterator[T]:
- return iter(self.storage.keys())
- def update(self, items: OrderedSet[T]) -> None:
- self.storage.update(items.storage)
- def add(self, item: T) -> None:
- self.storage[item] = None
- def copy(self) -> OrderedSet[T]:
- ret: OrderedSet[T] = OrderedSet()
- ret.storage = self.storage.copy()
- return ret
- @staticmethod
- def union(*args: OrderedSet[T]) -> OrderedSet[T]:
- ret = args[0].copy()
- for s in args[1:]:
- ret.update(s)
- return ret
- def __or__(self, other: OrderedSet[T]) -> OrderedSet[T]:
- return OrderedSet.union(self, other)
- def __ior__(self, other: OrderedSet[T]) -> Self:
- self.update(other)
- return self
- def __eq__(self, other: object) -> bool:
- if isinstance(other, OrderedSet):
- return self.storage == other.storage
- else:
- return set(self.storage.keys()) == other
|