gen_schema_utils.py 3.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. from typing import Any
  2. from torchgen.model import (
  3. Annotation,
  4. Argument,
  5. Arguments,
  6. BaseOperatorName,
  7. BaseTy,
  8. BaseType,
  9. CustomClassType,
  10. FunctionSchema,
  11. ListType,
  12. OperatorName,
  13. Return,
  14. )
# Note: these utilities are not used by torchgen itself; they generate a schema from real
# (example) arguments. For example, they are used to generate the schemas of
# HigherOrderOperators (HOPs), because the schema can differ between instances of the same HOP.
  18. class TypeGen:
  19. convert_to_base_ty = {
  20. int: BaseTy.int,
  21. float: BaseTy.float,
  22. str: BaseTy.str,
  23. bool: BaseTy.bool,
  24. }
  25. @staticmethod
  26. def from_example(obj: Any) -> BaseType | ListType | CustomClassType:
  27. import torch
  28. if isinstance(obj, torch.fx.GraphModule):
  29. return BaseType(BaseTy.GraphModule)
  30. elif isinstance(obj, torch.Tensor):
  31. return BaseType(BaseTy.Tensor)
  32. elif isinstance(obj, torch.SymInt):
  33. return BaseType(BaseTy.SymInt)
  34. elif isinstance(obj, torch.SymBool):
  35. return BaseType(BaseTy.SymBool)
  36. elif isinstance(obj, torch.ScriptObject):
  37. return CustomClassType(obj._type().name()) # type: ignore[attr-defined]
  38. elif isinstance(obj, (list, tuple)):
  39. if len(obj) == 0:
  40. raise AssertionError("list/tuple must be non-empty")
  41. all_base_tys = [TypeGen.from_example(x) for x in obj]
  42. if len(set(all_base_tys)) > 1:
  43. raise RuntimeError(
  44. f"Cannot generate schema for a sequence of args of heterogeneous types: {all_base_tys}. "
  45. "Consider unpacking the argument and give proper names to them if possible "
  46. "instead of using *args."
  47. )
  48. return ListType(all_base_tys[0], len(obj))
  49. tp = type(obj)
  50. if tp not in TypeGen.convert_to_base_ty:
  51. raise RuntimeError(f"unsupported type {tp}")
  52. return BaseType(TypeGen.convert_to_base_ty[tp])
  53. class ReturnGen:
  54. @staticmethod
  55. def from_example(
  56. name: str | None, obj: Any, annotation: Annotation | None
  57. ) -> Return:
  58. return Return(name, TypeGen.from_example(obj), annotation)
  59. class ArgumentGen:
  60. @staticmethod
  61. def from_example(
  62. name: str, obj: Any, default: str | None, annotation: Annotation | None
  63. ) -> Argument:
  64. return Argument(
  65. name, TypeGen.from_example(obj), default=default, annotation=annotation
  66. )
  67. class FunctionSchemaGen:
  68. @staticmethod
  69. def from_example(
  70. op_name: str,
  71. example_inputs: tuple[tuple[str, Any], ...],
  72. example_outputs: tuple[Any, ...],
  73. ) -> FunctionSchema:
  74. args = []
  75. for name, inp in example_inputs:
  76. args.append(ArgumentGen.from_example(name, inp, None, None))
  77. # ignore the annotations and other attributes for now, we could add more when needed.
  78. arguments = Arguments(
  79. tuple(), None, tuple(args), tuple(), None, tuple(), tuple()
  80. )
  81. returns = tuple(
  82. ReturnGen.from_example(None, out, None) for out in example_outputs
  83. )
  84. op_name = OperatorName(BaseOperatorName(op_name, False, False, False), "")
  85. return FunctionSchema(op_name, arguments, returns)