dispatcher.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. from __future__ import annotations
  2. import itertools
  3. from typing import TYPE_CHECKING
  4. from typing_extensions import assert_never
  5. from torchgen.api import cpp
  6. from torchgen.api.types import ArgName, Binding, CType, NamedCType
  7. from torchgen.model import (
  8. Argument,
  9. FunctionSchema,
  10. Return,
  11. SelfArgument,
  12. TensorOptionsArguments,
  13. Type,
  14. )
  15. from torchgen.utils import concatMap
  16. if TYPE_CHECKING:
  17. from collections.abc import Sequence
# This file describes the translation of JIT schema to the dispatcher
# API, the *unboxed* calling convention by which invocations through
# the dispatcher are made. Historically, the dispatcher API matched
# the C++ API, but with the establishment of the boxed API, we've
# made changes to the dispatcher API so that the unboxed API
# better aligns with the boxed API. The dispatcher API hooks heavily
# into our template-based boxing/unboxing machinery, so changes
# to this convention will usually need template updates too.
#
# Prominent characteristics of the dispatcher API:
#
# - dtype, layout, device and pin_memory are represented as separate
#   arguments.
#
  32. def name(func: FunctionSchema) -> str:
  33. return cpp.name(func)
  34. def argumenttype_type(
  35. t: Type,
  36. *,
  37. mutable: bool,
  38. binds: ArgName,
  39. remove_non_owning_ref_types: bool = False,
  40. symint: bool = True,
  41. ) -> NamedCType:
  42. # This is a faux amis. If it makes sense in the future to add
  43. # more special cases here, or invert things so cpp.argument_type
  44. # calls this, or just completely inline the function, please do
  45. # it.
  46. return cpp.argumenttype_type(
  47. t,
  48. mutable=mutable,
  49. binds=binds,
  50. symint=symint,
  51. remove_non_owning_ref_types=remove_non_owning_ref_types,
  52. )
  53. def argument_type(
  54. a: Argument,
  55. *,
  56. binds: ArgName,
  57. remove_non_owning_ref_types: bool = False,
  58. symint: bool = True,
  59. ) -> NamedCType:
  60. return argumenttype_type(
  61. a.type,
  62. mutable=a.is_write,
  63. binds=binds,
  64. remove_non_owning_ref_types=remove_non_owning_ref_types,
  65. symint=symint,
  66. )
  67. def returns_type(rs: Sequence[Return], *, symint: bool = True) -> CType:
  68. # At present, there is no difference. But there could be!
  69. return cpp.returns_type(rs, symint=symint)
  70. def jit_arguments(func: FunctionSchema) -> list[Argument]:
  71. def to_argument(
  72. a: Argument | TensorOptionsArguments | SelfArgument,
  73. ) -> list[Argument]:
  74. if isinstance(a, Argument):
  75. return [a]
  76. elif isinstance(a, SelfArgument):
  77. return [a.argument]
  78. elif isinstance(a, TensorOptionsArguments):
  79. return [a.dtype, a.layout, a.device, a.pin_memory]
  80. else:
  81. assert_never(a)
  82. return list(
  83. concatMap(
  84. to_argument,
  85. itertools.chain(
  86. func.arguments.positional, func.arguments.kwarg_only, func.arguments.out
  87. ),
  88. )
  89. )
  90. def argument(
  91. a: Argument, *, remove_non_owning_ref_types: bool = False, symint: bool = True
  92. ) -> Binding:
  93. return Binding(
  94. nctype=argument_type(
  95. a,
  96. binds=a.name,
  97. remove_non_owning_ref_types=remove_non_owning_ref_types,
  98. symint=symint,
  99. ),
  100. name=a.name,
  101. argument=a,
  102. )
  103. def arguments(func: FunctionSchema, *, symint: bool = True) -> list[Binding]:
  104. return [argument(a, symint=symint) for a in jit_arguments(func)]