load_derivatives.py 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044
  1. # Parses derivatives.yaml into autograd functions
  2. #
  3. # Each autograd function is represented by `DifferentiabilityInfo` containing
  4. # a list of `Derivative`. See `torchgen.api.autograd` for the data models.
  5. from __future__ import annotations
  6. import re
  7. from collections import Counter, defaultdict
  8. from typing import Any, TYPE_CHECKING
  9. import yaml
  10. from torchgen.api import cpp
  11. from torchgen.api.autograd import (
  12. Derivative,
  13. DifferentiabilityInfo,
  14. ForwardDerivative,
  15. SavedAttribute,
  16. )
  17. from torchgen.api.types import (
  18. BaseCType,
  19. Binding,
  20. boolT,
  21. CppSignatureGroup,
  22. layoutT,
  23. longT,
  24. NamedCType,
  25. OptionalCType,
  26. scalarTypeT,
  27. SpecialArgName,
  28. stringT,
  29. symIntArrayRefT,
  30. SymIntT,
  31. tensorGeometryT,
  32. tensorOptionsT,
  33. typeAndSizeT,
  34. VectorCType,
  35. )
  36. from torchgen.context import with_native_function
  37. from torchgen.gen import get_grouped_by_view_native_functions, parse_native_yaml
  38. from torchgen.model import (
  39. AUTOGRAD_KEYS,
  40. FunctionSchema,
  41. NativeFunction,
  42. NativeFunctionsViewGroup,
  43. OperatorName,
  44. SchemaKind,
  45. Type,
  46. Variant,
  47. )
  48. from torchgen.utils import concatMap, IDENT_REGEX, split_name_params
  49. from torchgen.yaml_utils import YamlLoader
  50. if TYPE_CHECKING:
  51. from collections.abc import Sequence
  52. DerivativeRet = tuple[dict[FunctionSchema, dict[str, DifferentiabilityInfo]], set[str]]
  53. _GLOBAL_LOAD_DERIVATIVE_CACHE: dict[tuple[str, str], DerivativeRet] = {}
  54. _VALID_AUTOGRAD_KEYS = set(AUTOGRAD_KEYS)
  55. # This function directly adds per-dispatchkey derivative entries for {view}_copy variants of each view op.
  56. # Since every {view} and {view}_copy op shares the same derivative formula,
  57. # we generate them here instead of duplicating them in the yaml.
  58. # See Note [Codegen'd {view}_copy Operators]
  59. def add_view_copy_derivatives(
  60. infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
  61. view_groups: list[NativeFunctionsViewGroup],
  62. ) -> None:
  63. # Get the map from each view op's name to its corresponding view group
  64. view_name_to_group: dict[OperatorName, NativeFunctionsViewGroup] = {
  65. g.view.func.name: g for g in view_groups
  66. }
  67. view_infos = {}
  68. for info_dispatch_dict in infos.values():
  69. # maybe_view_group only needs to be calculated once per info_dispatch_dict
  70. maybe_view_group = None
  71. view_copy_differentiability_infos = {}
  72. for dispatch_key, info in info_dispatch_dict.items():
  73. maybe_view_group = view_name_to_group.get(info.func.func.name, None)
  74. if maybe_view_group is not None and maybe_view_group.view_copy is not None:
  75. view_copy_info = info.create_view_copy_from_view_derivative(
  76. maybe_view_group
  77. )
  78. if view_copy_info is not None:
  79. fn_schema = view_copy_info.func.func
  80. view_copy_differentiability_infos[dispatch_key] = view_copy_info
  81. else:
  82. break
  83. # prefer manually-defined derivatives if any
  84. # pyrefly: ignore [unbound-name]
  85. if len(view_copy_differentiability_infos) > 0 and fn_schema not in infos:
  86. # pyrefly: ignore [unbound-name]
  87. if fn_schema is None:
  88. raise AssertionError("Expected fn_schema to be non-None")
  89. # pyrefly: ignore [unbound-name]
  90. view_infos[fn_schema] = view_copy_differentiability_infos
  91. infos.update(view_infos)
  92. def load_derivatives(
  93. derivatives_yaml_path: str, native_yaml_path: str, tags_yaml_path: str
  94. ) -> DerivativeRet:
  95. # Do some caching as this is a deterministic function
  96. global _GLOBAL_LOAD_DERIVATIVE_CACHE
  97. key = (derivatives_yaml_path, native_yaml_path)
  98. if key not in _GLOBAL_LOAD_DERIVATIVE_CACHE:
  99. with open(derivatives_yaml_path) as f:
  100. definitions = yaml.load(f, Loader=YamlLoader)
  101. funcs = parse_native_yaml(native_yaml_path, tags_yaml_path).native_functions
  102. # From the parsed native functions, separate out the (generated) view_copy functions,
  103. # so we can generate derivatives for them separately.
  104. native_functions_with_view_groups = get_grouped_by_view_native_functions(funcs)
  105. native_functions = concatMap(
  106. lambda g: [g]
  107. if isinstance(g, NativeFunction)
  108. else list(g.functions(include_copy=True)),
  109. native_functions_with_view_groups,
  110. )
  111. view_groups = [
  112. g
  113. for g in native_functions_with_view_groups
  114. if isinstance(g, NativeFunctionsViewGroup)
  115. ]
  116. # What's the difference between function schema v.s. signature?
  117. # function schema is the complete declaration including mutability annotation / default value and etc.
  118. # signature is the canonical schema for a group of functions (in-place/out/functional variants)
  119. # that are semantically related.
  120. functions_by_signature: dict[FunctionSchema, list[NativeFunction]] = (
  121. defaultdict(list)
  122. )
  123. functions_by_schema: dict[str, NativeFunction] = {}
  124. for function in native_functions:
  125. functions_by_signature[function.func.signature()].append(function)
  126. if str(function.func) in functions_by_schema:
  127. raise AssertionError(f"Duplicate function schema: {str(function.func)}")
  128. functions_by_schema[str(function.func)] = function
  129. # Keep track of how many of which ops we've seen so we can
  130. # disambiguate them with a numeric suffix.
  131. op_counter = Counter[str]()
  132. # infos is a dict that maps FunctionSchema -> a dict of per dispatch key DifferentiabilityInfos
  133. # this is useful because in tools/autograd/gen_autograd.py:match_differentiability_info
  134. # we ultimately need to categorize the DifferentiabilityInfos by FunctionSchema
  135. infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]] = {}
  136. used_dispatch_keys: set[str] = set()
  137. for defn_dict in definitions:
  138. # Ensure that the old derivatives.yaml schema with no dispatch key can be loaded.
  139. if "dispatch" not in defn_dict:
  140. specification = defn_dict.pop("name")
  141. output_differentiability = defn_dict.pop(
  142. "output_differentiability", None
  143. )
  144. defn_dict = {"name": specification, "dispatch": {"Default": defn_dict}}
  145. if output_differentiability:
  146. defn_dict["output_differentiability"] = output_differentiability
  147. name, per_dispatch_diffinfos = create_differentiability_info(
  148. defn_dict,
  149. functions_by_signature,
  150. functions_by_schema,
  151. op_counter,
  152. used_dispatch_keys,
  153. )
  154. infos[name] = per_dispatch_diffinfos
  155. add_view_copy_derivatives(infos, view_groups)
  156. # cache both loaded infos as well a a set of all the dispatch_keys/aliases
  157. # that appear in derivatives.yaml. used_dispatch_keys is useful for generating
  158. # VariableType.cpp where we need a TORCH_LIBRARY_IMPL for every autograd dispatch key used
  159. _GLOBAL_LOAD_DERIVATIVE_CACHE[key] = infos, used_dispatch_keys
  160. return _GLOBAL_LOAD_DERIVATIVE_CACHE[key]
  161. # TODO: Why is this going through CppSignatureGroup, that doesn't make sense...
  162. @with_native_function
  163. def cpp_arguments(f: NativeFunction) -> Sequence[Binding]:
  164. sigs = CppSignatureGroup.from_native_function(f, method=False)
  165. if sigs.symint_signature is not None:
  166. return sigs.symint_signature.arguments()
  167. else:
  168. return sigs.signature.arguments()
  169. def create_derivative(
  170. f: NativeFunction,
  171. formula: str,
  172. var_names: tuple[str, ...],
  173. available_named_gradients: Sequence[str],
  174. ) -> Derivative:
  175. original_formula = formula
  176. arguments: list[NamedCType] = [
  177. a.nctype.remove_const_ref() for a in cpp_arguments(f)
  178. ]
  179. return_names = tuple(n if n != "self" else "result" for n in cpp.return_names(f))
  180. return_types = tuple(
  181. cpp.return_type(r, symint=True).remove_const_ref() for r in f.func.returns
  182. )
  183. named_returns = [
  184. NamedCType(name, type) for name, type in zip(return_names, return_types)
  185. ]
  186. formula, saved_inputs = saved_variables(formula, arguments, var_names)
  187. formula, saved_outputs = saved_variables(formula, named_returns, var_names)
  188. used_named_gradients = {
  189. name
  190. for name in available_named_gradients
  191. if re.search(IDENT_REGEX.format(name), formula)
  192. }
  193. # Check that the referenced derivatives in the formula are in bounds
  194. for i in used_gradient_indices(formula):
  195. if i >= len(f.func.returns):
  196. raise RuntimeError(
  197. f"Out of bounds grads access: derivative formula for {cpp.name(f.func)} "
  198. f"used grads[{i}], but the forward only returns {len(f.func.returns)} outputs."
  199. )
  200. return Derivative(
  201. formula=formula,
  202. original_formula=original_formula,
  203. var_names=var_names,
  204. saved_inputs=saved_inputs,
  205. saved_outputs=saved_outputs,
  206. named_gradients=used_named_gradients,
  207. )
  208. def create_forward_derivative(
  209. f: NativeFunction, formula: str, names: tuple[str, ...]
  210. ) -> ForwardDerivative:
  211. var_names = names
  212. var_types: tuple[Type, ...] | None = None
  213. for r in f.func.returns:
  214. if r.name in var_names:
  215. if var_types is None:
  216. var_types = ()
  217. var_types = var_types + (r.type,)
  218. # Handle default return names
  219. if var_types is None:
  220. if var_names == ("result",):
  221. if len(f.func.returns) != 1:
  222. raise AssertionError(
  223. f"Expected 1 return for 'result', got {len(f.func.returns)}"
  224. )
  225. var_types = (f.func.returns[0].type,)
  226. else:
  227. for var_name in var_names:
  228. res = re.findall(r"^result(\d+)$", var_name)
  229. if len(res) == 1:
  230. if var_types is None:
  231. var_types = ()
  232. arg_idx = int(res[0])
  233. var_types = var_types + (f.func.returns[arg_idx].type,)
  234. if var_types is None:
  235. raise AssertionError("No matching output for forward derivative definition")
  236. return ForwardDerivative(
  237. formula=formula,
  238. var_names=var_names,
  239. var_types=var_types,
  240. required_inputs_fw_grad=None,
  241. required_inputs_primal=None,
  242. required_original_self_value=False,
  243. is_reusing_outplace_formula=False,
  244. )
def postprocess_forward_derivatives(
    f: NativeFunction,
    defn_name: str,
    all_arg_names: list[str],
    derivatives: list[Derivative],
    forward_derivatives: list[ForwardDerivative],
    args_with_derivatives: Sequence[Binding],
) -> list[ForwardDerivative]:
    """Finalize the forward derivatives of ``f`` once its differentiable args are known.

    For every entry of ``forward_derivatives`` this:

    * expands the ``auto_element_wise`` shorthand (reusing the single backward
      formula in ``derivatives``) and the ``auto_linear`` shorthand (calling the
      forward op again on the tangents) into explicit formulas;
    * records which differentiable inputs' tangents (``<arg>_t``) and primals
      (``<arg>_p``) the final formula uses.

    Args:
        f: the native function the formulas belong to.
        defn_name: op name from derivatives.yaml; used in error messages and to
            build the ``auto_linear`` call.
        all_arg_names: names of all C++ arguments of ``f``, in order.
        derivatives: the backward derivatives parsed for ``f``.
        forward_derivatives: the raw forward derivatives parsed for ``f``.
        args_with_derivatives: bindings of the differentiable arguments.

    Returns:
        New ForwardDerivative objects with ``required_inputs_fw_grad`` and
        ``required_inputs_primal`` filled in. The input list is not modified.

    Raises:
        RuntimeError: if a formula uses the bare name of a differentiable input, or
            a shorthand is used where it is not supported.
        AssertionError: for inconsistent inputs (e.g. ``auto_element_wise`` on an
            in-place variant, or ``auto_linear`` without differentiable args).
    """

    def find_required_inputs(formula: str, postfix: str) -> tuple[str, ...]:
        # Return the differentiable inputs whose `<name><postfix>` appears in
        # `formula` (built from a set, so the order is unspecified). TensorList-like
        # args are skipped except for _foreach_ ops. Using the bare input name is an
        # error because it is ambiguous between primal and tangent.
        is_foreach = f.func.name.name.base.startswith("_foreach_")
        required_inputs = set()
        for arg in args_with_derivatives:
            if (
                arg.type in ("at::TensorList", "const at::ITensorListRef &")
                and not is_foreach
            ):
                # The functions taking TensorList handle everything internally
                continue
            arg_name = arg.name

            found = re.search(IDENT_REGEX.format(arg_name), formula)
            if found:
                raise RuntimeError(
                    f"The forward formula for {defn_name} is using the base name of the {arg_name} "
                    f"argument which is ambiguous. You should use {arg_name}_p to access the primal "
                    f"value and {arg_name}_t to access the tangent."
                )

            found = re.search(IDENT_REGEX.format(arg_name + postfix), formula)
            if found:
                required_inputs.add(arg_name)

        return tuple(required_inputs)

    updated_derivatives: list[ForwardDerivative] = []

    for defn in forward_derivatives:
        formula = defn.formula
        # Tangents used by a hand-written formula; overwritten below for the
        # auto_element_wise / auto_linear shorthands.
        required_inputs_tangent = find_required_inputs(formula, "_t")
        if formula == "auto_element_wise":
            if f.func.kind() == SchemaKind.inplace:
                raise AssertionError(
                    f"Cannot use auto_element_wise with {f.func.name} because it is an in-place variant"
                )
            if (
                (not len(args_with_derivatives) == 1)
                or len(forward_derivatives) > 1
                or len(forward_derivatives[0].var_names) > 1
            ):
                raise RuntimeError(
                    f"Derivative definition of {defn_name} in derivatives.yaml defines the "
                    "forward definition of gradient as element_wise but this only "
                    "works for functions with a single differentiable input and a "
                    "single differentiable output."
                )
            if not len(derivatives) == 1:
                raise RuntimeError(
                    f"Derivative definition of {defn_name} in derivatives.yaml defines the "
                    "forward definition of gradient as element_wise but it does not "
                    "defines the gradient formula for its argument which is required."
                )
            # This transformation is based on the observation that for element-wise functions, the Jacobian
            # matrix is diagonal and thus doing J * v is the same as (v^T J)^T (in practice, we ignore the transpositions)
            # For the complex case, we use hermitian transpose and get (v.conj() J).conj()
            # So here we are going to reuse the backward formula and apply three transformations:
            # 1) all occurrences of "grad" with "foo_t.conj()", where foo is the name of the unique differentiable input.
            # 2) all usage of an original input "foo" with its primal value "foo_p".
            # 3) conjugate the final result
            # For example, for abs, the backward formula is:
            #   grad * self.sgn()
            # And this function generates a forward formula that is:
            #   (self_t.conj() * self_p.sgn()).conj()

            backward_formula = derivatives[0].original_formula
            input_name = args_with_derivatives[0].name

            # Do replacement 1) of the grad
            def repl(m: Any) -> str:
                return f"{m.group(1)}{input_name}_t.conj(){m.group(2)}"

            fw_formula = re.sub(IDENT_REGEX.format("grad"), repl, backward_formula)

            # Do replacement 2) of the input variables
            for arg in args_with_derivatives:
                arg_name = arg.name

                # Rebinding `repl` inside the loop is safe: re.sub below calls it
                # immediately, while `arg_name` still holds this iteration's value.
                def repl(m: Any) -> str:
                    return f"{m.group(1)}{arg_name}_p{m.group(2)}"

                fw_formula = re.sub(IDENT_REGEX.format(arg_name), repl, fw_formula)

            # Do the final conjugate 3)
            fw_formula = f"({fw_formula}).conj()"

            # Since there is a single differentiable inputs and we necessarily need its tangent we can
            # simply require all differentiable input's tangent.
            # NOTE(review): this uses all_arg_names (every argument, not only the
            # differentiable one) — confirm consumers filter by differentiable inputs.
            required_inputs_tangent = tuple(all_arg_names)
            formula = fw_formula
        elif formula == "auto_linear":
            if (
                len(forward_derivatives) > 1
                or len(forward_derivatives[0].var_names) > 1
            ):
                raise RuntimeError(
                    f"Derivative definition of {defn_name} in derivatives.yaml defines the "
                    "forward definition of gradient as linear but this only works "
                    "for functions with a single differentiable output."
                )
            # This transformation is based on the observation that linear functions can be written as:
            #   y = f(x) = A * x
            # For some matrix A and the Jacobian of the function f is also A.
            # So doing J * v = A * v = f(v).
            # Hence to do the jvp, we simply need to evaluate the function at the point v instead of x.
            # We do this by calling the forward again by replacing any occurrence of the differentiable
            # input "foo" by it's tangent "foo_t".
            # Note that multiple inputs are not a problem as long as the function is truly linear wrt to
            # the vector where all the differentiable inputs are stacked.

            diff_arg_names = [arg.name for arg in args_with_derivatives]
            if len(diff_arg_names) == 0:
                raise AssertionError("Expected at least one differentiable argument")

            # Do replacement of input variables
            new_args = []
            for arg_name in all_arg_names:
                if arg_name in diff_arg_names:
                    arg_name = arg_name + "_t"
                # pyrefly: ignore [bad-argument-type]
                new_args.append(arg_name)

            # TODO we are trolling
            # Mutating defn_name here only affects this single iteration: the check
            # above guarantees forward_derivatives has exactly one entry.
            if f.func.has_symint():
                defn_name += "_symint"

            # Call into the forward again. We need two cases here to handle both Tensor methods and at:: functions.
            if Variant.function in f.variants:
                fw_formula = f"at::{defn_name}({', '.join(new_args)})"
            else:
                if Variant.method not in f.variants:
                    raise AssertionError(
                        f"Expected Variant.method in variants for {f.func.name}"
                    )
                fw_formula = f"{new_args[0]}.{defn_name}({', '.join(new_args[1:])})"

            # All of the input tangents are always used so all of them are required here.
            required_inputs_tangent = tuple(diff_arg_names)
            formula = fw_formula

        # At this point, the formula is final and is not modified anymore.

        # During forward formula, we use the primal instead of the input Tensors.
        # This call inspects the formula to find for which input's primal are used.
        required_inputs_primal = find_required_inputs(formula, "_p")

        updated_derivatives.append(
            ForwardDerivative(
                formula=formula,
                var_names=defn.var_names,
                var_types=defn.var_types,
                required_inputs_fw_grad=required_inputs_tangent,
                required_inputs_primal=required_inputs_primal,
                required_original_self_value=False,
                is_reusing_outplace_formula=False,
            )
        )

    return updated_derivatives
  390. def is_forward_derivative_definition(
  391. all_arg_names: list[str], names: tuple[str, ...]
  392. ) -> bool:
  393. for name in names:
  394. return name not in all_arg_names
  395. raise RuntimeError("Expected `names` to be non-empty")
  396. def create_differentiability_info(
  397. defn_dict: dict[Any, Any],
  398. functions_by_signature: dict[FunctionSchema, list[NativeFunction]],
  399. functions_by_schema: dict[str, NativeFunction],
  400. op_counter: Counter[str],
  401. used_dispatch_keys: set[str],
  402. ) -> tuple[FunctionSchema, dict[str, DifferentiabilityInfo]]:
  403. """Processes a single entry `defn` in derivatives.yaml"""
  404. def canonical_function(
  405. functions: Sequence[NativeFunction], name: str
  406. ) -> NativeFunction:
  407. for f in functions:
  408. if (
  409. not f.func.is_functional_fn()
  410. and not f.func.is_out_fn()
  411. and name == str(f.func.name.name)
  412. ):
  413. return f
  414. # some functions only have in-place variants
  415. if name + "_" != cpp.name(functions[0].func):
  416. raise AssertionError(
  417. f"Expected inplace function name '{name}_', got '{cpp.name(functions[0].func)}'"
  418. )
  419. return functions[0]
  420. def split_names(raw_names: str) -> tuple[str, ...]:
  421. """Given "foo, bar", return ["foo", "bar"]."""
  422. return tuple(x.strip() for x in raw_names.split(","))
  423. def check_grad_usage(defn_name: str, derivatives: Sequence[Derivative]) -> None:
  424. """
  425. Check for some subtle mistakes one might make when writing derivatives.
  426. These mistakes will compile, but will be latent until a function is
  427. used with double backwards.
  428. """
  429. uses_grad = False # true if any derivative uses "grad"
  430. num_grads_uses = 0 # count of uses of "grads" or "grads[INDEX]"
  431. uses_named_grads = False # true if any derivative uses "grad_{name}"
  432. used_grads_indices: list[int] = [] # which indices of grads are used
  433. for d in derivatives:
  434. formula = d.formula
  435. uses_grad = uses_grad or bool(
  436. re.findall(IDENT_REGEX.format("grad"), formula)
  437. )
  438. num_grads_uses += len(re.findall(IDENT_REGEX.format("grads"), formula))
  439. uses_named_grads = uses_named_grads or bool(d.named_gradients)
  440. used_grads_indices.extend(used_gradient_indices(formula))
  441. # This is a basic sanity check: the number of places we see
  442. # "grads" should be no fewer than the number of indices we see
  443. # inside "grads". They may not be equal because we may use
  444. # "grads" without an index.
  445. if num_grads_uses < len(used_grads_indices):
  446. raise AssertionError(
  447. f"num_grads_uses ({num_grads_uses}) < len(used_grads_indices) ({len(used_grads_indices)})"
  448. )
  449. # Thus if the number is equal, every use of grads is also
  450. # indexed.
  451. only_used_grads_indices = num_grads_uses == len(used_grads_indices)
  452. if uses_grad and num_grads_uses > 0:
  453. raise RuntimeError(
  454. f"Derivative definition of {defn_name} in derivatives.yaml illegally "
  455. "mixes use of 'grad' and 'grads'. Consider replacing "
  456. "occurrences of 'grad' with 'grads[0]'"
  457. )
  458. if only_used_grads_indices and set(used_grads_indices) == {0}:
  459. raise RuntimeError(
  460. f"Derivative definition of {defn_name} in derivatives.yaml solely "
  461. "refers to 'grads[0]'. If the first output is indeed the "
  462. "only differentiable output, replace 'grads[0]' with 'grad'; "
  463. "otherwise, there is a likely error in your derivatives "
  464. "declaration."
  465. )
  466. if uses_named_grads and (uses_grad or num_grads_uses > 0):
  467. raise RuntimeError(
  468. f"Derivative definition of {defn_name} in derivatives.yaml illegally "
  469. 'mixes use of "grad_RETURN_NAME" and "grad" or "grads[x]". Use '
  470. "only one method for identifying gradients."
  471. )
  472. @with_native_function
  473. def set_up_derivatives(
  474. f: NativeFunction,
  475. ) -> tuple[
  476. Sequence[Derivative],
  477. Sequence[ForwardDerivative],
  478. Sequence[Binding],
  479. Sequence[str],
  480. Sequence[str],
  481. ]:
  482. # Set up the derivative information
  483. derivatives: list[Derivative] = []
  484. forward_derivatives: list[ForwardDerivative] = []
  485. non_differentiable_arg_names: list[str] = []
  486. args_with_derivatives_set: set[str] = set()
  487. all_arg_names = [a.name for a in cpp_arguments(f)]
  488. all_ret_names = [
  489. r.name for r in f.func.returns
  490. ] # only used for the assert below
  491. # output_differentiability is captured from the enclosed
  492. # scope. Don't modify it.
  493. #
  494. # If it is not present, then no output is explicitly
  495. # undifferentiable.
  496. #
  497. # It may be present and shorter than the length of return
  498. # values. If that's the case, any return value that does not
  499. # have a corresponding entry is considered not differentiable.
  500. differentiability = output_differentiability or [True] * len(f.func.returns)
  501. # A return is available as a named gradient ...
  502. available_named_gradients = [
  503. f"grad_{ret.name}"
  504. for ret, differentiable in zip(f.func.returns, differentiability)
  505. # if it has not been explicitly made undifferentiable
  506. if differentiable
  507. # and if it has a name
  508. and ret.name is not None
  509. # and if its type is differentiable
  510. and ret.type.is_tensor_like()
  511. ]
  512. for raw_names in sorted(defn.keys()):
  513. formula = defn[raw_names]
  514. names = split_names(raw_names)
  515. for name in names:
  516. if name in all_arg_names and name in all_ret_names:
  517. raise AssertionError(
  518. f"While processing the derivative formula for '{f.func.name}' wrt '{name}', "
  519. f"expected '{name}' to not be both an input arg and named return."
  520. )
  521. if is_forward_derivative_definition(all_arg_names, names):
  522. forward_derivatives.append(create_forward_derivative(f, formula, names))
  523. else:
  524. if formula.lower().strip() == "non_differentiable":
  525. non_differentiable_arg_names += names
  526. else:
  527. derivative = create_derivative(
  528. f, formula, names, available_named_gradients
  529. )
  530. derivatives.append(derivative)
  531. args_with_derivatives_set |= set(names)
  532. overlap = args_with_derivatives_set.intersection(non_differentiable_arg_names)
  533. if overlap:
  534. raise RuntimeError(
  535. f"derivatives definition for {defn} have overlapped non_differentiable "
  536. f"and differentiable variables: {overlap}"
  537. )
  538. # Next, let us determine the list of inputs in order.
  539. # TODO: do we need eagerly calculate and save it here? Can it be derived
  540. # from NativeFunction and `derivatives` on callsites instead?
  541. args_with_derivatives = [
  542. a for a in cpp_arguments(f) if a.name in args_with_derivatives_set
  543. ]
  544. # Postprocess forward derivatives definitions now that we know the differentiable arguments
  545. forward_derivatives = postprocess_forward_derivatives(
  546. f,
  547. defn_name,
  548. all_arg_names,
  549. derivatives,
  550. forward_derivatives,
  551. args_with_derivatives,
  552. )
  553. # Test to see if the use of 'grads' makes sense.
  554. check_grad_usage(defn_name, derivatives)
  555. return (
  556. derivatives,
  557. forward_derivatives,
  558. args_with_derivatives,
  559. non_differentiable_arg_names,
  560. available_named_gradients,
  561. )
  562. # NB: Removes 'name' from defn dictionary
  563. specification = defn_dict.pop("name")
  564. defn_name, _ = split_name_params(specification)
  565. # NB: Removes 'output_differentiability' from defn dictionary
  566. # `None` means all differentiable.
  567. output_differentiability = defn_dict.pop("output_differentiability", None)
  568. output_differentiability_conditions = None
  569. if output_differentiability and any(
  570. isinstance(diff, str) for diff in output_differentiability
  571. ):
  572. if len(output_differentiability) != 1:
  573. raise RuntimeError(
  574. f"Not supported: for {specification},"
  575. f"output_differentiability must either be "
  576. f"list[bool] or a list[str] where each str is a "
  577. f"condition. In the case where it is a condition, "
  578. f"we only support single-output functions. "
  579. f"Please file us an issue. "
  580. )
  581. output_differentiability_conditions = output_differentiability
  582. output_differentiability = [True]
  583. schema_function = functions_by_schema.get(specification)
  584. if not schema_function:
  585. avail = "\n".join(
  586. k for k, v in functions_by_schema.items() if cpp.name(v.func) == defn_name
  587. )
  588. raise RuntimeError(
  589. f"could not find ATen function for schema: {specification} "
  590. f". Available signatures:\n{avail}"
  591. )
  592. # now map this to the legacy schema; this isn't technically necessary, but we'd need some logic here
  593. # to map in-place schemas to the out-of-place variants.
  594. # TODO: maybe the logic to handle the legacy schema is no longer necessary?
  595. signature = schema_function.func.signature()
  596. functions = functions_by_signature[signature]
  597. if len(functions) == 0:
  598. avail = "\n".join(
  599. str(k)
  600. for k, v in functions_by_signature.items()
  601. if cpp.name(k) == defn_name
  602. )
  603. raise RuntimeError(
  604. f"could not find ATen function for legacy signature: {signature} "
  605. f"corresponding to schema {specification}. Please report a bug to PyTorch. "
  606. f"Available signatures:\n{avail}"
  607. )
  608. canonical = canonical_function(functions, defn_name)
  609. if "grad_input_mask" in (a.name for a in cpp_arguments(canonical)):
  610. raise RuntimeError(
  611. f"Schema for {defn_name} has an argument named grad_input_mask, "
  612. "but this name would be shadowed by our codegen. "
  613. "Please use a different name in native_functions.yaml."
  614. )
  615. if "result" in (a.name for a in cpp_arguments(canonical)):
  616. raise RuntimeError(
  617. f"Schema for {defn_name} has an argument named result, "
  618. "but this is only allowed for outputs."
  619. "Please use a different name in native_functions.yaml."
  620. )
  621. diffinfo_dict = {}
  622. for key, defn in defn_dict["dispatch"].items():
  623. if key != "Default" and key not in _VALID_AUTOGRAD_KEYS:
  624. raise RuntimeError(
  625. f"Invalid dispatch key {key} in derivatives.yaml for {specification},"
  626. f" expected key to be one of {_VALID_AUTOGRAD_KEYS}"
  627. )
  628. if key not in used_dispatch_keys:
  629. used_dispatch_keys.add(key)
  630. (
  631. derivatives,
  632. forward_derivatives,
  633. args_with_derivatives,
  634. non_differentiable_arg_names,
  635. available_named_gradients,
  636. ) = set_up_derivatives(canonical)
  637. used_named_gradients: set[str] = set()
  638. for d in derivatives:
  639. used_named_gradients |= d.named_gradients
  640. # only assign an op name if we are actually going to calculate a derivative
  641. op = None
  642. if args_with_derivatives:
  643. op_prefix = _create_op_prefix(defn_name)
  644. if key != "Default":
  645. op_prefix = op_prefix + key
  646. op = f"{op_prefix}{op_counter[op_prefix]}"
  647. op_counter[op_prefix] += 1
  648. diffinfo_dict[key] = DifferentiabilityInfo(
  649. name=defn_name,
  650. func=canonical,
  651. op=op,
  652. derivatives=derivatives,
  653. forward_derivatives=forward_derivatives,
  654. all_saved_inputs=dedup_vars(
  655. [v for d in derivatives for v in d.saved_inputs]
  656. ),
  657. all_saved_outputs=dedup_vars(
  658. [v for d in derivatives for v in d.saved_outputs]
  659. ),
  660. available_named_gradients=available_named_gradients,
  661. used_named_gradients=used_named_gradients,
  662. args_with_derivatives=args_with_derivatives,
  663. non_differentiable_arg_names=non_differentiable_arg_names,
  664. output_differentiability=output_differentiability,
  665. output_differentiability_conditions=output_differentiability_conditions,
  666. )
  667. return canonical.func, diffinfo_dict
  668. GRAD_INDEX_REGEX = r"(?:^|\W)grads\[(\d+)\]"
  669. def used_gradient_indices(formula: str) -> list[int]:
  670. """Determine a list of gradient indices (the i in grads[i]) that
  671. are used by the formula.
  672. >>> used_gradient_indices("foo(grads[0], grads[1])")
  673. [0, 1]
  674. """
  675. return [int(i) for i in re.findall(GRAD_INDEX_REGEX, formula)]
  676. def saved_variables(
  677. formula: str,
  678. nctypes: list[NamedCType],
  679. var_names: tuple[str, ...],
  680. ) -> tuple[str, tuple[SavedAttribute, ...]]:
  681. def stride_expr(name: str) -> str:
  682. if var_names != (name,):
  683. raise AssertionError(
  684. 'Replacement for ".strides()" is currently only supported for single derivatives of the same tensor '
  685. 'that ".strides()" is being called on.'
  686. )
  687. return f'strides_or_error({name}, "{name}")'
  688. REPLACEMENTS: list[tuple[str, dict[str, Any]]] = [
  689. # replace self.sym_sizes() with self_sym_sizes
  690. (
  691. r"{}.sym_sizes\(\)",
  692. {
  693. "suffix": "_sym_sizes",
  694. "nctype": lambda name: NamedCType(name, BaseCType(symIntArrayRefT)),
  695. },
  696. ),
  697. # replace self->sym_sizes() with self_sym_sizes_opt
  698. (
  699. r"{}->sym_sizes\(\)",
  700. {
  701. "suffix": "_sym_sizes_opt",
  702. "nctype": lambda name: NamedCType(
  703. name, OptionalCType(BaseCType(symIntArrayRefT))
  704. ),
  705. "expr": lambda name: f"{name}.has_value() ? std::optional<c10::SymIntArrayRef>({name}->sym_sizes()) : std::nullopt",
  706. },
  707. ),
  708. # replace self.sym_blocksize() with self_sym_blocksize_opt
  709. (
  710. r"{}.sym_blocksize\(\)",
  711. {
  712. "suffix": "_self_sym_blocksize_opt",
  713. "nctype": lambda name: NamedCType(
  714. name, OptionalCType(BaseCType(symIntArrayRefT))
  715. ),
  716. "expr": lambda name: f"at::sparse_csr::getSymIntBlockSize({name})",
  717. },
  718. ),
  719. # replace self.options() with self_options
  720. (
  721. r"{}.options\(\)",
  722. {
  723. "suffix": "_options",
  724. "nctype": lambda name: NamedCType(name, BaseCType(tensorOptionsT)),
  725. },
  726. ),
  727. # replace zeros_like(self) with self_info
  728. (
  729. r"zeros_like\({}\)",
  730. {
  731. "suffix": "_info",
  732. "nctype": lambda name: NamedCType(name, BaseCType(typeAndSizeT)),
  733. "expr": lambda name: name, # at save-time
  734. "res": lambda name: name + "_info.zeros()", # at eval-time
  735. },
  736. ),
  737. # replace self.sym_size(2) with self_sym_size_2
  738. (
  739. r"{}.sym_size\((-?\w+)\)",
  740. {
  741. "suffix": lambda m: f"_sym_argsize_{m.groups()[0].replace('-', 'minus_')}",
  742. "nctype": lambda name: NamedCType(name, BaseCType(SymIntT)),
  743. },
  744. ),
  745. # replace self.numel() with self_numel
  746. (
  747. r"{}.numel\(\)",
  748. {
  749. "suffix": "_numel",
  750. "nctype": lambda name: NamedCType(name, BaseCType(longT)),
  751. },
  752. ),
  753. # replace self.sym_numel() with self_sym_numel
  754. (
  755. r"{}.sym_numel\(\)",
  756. {
  757. "suffix": "_sym_numel",
  758. "nctype": lambda name: NamedCType(name, BaseCType(SymIntT)),
  759. },
  760. ),
  761. # replace to_args_sizes(self) with self_args_sizes
  762. (
  763. r"to_args_sizes\({}\)",
  764. {
  765. "suffix": "_args_sizes",
  766. "nctype": lambda name: NamedCType(
  767. name, VectorCType(VectorCType(BaseCType(longT)))
  768. ),
  769. },
  770. ),
  771. # replace to_args_sizes_symint(self) with self_args_sizes
  772. (
  773. r"to_args_sizes_symint\({}\)",
  774. {
  775. "suffix": "_args_sizes_symint",
  776. "nctype": lambda name: NamedCType(
  777. name, VectorCType(VectorCType(BaseCType(SymIntT)))
  778. ),
  779. },
  780. ),
  781. # replace to_args_scalartypes(self) with self_args_scalartypes
  782. (
  783. r"to_args_scalartypes\({}\)",
  784. {
  785. "suffix": "_args_scalartypes",
  786. "nctype": lambda name: NamedCType(
  787. name, VectorCType(BaseCType(scalarTypeT))
  788. ),
  789. },
  790. ),
  791. # replace TensorGeometry(self) with self_geometry
  792. (
  793. r"TensorGeometry\({}\)",
  794. {
  795. "suffix": "_geometry",
  796. "nctype": lambda name: NamedCType(name, BaseCType(tensorGeometryT)),
  797. },
  798. ),
  799. (
  800. r"{}.scalar_type\(\)",
  801. {
  802. "suffix": "_scalar_type",
  803. "nctype": lambda name: NamedCType(name, BaseCType(scalarTypeT)),
  804. },
  805. ),
  806. # replace self.dim() with self_dim
  807. (
  808. r"{}.dim\(\)",
  809. {
  810. "suffix": "_dim",
  811. "nctype": lambda name: NamedCType(name, BaseCType(longT)),
  812. },
  813. ),
  814. # replace self.sym_strides() with self_sym_strides
  815. (
  816. r"{}.sym_strides\(\)",
  817. {
  818. "suffix": "_sym_strides",
  819. "nctype": lambda name: NamedCType(name, BaseCType(symIntArrayRefT)),
  820. "expr": stride_expr,
  821. },
  822. ),
  823. # replace self.layout() with self_layout
  824. (
  825. r"{}.layout\(\)",
  826. {
  827. "suffix": "_layout",
  828. "nctype": lambda name: NamedCType(name, BaseCType(layoutT)),
  829. },
  830. ),
  831. # replace self.is_conj() with self_conjugate
  832. (
  833. r"{}.is_conj\(\)",
  834. {
  835. "suffix": "_conjugate",
  836. "nctype": lambda name: NamedCType(name, BaseCType(boolT)),
  837. },
  838. ),
  839. ]
  840. # find which arguments need to be saved
  841. saved: list[SavedAttribute] = []
  842. if ".sizes()" in formula or "->sizes()" in formula:
  843. raise RuntimeError(
  844. ".sizes() is not supported in derivative formulas. Instead, please use the SymInt version,"
  845. + f".sym_sizes(), which returned a c10::SymIntArrayRef. formula={formula}"
  846. )
  847. if re.search(r"\.size\([-]?\d+\)", formula) or re.search(
  848. r"->size\([-]?\d+\)", formula
  849. ):
  850. raise RuntimeError(
  851. ".size(int) is not supported in derivative formulas. Instead, please use the SymInt version,"
  852. + f".sym_size(int), which returned a c10::SymIntArrayRef. formula={formula}"
  853. )
  854. if ".strides()" in formula or "->strides()" in formula:
  855. raise RuntimeError(
  856. ".strides() is not supported in derivative formulas. Instead, please use the SymInt version,"
  857. + f".sym_strides(), which returned a c10::SymIntArrayRef. formula={formula}"
  858. )
  859. for nctype in nctypes:
  860. # pyrefly: ignore [bad-assignment]
  861. name = (
  862. nctype.name.name if isinstance(nctype.name, SpecialArgName) else nctype.name
  863. )
  864. # First search the formula for expressions which can be evaluated
  865. # when the autograd Function is created to avoid saving variables
  866. for regex, info in REPLACEMENTS:
  867. def repl(m: re.Match[str]) -> str:
  868. suffix: str = (
  869. # pyrefly: ignore [bad-assignment]
  870. info["suffix"](m) if callable(info["suffix"]) else info["suffix"]
  871. )
  872. expr: str = info["expr"](name) if "expr" in info else m.group(0)
  873. saved.append(
  874. SavedAttribute(
  875. nctype=info["nctype"](name + suffix),
  876. expr=expr,
  877. )
  878. )
  879. if "res" in info:
  880. replacement: str = info["res"](name)
  881. return replacement
  882. return name + suffix
  883. formula = re.sub(regex.format(name), repl, formula)
  884. # std::optional<std::string> types stored in Backward nodes must be
  885. # converted to std::optional<std::string_view> before being passed into
  886. # the backward function
  887. if nctype.type == OptionalCType(BaseCType(stringT)):
  888. formula = re.sub(
  889. rf"\b{name}\b",
  890. f"{name}.has_value() ? std::optional<std::string_view>({name}.value()) : std::nullopt",
  891. formula,
  892. )
  893. # Find any variables which remain in the formula and save them
  894. if re.search(IDENT_REGEX.format(name), formula):
  895. saved.append(
  896. SavedAttribute(
  897. nctype=nctype,
  898. expr=name,
  899. )
  900. )
  901. return formula, tuple(saved)
  902. def _create_op_prefix(name: str) -> str:
  903. r"""Takes a native function name converts to an op prefix name.
  904. Note that the "name" parameter must be the native function name
  905. without the optional variant suffix, so "add" instead of
  906. "add.out".
  907. OP names correspond to classes, hence the change to title case.
  908. Example::
  909. >>> _create_op_prefix("add")
  910. 'AddBackward'
  911. """
  912. camel_case = "".join([p.title() for p in name.split("_")])
  913. return (camel_case + "Backward").replace("ForwardBackward", "Backward")
  914. def dedup_vars(vars: Sequence[SavedAttribute]) -> Sequence[SavedAttribute]:
  915. seen: set[str] = set()
  916. saved: list[SavedAttribute] = []
  917. for var in vars:
  918. name = (
  919. var.nctype.name.name
  920. if isinstance(var.nctype.name, SpecialArgName)
  921. else var.nctype.name
  922. )
  923. if name in seen:
  924. continue
  925. seen.add(name)
  926. saved.append(var)
  927. return saved