| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212 |
- # mypy: allow-untyped-defs
- import torch
- from torch._vmap_internals import _vmap
- from . import forward_ad as fwAD
- __all__ = ["vjp", "jvp", "jacobian", "hessian", "hvp", "vhp"]
- # Utility functions
- def _as_tuple_nocheck(x):
- if isinstance(x, tuple):
- return x
- elif isinstance(x, list):
- return tuple(x)
- else:
- return (x,)
- def _as_tuple(inp, arg_name=None, fn_name=None):
- # Ensures that inp is a tuple of Tensors
- # Returns whether or not the original inp was a tuple and the tupled version of the input
- if arg_name is None and fn_name is None:
- return _as_tuple_nocheck(inp)
- is_inp_tuple = True
- if not isinstance(inp, tuple):
- inp = (inp,)
- is_inp_tuple = False
- for i, el in enumerate(inp):
- if not isinstance(el, torch.Tensor):
- if is_inp_tuple:
- raise TypeError(
- f"The {arg_name} given to {fn_name} must be either a Tensor or a tuple of Tensors but the"
- f" value at index {i} has type {type(el)}."
- )
- else:
- raise TypeError(
- f"The {arg_name} given to {fn_name} must be either a Tensor or a tuple of Tensors but the"
- f" given {arg_name} has type {type(el)}."
- )
- return is_inp_tuple, inp
- def _tuple_postprocess(res, to_unpack):
- # Unpacks a potentially nested tuple of Tensors
- # to_unpack should be a single boolean or a tuple of two booleans.
- # It is used to:
- # - invert _as_tuple when res should match the inp given to _as_tuple
- # - optionally remove nesting of two tuples created by multiple calls to _as_tuple
- if isinstance(to_unpack, tuple):
- if len(to_unpack) != 2:
- raise AssertionError("Expected to_unpack tuple to have exactly 2 elements")
- if not to_unpack[1]:
- res = tuple(el[0] for el in res)
- if not to_unpack[0]:
- res = res[0]
- else:
- if not to_unpack:
- res = res[0]
- return res
- def _grad_preprocess(inputs, create_graph, need_graph):
- # Preprocess the inputs to make sure they require gradient
- # inputs is a tuple of Tensors to preprocess
- # create_graph specifies if the user wants gradients to flow back to the Tensors in inputs
- # need_graph specifies if we internally want gradients to flow back to the Tensors in res
- # Note that we *always* create a new Tensor object to be able to see the difference between
- # inputs given as arguments and the same Tensors automatically captured by the user function.
- # Check this issue for more details on how that can happen: https://github.com/pytorch/pytorch/issues/32576
- res = []
- for inp in inputs:
- if create_graph and inp.requires_grad:
- # Create at least a new Tensor object in a differentiable way
- if not inp.is_sparse:
- # Use .view_as() to get a shallow copy
- res.append(inp.view_as(inp))
- else:
- # We cannot use view for sparse Tensors so we clone
- res.append(inp.clone())
- else:
- res.append(inp.detach().requires_grad_(need_graph))
- return tuple(res)
- def _grad_postprocess(inputs, create_graph):
- # Postprocess the generated Tensors to avoid returning Tensors with history when the user did not
- # request it.
- if isinstance(inputs[0], torch.Tensor):
- if not create_graph:
- return tuple(inp.detach() for inp in inputs)
- else:
- return inputs
- else:
- return tuple(_grad_postprocess(inp, create_graph) for inp in inputs)
- def _validate_v(v, other, is_other_tuple):
- # This assumes that other is the correct shape, and v should match
- # Both are assumed to be tuples of Tensors
- if len(other) != len(v):
- if is_other_tuple:
- raise RuntimeError(
- f"v is a tuple of invalid length: should be {len(other)} but got {len(v)}."
- )
- else:
- raise RuntimeError("The given v should contain a single Tensor.")
- for idx, (el_v, el_other) in enumerate(zip(v, other)):
- if el_v.size() != el_other.size():
- prepend = ""
- if is_other_tuple:
- prepend = f"Entry {idx} in "
- raise RuntimeError(
- f"{prepend}v has invalid size: should be {el_other.size()} but got {el_v.size()}."
- )
- def _check_requires_grad(inputs, input_type, strict):
- # Used to make all the necessary checks to raise nice errors in strict mode.
- if not strict:
- return
- if input_type not in ["outputs", "grad_inputs", "jacobian", "hessian"]:
- raise RuntimeError("Invalid input_type to _check_requires_grad")
- for i, inp in enumerate(inputs):
- if inp is None:
- # This can only be reached for grad_inputs.
- raise RuntimeError(
- f"The output of the user-provided function is independent of input {i}."
- " This is not allowed in strict mode."
- )
- if not inp.requires_grad:
- if input_type == "hessian":
- raise RuntimeError(
- f"The hessian of the user-provided function with respect to input {i}"
- " is independent of the input. This is not allowed in strict mode."
- " You should ensure that your function is thrice differentiable and that"
- " the hessian depends on the inputs."
- )
- elif input_type == "jacobian":
- raise RuntimeError(
- "While computing the hessian, found that the jacobian of the user-provided"
- f" function with respect to input {i} is independent of the input. This is not"
- " allowed in strict mode. You should ensure that your function is twice"
- " differentiable and that the jacobian depends on the inputs (this would be"
- " violated by a linear function for example)."
- )
- elif input_type == "grad_inputs":
- raise RuntimeError(
- f"The gradient with respect to input {i} is independent of the inputs of the"
- " user-provided function. This is not allowed in strict mode."
- )
- else:
- raise RuntimeError(
- f"Output {i} of the user-provided function does not require gradients."
- " The outputs must be computed in a differentiable manner from the input"
- " when running in strict mode."
- )
- def _autograd_grad(
- outputs,
- inputs,
- grad_outputs=None,
- create_graph=False,
- retain_graph=None,
- is_grads_batched=False,
- ):
- # Version of autograd.grad that accepts `None` in outputs and do not compute gradients for them.
- # This has the extra constraint that inputs has to be a tuple
- if not isinstance(outputs, tuple):
- raise AssertionError("Expected outputs to be a tuple")
- if grad_outputs is None:
- grad_outputs = (None,) * len(outputs)
- if not isinstance(grad_outputs, tuple):
- raise AssertionError("Expected grad_outputs to be a tuple")
- if len(outputs) != len(grad_outputs):
- raise AssertionError(
- f"Expected outputs and grad_outputs to have the same length, "
- f"but got {len(outputs)} and {len(grad_outputs)}"
- )
- new_outputs: tuple[torch.Tensor, ...] = ()
- new_grad_outputs: tuple[torch.Tensor, ...] = ()
- for out, grad_out in zip(outputs, grad_outputs):
- if out is not None and out.requires_grad:
- new_outputs += (out,)
- # pyrefly: ignore [bad-assignment]
- new_grad_outputs += (grad_out,)
- if len(new_outputs) == 0:
- # No differentiable output, we don't need to call the autograd engine
- return (None,) * len(inputs)
- else:
- return torch.autograd.grad(
- new_outputs,
- inputs,
- new_grad_outputs,
- allow_unused=True,
- create_graph=create_graph,
- retain_graph=retain_graph,
- is_grads_batched=is_grads_batched,
- )
- def _fill_in_zeros(grads, refs, strict, create_graph, stage):
- # Used to detect None in the grads and depending on the flags, either replace them
- # with Tensors full of 0s of the appropriate size based on the refs or raise an error.
- # strict and create graph allow us to detect when it is appropriate to raise an error
- # stage gives us information of which backward call we consider to give good error message
- if stage not in ["back", "back_trick", "double_back", "double_back_trick"]:
- raise RuntimeError(f"Invalid stage argument '{stage}' to _fill_in_zeros")
- res: tuple[torch.Tensor, ...] = ()
- for i, grads_i in enumerate(grads):
- if grads_i is None:
- if strict:
- if stage == "back":
- raise RuntimeError(
- "The output of the user-provided function is independent of "
- f"input {i}. This is not allowed in strict mode."
- )
- elif stage == "back_trick":
- raise RuntimeError(
- f"The gradient with respect to the input is independent of entry {i}"
- " in the grad_outputs when using the double backward trick to compute"
- " forward mode gradients. This is not allowed in strict mode."
- )
- elif stage == "double_back":
- raise RuntimeError(
- "The jacobian of the user-provided function is independent of "
- f"input {i}. This is not allowed in strict mode."
- )
- else:
- raise RuntimeError(
- "The hessian of the user-provided function is independent of "
- f"entry {i} in the grad_jacobian. This is not allowed in strict "
- "mode as it prevents from using the double backward trick to "
- "replace forward mode AD."
- )
- grads_i = torch.zeros_like(refs[i])
- else:
- if strict and create_graph and not grads_i.requires_grad:
- if "double" not in stage:
- raise RuntimeError(
- "The jacobian of the user-provided function is independent of "
- f"input {i}. This is not allowed in strict mode when create_graph=True."
- )
- else:
- raise RuntimeError(
- "The hessian of the user-provided function is independent of "
- f"input {i}. This is not allowed in strict mode when create_graph=True."
- )
- res += (grads_i,)
- return res
- # Public API
def vjp(func, inputs, v=None, create_graph=False, strict=False):
    r"""Compute the dot product between a vector ``v`` and the Jacobian of the given function at the point given by the inputs.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a tuple of Tensors or a Tensor.
        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        v (tuple of Tensors or Tensor): The vector for which the vector
            Jacobian product is computed. Must be the same size as the output
            of ``func``. This argument is optional when the output of ``func``
            contains a single element and (if it is not provided) will be set
            as a Tensor containing a single ``1``.
        create_graph (bool, optional): If ``True``, both the output and result
            will be computed in a differentiable way. Note that when ``strict``
            is ``False``, the result may not require gradients or may be
            disconnected from the inputs. Defaults to ``False``.
        strict (bool, optional): If ``True``, an error will be raised when we
            detect that there exists an input such that all the outputs are
            independent of it. If ``False``, we return a Tensor of zeros as the
            vjp for said inputs, which is the expected mathematical value.
            Defaults to ``False``.

    Returns:
        output (tuple): tuple with:
            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``

            vjp (tuple of Tensors or Tensor): result of the dot product with
            the same shape as the inputs.

    Raises:
        TypeError: if ``inputs``, ``v`` or the outputs of ``func`` are not a
            Tensor or a tuple of Tensors.
        RuntimeError: if ``v`` is omitted while ``func`` does not return a
            single one-element Tensor, if ``v`` does not match the outputs, or
            on a strict-mode violation.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def exp_reducer(x):
        ...     return x.exp().sum(dim=1)
        >>> inputs = torch.rand(4, 4)
        >>> v = torch.ones(4)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> vjp(exp_reducer, inputs, v)
        (tensor([5.7817, 7.2458, 5.7830, 6.7782]),
         tensor([[1.4458, 1.3962, 1.3042, 1.6354],
                [2.1288, 1.0652, 1.5483, 2.5035],
                [2.2046, 1.1292, 1.1432, 1.3059],
                [1.3225, 1.6652, 1.7753, 2.0152]]))

        >>> vjp(exp_reducer, inputs, v, create_graph=True)
        (tensor([5.7817, 7.2458, 5.7830, 6.7782], grad_fn=<SumBackward1>),
         tensor([[1.4458, 1.3962, 1.3042, 1.6354],
                [2.1288, 1.0652, 1.5483, 2.5035],
                [2.2046, 1.1292, 1.1432, 1.3059],
                [1.3225, 1.6652, 1.7753, 2.0152]], grad_fn=<MulBackward0>))

        >>> def adder(x, y):
        ...     return 2 * x + 3 * y
        >>> inputs = (torch.rand(2), torch.rand(2))
        >>> v = torch.ones(2)
        >>> vjp(adder, inputs, v)
        (tensor([2.4225, 2.3340]),
         (tensor([2., 2.]), tensor([3., 3.])))
    """
    # The forward pass always records a graph so that it can be differentiated.
    with torch.enable_grad():
        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "vjp")
        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)

        outputs = func(*inputs)
        is_outputs_tuple, outputs = _as_tuple(
            outputs, "outputs of the user-provided function", "vjp"
        )
        _check_requires_grad(outputs, "outputs", strict=strict)

        if v is None:
            if len(outputs) != 1 or outputs[0].nelement() != 1:
                raise RuntimeError(
                    "The vector v can only be None if the "
                    "user-provided function returns "
                    "a single Tensor with a single element."
                )
        else:
            _, v = _as_tuple(v, "v", "vjp")
            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
            _validate_v(v, outputs, is_outputs_tuple)

    # The backward pass only needs grad mode when a differentiable result was
    # requested; otherwise the caller's ambient grad mode is kept.
    grad_mode = True if create_graph else torch.is_grad_enabled()
    with torch.set_grad_enabled(grad_mode):
        raw_grads = _autograd_grad(outputs, inputs, v, create_graph=create_graph)
        vjp_res = _fill_in_zeros(raw_grads, inputs, strict, create_graph, "back")

    # Cleanup objects and return them to the user
    outputs = _grad_postprocess(outputs, create_graph)
    vjp_res = _grad_postprocess(vjp_res, create_graph)
    return (
        _tuple_postprocess(outputs, is_outputs_tuple),
        _tuple_postprocess(vjp_res, is_inputs_tuple),
    )
def jvp(func, inputs, v=None, create_graph=False, strict=False):
    r"""Compute the dot product between the Jacobian of the given function at the point given by the inputs and a vector ``v``.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a tuple of Tensors or a Tensor.
        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        v (tuple of Tensors or Tensor): The vector for which the Jacobian
            vector product is computed. Must be the same size as the input of
            ``func``. This argument is optional when the input to ``func``
            contains a single element and (if it is not provided) will be set
            as a Tensor containing a single ``1``.
        create_graph (bool, optional): If ``True``, both the output and result
            will be computed in a differentiable way. Note that when ``strict``
            is ``False``, the result may not require gradients or may be
            disconnected from the inputs. Defaults to ``False``.
        strict (bool, optional): If ``True``, an error will be raised when we
            detect that there exists an input such that all the outputs are
            independent of it. If ``False``, we return a Tensor of zeros as the
            jvp for said inputs, which is the expected mathematical value.
            Defaults to ``False``.

    Returns:
        output (tuple): tuple with:
            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``

            jvp (tuple of Tensors or Tensor): result of the dot product with
            the same shape as the output.

    Note:
        ``autograd.functional.jvp`` computes the jvp by using the backward of
        the backward (sometimes called the double backwards trick). This is not
        the most performant way of computing the jvp. Please consider using
        :func:`torch.func.jvp` or the
        :ref:`low-level forward-mode AD API <forward-mode-ad>` instead.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def exp_reducer(x):
        ...     return x.exp().sum(dim=1)
        >>> inputs = torch.rand(4, 4)
        >>> v = torch.ones(4, 4)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> jvp(exp_reducer, inputs, v)
        (tensor([6.3090, 4.6742, 7.9114, 8.2106]),
         tensor([6.3090, 4.6742, 7.9114, 8.2106]))

        >>> jvp(exp_reducer, inputs, v, create_graph=True)
        (tensor([6.3090, 4.6742, 7.9114, 8.2106], grad_fn=<SumBackward1>),
         tensor([6.3090, 4.6742, 7.9114, 8.2106], grad_fn=<SqueezeBackward1>))

        >>> def adder(x, y):
        ...     return 2 * x + 3 * y
        >>> inputs = (torch.rand(2), torch.rand(2))
        >>> v = (torch.ones(2), torch.ones(2))
        >>> jvp(adder, inputs, v)
        (tensor([2.2399, 2.5005]),
         tensor([5., 5.]))
    """
    with torch.enable_grad():
        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jvp")
        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)

        if v is None:
            if len(inputs) != 1 or inputs[0].nelement() != 1:
                raise RuntimeError(
                    "The vector v can only be None if the input to "
                    "the user-provided function is a single Tensor "
                    "with a single element."
                )
        else:
            _, v = _as_tuple(v, "v", "jvp")
            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
            _validate_v(v, inputs, is_inputs_tuple)

        outputs = func(*inputs)
        is_outputs_tuple, outputs = _as_tuple(
            outputs, "outputs of the user-provided function", "jvp"
        )
        _check_requires_grad(outputs, "outputs", strict=strict)

        # Double backward trick, first pass: differentiate with respect to
        # dummy grad_outputs. The backward is linear so the value of
        # grad_outputs is not important as it won't appear in the double
        # backward graph. We only need to ensure that it does not contain inf
        # or nan.
        dummy_grad_outputs = tuple(
            torch.zeros_like(out, requires_grad=True) for out in outputs
        )
        grad_inputs = _autograd_grad(
            outputs, inputs, dummy_grad_outputs, create_graph=True
        )
        _check_requires_grad(grad_inputs, "grad_inputs", strict=strict)

    # Second pass: differentiating grad_inputs w.r.t. the dummy grad_outputs
    # in direction v yields J @ v. Grad mode is forced on only when a
    # differentiable result was requested; otherwise the ambient mode is kept.
    grad_mode = True if create_graph else torch.is_grad_enabled()
    with torch.set_grad_enabled(grad_mode):
        raw_grads = _autograd_grad(
            grad_inputs, dummy_grad_outputs, v, create_graph=create_graph
        )
        jvp_res = _fill_in_zeros(raw_grads, outputs, strict, create_graph, "back_trick")

    # Cleanup objects and return them to the user
    outputs = _grad_postprocess(outputs, create_graph)
    jvp_res = _grad_postprocess(jvp_res, create_graph)
    return (
        _tuple_postprocess(outputs, is_outputs_tuple),
        _tuple_postprocess(jvp_res, is_outputs_tuple),
    )
- def _construct_standard_basis_for(
- tensors: tuple[torch.Tensor, ...], tensor_numels: tuple[int, ...]
- ) -> tuple[torch.Tensor, ...]:
- # This function:
- # - constructs a N=sum(tensor_numels) standard basis. i.e. an NxN identity matrix.
- # - Splits the identity matrix into chunks with each chunk size determined by `tensor_numels`.
- # - Each chunk corresponds to one tensor. The chunk has the same dtype and
- # device as the tensor
- #
- # For example, with tensor_numels = [1, 2, 1], this function returns:
- # ( tensor([[1], tensor([[0, 0], tensor([[0],
- # [0], [1, 0], [0],
- # [0], [0, 1], [0],
- # [0]]) , [0, 0]]) , [1]]) )
- #
- # Precondition: tensor_numels == tuple(tensor.numel() for tensor in tensors)
- # Precondition: tensors always has at least one element.
- #
- # See NOTE: [Computing jacobian with vmap and grad for multiple tensors]
- # for context behind this function. All the pre-conditions are guarded for
- # in torch.autograd.functional.jacobian.
- if len(tensors) != len(tensor_numels):
- raise AssertionError(
- f"Expected tensors and tensor_numels to have the same length, "
- f"but got {len(tensors)} and {len(tensor_numels)}"
- )
- if len(tensors) == 0:
- raise AssertionError("Expected at least one tensor")
- total_numel = sum(tensor_numels)
- chunks = tuple(
- tensor.new_zeros(total_numel, tensor_numel)
- for tensor, tensor_numel in zip(tensors, tensor_numels)
- )
- diag_start_idx = 0
- for chunk, numel in zip(chunks, tensor_numels):
- chunk.diagonal(diag_start_idx).fill_(1)
- diag_start_idx -= numel
- return chunks
def _jacfwd(func, inputs, strict=False, vectorize=False):
    """Compute the Jacobian of ``func`` at ``inputs`` with forward-mode AD.

    Backend for ``jacobian(..., strategy="forward-mode")``. Every standard
    basis direction of the flattened inputs is pushed through ``func`` as a
    tangent, batched with ``_vmap``. Only ``vectorize=True`` is implemented.

    Returns:
        The Jacobian using the same (possibly nested) tuple structure as the
        reverse-mode path. Entry ``[i][j]`` has shape
        ``outputs[i].shape + inputs[j].shape``; tuple levels corresponding to
        non-tuple outputs/inputs are unpacked.

    Raises:
        RuntimeError: if ``strict`` is ``True`` (not supported in forward mode).
        TypeError: if ``inputs`` or the outputs of ``func`` are not a Tensor or
            a tuple of Tensors.
        NotImplementedError: if ``vectorize`` is ``False``.
    """
    if strict:
        raise RuntimeError(
            "torch.autograd.functional.jacobian: `strict=True` "
            'and `strategy="forward-mode"` are not supported together (yet). '
            "Please either set `strict=False` or "
            '`strategy="reverse-mode"`.'
        )
    is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jacobian")
    if not vectorize:
        raise NotImplementedError(
            "Computing Jacobian using forward-AD or forward-over-reverse Hessian is "
            "only implemented for `vectorize=True`."
        )

    # Filled by the vmapped function below: [is_outputs_tuple, primal_outputs].
    output_info = []

    # See NOTE: [Computing jacobian with vmap and grad for multiple outputs]
    input_numels = tuple(inp.numel() for inp in inputs)

    # Step 1: Prepare tangents (one standard-basis chunk per input).
    tangents = _construct_standard_basis_for(inputs, input_numels)

    # Step 2: Compute vmap over computation with dual tensors
    def _jvp_for_tangents(tangents):
        with fwAD.dual_level():
            dual_inputs = tuple(
                fwAD.make_dual(inp, tangent.view_as(inp))
                for inp, tangent in zip(inputs, tangents)
            )
            outputs_are_tuple, dual_outputs = _as_tuple(
                func(*dual_inputs), "outputs of the user-provided function", "jacobian"
            )
            output_info.append(outputs_are_tuple)
            jv = []
            primal_outs = []
            for dual_out in dual_outputs:
                primal, tangent = fwAD.unpack_dual(dual_out)
                primal_outs.append(primal)
                # An output without a tangent does not depend on the inputs,
                # so its Jacobian contribution is zero.
                jv.append(tangent if tangent is not None else torch.zeros_like(primal))
            output_info.append(primal_outs)
            return tuple(jv)

    outputs_before_split = _vmap(_jvp_for_tangents)(tangents)
    is_outputs_tuple, outputs = output_info

    # Step 3: for each of the output tangents, split along dim 0 by input.
    jacobian_output_input = []
    for jac_output_i, output_i in zip(outputs_before_split, outputs):
        jacobian_output_i = []
        for jac, input_j in zip(jac_output_i.split(input_numels, dim=0), inputs):
            # In forward AD the leading (batch) dimension indexes the input
            # elements, so move it last before reshaping to
            # output_i.shape + input_j.shape.
            jacobian_output_i_input_j = jac.permute(*range(1, jac.ndim), 0).reshape(
                (*output_i.shape, *input_j.shape)
            )
            jacobian_output_i.append(jacobian_output_i_input_j)
        # Use tuples so the result matches the documented (and reverse-mode)
        # "tuple of tuple of Tensors" structure instead of leaking lists.
        jacobian_output_input.append(tuple(jacobian_output_i))

    # Omit [Step 4] because everything is already transposed w/ forward AD
    return _tuple_postprocess(
        tuple(jacobian_output_input), (is_outputs_tuple, is_inputs_tuple)
    )
- def jacobian(
- func,
- inputs,
- create_graph=False,
- strict=False,
- vectorize=False,
- strategy="reverse-mode",
- ):
- r"""Compute the Jacobian of a given function.
- Args:
- func (function): a Python function that takes Tensor inputs and returns
- a tuple of Tensors or a Tensor.
- inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
- create_graph (bool, optional): If ``True``, the Jacobian will be
- computed in a differentiable manner. Note that when ``strict`` is
- ``False``, the result can not require gradients or be disconnected
- from the inputs. Defaults to ``False``.
- strict (bool, optional): If ``True``, an error will be raised when we
- detect that there exists an input such that all the outputs are
- independent of it. If ``False``, we return a Tensor of zeros as the
- jacobian for said inputs, which is the expected mathematical value.
- Defaults to ``False``.
- vectorize (bool, optional): This feature is experimental.
- Please consider using :func:`torch.func.jacrev` or
- :func:`torch.func.jacfwd` instead if you are looking for something
- less experimental and more performant.
- When computing the jacobian, usually we invoke
- ``autograd.grad`` once per row of the jacobian. If this flag is
- ``True``, we perform only a single ``autograd.grad`` call with
- ``batched_grad=True`` which uses the vmap prototype feature.
- Though this should lead to performance improvements in many cases,
- because this feature is still experimental, there may be performance
- cliffs. See :func:`torch.autograd.grad`'s ``batched_grad`` parameter for
- more information.
- strategy (str, optional): Set to ``"forward-mode"`` or ``"reverse-mode"`` to
- determine whether the Jacobian will be computed with forward or reverse
- mode AD. Currently, ``"forward-mode"`` requires ``vectorized=True``.
- Defaults to ``"reverse-mode"``. If ``func`` has more outputs than
- inputs, ``"forward-mode"`` tends to be more performant. Otherwise,
- prefer to use ``"reverse-mode"``.
- Returns:
- Jacobian (Tensor or nested tuple of Tensors): if there is a single
- input and output, this will be a single Tensor containing the
- Jacobian for the linearized inputs and output. If one of the two is
- a tuple, then the Jacobian will be a tuple of Tensors. If both of
- them are tuples, then the Jacobian will be a tuple of tuple of
- Tensors where ``Jacobian[i][j]`` will contain the Jacobian of the
- ``i``\th output and ``j``\th input and will have as size the
- concatenation of the sizes of the corresponding output and the
- corresponding input and will have same dtype and device as the
- corresponding input. If strategy is ``forward-mode``, the dtype will be
- that of the output; otherwise, the input.
- Example:
- >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
- >>> def exp_reducer(x):
- ... return x.exp().sum(dim=1)
- >>> inputs = torch.rand(2, 2)
- >>> # xdoctest: +IGNORE_WANT("non-deterministic")
- >>> jacobian(exp_reducer, inputs)
- tensor([[[1.4917, 2.4352],
- [0.0000, 0.0000]],
- [[0.0000, 0.0000],
- [2.4369, 2.3799]]])
- >>> jacobian(exp_reducer, inputs, create_graph=True)
- tensor([[[1.4917, 2.4352],
- [0.0000, 0.0000]],
- [[0.0000, 0.0000],
- [2.4369, 2.3799]]], grad_fn=<ViewBackward>)
- >>> def exp_adder(x, y):
- ... return 2 * x.exp() + 3 * y
- >>> inputs = (torch.rand(2), torch.rand(2))
- >>> jacobian(exp_adder, inputs)
- (tensor([[2.8052, 0.0000],
- [0.0000, 3.3963]]),
- tensor([[3., 0.],
- [0., 3.]]))
- >>> def linear_model(x):
- ... W = torch.tensor([[2.0, -1.0], [0.0, 1.0]])
- ... b = torch.tensor([1.0, 0.5])
- ... return x @ W.T + b
- >>> x = torch.randn(4, 2, requires_grad=True)
- >>> jac = jacobian(linear_model, x, vectorize=True)
- >>> jac.shape
- torch.Size([4, 2, 4, 2])
- """
- if strategy not in ("forward-mode", "reverse-mode"):
- raise AssertionError(
- 'Expected strategy to be either "forward-mode" or "reverse-mode". Hint: If your '
- 'function has more outputs than inputs, "forward-mode" tends to be more performant. '
- 'Otherwise, prefer to use "reverse-mode".'
- )
- if strategy == "forward-mode":
- if create_graph:
- raise NotImplementedError(
- "torch.autograd.functional.jacobian: `create_graph=True` "
- 'and `strategy="forward-mode"` are not supported together (yet). '
- "Please either set `create_graph=False` or "
- '`strategy="reverse-mode"`.'
- )
- return _jacfwd(func, inputs, strict, vectorize)
- with torch.enable_grad():
- is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jacobian")
- inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)
- outputs = func(*inputs)
- is_outputs_tuple, outputs = _as_tuple(
- outputs, "outputs of the user-provided function", "jacobian"
- )
- _check_requires_grad(outputs, "outputs", strict=strict)
- if vectorize:
- if strict:
- raise RuntimeError(
- "torch.autograd.functional.jacobian: `strict=True` "
- "and `vectorized=True` are not supported together. "
- "Please either set `strict=False` or "
- "`vectorize=False`."
- )
- # NOTE: [Computing jacobian with vmap and grad for multiple outputs]
- #
- # Let's consider f(x) = (x**2, x.sum()) and let x = torch.randn(3).
- # It turns out we can compute the jacobian of this function with a single
- # call to autograd.grad by using vmap over the correct grad_outputs.
- #
- # Firstly, one way to compute the jacobian is to stack x**2 and x.sum()
- # into a 4D vector. E.g., use g(x) = torch.stack([x**2, x.sum()])
- #
- # To get the first row of the jacobian, we call
- # >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([1, 0, 0, 0]))
- # To get the 2nd row of the jacobian, we call
- # >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([0, 1, 0, 0]))
- # and so on.
- #
- # Using vmap, we can vectorize all 4 of these computations into one by
- # passing the standard basis for R^4 as the grad_output.
- # vmap(partial(autograd.grad, g(x), x))(torch.eye(4)).
- #
- # Now, how do we compute the jacobian *without stacking the output*?
- # We can just split the standard basis across the outputs. So to
- # compute the jacobian of f(x), we'd use
- # >>> autograd.grad(f(x), x, grad_outputs=_construct_standard_basis_for(...))
- # The grad_outputs looks like the following:
- # ( torch.tensor([[1, 0, 0],
- # [0, 1, 0],
- # [0, 0, 1],
- # [0, 0, 0]]),
- # torch.tensor([[0],
- # [0],
- # [0],
- # [1]]) )
- #
- # But we're not done yet!
- # >>> vmap(partial(autograd.grad(f(x), x, grad_outputs=...)))
- # returns a Tensor of shape [4, 3]. We have to remember to split the
- # jacobian of shape [4, 3] into two:
- # - one of shape [3, 3] for the first output
- # - one of shape [ 3] for the second output
- # Step 1: Construct grad_outputs by splitting the standard basis
- output_numels = tuple(output.numel() for output in outputs)
- grad_outputs = _construct_standard_basis_for(outputs, output_numels)
- flat_outputs = tuple(output.reshape(-1) for output in outputs)
- # Step 2: Call vmap + autograd.grad
- def vjp(grad_output):
- vj = list(
- _autograd_grad(
- flat_outputs,
- inputs,
- grad_output,
- create_graph=create_graph,
- is_grads_batched=True,
- )
- )
- for el_idx, vj_el in enumerate(vj):
- if vj_el is not None:
- continue
- vj[el_idx] = torch.zeros_like(inputs[el_idx]).expand(
- (sum(output_numels),) + inputs[el_idx].shape
- )
- return tuple(vj)
- jacobians_of_flat_output = vjp(grad_outputs)
- # Step 3: The returned jacobian is one big tensor per input. In this step,
- # we split each Tensor by output.
- jacobian_input_output = []
- for jac_input_i, input_i in zip(jacobians_of_flat_output, inputs):
- jacobian_input_i_output = []
- for jac, output_j in zip(
- jac_input_i.split(output_numels, dim=0), outputs
- ):
- jacobian_input_i_output_j = jac.view(output_j.shape + input_i.shape)
- jacobian_input_i_output.append(jacobian_input_i_output_j)
- jacobian_input_output.append(jacobian_input_i_output)
- # Step 4: Right now, `jacobian` is a List[List[Tensor]].
- # The outer List corresponds to the number of inputs,
- # the inner List corresponds to the number of outputs.
- # We need to exchange the order of these and convert to tuples
- # before returning.
- jacobian_output_input = tuple(zip(*jacobian_input_output))
- jacobian_output_input = _grad_postprocess(
- jacobian_output_input, create_graph
- )
- return _tuple_postprocess(
- jacobian_output_input, (is_outputs_tuple, is_inputs_tuple)
- )
- jacobian: tuple[torch.Tensor, ...] = ()
- for i, out in enumerate(outputs):
- # mypy complains that expression and variable have different types due to the empty list
- jac_i: tuple[list[torch.Tensor]] = tuple([] for _ in range(len(inputs))) # type: ignore[assignment]
- for j in range(out.nelement()):
- vj = _autograd_grad(
- (out.reshape(-1)[j],),
- inputs,
- retain_graph=True,
- create_graph=create_graph,
- )
- for el_idx, (jac_i_el, vj_el, inp_el) in enumerate(
- zip(jac_i, vj, inputs)
- ):
- if vj_el is not None:
- if strict and create_graph and not vj_el.requires_grad:
- msg = (
- "The jacobian of the user-provided function is "
- f"independent of input {i}. This is not allowed in "
- "strict mode when create_graph=True."
- )
- raise RuntimeError(msg)
- jac_i_el.append(vj_el)
- else:
- if strict:
- msg = (
- f"Output {i} of the user-provided function is "
- f"independent of input {el_idx}. This is not allowed in "
- "strict mode."
- )
- raise RuntimeError(msg)
- jac_i_el.append(torch.zeros_like(inp_el))
- # pyrefly: ignore [bad-assignment]
- jacobian += (
- tuple(
- torch.stack(jac_i_el, dim=0).view(
- out.size() + inputs[el_idx].size() # type: ignore[operator]
- )
- for (el_idx, jac_i_el) in enumerate(jac_i)
- ),
- )
- jacobian = _grad_postprocess(jacobian, create_graph)
- return _tuple_postprocess(jacobian, (is_outputs_tuple, is_inputs_tuple))
def hessian(
    func,
    inputs,
    create_graph=False,
    strict=False,
    vectorize=False,
    outer_jacobian_strategy="reverse-mode",
):
    r"""Compute the Hessian of a given scalar function.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a Tensor with a single element.
        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        create_graph (bool, optional): If ``True``, the Hessian will be computed in
            a differentiable manner. Note that when ``strict`` is ``False``, the result can not
            require gradients or be disconnected from the inputs.
            Defaults to ``False``.
        strict (bool, optional): If ``True``, an error will be raised when we detect that there exists an input
            such that all the outputs are independent of it. If ``False``, we return a Tensor of zeros as the
            hessian for said inputs, which is the expected mathematical value.
            Defaults to ``False``.
        vectorize (bool, optional): This feature is experimental.
            Please consider using :func:`torch.func.hessian`
            instead if you are looking for something less experimental and more performant.
            When computing the hessian, usually we invoke
            ``autograd.grad`` once per row of the hessian. If this flag is
            ``True``, we use the vmap prototype feature as the backend to
            vectorize calls to ``autograd.grad`` so we only invoke it once
            instead of once per row. This should lead to performance
            improvements in many use cases, however, due to this feature
            being incomplete, there may be performance cliffs. Please
            use `torch._C._debug_only_display_vmap_fallback_warnings(True)`
            to show any performance warnings and file us issues if
            warnings exist for your use case. Defaults to ``False``.
        outer_jacobian_strategy (str, optional): The Hessian is computed by
            computing the Jacobian of a Jacobian. The inner Jacobian is always
            computed in reverse-mode AD. Setting strategy to ``"forward-mode"``
            or ``"reverse-mode"`` determines whether the outer Jacobian will be
            computed with forward or reverse mode AD. Currently, computing the outer
            Jacobian in ``"forward-mode"`` requires ``vectorize=True``. Defaults
            to ``"reverse-mode"``.

    Returns:
        Hessian (Tensor or a tuple of tuple of Tensors): if there is a single input,
        this will be a single Tensor containing the Hessian for the input.
        If it is a tuple, then the Hessian will be a tuple of tuples where
        ``Hessian[i][j]`` will contain the Hessian of the ``i``\th input
        and ``j``\th input, with shape ``inputs[i].shape + inputs[j].shape``
        (the ``i``\th input's shape followed by the ``j``\th input's shape).
        ``Hessian[i][j]`` will have the same dtype and device as the
        corresponding ``i``\th input.

    Raises:
        AssertionError: if ``outer_jacobian_strategy`` is neither
            ``"forward-mode"`` nor ``"reverse-mode"``.
        RuntimeError: if ``func`` does not return a single Tensor, or the
            returned Tensor does not contain exactly one element.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def pow_reducer(x):
        ...     return x.pow(3).sum()
        >>> inputs = torch.rand(2, 2)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> hessian(pow_reducer, inputs)
        tensor([[[[5.2265, 0.0000],
                  [0.0000, 0.0000]],
                 [[0.0000, 4.8221],
                  [0.0000, 0.0000]]],
                [[[0.0000, 0.0000],
                  [1.9456, 0.0000]],
                 [[0.0000, 0.0000],
                  [0.0000, 3.2550]]]])

        >>> hessian(pow_reducer, inputs, create_graph=True)
        tensor([[[[5.2265, 0.0000],
                  [0.0000, 0.0000]],
                 [[0.0000, 4.8221],
                  [0.0000, 0.0000]]],
                [[[0.0000, 0.0000],
                  [1.9456, 0.0000]],
                 [[0.0000, 0.0000],
                  [0.0000, 3.2550]]]], grad_fn=<ViewBackward>)


        >>> def pow_adder_reducer(x, y):
        ...     return (2 * x.pow(2) + 3 * y.pow(2)).sum()
        >>> inputs = (torch.rand(2), torch.rand(2))
        >>> hessian(pow_adder_reducer, inputs)
        ((tensor([[4., 0.],
                  [0., 4.]]),
          tensor([[0., 0.],
                  [0., 0.]])),
         (tensor([[0., 0.],
                  [0., 0.]]),
          tensor([[6., 0.],
                  [0., 6.]])))
    """
    is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "hessian")
    if outer_jacobian_strategy not in (
        "forward-mode",
        "reverse-mode",
    ):
        raise AssertionError(
            'Expected strategy to be either "forward-mode" or "reverse-mode".'
        )

    def ensure_single_output_function(*inp):
        # Wrap ``func`` so that non-Tensor / multi-element outputs are rejected
        # with a clear error and the inner jacobian sees a 0-dim output.
        out = func(*inp)
        is_out_tuple, t_out = _as_tuple(
            out, "outputs of the user-provided function", "hessian"
        )
        _check_requires_grad(t_out, "outputs", strict=strict)

        if is_out_tuple or not isinstance(out, torch.Tensor):
            raise RuntimeError(
                "The function given to hessian should return a single Tensor"
            )

        if out.nelement() != 1:
            raise RuntimeError(
                "The Tensor returned by the function given to hessian should contain a single element"
            )

        # A one-element Tensor squeezes to 0-dim.
        return out.squeeze()

    def jac_func(*inp):
        # Inner jacobian (always reverse-mode). create_graph=True so the outer
        # jacobian can differentiate through it.
        if outer_jacobian_strategy == "forward-mode":
            # _grad_preprocess requires create_graph=True and input to require_grad
            # or else the input will be detached
            inp = tuple(t.requires_grad_(True) for t in inp)
        jac = jacobian(ensure_single_output_function, inp, create_graph=True)
        _check_requires_grad(jac, "jacobian", strict=strict)
        return jac

    res = jacobian(
        jac_func,
        inputs,
        create_graph=create_graph,
        strict=strict,
        vectorize=vectorize,
        strategy=outer_jacobian_strategy,
    )
    # jac_func yields one jacobian per input, so both axes of ``res`` are
    # indexed by inputs; unwrap them according to the caller's input form.
    return _tuple_postprocess(res, (is_inputs_tuple, is_inputs_tuple))
def vhp(func, inputs, v=None, create_graph=False, strict=False):
    r"""Compute the dot product between vector ``v`` and Hessian of a given scalar function at a specified point.

    The product is obtained with two reverse-mode passes. The first pass
    computes the gradient of ``func`` with a graph kept. The second pass
    differentiates that gradient against ``inputs``, using ``v`` as the
    grad_outputs.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a Tensor with a single element.
        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        v (tuple of Tensors or Tensor): The vector for which the vector Hessian
            product is computed. Must be the same size as the input of
            ``func``. This argument is optional when ``func``'s input contains
            a single element and (if it is not provided) will be set as a
            Tensor containing a single ``1``.
        create_graph (bool, optional): If ``True``, both the output and result
            will be computed in a differentiable way. Note that when ``strict``
            is ``False``, the result can not require gradients or be
            disconnected from the inputs.
            Defaults to ``False``.
        strict (bool, optional): If ``True``, an error will be raised when we
            detect that there exists an input such that all the outputs are
            independent of it. If ``False``, we return a Tensor of zeros as the
            vhp for said inputs, which is the expected mathematical value.
            Defaults to ``False``.

    Returns:
        output (tuple): tuple with:
            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``

            vhp (tuple of Tensors or Tensor): result of the dot product with the
            same shape as the inputs.

    Raises:
        RuntimeError: if ``v`` is ``None`` while ``inputs`` is not a single
            one-element Tensor, or if ``func`` does not return a single Tensor
            containing exactly one element.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def pow_reducer(x):
        ...     return x.pow(3).sum()
        >>> inputs = torch.rand(2, 2)
        >>> v = torch.ones(2, 2)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> vhp(pow_reducer, inputs, v)
        (tensor(0.5591),
         tensor([[1.0689, 1.2431],
                 [3.0989, 4.4456]]))
        >>> vhp(pow_reducer, inputs, v, create_graph=True)
        (tensor(0.5591, grad_fn=<SumBackward0>),
         tensor([[1.0689, 1.2431],
                 [3.0989, 4.4456]], grad_fn=<MulBackward0>))
        >>> def pow_adder_reducer(x, y):
        ...     return (2 * x.pow(2) + 3 * y.pow(2)).sum()
        >>> inputs = (torch.rand(2), torch.rand(2))
        >>> v = (torch.zeros(2), torch.ones(2))
        >>> vhp(pow_adder_reducer, inputs, v)
        (tensor(4.8053),
         (tensor([0., 0.]),
          tensor([6., 6.])))
    """
    # The forward call and the first-order gradient must be recorded, so grad
    # mode is forced on here regardless of the caller's setting.
    with torch.enable_grad():
        inputs_were_tuple, inputs = _as_tuple(inputs, "inputs", "vhp")
        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)

        if v is None:
            # Leaving ``v`` out is only accepted for a single one-element input.
            if len(inputs) != 1 or inputs[0].nelement() != 1:
                raise RuntimeError(
                    "The vector v can only be None if the input to the user-provided function "
                    "is a single Tensor with a single element."
                )
        else:
            _, v = _as_tuple(v, "v", "vhp")
            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
            _validate_v(v, inputs, inputs_were_tuple)

        outputs_were_tuple, outputs = _as_tuple(
            func(*inputs), "outputs of the user-provided function", "vhp"
        )
        _check_requires_grad(outputs, "outputs", strict=strict)

        returns_single_tensor = not outputs_were_tuple and isinstance(
            outputs[0], torch.Tensor
        )
        if not returns_single_tensor:
            raise RuntimeError(
                "The function given to vhp should return a single Tensor"
            )
        if outputs[0].nelement() != 1:
            raise RuntimeError(
                "The Tensor returned by the function given to vhp should contain a single element"
            )

        # Keep the graph of the first gradient so it can be differentiated again.
        first_grad = _autograd_grad(outputs, inputs, create_graph=True)
        _check_requires_grad(first_grad, "jacobian", strict=strict)

    # The second pass records a graph only when the caller asked for one;
    # otherwise it follows the caller's own grad mode.
    if create_graph:
        second_pass_grad_mode = True
    else:
        second_pass_grad_mode = torch.is_grad_enabled()
    with torch.set_grad_enabled(second_pass_grad_mode):
        raw_product = _autograd_grad(first_grad, inputs, v, create_graph=create_graph)
        product = _fill_in_zeros(
            raw_product, inputs, strict, create_graph, "double_back"
        )

    outputs = _grad_postprocess(outputs, create_graph)
    product = _grad_postprocess(product, create_graph)
    return (
        _tuple_postprocess(outputs, outputs_were_tuple),
        _tuple_postprocess(product, inputs_were_tuple),
    )
def hvp(func, inputs, v=None, create_graph=False, strict=False):
    r"""Compute the dot product between the scalar function's Hessian and a vector ``v`` at a specified point.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a Tensor with a single element.
        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        v (tuple of Tensors or Tensor): The vector for which the Hessian vector
            product is computed. Must be the same size as the input of
            ``func``. This argument is optional when ``func``'s input contains
            a single element and (if it is not provided) will be set as a
            Tensor containing a single ``1``.
        create_graph (bool, optional): If ``True``, both the output and result will be
            computed in a differentiable way. Note that when ``strict`` is
            ``False``, the result can not require gradients or be disconnected
            from the inputs. Defaults to ``False``.
        strict (bool, optional): If ``True``, an error will be raised when we
            detect that there exists an input such that all the outputs are
            independent of it. If ``False``, we return a Tensor of zeros as the
            hvp for said inputs, which is the expected mathematical value.
            Defaults to ``False``.

    Returns:
        output (tuple): tuple with:
            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``

            hvp (tuple of Tensors or Tensor): result of the dot product with
            the same shape as the inputs.

    Raises:
        RuntimeError: if ``v`` is ``None`` while ``inputs`` is not a single
            one-element Tensor, or if ``func`` does not return a single Tensor
            containing exactly one element.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def pow_reducer(x):
        ...     return x.pow(3).sum()
        >>> inputs = torch.rand(2, 2)
        >>> v = torch.ones(2, 2)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> hvp(pow_reducer, inputs, v)
        (tensor(0.1448),
         tensor([[2.0239, 1.6456],
                 [2.4988, 1.4310]]))

        >>> hvp(pow_reducer, inputs, v, create_graph=True)
        (tensor(0.1448, grad_fn=<SumBackward0>),
         tensor([[2.0239, 1.6456],
                 [2.4988, 1.4310]], grad_fn=<MulBackward0>))


        >>> def pow_adder_reducer(x, y):
        ...     return (2 * x.pow(2) + 3 * y.pow(2)).sum()
        >>> inputs = (torch.rand(2), torch.rand(2))
        >>> v = (torch.zeros(2), torch.ones(2))
        >>> hvp(pow_adder_reducer, inputs, v)
        (tensor(2.3030),
         (tensor([0., 0.]),
          tensor([6., 6.])))

    Note:

        This function is significantly slower than `vhp` due to backward mode AD constraints.
        If your function is twice continuously differentiable, then hvp = vhp.t(). So if you
        know that your function satisfies this condition, you should use vhp instead, which is
        much faster with the current implementation.

    """
    # The forward call and the first two backward passes must be recorded,
    # so grad mode is forced on here regardless of the caller's setting.
    with torch.enable_grad():
        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "hvp")
        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)

        if v is not None:
            _, v = _as_tuple(v, "v", "hvp")
            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
            _validate_v(v, inputs, is_inputs_tuple)
        else:
            if len(inputs) != 1 or inputs[0].nelement() != 1:
                raise RuntimeError(
                    "The vector v can only be None if the input to the user-provided function "
                    "is a single Tensor with a single element."
                )
        outputs = func(*inputs)
        is_outputs_tuple, outputs = _as_tuple(
            outputs, "outputs of the user-provided function", "hvp"
        )
        _check_requires_grad(outputs, "outputs", strict=strict)

        if is_outputs_tuple or not isinstance(outputs[0], torch.Tensor):
            raise RuntimeError(
                "The function given to hvp should return a single Tensor"
            )

        if outputs[0].nelement() != 1:
            raise RuntimeError(
                "The Tensor returned by the function given to hvp should contain a single element"
            )

        jac = _autograd_grad(outputs, inputs, create_graph=True)
        _check_requires_grad(jac, "jacobian", strict=strict)

        # Double-backward trick. Differentiating jac . grad_jac w.r.t. the
        # inputs gives a result that is linear in grad_jac. Differentiating
        # that result w.r.t. grad_jac with grad_outputs ``v`` (done below)
        # yields the Hessian-vector product without forward-mode AD.
        # grad_jac's zero values are irrelevant; only the graph built from it is used.
        grad_jac = tuple(torch.zeros_like(inp, requires_grad=True) for inp in inputs)

        double_back = _autograd_grad(jac, inputs, grad_jac, create_graph=True)
        # NOTE(review): this re-checks ``jac`` (already validated just above)
        # under the "hessian" label. ``double_back`` may have been intended;
        # confirm before changing, since that would alter strict-mode errors.
        _check_requires_grad(jac, "hessian", strict=strict)

    # The final pass records a graph only when the caller asked for one;
    # otherwise it follows the caller's own grad mode.
    grad_mode = True if create_graph else torch.is_grad_enabled()
    with torch.set_grad_enabled(grad_mode):
        grad_res = _autograd_grad(double_back, grad_jac, v, create_graph=create_graph)
        hvp = _fill_in_zeros(
            grad_res, inputs, strict, create_graph, "double_back_trick"
        )

    outputs = _grad_postprocess(outputs, create_graph)
    hvp = _grad_postprocess(hvp, create_graph)

    return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(
        hvp, is_inputs_tuple
    )
|