functionalization.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  1. from __future__ import annotations
  2. from torchgen.api import dispatcher
  3. from torchgen.api.types import (
  4. BaseCppType,
  5. BaseCType,
  6. Binding,
  7. boolT,
  8. ConstRefCType,
  9. CType,
  10. longT,
  11. NamedCType,
  12. tensorT,
  13. )
  14. from torchgen.model import (
  15. Argument,
  16. BaseTy,
  17. BaseType,
  18. FunctionSchema,
  19. NativeFunction,
  20. NativeFunctionsViewGroup,
  21. )
  22. # This file describes the translation of JIT schema to API's used
  23. # when creating `ViewMeta` specializations that are used by the functionalization pass.
  24. # These API's mostly follow the dispatcher API, with one difference:
  25. # - While the forward function just directly calls into the at::_ops API
  26. # (following the dispatcher convention), the logic here for the reverse function
  27. # is responsible for generating both the call-site, and the declarations
  28. # (which are implemented manually in the at::functionalization::impl namespace).
  29. # Define some specific lambda input arguments.
  30. base_binding = Binding(
  31. name="base",
  32. nctype=NamedCType(name="base", type=ConstRefCType(BaseCType(tensorT))),
  33. argument=Argument(
  34. name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None
  35. ),
  36. default=None,
  37. )
  38. has_symbolic_inputs_binding = Binding(
  39. name="has_symbolic_inputs",
  40. nctype=NamedCType(name="has_symbolic_inputs", type=BaseCType(boolT)),
  41. argument=Argument(
  42. name="has_symbolic_inputs",
  43. type=BaseType(BaseTy.bool),
  44. default=None,
  45. annotation=None,
  46. ),
  47. default=None,
  48. )
  49. mutated_view_binding = Binding(
  50. name="mutated_view",
  51. nctype=NamedCType(name="mutated_view", type=ConstRefCType(BaseCType(tensorT))),
  52. argument=Argument(
  53. name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None
  54. ),
  55. default=None,
  56. )
  57. out_index_binding = Binding(
  58. name="out_index",
  59. nctype=NamedCType(name="out_index", type=BaseCType(longT)),
  60. argument=Argument(
  61. name="out_index", type=BaseType(BaseTy.int), default=None, annotation=None
  62. ),
  63. default=None,
  64. )
  65. reapply_views_binding = Binding(
  66. name="reapply_views",
  67. nctype=NamedCType(name="reapply_views", type=BaseCType(boolT)),
  68. argument=Argument(
  69. name="reapply_views", type=BaseType(BaseTy.bool), default=None, annotation=None
  70. ),
  71. default=None,
  72. )
  73. InverseReturnModeT = BaseCppType("at::functionalization", "InverseReturnMode")
  74. inverse_return_mode_binding = Binding(
  75. name="inverse_return_mode",
  76. nctype=NamedCType(name="inverse_return_mode", type=BaseCType(InverseReturnModeT)),
  77. argument=Argument(
  78. name="inverse_return_mode",
  79. # NB: not actually a bool but it doesn't matter because this isn't used
  80. type=BaseType(BaseTy.bool),
  81. default=None,
  82. annotation=None,
  83. ),
  84. default=None,
  85. )
  86. # Name of the `ViewMeta` specialization class created.
  87. def classname(func: FunctionSchema, with_namespace: bool = False) -> str:
  88. namespace = "at::functionalization::" if with_namespace else ""
  89. return f"{namespace}{func.name.unambiguous_name()}_ViewMeta"
  90. # Name of the operation called inside the `forward`/`reverse` implementations.
  91. def name(
  92. g: NativeFunctionsViewGroup,
  93. *,
  94. is_reverse: bool,
  95. include_namespace: bool,
  96. reapply_views: bool | None = None,
  97. ) -> str:
  98. if reapply_views is None:
  99. # reapply_views is only important for the fwd lambda,
  100. # since we always plumb the runtime "reapply_views" argument into the reverse function.
  101. if not is_reverse:
  102. raise AssertionError("reapply_views can only be None for reverse")
  103. if is_reverse:
  104. return reverse_name(g.view, include_namespace)
  105. # in the forward case, we just directly call into the at::_ops API (so we always need the namespace)
  106. if not include_namespace:
  107. raise AssertionError("include_namespace must be True for forward")
  108. if g.view_copy is None:
  109. raise AssertionError("view_copy must be non-None for forward")
  110. api_name = (
  111. g.view.func.name.unambiguous_name()
  112. if reapply_views
  113. else g.view_copy.func.name.unambiguous_name()
  114. )
  115. return f"at::_ops::{api_name}::call"
  116. def reverse_name(f: NativeFunction, include_namespace: bool) -> str:
  117. # for the reverse: we plumb the "reapply_views" flag into that function and support
  118. # both copy and non-copy variants. (We could avoid doing that, but that would require
  119. # writing out twice as many view inverse functions).
  120. api_name = f.func.name.unambiguous_name()
  121. # in the reverse case, we codegen both the call-sites (which need the full namespace) and the declarations (which don't)
  122. if include_namespace:
  123. return f"at::functionalization::FunctionalInverses::{api_name}_inverse"
  124. else:
  125. return f"{api_name}_inverse"
  126. def returns_type(func: FunctionSchema) -> CType:
  127. # Assertion: all view ops return tensor-like outputs
  128. if len(func.returns) < 1:
  129. raise AssertionError("Expected at least one return value")
  130. for ret in func.returns:
  131. if not ret.type.is_tensor_like():
  132. raise AssertionError(f"Expected tensor-like return type, got {ret.type}")
  133. # However, the return type of the lambda is always an individual tensor.
  134. # For multi-tensor outputs, each tensor needs to be tracked individually.
  135. return BaseCType(tensorT)
  136. # Checks whether `func` might return more than one value.
  137. def is_multi_output(func: FunctionSchema) -> bool:
  138. return len(func.returns) > 1 or (
  139. len(func.returns) == 1 and func.returns[0].type.is_list_like() is not None
  140. )
  141. # `ViewMeta` specialization constructor parameters.
  142. def base_ctor_arguments(func: FunctionSchema) -> list[Binding]:
  143. # All specializations are parematerized by `has_symbolic_inputs` flag.
  144. arguments = [has_symbolic_inputs_binding]
  145. # If `func` might return more than 1 value, we also parameterize this specialization
  146. # with the output index.
  147. if is_multi_output(func):
  148. arguments.append(out_index_binding)
  149. return arguments
  150. # `ViewMeta` specialized class' constructor arguments.
  151. #
  152. # Values needed specifically by this specialization, that the base class does not need.
  153. # Same as the class' attributes, but non-owning.
  154. def extra_ctor_arguments(func: FunctionSchema) -> list[Binding]:
  155. return attributes(func, owning=False)
  156. # `ViewMeta` specialized class' non-static member data.
  157. #
  158. # Essential data for calling the instance's `forward` and `reverse` functions. You can
  159. # think of them as values that should be captured from the functionalization kernel.
  160. def attributes(func: FunctionSchema, owning: bool = True) -> list[Binding]:
  161. args = func.arguments.flat_all
  162. if args[0].type != BaseType(BaseTy.Tensor):
  163. raise AssertionError(f"Expected first arg to be Tensor, got {args[0].type}")
  164. return [
  165. reapply_views_binding,
  166. inverse_return_mode_binding,
  167. *[dispatcher.argument(a, remove_non_owning_ref_types=owning) for a in args[1:]],
  168. ]
  169. def op_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]:
  170. args = func.arguments.flat_all
  171. if args[0].type != BaseType(BaseTy.Tensor):
  172. raise AssertionError(f"Expected first arg to be Tensor, got {args[0].type}")
  173. non_self_args = args[1:]
  174. # The forward lambda calls the at::_ops API, while the reverse lambda calls the view inverse API.
  175. # Both of these follow the dispatcher API.
  176. non_self_bindings = [dispatcher.argument(a) for a in non_self_args]
  177. if not is_reverse:
  178. # the forward lambda swaps out the original tensor argument with the lambd arg "base"
  179. return [base_binding] + non_self_bindings
  180. else:
  181. # the reverse lambda does the same, but with an additional "mutated_view" arg
  182. # additionally, we have a calling convention: for view ops that return multiple tensor outputs
  183. # their corresponding view_inverse function takes in an additional index argument.
  184. if is_multi_output(func):
  185. return [
  186. base_binding,
  187. mutated_view_binding,
  188. inverse_return_mode_binding,
  189. out_index_binding,
  190. ] + non_self_bindings
  191. else:
  192. return [
  193. base_binding,
  194. mutated_view_binding,
  195. inverse_return_mode_binding,
  196. ] + non_self_bindings