context.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. from __future__ import annotations
  2. import contextlib
  3. import functools
  4. from typing import Any, TYPE_CHECKING, TypeVar
  5. import torchgen.local as local
  6. from torchgen.model import (
  7. BackendIndex,
  8. DispatchKey,
  9. NativeFunction,
  10. NativeFunctionsGroup,
  11. NativeFunctionsViewGroup,
  12. )
  13. from torchgen.utils import context, S, T
  14. if TYPE_CHECKING:
  15. from collections.abc import Callable, Iterator
  16. # Helper functions for defining generators on things in the model
  17. F = TypeVar(
  18. "F",
  19. NativeFunction,
  20. NativeFunctionsGroup,
  21. NativeFunctionsViewGroup,
  22. NativeFunction | NativeFunctionsGroup,
  23. NativeFunction | NativeFunctionsViewGroup,
  24. )
  25. F2 = TypeVar(
  26. "F2",
  27. NativeFunction,
  28. NativeFunctionsGroup,
  29. NativeFunction | None,
  30. bool,
  31. str,
  32. )
  33. F3 = TypeVar("F3", tuple[NativeFunction, Any], list[NativeFunction])
  34. @contextlib.contextmanager
  35. def native_function_manager(
  36. g: NativeFunctionsGroup | NativeFunctionsViewGroup | NativeFunction,
  37. ) -> Iterator[None]:
  38. if isinstance(g, NativeFunctionsGroup):
  39. # By default, we associate all errors with structured native functions
  40. # with the out variant. In some cases, it might be better to have
  41. # a more specific place to hang things; if so, use
  42. # native_function_manager again on the inside
  43. f = g.out
  44. elif isinstance(g, NativeFunctionsViewGroup):
  45. # We associate errors with the view operator
  46. f = g.view
  47. else:
  48. f = g
  49. with context(lambda: f"in native_functions.yaml line {f.loc}:\n {f.func}"):
  50. with local.parametrize(
  51. use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors,
  52. use_ilistref_for_tensor_lists=f.part_of_structured_group,
  53. ):
  54. yield
  55. # Given a function that operates on NativeFunction, wrap it into a new function
  56. # that sets some appropriate context managers for that native function.
  57. # YOU MUST WRAP FUNCTIONS IN THIS for calls to api modules to be sound
  58. # (you will get an error if we try to access the local variables without having
  59. # set them).
  60. def with_native_function(func: Callable[[F], T]) -> Callable[[F], T]:
  61. @functools.wraps(func)
  62. def wrapper(f: F) -> T:
  63. with native_function_manager(f):
  64. return func(f)
  65. return wrapper
  66. def with_native_function_and(func: Callable[[F, F2], T]) -> Callable[[F, F2], T]:
  67. @functools.wraps(func)
  68. def wrapper(f: F, f2: F2) -> T:
  69. # The first native_function is assumed to be the one with the appropriate context.
  70. with native_function_manager(f):
  71. return func(f, f2)
  72. return wrapper
  73. def method_with_native_function(func: Callable[[S, F], T]) -> Callable[[S, F], T]:
  74. @functools.wraps(func)
  75. def wrapper(slf: S, f: F) -> T:
  76. with native_function_manager(f):
  77. return func(slf, f)
  78. return wrapper
  79. def method_with_nested_native_function(
  80. func: Callable[[S, F3], T],
  81. ) -> Callable[[S, F3], T]:
  82. @functools.wraps(func)
  83. def wrapper(slf: S, f: F3) -> T:
  84. with native_function_manager(f[0]):
  85. return func(slf, f)
  86. return wrapper
  87. # Convenience decorator for functions that explicitly take in a BackendIndex,
  88. # instead of indirectly taking one in as a closure
  89. def with_native_function_and_index(
  90. func: Callable[[F, BackendIndex], T],
  91. ) -> Callable[[F, BackendIndex], T]:
  92. @functools.wraps(func)
  93. def wrapper(f: F, backend_index: BackendIndex) -> T:
  94. with native_function_manager(f):
  95. return func(f, backend_index)
  96. return wrapper
  97. # Convenience decorator for functions that explicitly take in a Dict of BackendIndices
  98. def with_native_function_and_indices(
  99. func: Callable[[F, dict[DispatchKey, BackendIndex]], T],
  100. ) -> Callable[[F, dict[DispatchKey, BackendIndex]], T]:
  101. @functools.wraps(func)
  102. def wrapper(f: F, backend_indices: dict[DispatchKey, BackendIndex]) -> T:
  103. with native_function_manager(f):
  104. return func(f, backend_indices)
  105. return wrapper