| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222 |
- from __future__ import annotations
- from torchgen.api import dispatcher
- from torchgen.api.types import (
- BaseCppType,
- BaseCType,
- Binding,
- boolT,
- ConstRefCType,
- CType,
- longT,
- NamedCType,
- tensorT,
- )
- from torchgen.model import (
- Argument,
- BaseTy,
- BaseType,
- FunctionSchema,
- NativeFunction,
- NativeFunctionsViewGroup,
- )
# This file describes the translation of JIT schema to the APIs used
# when creating `ViewMeta` specializations that are used by the functionalization pass.
# These APIs mostly follow the dispatcher API, with one difference:
# - While the forward function just directly calls into the at::_ops API
#   (following the dispatcher convention), the logic here for the reverse function
#   is responsible for generating both the call-site and the declarations
#   (which are implemented manually in the at::functionalization::impl namespace).
# Define some specific lambda input arguments.
#
# These bindings are parameters of the generated `ViewMeta` code that are
# synthesized here rather than derived from the view op's schema (`base`
# stands in for the op's own `self` tensor). For each binding, `name` and
# `nctype` give the C++ parameter name and type. The `argument` field is filled
# in because `Binding` requires one. Per the NB on `inverse_return_mode_binding`,
# it is not consumed for these bindings.
# NOTE(review): confirm no consumer reads `argument.name` for them.

# The tensor the view is (re)applied to; replaces the op's original `self` arg.
base_binding = Binding(
    name="base",
    nctype=NamedCType(name="base", type=ConstRefCType(BaseCType(tensorT))),
    argument=Argument(
        name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None
    ),
    default=None,
)

# Flag every `ViewMeta` specialization is parameterized by.
has_symbolic_inputs_binding = Binding(
    name="has_symbolic_inputs",
    nctype=NamedCType(name="has_symbolic_inputs", type=BaseCType(boolT)),
    argument=Argument(
        name="has_symbolic_inputs",
        type=BaseType(BaseTy.bool),
        default=None,
        annotation=None,
    ),
    default=None,
)

# Extra tensor input taken by the reverse (view-inverse) function.
mutated_view_binding = Binding(
    name="mutated_view",
    nctype=NamedCType(name="mutated_view", type=ConstRefCType(BaseCType(tensorT))),
    argument=Argument(
        # Previously copy-pasted as "base"; keep the argument name in sync with
        # the binding name, as every other binding here does.
        name="mutated_view",
        type=BaseType(BaseTy.Tensor),
        default=None,
        annotation=None,
    ),
    default=None,
)

# Index of the output being tracked, for view ops returning multiple tensors.
out_index_binding = Binding(
    name="out_index",
    nctype=NamedCType(name="out_index", type=BaseCType(longT)),
    argument=Argument(
        name="out_index", type=BaseType(BaseTy.int), default=None, annotation=None
    ),
    default=None,
)

# Whether the forward call uses the view op (True) or its view_copy variant.
reapply_views_binding = Binding(
    name="reapply_views",
    nctype=NamedCType(name="reapply_views", type=BaseCType(boolT)),
    argument=Argument(
        name="reapply_views", type=BaseType(BaseTy.bool), default=None, annotation=None
    ),
    default=None,
)

# C++ type `at::functionalization::InverseReturnMode`.
InverseReturnModeT = BaseCppType("at::functionalization", "InverseReturnMode")

inverse_return_mode_binding = Binding(
    name="inverse_return_mode",
    nctype=NamedCType(name="inverse_return_mode", type=BaseCType(InverseReturnModeT)),
    argument=Argument(
        name="inverse_return_mode",
        # NB: not actually a bool but it doesn't matter because this isn't used
        type=BaseType(BaseTy.bool),
        default=None,
        annotation=None,
    ),
    default=None,
)
# Name of the `ViewMeta` specialization class created.
def classname(func: FunctionSchema, with_namespace: bool = False) -> str:
    """Return the `ViewMeta` specialization class name for `func`.

    The name is `func`'s unambiguous operator name followed by `_ViewMeta`.
    When `with_namespace` is true, it is qualified with
    `at::functionalization::`.
    """
    unqualified = f"{func.name.unambiguous_name()}_ViewMeta"
    if not with_namespace:
        return unqualified
    return f"at::functionalization::{unqualified}"
# Name of the operation called inside the `forward`/`reverse` implementations.
def name(
    g: NativeFunctionsViewGroup,
    *,
    is_reverse: bool,
    include_namespace: bool,
    reapply_views: bool | None = None,
) -> str:
    """Return the C++ callee invoked by a `ViewMeta` forward/reverse implementation.

    Reverse: delegates to `reverse_name(g.view, include_namespace)`;
    `reapply_views` is ignored.
    Forward: returns `at::_ops::<op>::call`, where `<op>` is the unambiguous
    name of `g.view` when `reapply_views` is truthy, and of `g.view_copy`
    otherwise.

    Raises:
        AssertionError: on the forward path, if `reapply_views` is None,
            if `include_namespace` is false, or if `g.view_copy` is None
            (checked in that order).
    """
    if is_reverse:
        # reapply_views only matters for the forward call; the reverse side
        # receives its runtime configuration as an argument instead.
        return reverse_name(g.view, include_namespace)

    if reapply_views is None:
        raise AssertionError("reapply_views can only be None for reverse")
    # The forward path calls straight into at::_ops, so the namespace is mandatory.
    if not include_namespace:
        raise AssertionError("include_namespace must be True for forward")
    if g.view_copy is None:
        raise AssertionError("view_copy must be non-None for forward")

    callee = g.view if reapply_views else g.view_copy
    return f"at::_ops::{callee.func.name.unambiguous_name()}::call"
def reverse_name(f: NativeFunction, include_namespace: bool) -> str:
    """Return the name of the view-inverse function for the view op `f`.

    The inverse is named `<unambiguous op name>_inverse`. One inverse function
    serves both the copy and non-copy variants, which avoids writing twice as
    many view inverse functions. This codegen emits both the call sites, which
    need the fully qualified name (`include_namespace=True`), and the
    declarations, which do not.
    """
    unqualified = f"{f.func.name.unambiguous_name()}_inverse"
    qualifier = "at::functionalization::FunctionalInverses::" if include_namespace else ""
    return qualifier + unqualified
def returns_type(func: FunctionSchema) -> CType:
    """Return the C++ return type of the generated forward/reverse lambdas.

    View ops must return at least one value, and every return must be
    tensor-like. Even for multi-tensor outputs, the lambda returns a single
    tensor, because each output is tracked individually.

    Raises:
        AssertionError: if `func` has no returns, or if any return type is not
            tensor-like (the first offending type is reported).
    """
    if not func.returns:
        raise AssertionError("Expected at least one return value")
    offending = next(
        (r.type for r in func.returns if not r.type.is_tensor_like()), None
    )
    if offending is not None:
        raise AssertionError(f"Expected tensor-like return type, got {offending}")
    return BaseCType(tensorT)
# Checks whether `func` might return more than one value.
def is_multi_output(func: FunctionSchema) -> bool:
    """Return True if `func` may produce more than one output tensor.

    This holds when `func` has several returns, or a single return whose
    type is list-like (its `is_list_like()` is not None).
    """
    num_returns = len(func.returns)
    if num_returns > 1:
        return True
    if num_returns == 1:
        return func.returns[0].type.is_list_like() is not None
    return False
# `ViewMeta` specialization constructor parameters.
def base_ctor_arguments(func: FunctionSchema) -> list[Binding]:
    """Return the constructor parameters shared with the `ViewMeta` base class.

    Every specialization is parameterized by the `has_symbolic_inputs` flag.
    When `func` might return more than one value (see `is_multi_output`), the
    output index `out_index` is appended as well.
    """
    if is_multi_output(func):
        return [has_symbolic_inputs_binding, out_index_binding]
    return [has_symbolic_inputs_binding]
# `ViewMeta` specialized class' constructor arguments.
#
# Values needed specifically by this specialization that the base class does
# not need. These are the same as the class' attributes, but non-owning.
def extra_ctor_arguments(func: FunctionSchema) -> list[Binding]:
    """Return the specialization-specific constructor parameters for `func`.

    These mirror `attributes(func)` but are built with `owning=False`, so the
    dispatcher bindings keep their non-owning reference types.
    """
    non_owning_params = attributes(func, owning=False)
    return non_owning_params
# `ViewMeta` specialized class' non-static member data.
#
# Essential data for calling the instance's `forward` and `reverse` functions.
# You can think of them as values that should be captured from the
# functionalization kernel.
def attributes(func: FunctionSchema, owning: bool = True) -> list[Binding]:
    """Return the member-data bindings of the `ViewMeta` specialization for `func`.

    The result is `reapply_views`, then `inverse_return_mode`, then one
    dispatcher binding per schema argument after the leading `self` Tensor.
    `owning` is forwarded to `dispatcher.argument` as
    `remove_non_owning_ref_types`.

    Raises:
        AssertionError: if the first flattened argument is not a plain Tensor.
    """
    flat_args = func.arguments.flat_all
    self_type = flat_args[0].type
    if self_type != BaseType(BaseTy.Tensor):
        raise AssertionError(f"Expected first arg to be Tensor, got {self_type}")
    # `self` is skipped: the lambdas receive the tensor as `base` instead.
    captured = [
        dispatcher.argument(arg, remove_non_owning_ref_types=owning)
        for arg in flat_args[1:]
    ]
    return [reapply_views_binding, inverse_return_mode_binding, *captured]
def op_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]:
    """Return the parameter bindings of the generated forward or reverse lambda.

    The forward lambda calls the at::_ops API and the reverse lambda calls the
    view inverse API; both follow the dispatcher API for the op's non-self
    arguments.

    Forward: `[base, *non_self]`, where `base` replaces the op's `self` tensor.
    Reverse: `[base, mutated_view, inverse_return_mode, *non_self]`. For view
    ops that may return multiple tensors, `out_index` is inserted before the
    non-self arguments, because their view_inverse functions take that extra
    index.

    Raises:
        AssertionError: if the first flattened argument is not a plain Tensor.
    """
    flat_args = func.arguments.flat_all
    if flat_args[0].type != BaseType(BaseTy.Tensor):
        raise AssertionError(f"Expected first arg to be Tensor, got {flat_args[0].type}")
    trailing = [dispatcher.argument(arg) for arg in flat_args[1:]]

    if not is_reverse:
        # The forward lambda swaps out the original tensor argument with the
        # lambda arg "base".
        return [base_binding, *trailing]

    leading = [base_binding, mutated_view_binding, inverse_return_mode_binding]
    if is_multi_output(func):
        leading.append(out_index_binding)
    return leading + trailing
|