| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798 |
- from typing import Any
- from torchgen.model import (
- Annotation,
- Argument,
- Arguments,
- BaseOperatorName,
- BaseTy,
- BaseType,
- CustomClassType,
- FunctionSchema,
- ListType,
- OperatorName,
- Return,
- )
- # Note: These aren't actually used in torchgen, they're some utilities for generating a schema
- # from real arguments. For example, this is used to generate HigherOrderOperators' schema since
- # their schemas can vary for different instances of the same HOP.
class TypeGen:
    """Derive a torchgen type object from an example runtime value."""

    # Plain Python scalars are resolved by looking up their exact ``type(obj)``
    # in this table (see the end of ``from_example``), so subclasses are not
    # matched implicitly and ``bool`` carries its own entry.
    convert_to_base_ty = {
        int: BaseTy.int,
        float: BaseTy.float,
        str: BaseTy.str,
        bool: BaseTy.bool,
    }

    @staticmethod
    def from_example(obj: Any) -> BaseType | ListType | CustomClassType:
        """Return the schema type that describes ``obj``.

        The checks run in a fixed order: torch graph modules, tensors and
        symbolic ints/bools first, then script objects, then lists/tuples
        (handled recursively), and finally plain Python scalars.

        Args:
            obj: The example value to classify.

        Returns:
            A ``BaseType`` for torch objects and Python scalars, a
            ``CustomClassType`` for ``torch.ScriptObject`` instances, or a
            ``ListType`` for a non-empty homogeneous list/tuple.

        Raises:
            AssertionError: If ``obj`` is an empty list or tuple.
            RuntimeError: If a list/tuple mixes element types, or if the
                type of ``obj`` is not supported.
        """
        import torch

        # Order matters: the first matching class wins.
        torch_kinds = (
            (torch.fx.GraphModule, BaseTy.GraphModule),
            (torch.Tensor, BaseTy.Tensor),
            (torch.SymInt, BaseTy.SymInt),
            (torch.SymBool, BaseTy.SymBool),
        )
        for torch_cls, base_ty in torch_kinds:
            if isinstance(obj, torch_cls):
                return BaseType(base_ty)

        if isinstance(obj, torch.ScriptObject):
            return CustomClassType(obj._type().name())  # type: ignore[attr-defined]

        if isinstance(obj, (list, tuple)):
            size = len(obj)
            if size == 0:
                raise AssertionError("list/tuple must be non-empty")
            # Classify every element; a valid sequence collapses to one type.
            all_base_tys = list(map(TypeGen.from_example, obj))
            if len(set(all_base_tys)) > 1:
                raise RuntimeError(
                    f"Cannot generate schema for a sequence of args of heterogeneous types: {all_base_tys}. "
                    "Consider unpacking the argument and give proper names to them if possible "
                    "instead of using *args."
                )
            return ListType(all_base_tys[0], size)

        # Fallback: plain Python scalar, matched on its exact type.
        tp = type(obj)
        scalar_ty = TypeGen.convert_to_base_ty.get(tp)
        if scalar_ty is None:
            raise RuntimeError(f"unsupported type {tp}")
        return BaseType(scalar_ty)
class ReturnGen:
    """Build a ``Return`` entry from an example output value."""

    @staticmethod
    def from_example(
        name: str | None, obj: Any, annotation: Annotation | None
    ) -> Return:
        """Create a ``Return`` whose type is inferred from ``obj``.

        Args:
            name: Name to give the return value, or ``None``.
            obj: Example output; its type is derived via
                ``TypeGen.from_example``.
            annotation: Annotation to attach, or ``None``.

        Returns:
            The constructed ``Return``.
        """
        inferred_type = TypeGen.from_example(obj)
        return Return(name, inferred_type, annotation)
class ArgumentGen:
    """Build an ``Argument`` entry from an example input value."""

    @staticmethod
    def from_example(
        name: str, obj: Any, default: str | None, annotation: Annotation | None
    ) -> Argument:
        """Create an ``Argument`` whose type is inferred from ``obj``.

        Args:
            name: Name of the argument.
            obj: Example input; its type is derived via
                ``TypeGen.from_example``.
            default: Default value string to record, or ``None``.
            annotation: Annotation to attach, or ``None``.

        Returns:
            The constructed ``Argument``.
        """
        inferred_type = TypeGen.from_example(obj)
        return Argument(name, inferred_type, default=default, annotation=annotation)
class FunctionSchemaGen:
    """Assemble a ``FunctionSchema`` from example inputs and outputs."""

    @staticmethod
    def from_example(
        op_name: str,
        example_inputs: tuple[tuple[str, Any], ...],
        example_outputs: tuple[Any, ...],
    ) -> FunctionSchema:
        """Generate a schema for ``op_name`` from concrete example values.

        Args:
            op_name: Base name of the operator.
            example_inputs: ``(name, value)`` pairs, one per input; each
                value's type is inferred via ``ArgumentGen.from_example``.
            example_outputs: Example output values; each becomes an unnamed
                return via ``ReturnGen.from_example``.

        Returns:
            The ``FunctionSchema`` built from the inferred arguments and
            returns.
        """
        # No default value and no annotation are recorded for any input.
        inferred_args = [
            ArgumentGen.from_example(arg_name, value, None, None)
            for arg_name, value in example_inputs
        ]
        # ignore the annotations and other attributes for now, we could add more when needed.
        # Only the third positional slot of Arguments is populated.
        arguments = Arguments((), None, tuple(inferred_args), (), None, (), ())
        returns = tuple(
            ReturnGen.from_example(None, output, None) for output in example_outputs
        )
        base_name = BaseOperatorName(op_name, False, False, False)
        operator_name = OperatorName(base_name, "")
        return FunctionSchema(operator_name, arguments, returns)
|