lazy.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470
  1. from __future__ import annotations
  2. from typing import Any
  3. from torchgen.api.types import (
  4. BaseCppType,
  5. BaseCType,
  6. boolT,
  7. CType,
  8. deviceT,
  9. doubleT,
  10. generatorT,
  11. layoutT,
  12. ListCType,
  13. longT,
  14. memoryFormatT,
  15. NamedCType,
  16. OptionalCType,
  17. scalarT,
  18. scalarTypeT,
  19. stringT,
  20. SymIntT,
  21. VectorCType,
  22. )
  23. from torchgen.model import (
  24. Argument,
  25. BaseTy,
  26. BaseType,
  27. FunctionSchema,
  28. ListType,
  29. OperatorName,
  30. OptionalType,
  31. Return,
  32. TensorOptionsArguments,
  33. Type,
  34. )
  35. _valueT: BaseCppType | None = None
  36. # A ValueT is an IR type which represents the computation of a Tensor. In other
  37. # words, a PyTorch user will do operations on lazy tensors, and each output lazy
  38. # tensor internally tracks a ValueT representing the IR node that would have
  39. # actually produced the value of this tensor for real.
  40. #
  41. # This is configurable because different lazy tensor backends (LTC vs XLA) will
  42. # have different IR representations. (Though, arguably, after unification they
  43. # shouldn't!)
  44. def getValueT() -> BaseCppType:
  45. global _valueT
  46. if not _valueT:
  47. raise NotImplementedError(
  48. "The value type needs to be set with setValueT() in run_gen_lazy_tensor()"
  49. )
  50. return _valueT
  51. def setValueT(val: BaseCppType) -> None:
  52. global _valueT
  53. _valueT = val
  54. # this is a bad hack. I need to refactor the data model to represent each arg in the schema as an object,
  55. # making it easier to represent special properties of an arg.
  56. tensorListValueT = BaseCppType("torch::lazy", "Value")
  57. def process_ir_type(
  58. typ: Type, properties: LazyIrProperties, *, symint: bool
  59. ) -> BaseCType | VectorCType | OptionalCType | ListCType:
  60. """
  61. This function takes a type from NativeFunctions and converts it for use with
  62. lazy tensor codegen.
  63. Type conversion for lazy currently consists of
  64. (1) changing at::Tensors into lazy::Values
  65. (2) wrapping everything in a BaseCType
  66. (3) making cpp-reference types into cpp-value types (e.g. vector instead of IntArrayRef)
  67. (1) converts at::Tensors to lazy::Values (which wrap lazy::Nodes, with which Lazy IR represents tensors.)
  68. There is special handling for Optional[Tensor] or list[Tensor], etc- hence 'tensor-like'
  69. This is incomplete- there are assertions in places that it's expected to need to add
  70. more types as the codegen is used with more operators.
  71. """
  72. if isinstance(typ, BaseType):
  73. if typ.name == BaseTy.Tensor:
  74. return BaseCType(getValueT())
  75. elif typ.name == BaseTy.Scalar:
  76. if properties.TreatScalarsAsConstants:
  77. return BaseCType(scalarT)
  78. # at::scalar has special handling,
  79. # and is wrapped in an lazy::Value just like at::tensor
  80. return BaseCType(getValueT())
  81. elif typ.name == BaseTy.ScalarType:
  82. return BaseCType(scalarTypeT)
  83. elif typ.name == BaseTy.int:
  84. return BaseCType(longT)
  85. elif typ.name == BaseTy.SymInt:
  86. if symint:
  87. return BaseCType(getValueT())
  88. else:
  89. return BaseCType(longT)
  90. elif typ.name == BaseTy.bool:
  91. return BaseCType(boolT)
  92. elif typ.name == BaseTy.float:
  93. return BaseCType(doubleT)
  94. elif typ.name == BaseTy.str:
  95. return BaseCType(stringT)
  96. elif typ.name == BaseTy.Device:
  97. return BaseCType(deviceT)
  98. elif typ.name == BaseTy.Generator:
  99. return BaseCType(generatorT)
  100. elif typ.name == BaseTy.Layout:
  101. return BaseCType(layoutT)
  102. elif typ.name == BaseTy.MemoryFormat:
  103. return BaseCType(memoryFormatT)
  104. else:
  105. raise AssertionError(f"TODO add support for type {repr(typ)}")
  106. elif isinstance(typ, OptionalType):
  107. return OptionalCType(process_ir_type(typ.elem, properties, symint=symint))
  108. elif isinstance(typ, ListType):
  109. if str(typ.elem) == "Tensor?":
  110. # TODO(whc) is this actually correct? or should it use a Vector like above
  111. return ListCType(OptionalCType(BaseCType(getValueT())))
  112. elif str(typ.elem) == "Tensor":
  113. # this is a TensorList which comes in from GetTensorList as a Value
  114. return BaseCType(tensorListValueT)
  115. elif typ.elem == BaseType(BaseTy.SymInt):
  116. # TODO: return a value type. The problem here is analogous to
  117. # the problem with tensorListValueT: if you have SymInt[] you
  118. # cannot conveniently save the list of Value directly, as nodes
  119. # expect to save values as a vector for ALL arguments. So you
  120. # need a separate IR node that represents all of the size nodes
  121. # assembled into a list. I'm not an LTC dev so I don't want to
  122. # figure it out right now. Y'all figure it out...
  123. return VectorCType(BaseCType(longT))
  124. else:
  125. return VectorCType(process_ir_type(typ.elem, properties, symint=symint))
  126. else:
  127. raise AssertionError(f"unrecognized type {repr(typ)}")
  128. # TODO: Determining this based off of CType is bad; this should be computed
  129. # from Type directly; then the same logic as process_ir_type can be used
  130. #
  131. # Invariant: passed typ should be an *owning* CType (e.g., we will report
  132. # that ArrayRef<Value> is NOT a value type)
  133. def isValueType(typ: CType, properties: LazyIrProperties | None = None) -> bool:
  134. """
  135. Given a type, determine if it is a Value-like type. This is equivalent to
  136. being Tensor-like, but assumes the type has already been transformed.
  137. """
  138. if isinstance(typ, BaseCType):
  139. # I am regretting my naming conventions, but now we are wrapping at::scalar in
  140. # lazy value, while preserving other 'scalar' types as scalars in the IR
  141. treat_scalars_as_constants = properties and properties.TreatScalarsAsConstants
  142. return (
  143. typ.type == getValueT()
  144. or (typ.type == scalarT and not treat_scalars_as_constants)
  145. or typ.type == SymIntT
  146. )
  147. elif typ == VectorCType(BaseCType(SymIntT)):
  148. # TODO: report True for this
  149. return False
  150. elif isinstance(typ, (OptionalCType, ListCType, VectorCType)):
  151. return isValueType(typ.elem, properties)
  152. return False
  153. def isSymIntType(typ: Type) -> bool:
  154. return isinstance(typ, BaseType) and typ.name == BaseTy.SymInt
  155. def isWrappedScalarType(typ: Type) -> bool:
  156. """
  157. Given a type, determine if it is a c10::scalar which we will wrap in a lazy Value.
  158. Since we literally change the type from scalarT to valueT, information is lost.
  159. This function helps build a list of wrapped scalars to save that information
  160. """
  161. if isinstance(typ, BaseType):
  162. # I am regretting my naming conventions, but now we are wrapping at::scalar in
  163. # lazy value, while preserving other 'scalar' types as scalars in the IR
  164. return typ.name == BaseTy.Scalar
  165. elif isinstance(typ, (OptionalType, ListType)):
  166. return isWrappedScalarType(typ.elem)
  167. return False
  168. # TODO: dedupe with Type.is_generator_like
  169. def isGeneratorType(typ: Type) -> bool:
  170. if isinstance(typ, BaseType):
  171. return typ.name == BaseTy.Generator
  172. elif isinstance(typ, (OptionalType)):
  173. return isGeneratorType(typ.elem)
  174. return False
  175. # This class caches a few derived properties computed from an Argument
  176. # and LazyIrProperties
  177. class LazyArgument:
  178. name: str
  179. orig_type: Type
  180. lazy_type_: CType | None
  181. is_wrapped_scalar: bool
  182. is_generator: bool
  183. # TODO: this is lies, it is false for symint list
  184. is_symint_or_list: bool
  185. # Whether or not we are treating this as symint or not
  186. symint: bool
  187. # true if this argument is or contains a lazy IR value
  188. is_lazy_value: bool
  189. def __init__(
  190. self, arg: Argument, properties: LazyIrProperties, *, symint: bool
  191. ) -> None:
  192. self.name = arg.name
  193. self.orig_type = arg.type
  194. self.symint = symint
  195. self.is_optional = isinstance(arg.type, OptionalType)
  196. self.is_generator = isGeneratorType(arg.type)
  197. self.lazy_type_ = process_ir_type(arg.type, properties, symint=symint)
  198. self.is_wrapped_scalar = isWrappedScalarType(arg.type)
  199. self.is_symint_or_list = symint and (
  200. isSymIntType(arg.type)
  201. or (isinstance(arg.type, OptionalType) and isSymIntType(arg.type.elem))
  202. # TODO: lists of symints are not currently treated as value types
  203. # or (isinstance(arg.type, ListType) and isSymIntType(arg.type.elem))
  204. )
  205. self.is_lazy_value = isValueType(self.lazy_type, properties)
  206. @property
  207. def lazy_type(self) -> CType:
  208. if self.lazy_type_ is None:
  209. raise AssertionError(
  210. f"Attempted to access lazy_type for invalid argument {self.name}"
  211. )
  212. return self.lazy_type_
  213. class LazyIrProperties:
  214. """Collection of properties for an IR node
  215. The property groups are listed below. Each group is mutually
  216. exclusive, meaning that only one property from each group can be True
  217. at any one time. The properties can be accessed as if they were normal
  218. attributes. The mutual exclusivity is automatically handled.
  219. """
  220. Properties: tuple[tuple[str, ...], ...] = (
  221. (
  222. "ShapePrecompute", # Assume shape has been precomputed
  223. "ShapeCompute", # Need to compute the shape on construction
  224. "ShapeCache", # Utilize the shape cache to defer computation
  225. ),
  226. (
  227. "Lower", # Codegen full lower function
  228. "LowerDeclOnly", # Codegen only lower function declaration
  229. ),
  230. (
  231. "CanBeReused", # Codegen full reuse function
  232. "CanBeReusedDeclOnly", # Codegen only reuse function declaration
  233. ),
  234. (
  235. "CreateFn", # Codegen full create function
  236. "CreateFnDeclOnly", # Codegen only create function declaration
  237. ),
  238. (
  239. "TreatScalarsAsConstants", # Treat Scalars as constants instead of handling like values
  240. ),
  241. )
  242. def __init__(self, *default_properties: str) -> None:
  243. properties: dict[tuple[str, ...], str | None] = dict.fromkeys(
  244. LazyIrProperties.Properties
  245. )
  246. self.__dict__["properties"] = properties
  247. for p in default_properties:
  248. setattr(self, p, True)
  249. def __getattr__(self, key: str) -> Any:
  250. properties = self.__dict__["properties"]
  251. for values in LazyIrProperties.Properties:
  252. if key in values:
  253. return properties[values] == key
  254. return self.__getattribute__(key)
  255. def __setattr__(self, key: str, value: Any) -> Any:
  256. properties = self.__dict__["properties"]
  257. for values in LazyIrProperties.Properties:
  258. if key in values:
  259. properties[values] = key if value else None
  260. return value
  261. raise KeyError(f"Invalid property: {key}")
  262. # Inspired by a FunctionSchema object, a LazyIrSchema holds the schema of a Lazy IR node.
  263. # Unlike a FunctionSchema, it has no round-trippable string form (relating to the YAML),
  264. # but carries type information from a native FunctionSchema modified for use with IR nodes,
  265. # and preserving original argument names.
  266. #
  267. # TODO: This is not idiomatic with how other torchgen APIs transform on schema.
  268. class LazyIrSchema:
  269. # The name of the operator this function schema describes.
  270. name: OperatorName
  271. positional_args: tuple[LazyArgument, ...]
  272. keyword_args: tuple[LazyArgument, ...]
  273. # TODO: Need to handle collisions with argument names at some point
  274. returns: tuple[Return, ...]
  275. # if this schema has a Generator arg, list its orig ctype/name but don't
  276. # build a LazyArgument since lazy IR doesn't support it
  277. generator_arg: NamedCType | None = None
  278. # original function schema
  279. func: FunctionSchema
  280. # Whether or not we are code-genning for SymInt or not
  281. symint: bool
  282. properties: LazyIrProperties = LazyIrProperties(
  283. # default properties
  284. "ShapePrecompute",
  285. "Lower",
  286. "CanBeReused",
  287. )
  288. opkind: str | None = None
  289. def __init__(
  290. self,
  291. func: FunctionSchema,
  292. properties: LazyIrProperties | None = None,
  293. *,
  294. symint: bool,
  295. ) -> None:
  296. if properties:
  297. self.properties = properties
  298. self.func = func
  299. self.symint = symint
  300. positional_args: list[LazyArgument] = []
  301. for arg_field in ["pre_self_positional", "self_arg", "post_self_positional"]:
  302. if arg_field == "self_arg" and func.arguments.self_arg is not None:
  303. arg = func.arguments.self_arg.argument
  304. positional_args.append(
  305. LazyArgument(arg, self.properties, symint=symint)
  306. )
  307. elif getattr(func.arguments, arg_field) is not None:
  308. positional_args.extend(
  309. LazyArgument(arg, self.properties, symint=symint)
  310. for arg in getattr(func.arguments, arg_field)
  311. )
  312. self.positional_args = tuple(positional_args)
  313. keyword_args: list[LazyArgument] = []
  314. for arg_field in [
  315. "pre_tensor_options_kwarg_only",
  316. "tensor_options",
  317. "post_tensor_options_kwarg_only",
  318. "out",
  319. ]:
  320. curr_args = getattr(func.arguments, arg_field)
  321. if curr_args is not None:
  322. if isinstance(curr_args, TensorOptionsArguments):
  323. curr_args = curr_args.all()
  324. for arg in curr_args:
  325. if isGeneratorType(arg.type):
  326. if self.generator_arg is not None:
  327. raise AssertionError(
  328. "We expect there is only one generator arg"
  329. )
  330. self.generator_arg = NamedCType(
  331. arg.name,
  332. arg.type, # type:ignore[arg-type]
  333. )
  334. keyword_args.extend(
  335. LazyArgument(arg, self.properties, symint=symint)
  336. for arg in curr_args
  337. )
  338. self.keyword_args = tuple(keyword_args)
  339. self.name = func.name
  340. self.returns = func.returns
  341. @property
  342. def node_name(self) -> str:
  343. """
  344. Return camel-case version of op in node.
  345. Note: This function also appends any `overload_name` in the operation.
  346. For example, if the op is `bitwise_and.Tensor`, the returned name
  347. will be `BitwiseAndTensor`.
  348. """
  349. op_name = f"{self.name.name}_{self.name.overload_name}".lower()
  350. return "".join(word.capitalize() or "" for word in op_name.split("_"))
  351. @property
  352. def aten_name(self) -> str:
  353. return str(self.name.name)
  354. @property
  355. def base_name(self) -> str:
  356. return f"{self.name.name.base}"
  357. def filtered_args(
  358. self,
  359. positional: bool = True,
  360. keyword: bool = True,
  361. values: bool = True,
  362. scalars: bool = True,
  363. generator: bool = True,
  364. ) -> list[LazyArgument]:
  365. # This function maintains the sorted order of arguments but provides different filtered views.
  366. # Some parts of the code care about kwargs vs args (TS lowerings),
  367. # other parts care about whether they need to wrap the arg in a lazy value or leave it alone.
  368. # Generators are special cased, as they are needed for fallback/shape-inference but not supported
  369. # in TS lowerings and therefore also omitted from lazy IR.
  370. args: list[LazyArgument] = []
  371. if positional:
  372. args.extend(self.positional_args)
  373. if keyword:
  374. args.extend(self.keyword_args)
  375. if values and scalars and generator:
  376. return args
  377. elif values and scalars:
  378. return [a for a in args if not a.is_generator]
  379. elif values:
  380. return [a for a in args if a.is_lazy_value]
  381. elif scalars:
  382. return [
  383. a
  384. for a in args
  385. if not a.is_lazy_value and (generator or not a.is_generator)
  386. ]
  387. return []
  388. @property
  389. def positional_values(self) -> list[LazyArgument]:
  390. return self.filtered_args(
  391. positional=True, keyword=False, values=True, scalars=False
  392. )
  393. @property
  394. def positional_scalars(self) -> list[LazyArgument]:
  395. return self.filtered_args(
  396. positional=True, keyword=False, values=False, scalars=True
  397. )
  398. @property
  399. def keyword_values(self) -> list[LazyArgument]:
  400. return self.filtered_args(
  401. positional=False, keyword=True, values=True, scalars=False
  402. )
  403. @property
  404. def keyword_scalars(self) -> list[LazyArgument]:
  405. return self.filtered_args(
  406. positional=False, keyword=True, values=False, scalars=True
  407. )