# autograd.py (39 KB)
  1. from __future__ import annotations
  2. import re
  3. from dataclasses import dataclass
  4. from typing import cast, TYPE_CHECKING
  5. from torchgen import local
  6. from torchgen.api import cpp
  7. from torchgen.api.types import BaseCType, Binding, NamedCType, tensorListT
  8. from torchgen.model import (
  9. BaseTy,
  10. BaseType,
  11. FunctionSchema,
  12. ListType,
  13. NativeFunction,
  14. NativeFunctionsViewGroup,
  15. SchemaKind,
  16. Type,
  17. )
  18. from torchgen.utils import IDENT_REGEX
  19. if TYPE_CHECKING:
  20. from collections.abc import Sequence
  21. # Represents a saved attribute involved in backward calculation.
  22. # Note that it can be a derived property of an input argument, e.g.:
  23. # we could save `other.scalar_type()` instead of the entire `other` tensor.
  24. @dataclass(frozen=True)
  25. class SavedAttribute:
  26. # The NamedCType holds the updated name and cpp type of the attribute
  27. # for the name, Suffix is appended if it's derived property, e.g.: `other_scalar_type`
  28. nctype: NamedCType
  29. # The expression to read the derived property at save time, e.g.:
  30. # `other.scalar_type()`.
  31. expr: str
  32. # Represents a backward formula that calculates derivatives for one
  33. # or more tensors.
  34. @dataclass(frozen=True)
  35. class Derivative:
  36. # The formula string (legit C++ expression).
  37. # Note that expressions against input arguments have been replaced with the
  38. # corresponding saved attributes.
  39. # E.g.:
  40. # raw formula: `mul_tensor_backward(grad, self, other.scalar_type())`
  41. # here: `mul_tensor_backward(grad, self, other_scalar_type)`
  42. formula: str
  43. # The formula string before input argument replacement
  44. original_formula: str
  45. # Names of the arguments for which this formula calculates derivatives.
  46. var_names: tuple[str, ...]
  47. # Saved inputs that are referenced by the formula.
  48. saved_inputs: tuple[SavedAttribute, ...]
  49. # Saved outputs that are referenced by the formula.
  50. saved_outputs: tuple[SavedAttribute, ...]
  51. # Gradients that are referenced by name in the formula.
  52. named_gradients: set[str]
  53. # Represents a forward formula that calculates forward derivatives
  54. # for one tensor.
  55. @dataclass(frozen=True)
  56. class ForwardDerivative:
  57. # The formula string (legit C++ expression).
  58. # Note that special keywords such as "linear" or "element_wise" have been
  59. # replaced by the automatically generated formula.
  60. formula: str
  61. # Name of the output arguments for which this formula calculates forward
  62. # derivatives
  63. var_names: tuple[str, ...]
  64. # Type of the output arguments for which this formula calculates forward
  65. # derivatives
  66. var_types: tuple[Type, ...]
  67. # Inputs for which the forward derivatives are required for this formula
  68. required_inputs_fw_grad: tuple[str, ...] | None
  69. # Inputs for which the primal is required for this formula
  70. required_inputs_primal: tuple[str, ...] | None
  71. # Flag to specify if this formula requires the original value of self
  72. # This is only used by inplace operations
  73. required_original_self_value: bool
  74. # If this formula is specified in derivatives.yaml or if we are reusing the
  75. # out of place formula for inplace
  76. is_reusing_outplace_formula: bool
  77. # Represents differentiability info for a NativeFunction.
  78. @dataclass(frozen=True)
  79. class DifferentiabilityInfo:
  80. # The base name read from derivatives.yaml.
  81. name: str
  82. # The matching native function.
  83. #
  84. # There can be multiple NativeFunction having the same base name:
  85. # - different overloads with different types of input arguments;
  86. # - in-place/out/functional variants of the same function;
  87. #
  88. # We first use the schema string (under the 'name' key) in derivatives.yaml
  89. # to find the NativeFunction having the same schema string.
  90. # Then we find the in-place/out/functional variants of the matching function.
  91. # Among these variants, we choose the one having the same name as the
  92. # derivatives.yaml entry. If there is no exact match, then we choose the
  93. # in-place variant.
  94. # TODO: maybe the logic to search for all variants is no longer necessary?
  95. func: NativeFunction
  96. # The name of the generated autograd function.
  97. # It's set only if we will calculate a derivative, i.e.
  98. # 'args_with_derivatives' is not empty.
  99. op: str | None
  100. # The derivatives formulae for this function.
  101. # Note that the length of this sequence is the number of differentiable inputs
  102. derivatives: Sequence[Derivative]
  103. # The forward derivatives formulae for this function.
  104. # Note that the length of this sequence is the number of differentiable outputs
  105. forward_derivatives: Sequence[ForwardDerivative]
  106. # The union of 'saved_inputs' of all 'derivatives'.
  107. all_saved_inputs: Sequence[SavedAttribute]
  108. # The union of 'saved_outputs' of all 'derivatives'.
  109. all_saved_outputs: Sequence[SavedAttribute]
  110. # All named gradients that are available for use, in the same
  111. # order as in the grads vector.
  112. available_named_gradients: Sequence[str]
  113. # The named gradients that are used in any of the derivatives.
  114. # Invariant: all(name in available_named_gradients for name in used_named_gradients)
  115. used_named_gradients: set[str]
  116. # The function's input arguments for which it calculates derivatives.
  117. # It's the union of 'var_names' of all 'derivatives', sorted by the
  118. # argument order in the function schema.
  119. args_with_derivatives: Sequence[Binding]
  120. # Names of arguments whose derivative formula is 'non_differentiable'.
  121. non_differentiable_arg_names: Sequence[str]
  122. # Raw data read from derivatives.yaml.
  123. output_differentiability: list[bool] | None
  124. # output_differentiability in derivatives.yaml can be a list of
  125. # conditions that express if the output is differentiable. In this case,
  126. # the number of conditions must match the number of outputs
  127. # (NB: we only support one condition right now).
  128. # output_differentiability gets populated with True for each condition,
  129. # while output_differentiability_conditions gets populated with the conditions
  130. output_differentiability_conditions: list[str] | None
  131. @property
  132. def has_derivatives(self) -> bool:
  133. return len(self.args_with_derivatives) > 0
  134. # Generates a new DifferentiabilityInfo using the exact same set of derivative information,
  135. # but with a new operator name.
  136. # This is used when generating "copy" variants of view ops,
  137. # which are able to use the exact same derivative formula as the original view op
  138. # See Note [Codegen'd {view}_copy Operators]
  139. def create_view_copy_from_view_derivative(
  140. self, g: NativeFunctionsViewGroup
  141. ) -> DifferentiabilityInfo | None:
  142. if g.view_copy is None:
  143. return None
  144. f = g.view_copy
  145. name_split_by_period = self.name.split(".", maxsplit=2)
  146. # Append a "_copy" to the base name of the operator (but keep the overload name the same)
  147. view_copy_name = f"{name_split_by_period[0]}_copy." + ".".join(
  148. name_split_by_period[1:]
  149. )
  150. view_copy_op_name = None if self.op is None else f"{self.op}_copy"
  151. return DifferentiabilityInfo(
  152. # Use the "_copy" version of name/func/op
  153. name=view_copy_name,
  154. func=f,
  155. op=view_copy_op_name,
  156. # But keep all derivative info the same
  157. derivatives=self.derivatives,
  158. forward_derivatives=self.forward_derivatives,
  159. all_saved_inputs=self.all_saved_inputs,
  160. all_saved_outputs=self.all_saved_outputs,
  161. available_named_gradients=self.available_named_gradients,
  162. used_named_gradients=self.used_named_gradients,
  163. args_with_derivatives=self.args_with_derivatives,
  164. non_differentiable_arg_names=self.non_differentiable_arg_names,
  165. output_differentiability=self.output_differentiability,
  166. output_differentiability_conditions=self.output_differentiability_conditions,
  167. )
  168. def uses_ident(info: DifferentiabilityInfo | None, ident: str) -> bool:
  169. if info is None:
  170. return False
  171. for derivative in info.derivatives:
  172. formula = derivative.formula
  173. if re.search(IDENT_REGEX.format(ident), formula):
  174. return True
  175. return False
  176. def uses_retain_variables(info: DifferentiabilityInfo | None) -> bool:
  177. return uses_ident(info, "retain_variables")
  178. def uses_single_grad(info: DifferentiabilityInfo | None) -> bool:
  179. return uses_ident(info, "grad")
  180. # Represents a differentiable `Argument`.
  181. # How is it different from the `Argument` type?
  182. # - It's processed Arguments which are differentiable and only used in the
  183. # context of the autograd codegen;
  184. # - It can represent SelfArgument or regular Argument but not TensorOptionsArgument;
  185. @dataclass(frozen=True)
  186. class DifferentiableInput:
  187. name: str
  188. type: Type
  189. # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove.
  190. cpp_type: str
  191. # Represents a differentiable `Return`.
  192. # How it it different from the `Return` type?
  193. # - The name in `Return` is optional. Here it is always populated using the same
  194. # `cpp.return_names()` method.
  195. # TODO: some cpp naming logic (e.g. resolving name conflict) might be irrelevant?
  196. # - It's processed Returns which are differentiable, in compliance with the
  197. # `output_differentiability` field defined in derivatives.yaml (if specified),
  198. # and are only used in the context of the autograd codegen;
  199. @dataclass(frozen=True)
  200. class DifferentiableOutput:
  201. name: str
  202. type: Type
  203. # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove.
  204. cpp_type: str
  205. @dataclass(frozen=True)
  206. class NativeFunctionWithDifferentiabilityInfo:
  207. func: NativeFunction
  208. info: dict[str, DifferentiabilityInfo] | None
  209. fw_derivatives: dict[str, Sequence[ForwardDerivative]] | None
  210. # TODO: Update comment below since it is out of date.
  211. def dispatch_strategy(fn: NativeFunctionWithDifferentiabilityInfo) -> str:
  212. """How are we going to call the underlying implementation of a
  213. declaration? There are two strategies:
  214. - use_derived: we want to call the implementation on CPUDoubleType
  215. (or a similar, derived Type instance). Because these derived
  216. instances deal in Tensors, not Variables (it's a completely different
  217. object, so it doesn't dispatch back to VariableType), code on
  218. this dispatch path needs to wrap/unwrap tensors. If the
  219. derived implementation takes and returns tensors, the
  220. implementation is usually differentiable (although we also use
  221. the derived dispatch path for non-differentiable functions
  222. that we still want to dispatch on the derived Type instance;
  223. e.g., size())
  224. - use_type: we want to call the implementation on Type, because
  225. it is implemented concretely, and the functions it invokes will
  226. get dispatched back to VariableType (which will ensure that they
  227. are differentiable.)
  228. """
  229. # fn is derived as long as any of its per-key differentiability infos
  230. # has_derivatives. dispatch_strategy() is used to guard generation of fns in VariableType
  231. # and ADInplaceOrViewType. We want to generate these functions as long as a
  232. # derivative is defined for ANY dispatch key.
  233. if fn.func.is_abstract or (
  234. fn.info is not None and any(info.has_derivatives for info in fn.info.values())
  235. ):
  236. # If the function is abstract (not implemented on at::Type), we must
  237. # call the implementation on the derived type with unpacked tensors.
  238. # If the function has a derivative specified and is concrete, we could
  239. # call either implementation. We prefer the calling the derived
  240. # type's implementation with unpacked tensors because it is more
  241. # performant in some cases: any internal calls to other ATen functions
  242. # won't have the history tracked.
  243. # If the function has a type dispatched argument (i.e. is a factory),
  244. # we prefer calling the derived type's implementation both because it is
  245. # more performant and to ensure factory functions return tensors with _version
  246. # of 0 (probably not strictly necessary, but nice to have to keeps versions simple
  247. # to understand.
  248. return "use_derived"
  249. else:
  250. # If the function is concrete (we don't have to override it) and we
  251. # didn't declare it in derivatives.yaml, we'll assume that it is
  252. # actually implemented out of differentiable functions. (This
  253. # assumption might not hold, but then you'll see gradcheck fail.)
  254. return "use_type"
  255. def is_foreach_func(f: NativeFunction) -> bool:
  256. return f.func.name.name.base.startswith("_foreach_")
  257. # note(crcrpar): Most foreach functions can reference an out-place `torch` function whose schema kind
  258. # is functional for their backward derivatives (and forward derivatives in the future), i.e.,
  259. # they would find such one in `functional_info_by_signature`. There however are some exceptions:
  260. _foreach_with_inplace_ref = {"_foreach_zero_"}
  261. _foreach_with_tensor_overload = {
  262. "_foreach_add.Tensor",
  263. "_foreach_mul.Tensor",
  264. "_foreach_div.Tensor",
  265. }
  266. # The following do not support the alpha kwarg, which the nonforeach versions support.
  267. _skip_argument_len_check = {
  268. "_foreach_add.Scalar",
  269. "_foreach_add_.Scalar",
  270. "_foreach_add.ScalarList",
  271. "_foreach_add_.ScalarList",
  272. "_foreach_sub.Scalar",
  273. "_foreach_sub_.Scalar",
  274. "_foreach_sub.ScalarList",
  275. "_foreach_sub_.ScalarList",
  276. }
  277. # Checks if `function_schema` is a native, non-foreach function which `f`, a foreach function
  278. # reference to generate derivatives.
  279. def is_reference_for_foreach(
  280. f: NativeFunction,
  281. function_schema: FunctionSchema,
  282. ) -> bool:
  283. return (
  284. f.func.name.name.base.split("_foreach_")[-1] == function_schema.name.name.base
  285. and (
  286. not function_schema.name.name.inplace
  287. or str(f.func.name) in _foreach_with_inplace_ref
  288. )
  289. and (
  290. str(f.func.name) in _skip_argument_len_check
  291. or len(f.func.arguments.flat_non_out)
  292. == len(function_schema.arguments.flat_non_out)
  293. )
  294. and all(
  295. ref_arg.type in (arg.type, getattr(arg.type, "elem", None))
  296. for arg, ref_arg in zip(
  297. f.func.arguments.flat_non_out,
  298. function_schema.arguments.flat_non_out,
  299. )
  300. )
  301. )
  302. # TODO(crcrpar): Avoid hard coding "Default" ideally.
  303. def gen_foreach_derivativeinfo(
  304. foreach_function: NativeFunction,
  305. functional_info_by_signature: dict[
  306. FunctionSchema, dict[str, DifferentiabilityInfo]
  307. ],
  308. non_functional_info_by_signature: dict[
  309. FunctionSchema, dict[str, DifferentiabilityInfo]
  310. ],
  311. dispatch_key: str = "Default",
  312. ) -> tuple[DifferentiabilityInfo | None, bool]:
  313. """Generate DifferentiabilityInfo for out-place foreach function, return the existing one for in-place.
  314. The second return value indicates whether the info is generated in this function.
  315. """
  316. ref_diff_info: DifferentiabilityInfo | None = None
  317. for function_schema, diff_info in functional_info_by_signature.items():
  318. if not is_reference_for_foreach(foreach_function, function_schema):
  319. continue
  320. ref_diff_info = diff_info[dispatch_key]
  321. if ref_diff_info is not None:
  322. break
  323. # note(crcrpar): It seems like `zero`'s info isn't available in functional_info_by_signature
  324. # while the info of `zero_` is in non_functional_info_by_signature
  325. if (
  326. ref_diff_info is None
  327. and foreach_function.func.kind() == SchemaKind.inplace
  328. and str(foreach_function.func.name) in _foreach_with_inplace_ref
  329. ):
  330. for function_schema, diff_info in non_functional_info_by_signature.items():
  331. if not is_reference_for_foreach(foreach_function, function_schema):
  332. continue
  333. ref_diff_info = diff_info[dispatch_key]
  334. if ref_diff_info is not None:
  335. break
  336. if ref_diff_info is None:
  337. return None, False
  338. # non out-place uses the existing Derivative.
  339. if foreach_function.func.kind() == SchemaKind.inplace:
  340. return ref_diff_info, False
  341. map_refarg2foreacharg, map_name2arg = {}, {}
  342. for i, (arg, ref_arg) in enumerate(
  343. zip(
  344. foreach_function.func.arguments.flat_non_out,
  345. function_schema.arguments.flat_non_out,
  346. )
  347. ):
  348. map_refarg2foreacharg[ref_arg.name] = arg.name
  349. map_name2arg[arg.name] = arg
  350. all_saved_inputs, all_saved_outputs, all_var_names = [], [], []
  351. modified_derivative_formulas = []
  352. for i, derivative in enumerate(ref_diff_info.derivatives):
  353. modified_formula = derivative.formula.replace("grad", "grads[i]").replace(
  354. "result", "result[i]"
  355. )
  356. saved_inputs, saved_outputs = [], []
  357. # note(crcrpar): This context seems necessary to call `cpp.argument_type`
  358. with local.parametrize(
  359. use_const_ref_for_mutable_tensors=foreach_function.use_const_ref_for_mutable_tensors,
  360. use_ilistref_for_tensor_lists=foreach_function.part_of_structured_group,
  361. ):
  362. for ref_input in derivative.saved_inputs:
  363. ref_input_jit_name = ref_input.expr.split(".")[0]
  364. mapped_name = map_refarg2foreacharg[ref_input_jit_name]
  365. if isinstance(map_name2arg[mapped_name].type, ListType):
  366. mapped_expr = mapped_name + "[i]"
  367. else:
  368. mapped_expr = mapped_name
  369. new_expr = ref_input.expr.replace(ref_input_jit_name, mapped_expr)
  370. modified_formula = modified_formula.replace(
  371. cast(str, ref_input.nctype.name), new_expr
  372. )
  373. nctype = cpp.argument_type(map_name2arg[mapped_name], binds=mapped_name)
  374. canonical_nctype = NamedCType(
  375. nctype.name, nctype.type.remove_const_ref()
  376. )
  377. saved_inputs.append(
  378. SavedAttribute(nctype=canonical_nctype, expr=mapped_name)
  379. )
  380. for ref_output in derivative.saved_outputs:
  381. if ref_output.nctype.name == "result":
  382. saved_outputs.append(
  383. SavedAttribute(
  384. nctype=NamedCType(
  385. name="result", type=BaseCType(tensorListT)
  386. ),
  387. expr="result",
  388. )
  389. )
  390. else:
  391. raise RuntimeError("")
  392. var_names = [map_refarg2foreacharg[var] for var in derivative.var_names]
  393. all_var_names.extend(var_names)
  394. all_saved_inputs.extend(saved_inputs)
  395. all_saved_outputs.extend(saved_outputs)
  396. modified_derivative = Derivative(
  397. formula=modified_formula,
  398. original_formula=derivative.formula,
  399. var_names=tuple(var_names),
  400. saved_inputs=tuple(saved_inputs),
  401. saved_outputs=tuple(saved_outputs),
  402. named_gradients=set(),
  403. )
  404. modified_derivative_formulas.append(modified_derivative)
  405. with local.parametrize(
  406. use_const_ref_for_mutable_tensors=foreach_function.use_const_ref_for_mutable_tensors,
  407. use_ilistref_for_tensor_lists=foreach_function.part_of_structured_group,
  408. ):
  409. args_with_derivatives = [
  410. Binding(
  411. name=arg.name,
  412. nctype=cpp.argument_type(arg, binds=arg.name),
  413. argument=arg,
  414. default=None,
  415. )
  416. for arg in foreach_function.func.arguments.flat_non_out
  417. if arg.name in all_var_names
  418. ]
  419. forward_derivatives: list[ForwardDerivative] = []
  420. fw_derivative: ForwardDerivative
  421. for fw_derivative in ref_diff_info.forward_derivatives:
  422. var_names: list[str] = list(fw_derivative.var_names) # type: ignore[no-redef]
  423. var_types: list[Type] = list(fw_derivative.var_types)
  424. required_inputs_fw_grad: list[str] = []
  425. required_inputs_primal: list[str] = []
  426. if fw_derivative.required_inputs_fw_grad is not None:
  427. required_inputs_fw_grad = list(fw_derivative.required_inputs_fw_grad)
  428. if fw_derivative.required_inputs_primal:
  429. required_inputs_primal = list(fw_derivative.required_inputs_primal)
  430. modified_formula = fw_derivative.formula
  431. # Foreach's result is TensorList
  432. if "result" in modified_formula:
  433. modified_formula = fw_derivative.formula.replace("result", "result[i]")
  434. for foreach_arg, ref_arg in zip(
  435. foreach_function.func.arguments.flat_non_out,
  436. ref_diff_info.func.func.arguments.flat_non_out,
  437. ):
  438. # Modify reference forward formula
  439. if (
  440. isinstance(foreach_arg.type, ListType)
  441. and not foreach_arg.type.is_tensor_like()
  442. ):
  443. # Assuming ScalarList
  444. modified_formula = modified_formula.replace(
  445. ref_arg.name, foreach_arg.name + "[i]"
  446. )
  447. elif foreach_arg.type.is_tensor_like():
  448. # Assuming TensorList / Tensor
  449. if not (
  450. isinstance(foreach_arg.type, ListType)
  451. or (
  452. foreach_arg.type == BaseType(BaseTy.Tensor)
  453. and str(foreach_function.func.name)
  454. in _foreach_with_tensor_overload
  455. )
  456. ):
  457. raise AssertionError(
  458. f"{foreach_function.func.name}, {foreach_arg.type}"
  459. )
  460. for suffix in ("_p", "_t"):
  461. curr_expr = ref_arg.name + suffix
  462. if curr_expr in modified_formula:
  463. new_expr = foreach_arg.name + suffix
  464. modified_formula = modified_formula.replace(curr_expr, new_expr)
  465. else:
  466. # Assuming Scalar
  467. if foreach_arg.name != ref_arg.name:
  468. modified_formula = modified_formula.replace(
  469. ref_arg.name, foreach_arg.name
  470. )
  471. # note(crcrpar): there should exist a cooler way...
  472. for i, name in enumerate(var_names):
  473. if name == ref_arg.name:
  474. var_names[i] = foreach_arg.name
  475. var_types[i] = foreach_arg.type
  476. for i, name in enumerate(required_inputs_fw_grad):
  477. if name == ref_arg.name:
  478. required_inputs_fw_grad[i] = foreach_arg.name
  479. for i, name in enumerate(required_inputs_primal):
  480. if name == ref_arg.name:
  481. required_inputs_primal[i] = foreach_arg.name
  482. forward_derivatives.append(
  483. ForwardDerivative(
  484. formula=modified_formula,
  485. var_names=tuple(var_names),
  486. var_types=tuple(var_types),
  487. required_inputs_fw_grad=tuple(required_inputs_fw_grad),
  488. required_inputs_primal=tuple(required_inputs_primal),
  489. required_original_self_value=fw_derivative.required_original_self_value,
  490. is_reusing_outplace_formula=fw_derivative.is_reusing_outplace_formula,
  491. )
  492. )
  493. return (
  494. DifferentiabilityInfo(
  495. name=foreach_function.func.name.name.base,
  496. func=foreach_function,
  497. op=f"Foreach{ref_diff_info.op}{foreach_function.func.name.overload_name}",
  498. derivatives=modified_derivative_formulas,
  499. forward_derivatives=forward_derivatives,
  500. all_saved_inputs=tuple(set(all_saved_inputs)),
  501. all_saved_outputs=tuple(set(all_saved_outputs)),
  502. available_named_gradients=(),
  503. used_named_gradients=set(),
  504. args_with_derivatives=args_with_derivatives,
  505. non_differentiable_arg_names=[],
  506. output_differentiability=None,
  507. output_differentiability_conditions=None,
  508. ),
  509. True,
  510. )
  511. def match_differentiability_info(
  512. native_functions: list[NativeFunction],
  513. differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
  514. ) -> list[NativeFunctionWithDifferentiabilityInfo]:
  515. """Sets the "derivative" key on declarations to matching autograd function
  516. In-place functions will use the out-of-place derivative definition if there
  517. is no in-place specific derivative.
  518. """
  519. functional_info_by_signature = {
  520. schema.signature(strip_default=True): info_dict
  521. for schema, info_dict in differentiability_infos.items()
  522. if schema.kind() == SchemaKind.functional
  523. }
  524. non_functional_info_by_signature = {
  525. schema.signature(strip_default=True): info_dict
  526. for schema, info_dict in differentiability_infos.items()
  527. if schema.kind() != SchemaKind.functional
  528. }
  529. def find_info(
  530. f: NativeFunction,
  531. ) -> tuple[dict[str, DifferentiabilityInfo] | None, bool]:
  532. # Don't bother matching info to generated out= variants
  533. if "generated" in f.tags and f.func.kind() == SchemaKind.out:
  534. return None, False
  535. # (1) Check for an exact match
  536. if f.func in differentiability_infos:
  537. return differentiability_infos[f.func], True
  538. # (2) If no exact match, check if the out-of-place variant
  539. # of this operator has a match.
  540. # i.e mul() for mul_() or mul_out()
  541. # note(crcrpar): Check foreach or not because in-place foreach functions use backward defined for the existing
  542. # native functions instead of the out-place counterparts.
  543. f_sig = f.func.signature(strip_default=True)
  544. if f_sig in functional_info_by_signature and not is_foreach_func(f):
  545. return functional_info_by_signature[f_sig], False
  546. # (3) Some operators have a derivative explicitly defined for the mutable
  547. # variant, but get a code-generated out-of-place variant which does *not*
  548. # come with a derivative formula.
  549. # For the generated out-of-place variant, use the mutable variant's formula
  550. # if it exists.
  551. if "generated" in f.tags and f_sig in non_functional_info_by_signature:
  552. info_dict = non_functional_info_by_signature[f_sig]
  553. # See https://github.com/pytorch/pytorch/pull/76320/files#r874816389
  554. if any(
  555. any("self" in str(input.nctype.name) for input in info.all_saved_inputs)
  556. for info in info_dict.values()
  557. ):
  558. raise AssertionError(
  559. f"Attempted to convert a derivative formula for a mutable operator "
  560. f'to be used automatically by its functional variant ("{str(f.func)}"). '
  561. "This is not currently supported (we'd need to fix up the formula in the codegen)."
  562. )
  563. return info_dict, False
  564. # (4) Generate derivative information of foreach functions if none is defined in `derivatives.yaml`
  565. if is_foreach_func(f):
  566. if f.func in differentiability_infos:
  567. raise AssertionError(
  568. f"Foreach function {f.func.name} already has differentiability info"
  569. )
  570. diff_info, is_generated = gen_foreach_derivativeinfo(
  571. f,
  572. functional_info_by_signature,
  573. non_functional_info_by_signature,
  574. )
  575. if diff_info is None:
  576. return None, False
  577. # TODO(crcrpar): Avoid hard coding "Default" ideally.
  578. diff_info_dict = {"Default": diff_info}
  579. if is_generated:
  580. differentiability_infos[f.func] = diff_info_dict
  581. functional_info_by_signature[f.func] = diff_info_dict
  582. return diff_info_dict, is_generated
  583. return None, False
  584. result: list[NativeFunctionWithDifferentiabilityInfo] = []
  585. for f in native_functions:
  586. info_dict, is_exact_match = find_info(f)
  587. # Currently, the '.strides()' to 'strides_or_error' replacement does not support
  588. # 'self' derivatives of an inplace function, so we must check for this case.
  589. if f.func.kind() == SchemaKind.inplace and (info_dict is not None):
  590. for info in info_dict.values():
  591. for derivative in info.derivatives:
  592. if "self" in derivative.var_names:
  593. for saved_input in derivative.saved_inputs:
  594. if "strides_or_error" in saved_input.expr:
  595. raise AssertionError(
  596. "Calling '.strides()' in the 'self' derivative formula of an "
  597. f"in-place function is not supported: {f.func}"
  598. )
  599. if not info_dict:
  600. result.append(
  601. NativeFunctionWithDifferentiabilityInfo(
  602. func=f, info=None, fw_derivatives=None
  603. )
  604. )
  605. continue
  606. fw_derivative_dict: dict[str, Sequence[ForwardDerivative]] = {}
  607. for key, info in info_dict.items():
  608. if not info.forward_derivatives:
  609. fw_derivative_dict[key] = []
  610. continue
  611. forward_derivatives = info.forward_derivatives
  612. # For functions that have a single def for out-of-place and inplace (like abs())
  613. if f.func.kind() == SchemaKind.inplace:
  614. # For inplace functions there is a little bit of work to do:
  615. # 1) Validate the formula and make sure the input that is modified in not used:
  616. # - If there is a formula for the inplace variant of the function (is_exact_match == True) then
  617. # we make sure that the original value of the input that is being modified inplace (self_p) is
  618. # not used in the formula. Note that the formula can use "original_self_p" here and that would
  619. # trigger a clone of the original input.
  620. # - If we are reusing the out of place formula (is_exact_match == False) then we replace every
  621. # occurrence of self_p and self_t by original_self_p and original_self_t. These will be
  622. # populated by cloned version of the original input (either the clone done by the backward AD
  623. # logic if self is also used in a backward formula or a special clone that we add).
  624. # 2) At this point, there cannot be a self_p in the formula.
  625. # 3) Change "result" into "self_p" as by design, in the inplace function codegen, the result is
  626. # simply called self (as it is modified inplace).
  627. # 4) Update the required primals data in case it used to contain "result" but should now contain
  628. # "self"
  629. # 5) If it is not an exact match, the user formula is not modifying the existing forward grad
  630. # inplace as it should. So add some code that makes sure that we do so if the forward grad
  631. # already exists.
  632. if len(info.forward_derivatives) != 1:
  633. raise AssertionError(
  634. "Only single output inplace should exist, "
  635. f"got {len(info.forward_derivatives)}"
  636. )
  637. fw_info = info.forward_derivatives[0]
  638. formula = fw_info.formula
  639. def replace_self_with_original_self(formula: str, postfix: str) -> str:
  640. def repl(m: re.Match[str]) -> str:
  641. return f"{m.group(1)}original_self{postfix}{m.group(2)}"
  642. return re.sub(IDENT_REGEX.format(f"self{postfix}"), repl, formula)
  643. if re.search(IDENT_REGEX.format("self_p"), formula):
  644. if is_exact_match:
  645. # For manually defined formulas, don't allow the original value to be used
  646. raise RuntimeError(
  647. f'The formula for "{f.func.name}" is using the original value of self '
  648. "that is being modified inplace. This would lead to wrong forward gradients. "
  649. 'Please use "result" in the formula only.'
  650. )
  651. else:
  652. # When the original formula is out of place, we save a clone of the primal
  653. # value to be able to access this value if needed
  654. # replace "self_p"/"self_t" from the formula by "original_self_p"/"original_self_t"
  655. formula = replace_self_with_original_self(formula, "_p")
  656. formula = replace_self_with_original_self(formula, "_t")
  657. # replace "result" from the formula by "self_p"
  658. def repl(m: re.Match[str]) -> str:
  659. return f"{m.group(1)}self_p{m.group(2)}"
  660. formula = re.sub(IDENT_REGEX.format("result"), repl, formula)
  661. required_primals = fw_info.required_inputs_primal
  662. if re.search(IDENT_REGEX.format("self_p"), formula):
  663. required_primals = (
  664. required_primals + ("self",) if required_primals else ("self",)
  665. )
  666. if not is_exact_match:
  667. # NOTE [In-place forward AD formula Optimization]
  668. #
  669. # This optimization transforms the formula to directly do inplace, i.e.
  670. # instead of self_t.copy_(self_t.op()) we do self_t.op_() when the following are met:
  671. #
  672. # 1) the formula satisfies the pattern: "self_t.op(*args)"
  673. # 2) "op" in (1) needs to be the same as the op the derivative is for
  674. #
  675. # (2) may seem too strict, but currently the only ops that satisfy (1) also satisfy (2)
  676. # If there is a need, we can relax (2) to allow any op that has an in-place variant
  677. is_single_method_on_self_t = False
  678. directly_do_inplace = False
  679. op_name: str | None = None
  680. between_parens: str | None = None
  681. match = re.fullmatch(r"self_t.([\w]*)\((.*)\)", formula)
  682. if match:
  683. op_name, between_parens = match.group(1), match.group(2)
  684. # We want to...
  685. # Match: self_t.op1(other_p.op2(arg))
  686. # Avoid: self_t.op1(args) + self_t.op2(args)
  687. # Avoid: self_t.op1(other_p.op2(arg)) + self_t.op2(args)
  688. def check_parens_nest_level_gt_zero(s: str) -> bool:
  689. level = 1
  690. for ch in s:
  691. if ch == ")":
  692. level -= 1
  693. if level == 0:
  694. return False
  695. if ch == "(":
  696. level += 1
  697. return True
  698. is_single_method_on_self_t = check_parens_nest_level_gt_zero(
  699. between_parens
  700. )
  701. directly_do_inplace = (
  702. is_single_method_on_self_t and op_name == info.name
  703. )
  704. if directly_do_inplace:
  705. if op_name is None:
  706. raise AssertionError("op_name must be non-None for inplace")
  707. if between_parens is None:
  708. raise AssertionError(
  709. "between_parens must be non-None for inplace"
  710. )
  711. formula = f"self_t_raw.defined() ? self_t_raw.{op_name}_({between_parens}) : {formula}"
  712. else:
  713. # Make sure that the forward grad is modified inplace when the original formula
  714. # is out of place
  715. formula = f"self_t_raw.defined() ? self_t_raw.copy_({formula}) : {formula}"
  716. required_original_self_value = bool(
  717. re.search(IDENT_REGEX.format("original_self_p"), formula)
  718. ) or bool(re.search(IDENT_REGEX.format("original_self_t"), formula))
  719. forward_derivatives = [
  720. ForwardDerivative(
  721. formula=formula,
  722. var_names=("self",),
  723. var_types=fw_info.var_types,
  724. required_inputs_fw_grad=fw_info.required_inputs_fw_grad,
  725. required_inputs_primal=required_primals,
  726. required_original_self_value=required_original_self_value,
  727. is_reusing_outplace_formula=not is_exact_match,
  728. ),
  729. ]
  730. fw_derivative_dict[key] = forward_derivatives
  731. result.append(
  732. NativeFunctionWithDifferentiabilityInfo(
  733. func=f, info=info_dict, fw_derivatives=fw_derivative_dict
  734. )
  735. )
  736. return result
  737. def is_differentiable(
  738. name: str, type: Type, info: DifferentiabilityInfo | None
  739. ) -> bool:
  740. return type.is_tensor_like() and (
  741. info is None or name not in info.non_differentiable_arg_names
  742. )
  743. def gen_differentiable_outputs(
  744. fn: NativeFunctionWithDifferentiabilityInfo, key: str = "Default"
  745. ) -> list[DifferentiableOutput]:
  746. f = fn.func
  747. info = fn.info[key] if fn.info else None
  748. outputs: list[DifferentiableOutput] = [
  749. DifferentiableOutput(
  750. name=name,
  751. type=ret.type,
  752. cpp_type=cpp.return_type(ret, symint=True).cpp_type(),
  753. )
  754. for name, ret in zip(cpp.return_names(f), f.func.returns)
  755. ]
  756. output_differentiability = info.output_differentiability if info else None
  757. if output_differentiability is not None:
  758. if len(output_differentiability) != len(outputs):
  759. raise RuntimeError(
  760. f"The length of output_differentiability ({len(output_differentiability)}), "
  761. f"does not match the number of outputs ({len(outputs)})."
  762. )
  763. differentiable_outputs: list[DifferentiableOutput] = []
  764. if False in output_differentiability and f.func.kind() == SchemaKind.inplace:
  765. raise RuntimeError(
  766. "output_differentiability=False for inplace operation (version_counter won't get updated)"
  767. )
  768. for differentiable, output in zip(output_differentiability, outputs):
  769. if differentiable:
  770. differentiable_outputs.append(output)
  771. return differentiable_outputs
  772. candidate_differentiable_outputs = list(
  773. filter(lambda r: is_differentiable(r.name, r.type, info), outputs)
  774. )
  775. if uses_single_grad(info):
  776. return candidate_differentiable_outputs[:1]
  777. else:
  778. return candidate_differentiable_outputs