| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
7771778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222 |
- # mypy: allow-untyped-defs
- import itertools
- import operator
- from collections.abc import Sequence
- from typing import Any, TYPE_CHECKING
- import torch
- import torch.nn.functional as F
- from torch import _VF, Tensor
- from torch._C import _add_docstr
- from torch._jit_internal import _overload as overload, boolean_dispatch
- from torch._lowrank import pca_lowrank, svd_lowrank
- from torch.overrides import (
- handle_torch_function,
- has_torch_function,
- has_torch_function_unary,
- has_torch_function_variadic,
- )
# Public API of this module: the names exported by a star-import.
__all__ = [
    "atleast_1d",
    "atleast_2d",
    "atleast_3d",
    "align_tensors",
    "broadcast_shapes",
    "broadcast_tensors",
    "cartesian_prod",
    "block_diag",
    "cdist",
    "chain_matmul",
    "einsum",
    "istft",
    "lu",
    "norm",
    "meshgrid",
    "pca_lowrank",
    "split",
    "stft",
    "svd_lowrank",
    "tensordot",
    "unique",
    "unique_consecutive",
    "unravel_index",
]
def broadcast_tensors(*tensors):
    r"""broadcast_tensors(*tensors) -> List of Tensors

    Expands the given tensors to a common shape following
    :ref:`broadcasting-semantics`.

    Args:
        *tensors: any number of tensors of the same type

    .. warning::

        Several elements of a broadcasted tensor may alias the same memory
        location. In-place operations (vectorized ones in particular) can
        therefore produce incorrect results. Clone the tensors first if you
        need to write to them.

    Example::

        >>> x = torch.arange(3).view(1, 3)
        >>> y = torch.arange(2).view(2, 1)
        >>> a, b = torch.broadcast_tensors(x, y)
        >>> a.size()
        torch.Size([2, 3])
        >>> a
        tensor([[0, 1, 2],
                [0, 1, 2]])
    """
    # The public API is variadic, while the underlying op is handed the
    # operands as one sequence; this wrapper bridges the two.
    if not has_torch_function(tensors):
        return _VF.broadcast_tensors(tensors)  # type: ignore[attr-defined]
    return handle_torch_function(broadcast_tensors, tensors, *tensors)
def broadcast_shapes(*shapes):
    r"""broadcast_shapes(*shapes) -> Size

    Similar to :func:`broadcast_tensors` but for shapes.

    This is equivalent to
    ``torch.broadcast_tensors(*map(torch.empty, shapes))[0].shape``
    but avoids the need to create intermediate tensors. This is useful for
    broadcasting tensors of common batch shape but different rightmost shape,
    e.g. to broadcast mean vectors with covariance matrices.

    Example::

        >>> torch.broadcast_shapes((2,), (3, 1), (1, 1, 1))
        torch.Size([1, 3, 2])

    Args:
        \*shapes (torch.Size): Shapes of tensors.

    Returns:
        shape (torch.Size): A shape compatible with all input shapes.

    Raises:
        RuntimeError: If shapes are incompatible.
    """
    # This wrapper exists to support variadic args.
    # TODO Move this to C++ once the jit has better support for torch.Size.
    if torch.jit.is_tracing():
        # Under torch.jit.trace the pure-Python shape computation below would
        # hardcode the sizes and make subsequent replays fail, so the shape is
        # derived from real (zero-dim, expanded) tensors instead.
        with torch.no_grad():
            zero_dim = torch.zeros((), device="cpu")
            expanded = broadcast_tensors(
                *[zero_dim.expand(shape) for shape in shapes]
            )
            return expanded[0].shape

    result = torch._refs._broadcast_shapes(*shapes)
    # The helper may return None; that case maps to the empty shape.
    return torch.Size([]) if result is None else torch.Size(result)
- def split(
- tensor: Tensor,
- split_size_or_sections: int | list[int],
- dim: int = 0,
- ) -> tuple[Tensor, ...]:
- r"""Splits the tensor into chunks. Each chunk is a view of the original tensor.
- If :attr:`split_size_or_sections` is an integer type, then :attr:`tensor` will
- be split into equally sized chunks (if possible). Last chunk will be smaller if
- the tensor size along the given dimension :attr:`dim` is not divisible by
- :attr:`split_size`.
- If :attr:`split_size_or_sections` is a list, then :attr:`tensor` will be split
- into ``len(split_size_or_sections)`` chunks with sizes in :attr:`dim` according
- to :attr:`split_size_or_sections`.
- Args:
- tensor (Tensor): tensor to split.
- split_size_or_sections (int) or (list(int)): size of a single chunk or
- list of sizes for each chunk
- dim (int): dimension along which to split the tensor.
- Example::
- >>> a = torch.arange(10).reshape(5, 2)
- >>> a
- tensor([[0, 1],
- [2, 3],
- [4, 5],
- [6, 7],
- [8, 9]])
- >>> torch.split(a, 2)
- (tensor([[0, 1],
- [2, 3]]),
- tensor([[4, 5],
- [6, 7]]),
- tensor([[8, 9]]))
- >>> torch.split(a, [1, 4])
- (tensor([[0, 1]]),
- tensor([[2, 3],
- [4, 5],
- [6, 7],
- [8, 9]]))
- """
- if has_torch_function_unary(tensor):
- return handle_torch_function(
- split, (tensor,), tensor, split_size_or_sections, dim=dim
- )
- # Overwriting reason:
- # This dispatches to two ATen functions depending on the type of
- # split_size_or_sections. The branching code is in _tensor.py, which we
- # call here.
- return tensor.split(split_size_or_sections, dim)
def einsum(*args: Any) -> Tensor:
    r"""einsum(equation, *operands) -> Tensor

    Sums the product of the elements of the input :attr:`operands` along dimensions specified using a notation
    based on the Einstein summation convention.

    Einsum allows computing many common multi-dimensional linear algebraic array operations by representing them
    in a short-hand format based on the Einstein summation convention, given by :attr:`equation`. The details of
    this format are described below, but the general idea is to label every dimension of the input :attr:`operands`
    with some subscript and define which subscripts are part of the output. The output is then computed by summing
    the product of the elements of the :attr:`operands` along the dimensions whose subscripts are not part of the
    output. For example, matrix multiplication can be computed using einsum as `torch.einsum("ij,jk->ik", A, B)`.
    Here, j is the summation subscript and i and k the output subscripts (see section below for more details on why).

    Equation:

        The :attr:`equation` string specifies the subscripts (letters in `[a-zA-Z]`) for each dimension of
        the input :attr:`operands` in the same order as the dimensions, separating subscripts for each operand by a
        comma (','), e.g. `'ij,jk'` specify subscripts for two 2D operands. The dimensions labeled with the same
        subscript must be broadcastable, that is, their size must either match or be `1`. The exception is if a
        subscript is repeated for the same input operand, in which case the dimensions labeled with this subscript
        for this operand must match in size and the operand will be replaced by its diagonal along these dimensions.
        The subscripts that appear exactly once in the :attr:`equation` will be part of the output, sorted in
        increasing alphabetical order. The output is computed by multiplying the input :attr:`operands`
        element-wise, with their dimensions aligned based on the subscripts, and then summing out the dimensions
        whose subscripts are not part of the output.

        Optionally, the output subscripts can be explicitly defined by adding an arrow ('->') at the end of the
        equation followed by the subscripts for the output. For instance, the following equation computes the
        transpose of a matrix multiplication: 'ij,jk->ki'. The output subscripts must appear at least once for some
        input operand and at most once for the output.

        Ellipsis ('...') can be used in place of subscripts to broadcast the dimensions covered by the ellipsis.
        Each input operand may contain at most one ellipsis which will cover the dimensions not covered by
        subscripts, e.g. for an input operand with 5 dimensions, the ellipsis in the equation `'ab...c'` covers the
        third and fourth dimensions. The ellipsis does not need to cover the same number of dimensions across the
        :attr:`operands` but the 'shape' of the ellipsis (the size of the dimensions covered by them) must broadcast
        together. If the output is not explicitly defined with the arrow ('->') notation, the ellipsis will come
        first in the output (left-most dimensions), before the subscript labels that appear exactly once for the
        input operands. e.g. the following equation implements batch matrix multiplication `'...ij,...jk'`.

        A few final notes: the equation may contain whitespaces between the different elements (subscripts,
        ellipsis, arrow and comma) but something like `'. . .'` is not valid. An empty string `''` is valid for
        scalar operands.

    .. note::

        ``torch.einsum`` handles ellipsis ('...') differently from NumPy in that it allows dimensions
        covered by the ellipsis to be summed over, that is, ellipsis are not required to be part of the output.

    .. note::

        Please install opt-einsum (https://optimized-einsum.readthedocs.io/en/stable/) in order to enroll into a more
        performant einsum. You can install when installing torch like so: `pip install torch[opt-einsum]` or by itself
        with `pip install opt-einsum`.

        If opt-einsum is available, this function will automatically speed up computation and/or consume less memory
        by optimizing contraction order through our opt_einsum backend :mod:`torch.backends.opt_einsum` (The _ vs - is
        confusing, I know). This optimization occurs when there are at least three inputs, since the order does not
        matter otherwise. Note that finding `the` optimal path is an NP-hard problem, thus, opt-einsum relies on
        different heuristics to achieve near-optimal results. If opt-einsum is not available, the default order is
        to contract from left to right.

        To bypass this default behavior, add the following to disable opt_einsum and skip path calculation:
        ``torch.backends.opt_einsum.enabled = False``

        To specify which strategy you'd like for opt_einsum to compute the contraction path, add the following line:
        ``torch.backends.opt_einsum.strategy = 'auto'``. The default strategy is 'auto', and we also support 'greedy'
        and 'optimal'. Disclaimer that the runtime of 'optimal' is factorial in the number of inputs! See more
        details in the opt_einsum documentation (https://optimized-einsum.readthedocs.io/en/stable/path_finding.html).

    .. note::

        As of PyTorch 1.10 :func:`torch.einsum` also supports the sublist format (see examples below). In this format,
        subscripts for each operand are specified by sublists, list of integers in the range [0, 52). These sublists
        follow their operands, and an extra sublist can appear at the end of the input to specify the output's
        subscripts, e.g. `torch.einsum(op1, sublist1, op2, sublist2, ..., [sublist_out])`. Python's `Ellipsis` object
        may be provided in a sublist to enable broadcasting as described in the Equation section above.

    Args:
        equation (str): The subscripts for the Einstein summation.
        operands (List[Tensor]): The tensors to compute the Einstein summation of.

    Examples::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> # trace
        >>> torch.einsum('ii', torch.randn(4, 4))
        tensor(-1.2104)

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> # diagonal
        >>> torch.einsum('ii->i', torch.randn(4, 4))
        tensor([-0.1034,  0.7952, -0.2433,  0.4545])

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> # outer product
        >>> x = torch.randn(5)
        >>> y = torch.randn(4)
        >>> torch.einsum('i,j->ij', x, y)
        tensor([[ 0.1156, -0.2897, -0.3918,  0.4963],
                [-0.3744,  0.9381,  1.2685, -1.6070],
                [ 0.7208, -1.8058, -2.4419,  3.0936],
                [ 0.1713, -0.4291, -0.5802,  0.7350],
                [ 0.5704, -1.4290, -1.9323,  2.4480]])

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> # batch matrix multiplication
        >>> As = torch.randn(3, 2, 5)
        >>> Bs = torch.randn(3, 5, 4)
        >>> torch.einsum('bij,bjk->bik', As, Bs)
        tensor([[[-1.0564, -1.5904,  3.2023,  3.1271],
                [-1.6706, -0.8097, -0.8025, -2.1183]],

                [[ 4.2239,  0.3107, -0.5756, -0.2354],
                [-1.4558, -0.3460,  1.5087, -0.8530]],

                [[ 2.8153,  1.8787, -4.3839, -1.2112],
                [ 0.3728, -2.1131,  0.0921,  0.8305]]])

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> # with sublist format and ellipsis
        >>> torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2])
        tensor([[[-1.0564, -1.5904,  3.2023,  3.1271],
                [-1.6706, -0.8097, -0.8025, -2.1183]],

                [[ 4.2239,  0.3107, -0.5756, -0.2354],
                [-1.4558, -0.3460,  1.5087, -0.8530]],

                [[ 2.8153,  1.8787, -4.3839, -1.2112],
                [ 0.3728, -2.1131,  0.0921,  0.8305]]])

        >>> # batch permute
        >>> A = torch.randn(2, 3, 4, 5)
        >>> torch.einsum('...ij->...ji', A).shape
        torch.Size([2, 3, 5, 4])

        >>> # equivalent to torch.nn.functional.bilinear
        >>> A = torch.randn(3, 5, 4)
        >>> l = torch.randn(2, 5)
        >>> r = torch.randn(2, 4)
        >>> torch.einsum('bn,anm,bm->ba', l, A, r)
        tensor([[-0.3430, -5.2405,  0.4494],
                [ 0.3311,  5.5201, -3.0356]])
    """
    import torch.backends.opt_einsum as opt_einsum

    # This wrapper exists to support variadic args.
    if len(args) < 2:
        raise ValueError(
            "einsum(): must specify the equation string and at least one operand, "
            "or at least one operand and its subscripts list"
        )

    if isinstance(args[0], torch.Tensor):
        # Sublist format: operands interleaved with their subscript lists,
        # optionally followed by one trailing list with the output subscripts.
        # Translate it into the equation-string format and collect the
        # operands into a tensor list.
        def parse_subscript(n: int) -> str:
            # 0..25 map to 'A'..'Z', 26..51 map to 'a'..'z'.
            if n == Ellipsis:
                return "..."
            if 0 <= n < 26:
                return chr(ord("A") + n)
            if 26 <= n < 52:
                return chr(ord("a") + n - 26)
            raise ValueError(
                "einsum(): subscript in subscript list is not within the valid range [0, 52)"
            )

        # Subscripts of the input operands sit at the odd positions.
        input_terms = [
            "".join(map(parse_subscript, sublist)) for sublist in args[1::2]
        ]
        equation = ",".join(input_terms)

        # An odd argument count means a trailing output sublist was supplied.
        has_output_sublist = len(args) % 2 == 1
        if has_output_sublist:
            equation += "->" + "".join(map(parse_subscript, args[-1]))
            operands = args[:-1:2]
        else:
            operands = args[::2]
    else:
        equation, operands = args[0], args[1:]

    if has_torch_function(operands):
        return handle_torch_function(einsum, operands, equation, *operands)

    if len(operands) == 1 and isinstance(operands[0], (list, tuple)):
        # Legacy calling convention: all operands passed as a single list.
        # Recurse so that elements overriding __torch_function__ are still
        # dispatched correctly (the original implementation skipped this).
        return einsum(equation, *operands[0])

    # With at most two operands there is nothing to reorder, and the user may
    # also have switched opt_einsum off entirely.
    wants_optimized_path = len(operands) > 2 and opt_einsum.enabled
    if not wants_optimized_path:
        return _VF.einsum(equation, operands)  # type: ignore[attr-defined]

    path = None
    if opt_einsum.is_available():
        contraction_pairs = opt_einsum.get_opt_einsum().contract_path(
            equation, *operands, optimize=opt_einsum.strategy
        )[0]
        # The C++ side expects the path as one flat list of indices.
        path = list(itertools.chain.from_iterable(contraction_pairs))
    return _VF.einsum(equation, operands, path=path)  # type: ignore[attr-defined]
# This wrapper exists to support variadic args.
if TYPE_CHECKING:
    # The JIT doesn't understand Union, so only add type annotation for mypy
    # Seen by static type checkers only; the runtime definition is in the
    # ``else`` branch below.
    def meshgrid(
        *tensors: Tensor | list[Tensor], indexing: str | None = None
    ) -> tuple[Tensor, ...]:
        return _meshgrid(*tensors, indexing=indexing)

else:

    def meshgrid(*tensors, indexing: str | None = None) -> tuple[Tensor, ...]:
        r"""Creates grids of coordinates specified by the 1D inputs in :attr:`tensors`.

        This is helpful when you want to visualize data over some
        range of inputs. See below for a plotting example.

        Given :math:`N` 1D tensors :math:`T_0 \ldots T_{N-1}` as
        inputs with corresponding sizes :math:`S_0 \ldots S_{N-1}`,
        this creates :math:`N` N-dimensional tensors :math:`G_0 \ldots
        G_{N-1}`, each with shape :math:`(S_0, ..., S_{N-1})` where
        the output :math:`G_i` is constructed by expanding :math:`T_i`
        to the result shape.

        .. note::
            0D inputs are treated equivalently to 1D inputs of a
            single element.

        .. warning::
            `torch.meshgrid(*tensors)` currently has the same behavior
            as calling `numpy.meshgrid(*arrays, indexing='ij')`.

            In the future `torch.meshgrid` will transition to
            `indexing='xy'` as the default.

            https://github.com/pytorch/pytorch/issues/50276 tracks
            this issue with the goal of migrating to NumPy's behavior.

        .. seealso::

            :func:`torch.cartesian_prod` has the same effect but it
            collects the data in a tensor of vectors.

        Args:
            tensors (list of Tensor): list of scalars or 1 dimensional tensors. Scalars will be
                treated as tensors of size :math:`(1,)` automatically

            indexing (str, optional): the indexing mode, either "xy"
                or "ij", defaults to "ij". See warning for future changes.

                If "xy" is selected, the first dimension corresponds
                to the cardinality of the second input and the second
                dimension corresponds to the cardinality of the first
                input.

                If "ij" is selected, the dimensions are in the same
                order as the cardinality of the inputs.

        Returns:
            seq (sequence of Tensors): If the input has :math:`N`
            tensors of size :math:`S_0 \ldots S_{N-1}`, then the
            output will also have :math:`N` tensors, where each tensor
            is of shape :math:`(S_0, ..., S_{N-1})`.

        Example::

            >>> x = torch.tensor([1, 2, 3])
            >>> y = torch.tensor([4, 5, 6])

            Observe the element-wise pairings across the grid, (1, 4),
            (1, 5), ..., (3, 6). This is the same thing as the
            cartesian product.
            >>> grid_x, grid_y = torch.meshgrid(x, y, indexing='ij')
            >>> grid_x
            tensor([[1, 1, 1],
                    [2, 2, 2],
                    [3, 3, 3]])
            >>> grid_y
            tensor([[4, 5, 6],
                    [4, 5, 6],
                    [4, 5, 6]])

            This correspondence can be seen when these grids are
            stacked properly.
            >>> torch.equal(torch.cat(tuple(torch.dstack([grid_x, grid_y]))),
            ...             torch.cartesian_prod(x, y))
            True

            `torch.meshgrid` is commonly used to produce a grid for
            plotting.
            >>> # xdoctest: +REQUIRES(module:matplotlib)
            >>> # xdoctest: +REQUIRES(env:DOCTEST_SHOW)
            >>> import matplotlib.pyplot as plt
            >>> xs = torch.linspace(-5, 5, steps=100)
            >>> ys = torch.linspace(-5, 5, steps=100)
            >>> x, y = torch.meshgrid(xs, ys, indexing='xy')
            >>> z = torch.sin(torch.sqrt(x * x + y * y))
            >>> ax = plt.axes(projection='3d')
            >>> ax.plot_surface(x.numpy(), y.numpy(), z.numpy())
            >>> plt.show()

        .. image:: ../_static/img/meshgrid.png
            :width: 512

        """
        # All of the actual work happens in the private helper.
        return _meshgrid(*tensors, indexing=indexing)
- def _meshgrid(*tensors, indexing: str | None):
- if has_torch_function(tensors):
- return handle_torch_function(meshgrid, tensors, *tensors, indexing=indexing)
- if len(tensors) == 1 and isinstance(tensors[0], (list, tuple)):
- # the old interface of passing the operands as one list argument
- tensors = tensors[0] # type: ignore[assignment]
- # Continue allowing call of old method that takes no indexing
- # kwarg for forward compatibility reasons.
- #
- # Remove this two weeks after landing.
- kwargs = {} if indexing is None else {"indexing": indexing}
- return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]
- def stft(
- input: Tensor,
- n_fft: int,
- hop_length: int | None = None,
- win_length: int | None = None,
- window: Tensor | None = None,
- center: bool = True,
- pad_mode: str = "reflect",
- normalized: bool = False,
- onesided: bool | None = None,
- return_complex: bool | None = None,
- align_to_window: bool | None = None,
- ) -> Tensor:
- r"""Short-time Fourier transform (STFT).
- .. warning::
- From version 1.8.0, :attr:`return_complex` must always be given
- explicitly for real inputs and `return_complex=False` has been
- deprecated. Strongly prefer `return_complex=True` as in a future
- pytorch release, this function will only return complex tensors.
- Note that :func:`torch.view_as_real` can be used to recover a real
- tensor with an extra last dimension for real and imaginary components.
- .. warning::
- From version 2.1, a warning will be provided if a :attr:`window` is
- not specified. In a future release, this attribute will be required.
- Not providing a window currently defaults to using a rectangular window,
- which may result in undesirable artifacts. Consider using tapered windows,
- such as :func:`torch.hann_window`.
- The STFT computes the Fourier transform of short overlapping windows of the
- input. This giving frequency components of the signal as they change over
- time. The interface of this function is modeled after (but *not* a drop-in
- replacement for) librosa_ stft function.
- .. _librosa: https://librosa.org/doc/latest/generated/librosa.stft.html
- Ignoring the optional batch dimension, this method computes the following
- expression:
- .. math::
- X[\omega, m] = \sum_{k = 0}^{\text{win\_length-1}}%
- \text{window}[k]\ \text{input}[m \times \text{hop\_length} + k]\ %
- \exp\left(- j \frac{2 \pi \cdot \omega k}{\text{n\_fft}}\right),
- where :math:`m` is the index of the sliding window, and :math:`\omega` is
- the frequency :math:`0 \leq \omega < \text{n\_fft}` for ``onesided=False``,
- or :math:`0 \leq \omega < \lfloor \text{n\_fft} / 2 \rfloor + 1` for ``onesided=True``.
- * :attr:`input` must be either a 1-D time sequence or a 2-D batch of time
- sequences.
- * If :attr:`hop_length` is ``None`` (default), it is treated as equal to
- ``floor(n_fft / 4)``.
- * If :attr:`win_length` is ``None`` (default), it is treated as equal to
- :attr:`n_fft`.
- * :attr:`window` can be a 1-D tensor of size :attr:`win_length`, e.g., from
- :meth:`torch.hann_window`. If :attr:`window` is ``None`` (default), it is
- treated as if having :math:`1` everywhere in the window. If
- :math:`\text{win\_length} < \text{n\_fft}`, :attr:`window` will be padded on
- both sides to length :attr:`n_fft` before being applied.
- * If :attr:`center` is ``True`` (default), :attr:`input` will be padded on
- both sides so that the :math:`t`-th frame is centered at time
- :math:`t \times \text{hop\_length}`. Otherwise, the :math:`t`-th frame
- begins at time :math:`t \times \text{hop\_length}`.
- * :attr:`pad_mode` determines the padding method used on :attr:`input` when
- :attr:`center` is ``True``. See :meth:`torch.nn.functional.pad` for
- all available options. Default is ``"reflect"``.
- * If :attr:`onesided` is ``True`` (default for real input), only values for
- :math:`\omega` in :math:`\left[0, 1, 2, \dots, \left\lfloor
- \frac{\text{n\_fft}}{2} \right\rfloor + 1\right]` are returned because
- the real-to-complex Fourier transform satisfies the conjugate symmetry,
- i.e., :math:`X[m, \omega] = X[m, \text{n\_fft} - \omega]^*`.
- Note if the input or window tensors are complex, then :attr:`onesided`
- output is not possible.
- * If :attr:`normalized` is ``True`` (default is ``False``), the function
- returns the normalized STFT results, i.e., multiplied by :math:`(\text{frame\_length})^{-0.5}`.
- * If :attr:`return_complex` is ``True`` (default if input is complex), the
- return is a ``input.dim() + 1`` dimensional complex tensor. If ``False``,
- the output is a ``input.dim() + 2`` dimensional real tensor where the last
- dimension represents the real and imaginary components.
- Returns either a complex tensor of size :math:`(* \times N \times T)` if
- :attr:`return_complex` is true, or a real tensor of size :math:`(* \times N
- \times T \times 2)`. Where :math:`*` is the optional batch size of
- :attr:`input`, :math:`N` is the number of frequencies where STFT is applied
- and :math:`T` is the total number of frames used.
- .. warning::
- This function changed signature at version 0.4.1. Calling with the
- previous signature may cause error or return incorrect result.
- Args:
- input (Tensor): the input tensor of shape `(B?, L)` where `B?` is an optional
- batch dimension
- n_fft (int): size of Fourier transform
- hop_length (int, optional): the distance between neighboring sliding window
- frames. Default: ``None`` (treated as equal to ``floor(n_fft / 4)``)
- win_length (int, optional): the size of window frame and STFT filter.
- Default: ``None`` (treated as equal to :attr:`n_fft`)
- window (Tensor, optional): the optional window function.
- Shape must be 1d and `<= n_fft`
- Default: ``None`` (treated as window of all :math:`1` s)
- center (bool, optional): whether to pad :attr:`input` on both sides so
- that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
- Default: ``True``
- pad_mode (str, optional): controls the padding method used when
- :attr:`center` is ``True``. Default: ``"reflect"``
- normalized (bool, optional): controls whether to return the normalized STFT results
- Default: ``False``
- onesided (bool, optional): controls whether to return half of results to
- avoid redundancy for real inputs.
- Default: ``True`` for real :attr:`input` and :attr:`window`, ``False`` otherwise.
- return_complex (bool, optional): whether to return a complex tensor, or
- a real tensor with an extra last dimension for the real and
- imaginary components.
- .. versionchanged:: 2.0
- ``return_complex`` is now a required argument for real inputs,
- as the default is being transitioned to ``True``.
- .. deprecated:: 2.0
- ``return_complex=False`` is deprecated, instead use ``return_complex=True``
- Note that calling :func:`torch.view_as_real` on the output will
- recover the deprecated output format.
- Returns:
- Tensor: A tensor containing the STFT result with shape `(B?, N, T, C?)` where
- - `B?` is an optional batch dimension from the input.
- - `N` is the number of frequency samples, `(n_fft // 2) + 1` for
- `onesided=True`, or otherwise `n_fft`.
- - `T` is the number of frames, `1 + L // hop_length`
- for `center=True`, or `1 + (L - n_fft) // hop_length` otherwise.
- - `C?` is an optional length-2 dimension of real and imaginary
- components, present when `return_complex=False`.
- """
- if has_torch_function_unary(input):
- return handle_torch_function(
- stft,
- (input,),
- input,
- n_fft,
- hop_length=hop_length,
- win_length=win_length,
- window=window,
- center=center,
- pad_mode=pad_mode,
- normalized=normalized,
- onesided=onesided,
- return_complex=return_complex,
- align_to_window=align_to_window,
- )
- if center and align_to_window is not None:
- raise RuntimeError(
- "stft align_to_window should only be set when center = false"
- )
- # NOTE: Do not edit. This code will be removed once the forward-compatibility
- # period is over for PR #73432
- if center:
- signal_dim = input.dim()
- extended_shape = [1] * (3 - signal_dim) + list(input.size())
- pad = int(n_fft // 2)
- input = F.pad(input.view(extended_shape), [pad, pad], pad_mode)
- input = input.view(input.shape[-signal_dim:])
- return _VF.stft( # type: ignore[attr-defined]
- input,
- n_fft,
- hop_length,
- win_length,
- window,
- normalized,
- onesided,
- return_complex,
- align_to_window,
- )
# Bind the module-level name ``istft`` to the result of ``_add_docstr`` applied
# to the builtin ``torch.istft``. The three adjacent string literals below are
# concatenated by the parser into a single text: a one-line signature followed
# by the reST body.
istft = _add_docstr(
    torch.istft,
    "istft(input, n_fft, hop_length=None, win_length=None, window=None, center=True, "
    "normalized=False, onesided=None, length=None, return_complex=False) -> Tensor:\n"
    r"""
Inverse short time Fourier Transform. This is expected to be the inverse of :func:`~torch.stft`.

.. warning::
    From version 2.1, a warning will be provided if a :attr:`window` is
    not specified. In a future release, this attribute will be required.
    Please provide the same window used in the stft call.

It has the same parameters (+ additional optional parameter of :attr:`length`) and it should return the
least squares estimation of the original signal. The algorithm will check using the NOLA condition (
nonzero overlap).

Important consideration in the parameters :attr:`window` and :attr:`center` so that the envelope
created by the summation of all the windows is never zero at certain point in time. Specifically,
:math:`\sum_{t=-\infty}^{\infty} |w|^2[n-t\times hop\_length] \cancel{=} 0`.

Since :func:`~torch.stft` discards elements at the end of the signal if they do not fit in a frame,
``istft`` may return a shorter signal than the original signal (can occur if :attr:`center` is False
since the signal isn't padded). If `length` is given in the arguments and is longer than expected,
``istft`` will pad zeros to the end of the returned signal.

If :attr:`center` is ``True``, then there will be padding e.g. ``'constant'``, ``'reflect'``, etc.
Left padding can be trimmed off exactly because they can be calculated but right padding cannot be
calculated without additional information.

Example: Suppose the last window is:
``[17, 18, 0, 0, 0]`` vs ``[18, 0, 0, 0, 0]``

The :attr:`n_fft`, :attr:`hop_length`, :attr:`win_length` are all the same which prevents the calculation
of right padding. These additional values could be zeros or a reflection of the signal so providing
:attr:`length` could be useful. If :attr:`length` is ``None`` then padding will be aggressively removed
(some loss of signal).

[1] D. W. Griffin and J. S. Lim, "Signal estimation from modified short-time Fourier transform,"
IEEE Trans. ASSP, vol.32, no.2, pp.236-243, Apr. 1984.

Args:
    input (Tensor): The input tensor. Expected to be in the format of :func:`~torch.stft`,
        output. That is a complex tensor of shape `(B?, N, T)` where

        - `B?` is an optional batch dimension
        - `N` is the number of frequency samples, `(n_fft // 2) + 1`
          for onesided input, or otherwise `n_fft`.
        - `T` is the number of frames, `1 + length // hop_length` for centered stft,
          or `1 + (length - n_fft) // hop_length` otherwise.

        .. versionchanged:: 2.0
            Real datatype inputs are no longer supported. Input must now have a
            complex datatype, as returned by ``stft(..., return_complex=True)``.
    n_fft (int): Size of Fourier transform
    hop_length (Optional[int]): The distance between neighboring sliding window frames.
        (Default: ``n_fft // 4``)
    win_length (Optional[int]): The size of window frame and STFT filter. (Default: ``n_fft``)
    window (Optional[torch.Tensor]): The optional window function.
        Shape must be 1d and `<= n_fft`
        (Default: ``torch.ones(win_length)``)
    center (bool): Whether :attr:`input` was padded on both sides so that the :math:`t`-th frame is
        centered at time :math:`t \times \text{hop\_length}`.
        (Default: ``True``)
    normalized (bool): Whether the STFT was normalized. (Default: ``False``)
    onesided (Optional[bool]): Whether the STFT was onesided.
        (Default: ``True`` if `n_fft != fft_size` in the input size)
    length (Optional[int]): The amount to trim the signal by (i.e. the
        original signal length). Defaults to `(T - 1) * hop_length` for
        centered stft, or `n_fft + (T - 1) * hop_length` otherwise, where `T`
        is the number of input frames.
    return_complex (Optional[bool]):
        Whether the output should be complex, or if the input should be
        assumed to derive from a real signal and window.
        Note that this is incompatible with ``onesided=True``.
        (Default: ``False``)

Returns:
    Tensor: Least squares estimation of the original signal of shape `(B?, length)` where
        `B?` is an optional batch dimension from the input tensor.
""",
)
# Return-type alias shared by `_unique_impl` and `_unique_consecutive_impl`.
if TYPE_CHECKING:
    # These _impl functions return a variable number of tensors as output with
    # __torch_function__; tuple unpacking is done already rather than being
    # done by the caller of the _impl function, so static checkers only see
    # `Any` here.
    _unique_impl_out = Any
else:
    # At runtime the alias is the concrete three-tensor tuple.
    _unique_impl_out = tuple[Tensor, Tensor, Tensor]
def _unique_impl(
    input: Tensor,
    sorted: bool = True,
    return_inverse: bool = False,
    return_counts: bool = False,
    dim: int | None = None,
) -> _unique_impl_out:
    r"""unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None) -> tuple[Tensor, Tensor, Tensor]

    Returns the unique elements of the input tensor.

    .. note:: This function is different from :func:`torch.unique_consecutive` in the sense that
        this function also eliminates non-consecutive duplicate values.

    .. note:: Currently in the CUDA implementation and the CPU implementation,
        `torch.unique` always sorts the tensor at the beginning regardless of the `sorted` argument.
        Sorting could be slow, so if your input tensor is already sorted, it is recommended to use
        :func:`torch.unique_consecutive` which avoids the sorting.

    Args:
        input (Tensor): the input tensor
        sorted (bool): Whether to sort the unique elements in ascending order
            before returning as output.
        return_inverse (bool): Whether to also return the indices for where
            elements in the original input ended up in the returned unique list.
        return_counts (bool): Whether to also return the counts for each unique
            element.
        dim (int, optional): the dimension to operate upon. If ``None``, the
            unique of the flattened input is returned. Otherwise, each of the
            tensors indexed by the given dimension is treated as one of the
            elements to apply the unique operation upon. **Important:** when ``dim``
            is specified, the operation finds unique sub-tensors (e.g., unique rows
            or columns), not unique scalar values. This means individual values may
            appear multiple times in the output if they exist in different sub-tensors.
            See examples for more details. Default: ``None``

    Returns:
        (Tensor, Tensor (optional), Tensor (optional)): A tensor or a tuple of tensors containing

            - **output** (*Tensor*): the output list of unique scalar elements if :attr:`dim`
              is ``None``; otherwise, the unique sub-tensors along the specified dimension.
              Note that when :attr:`dim` is specified, scalar values may repeat in the output.
            - **inverse_indices** (*Tensor*): (optional) if
              :attr:`return_inverse` is True, there will be an additional
              returned tensor (same shape as input) representing the indices
              for where elements in the original input map to in the output;
              otherwise, this function will only return a single tensor.
            - **counts** (*Tensor*): (optional) if
              :attr:`return_counts` is True, there will be an additional
              returned tensor (same shape as output or output.size(dim),
              if dim was specified) representing the number of occurrences
              for each unique value or tensor.

    Example::

        >>> output = torch.unique(torch.tensor([1, 3, 2, 3], dtype=torch.long))
        >>> output
        tensor([1, 2, 3])

        >>> output, inverse_indices = torch.unique(
        ...     torch.tensor([1, 3, 2, 3], dtype=torch.long), sorted=True, return_inverse=True)
        >>> output
        tensor([1, 2, 3])
        >>> inverse_indices
        tensor([0, 2, 1, 2])

        >>> output, inverse_indices = torch.unique(
        ...     torch.tensor([[1, 3], [2, 3]], dtype=torch.long), sorted=True, return_inverse=True)
        >>> output
        tensor([1, 2, 3])
        >>> inverse_indices
        tensor([[0, 2],
                [1, 2]])

        >>> # When using dim, the operation finds unique sub-tensors, not unique values.
        >>> # Notice how values can repeat in the output:
        >>> x = torch.tensor([[1, 3, 2, 3], [1, 2, 1, 2]], dtype=torch.long)
        >>> torch.unique(x, dim=0)  # unique rows
        tensor([[1, 2, 1, 2],
                [1, 3, 2, 3]])
        >>> # Both rows are kept because they're different from each other,
        >>> # even though values 1, 2, 3 appear multiple times in the output.

        >>> torch.unique(x, dim=1)  # unique columns
        tensor([[1, 2, 3],
                [1, 1, 2]])
        >>> # The value 1 appears twice because we're comparing columns, not values.

        >>> # Compare with flattened (no dim):
        >>> torch.unique(x)
        tensor([1, 2, 3])

        >>> a = torch.tensor([
        ...     [
        ...         [1, 1, 0, 0],
        ...         [1, 1, 0, 0],
        ...         [0, 0, 1, 1],
        ...     ],
        ...     [
        ...         [0, 0, 1, 1],
        ...         [0, 0, 1, 1],
        ...         [1, 1, 1, 1],
        ...     ],
        ...     [
        ...         [1, 1, 0, 0],
        ...         [1, 1, 0, 0],
        ...         [0, 0, 1, 1],
        ...     ],
        ... ])

        >>> # If we call `torch.unique(a, dim=0)`, each of the tensors `a[idx, :, :]`
        >>> # will be compared. We can see that `a[0, :, :]` and `a[2, :, :]` match
        >>> # each other, so one of them will be removed.
        >>> (a[0, :, :] == a[2, :, :]).all()
        tensor(True)
        >>> a_unique_dim0 = torch.unique(a, dim=0)
        >>> a_unique_dim0
        tensor([[[0, 0, 1, 1],
                 [0, 0, 1, 1],
                 [1, 1, 1, 1]],

                [[1, 1, 0, 0],
                 [1, 1, 0, 0],
                 [0, 0, 1, 1]]])

        >>> # Notice which sub-tensors from `a` match with the sub-tensors from
        >>> # `a_unique_dim0`:
        >>> (a_unique_dim0[0, :, :] == a[1, :, :]).all()
        tensor(True)
        >>> (a_unique_dim0[1, :, :] == a[0, :, :]).all()
        tensor(True)

        >>> # For `torch.unique(a, dim=1)`, each of the tensors `a[:, idx, :]` are
        >>> # compared. `a[:, 0, :]` and `a[:, 1, :]` match each other, so one of
        >>> # them will be removed.
        >>> (a[:, 0, :] == a[:, 1, :]).all()
        tensor(True)
        >>> torch.unique(a, dim=1)
        tensor([[[0, 0, 1, 1],
                 [1, 1, 0, 0]],

                [[1, 1, 1, 1],
                 [0, 0, 1, 1]],

                [[0, 0, 1, 1],
                 [1, 1, 0, 0]]])

        >>> # For `torch.unique(a, dim=2)`, the tensors `a[:, :, idx]` are compared.
        >>> # `a[:, :, 0]` and `a[:, :, 1]` match each other. Also, `a[:, :, 2]` and
        >>> # `a[:, :, 3]` match each other as well. So in this case, two of the
        >>> # sub-tensors will be removed.
        >>> (a[:, :, 0] == a[:, :, 1]).all()
        tensor(True)
        >>> (a[:, :, 2] == a[:, :, 3]).all()
        tensor(True)
        >>> torch.unique(a, dim=2)
        tensor([[[0, 1],
                 [0, 1],
                 [1, 0]],

                [[1, 0],
                 [1, 0],
                 [1, 1]],

                [[0, 1],
                 [0, 1],
                 [1, 0]]])
    """
    # Tensor-likes overriding __torch_function__ get the public `unique` entry point.
    if has_torch_function_unary(input):
        return handle_torch_function(
            unique,
            (input,),
            input,
            sorted=sorted,
            return_inverse=return_inverse,
            return_counts=return_counts,
            dim=dim,
        )

    if dim is None:
        # No dimension given: use the op that works on the whole tensor.
        uniques, inverse, occurrences = torch._unique2(
            input,
            sorted=sorted,
            return_inverse=return_inverse,
            return_counts=return_counts,
        )
    else:
        # A dimension was given: use the per-dimension op.
        uniques, inverse, occurrences = _VF.unique_dim(
            input,
            dim,
            sorted=sorted,
            return_inverse=return_inverse,
            return_counts=return_counts,
        )
    # Always hand back all three tensors; the public wrappers drop the unused ones.
    return uniques, inverse, occurrences
def _unique_consecutive_impl(
    input: Tensor,
    return_inverse: bool = False,
    return_counts: bool = False,
    dim: int | None = None,
) -> _unique_impl_out:
    r"""Eliminates all but the first element from every consecutive group of equivalent elements.

    .. note:: This function is different from :func:`torch.unique` in the sense that this function
        only eliminates consecutive duplicate values. This semantics is similar to `std::unique`
        in C++.

    Args:
        input (Tensor): the input tensor
        return_inverse (bool): Whether to also return the indices for where
            elements in the original input ended up in the returned unique list.
        return_counts (bool): Whether to also return the counts for each unique
            element.
        dim (int, optional): the dimension to apply unique. If ``None``, the unique of the
            flattened input is returned. default: ``None``

    Returns:
        (Tensor, Tensor (optional), Tensor (optional)): A tensor or a tuple of tensors containing

            - **output** (*Tensor*): the output list of unique scalar elements.
            - **inverse_indices** (*Tensor*): (optional) if
              :attr:`return_inverse` is True, there will be an additional
              returned tensor (same shape as input) representing the indices
              for where elements in the original input map to in the output;
              otherwise, this function will only return a single tensor.
            - **counts** (*Tensor*): (optional) if
              :attr:`return_counts` is True, there will be an additional
              returned tensor (same shape as output or output.size(dim),
              if dim was specified) representing the number of occurrences
              for each unique value or tensor.

    Example::

        >>> x = torch.tensor([1, 1, 2, 2, 3, 1, 1, 2])
        >>> output = torch.unique_consecutive(x)
        >>> output
        tensor([1, 2, 3, 1, 2])

        >>> output, inverse_indices = torch.unique_consecutive(x, return_inverse=True)
        >>> output
        tensor([1, 2, 3, 1, 2])
        >>> inverse_indices
        tensor([0, 0, 1, 1, 2, 3, 3, 4])

        >>> output, counts = torch.unique_consecutive(x, return_counts=True)
        >>> output
        tensor([1, 2, 3, 1, 2])
        >>> counts
        tensor([2, 2, 1, 2, 1])
    """
    # Tensor-likes overriding __torch_function__ get the public entry point.
    if has_torch_function_unary(input):
        return handle_torch_function(
            unique_consecutive,
            (input,),
            input,
            return_inverse=return_inverse,
            return_counts=return_counts,
            dim=dim,
        )

    uniques, inverse, occurrences = _VF.unique_consecutive(  # type: ignore[attr-defined]
        input,
        return_inverse=return_inverse,
        return_counts=return_counts,
        dim=dim,
    )
    # All three tensors are returned; the public wrappers drop the unused ones.
    return uniques, inverse, occurrences
def _return_counts(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
    # type: (Tensor, bool, bool, bool, Optional[int]) -> tuple[Tensor, Tensor]
    """Variant of ``unique`` that yields ``(output, counts)``.

    All arguments are forwarded positionally to ``_unique_impl``; its
    inverse-indices result is discarded.
    """
    # On the __torch_function__ path the override's result is returned untouched.
    if has_torch_function_unary(input):
        return _unique_impl(input, sorted, return_inverse, return_counts, dim)

    uniques, _, occurrences = _unique_impl(
        input, sorted, return_inverse, return_counts, dim
    )
    return uniques, occurrences
def _return_output(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
    # type: (Tensor, bool, bool, bool, Optional[int]) -> Tensor
    """Variant of ``unique`` that yields only the ``output`` tensor.

    All arguments are forwarded positionally to ``_unique_impl``; its
    inverse-indices and counts results are discarded.
    """
    # On the __torch_function__ path the override's result is returned untouched.
    if has_torch_function_unary(input):
        return _unique_impl(input, sorted, return_inverse, return_counts, dim)

    uniques, _, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim)
    return uniques
def _return_inverse(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
    # type: (Tensor, bool, bool, bool, Optional[int]) -> tuple[Tensor, Tensor]
    """Variant of ``unique`` that yields ``(output, inverse_indices)``.

    All arguments are forwarded positionally to ``_unique_impl``; its counts
    result is discarded.
    """
    # On the __torch_function__ path the override's result is returned untouched.
    if has_torch_function_unary(input):
        return _unique_impl(input, sorted, return_inverse, return_counts, dim)

    uniques, inverse, _ = _unique_impl(
        input, sorted, return_inverse, return_counts, dim
    )
    return uniques, inverse
# Second-level dispatch used when `return_inverse` is false: choose between
# the (output, counts) and output-only variants based on `return_counts`.
# The wrappers take (input, sorted, return_inverse, return_counts, dim), so
# `return_counts` is positional argument 3.
_return_inverse_false = boolean_dispatch(
    arg_name="return_counts",
    arg_index=3,
    default=False,
    if_true=_return_counts,
    if_false=_return_output,
    module_name=__name__,
    func_name="unique",
)

# Second-level dispatch used when `return_inverse` is true: choose between
# the full three-tensor result and the (output, inverse_indices) variant.
_return_inverse_true = boolean_dispatch(
    arg_name="return_counts",
    arg_index=3,
    default=False,
    if_true=_unique_impl,
    if_false=_return_inverse,
    module_name=__name__,
    func_name="unique",
)

# The return type of unique depends on `return_inverse`, and `return_counts` so in order to
# resolve the output type in TorchScript we need to statically know the value of both parameters

# Top-level dispatch on `return_inverse` (positional argument 2).
unique = boolean_dispatch(
    arg_name="return_inverse",
    arg_index=2,
    default=False,
    if_true=_return_inverse_true,
    if_false=_return_inverse_false,
    module_name=__name__,
    func_name="unique",
)
# The dispatcher object has no docstring of its own; reuse the implementation's.
unique.__doc__ = _unique_impl.__doc__
def _consecutive_return_counts(input, return_inverse=False, return_counts=False, dim=None):
    # type: (Tensor, bool, bool, Optional[int]) -> tuple[Tensor, Tensor]
    """Variant of ``unique_consecutive`` that yields ``(output, counts)``.

    All arguments are forwarded positionally to ``_unique_consecutive_impl``;
    its inverse-indices result is discarded.
    """
    # On the __torch_function__ path the override's result is returned untouched.
    if has_torch_function_unary(input):
        return _unique_consecutive_impl(input, return_inverse, return_counts, dim)

    uniques, _, occurrences = _unique_consecutive_impl(
        input, return_inverse, return_counts, dim
    )
    return uniques, occurrences
def _consecutive_return_output(input, return_inverse=False, return_counts=False, dim=None):
    # type: (Tensor, bool, bool, Optional[int]) -> Tensor
    """Variant of ``unique_consecutive`` that yields only the ``output`` tensor.

    All arguments are forwarded positionally to ``_unique_consecutive_impl``;
    its inverse-indices and counts results are discarded.
    """
    # On the __torch_function__ path the override's result is returned untouched.
    if has_torch_function_unary(input):
        return _unique_consecutive_impl(input, return_inverse, return_counts, dim)

    uniques, _, _ = _unique_consecutive_impl(
        input, return_inverse, return_counts, dim
    )
    return uniques
def _consecutive_return_inverse(input, return_inverse=False, return_counts=False, dim=None):
    # type: (Tensor, bool, bool, Optional[int]) -> tuple[Tensor, Tensor]
    """Variant of ``unique_consecutive`` that yields ``(output, inverse_indices)``.

    All arguments are forwarded positionally to ``_unique_consecutive_impl``;
    its counts result is discarded.
    """
    # On the __torch_function__ path the override's result is returned untouched.
    if has_torch_function_unary(input):
        return _unique_consecutive_impl(input, return_inverse, return_counts, dim)

    uniques, inverse, _ = _unique_consecutive_impl(
        input, return_inverse, return_counts, dim
    )
    return uniques, inverse
- _consecutive_return_inverse_false = boolean_dispatch(
- arg_name="return_counts",
- arg_index=1,
- default=False,
- if_true=_consecutive_return_counts,
- if_false=_consecutive_return_output,
- module_name=__name__,
- func_name="unique_consecutive",
- )
- _consecutive_return_inverse_true = boolean_dispatch(
- arg_name="return_counts",
- arg_index=1,
- default=False,
- if_true=_unique_consecutive_impl,
- if_false=_consecutive_return_inverse,
- module_name=__name__,
- func_name="unique_consecutive",
- )
- # The return type of unique depends on `return_inverse`, and `return_counts` so in order to
- # resolve the output type in TorchScript we need to statically know the value of both parameters
- unique_consecutive = boolean_dispatch(
- arg_name="return_inverse",
- arg_index=2,
- default=False,
- if_true=_consecutive_return_inverse_true,
- if_false=_consecutive_return_inverse_false,
- module_name=__name__,
- func_name="unique_consecutive",
- )
- unique_consecutive.__doc__ = _unique_consecutive_impl.__doc__
if TYPE_CHECKING:
    pass
    # There's no good way to use this type annotation without breaking JIT
    # overloads. So leave untyped for mypy for now.
else:
    # Overload stubs for `tensordot`, one per accepted form of `dims`.
    # Each body is intentionally empty; only the signatures matter.

    # `dims` given as a single integer.
    @overload
    def tensordot(
        a,
        b,
        dims: int = 2,
        out: torch.Tensor | None = None,
    ):
        pass

    # `dims` given as a tuple of two integer lists.
    @overload
    def tensordot(  # noqa: F811
        a,
        b,
        dims: tuple[list[int], list[int]],
        out: torch.Tensor | None = None,
    ):
        pass

    # `dims` given as a list of integer lists.
    @overload
    def tensordot(  # noqa: F811
        a,
        b,
        dims: list[list[int]],
        out: torch.Tensor | None = None,
    ):
        pass

    # `dims` given as a tensor.
    @overload
    def tensordot(  # noqa: F811
        a,
        b,
        dims: torch.Tensor,
        out: torch.Tensor | None = None,
    ):
        pass
def tensordot(  # noqa: F811
    a,
    b,
    dims=2,
    out: torch.Tensor | None = None,
):
    r"""Returns a contraction of a and b over multiple dimensions.

    :attr:`tensordot` implements a generalized matrix product.

    Args:
      a (Tensor): Left tensor to contract
      b (Tensor): Right tensor to contract
      dims (int or Tuple[List[int], List[int]] or List[List[int]] containing two lists or Tensor): number of dimensions to
         contract or explicit lists of dimensions for :attr:`a` and
         :attr:`b` respectively
      out (Tensor, optional): when not ``None``, passed to the underlying
         operation as its ``out`` argument. Default: ``None``

    When called with a non-negative integer argument :attr:`dims` = :math:`d`, and
    the number of dimensions of :attr:`a` and :attr:`b` is :math:`m` and :math:`n`,
    respectively, :func:`~torch.tensordot` computes the tensor :math:`r` of shape
    ``a.shape[:-dims] + b.shape[dims:]`` given by:

    .. math::
        r_{i_1,...,i_{m-d}, j_1,...,j_{n-d}}
          = \sum_{k_1,...,k_d} a_{i_1,...,i_{m-d},k_1,...,k_d} \times b_{k_1,...,k_d, j_1,...,j_{n-d}}.

    When called with :attr:`dims` of the list form, the given dimensions will be contracted
    in place of the last :math:`d` of :attr:`a` and the first :math:`d` of :math:`b`. The sizes
    in these dimensions must match, but :func:`~torch.tensordot` will deal with broadcasted
    dimensions.

    Raises:
        RuntimeError: if :attr:`dims` is not an int, SymInt, tuple, list or Tensor,
            if an integer (or single-element tensor) :attr:`dims` is negative, or if
            an integer :attr:`dims` exceeds ``min(a.dim(), b.dim())``.
        AssertionError: if a multi-element :attr:`dims` tensor does not have
            size 2 in its first dimension.

    Examples::

        >>> a = torch.arange(60.).reshape(3, 4, 5)
        >>> b = torch.arange(24.).reshape(4, 3, 2)
        >>> torch.tensordot(a, b, dims=([1, 0], [0, 1]))
        tensor([[4400., 4730.],
                [4532., 4874.],
                [4664., 5018.],
                [4796., 5162.],
                [4928., 5306.]])

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> a = torch.randn(3, 4, 5, device='cuda')
        >>> b = torch.randn(4, 5, 6, device='cuda')
        >>> c = torch.tensordot(a, b, dims=2).cpu()
        tensor([[ 8.3504, -2.5436,  6.2922,  2.7556, -1.0732,  3.2741],
                [ 3.3161,  0.0704,  5.0187, -0.4079, -4.3126,  4.8744],
                [ 0.8223,  3.9445,  3.2168, -0.2400,  3.4117,  1.7780]])

        >>> a = torch.randn(3, 5, 4, 6)
        >>> b = torch.randn(6, 4, 5, 3)
        >>> torch.tensordot(a, b, dims=([2, 1, 3], [1, 2, 0]))
        tensor([[  7.7193,  -2.4867, -10.3204],
                [  1.5513, -14.4737,  -6.5113],
                [ -0.2850,   4.2573,  -3.5997]])
    """
    # Defer to __torch_function__ overrides on either operand.
    if has_torch_function_variadic(a, b):
        return handle_torch_function(tensordot, (a, b), a, b, dims=dims, out=out)

    # Reject unsupported `dims` types up front.
    if not isinstance(dims, (tuple, list, torch.Tensor, int, torch.SymInt)):
        raise RuntimeError(
            "tensordot expects dims to be int or "
            + "tuple[list[int], list[int]] or "
            + "list[list[int]] containing two lists, but got "
            + f"dims={dims}"
        )

    # Explicit contraction dimensions for `a` and `b`; exactly one of the
    # branches below fills them in depending on the runtime type of `dims`.
    dims_a: list[int] = []
    dims_b: list[int] = []

    # Sequence form: the two entries are the dimension lists themselves.
    if isinstance(dims, (tuple, list)):
        dims_a, dims_b = dims

    # Tensor form: either two rows of dimension indices, or a single count.
    if isinstance(dims, torch.Tensor):
        num_elements = dims.numel()
        if num_elements > 1:
            if dims.size()[0] != 2:
                raise AssertionError(
                    f"dims tensor must have size 2 in first dimension, got {dims.size()[0]}"
                )
            dims_a = torch.jit.annotate(list[int], dims[0].tolist())
            dims_b = torch.jit.annotate(list[int], dims[1].tolist())
        else:
            dims_val = int(dims.item())
            if dims_val < 0:
                raise RuntimeError(f"tensordot expects dims >= 0, but got dims={dims}")
            # Contract the last `dims_val` dims of `a` with the first `dims_val` of `b`.
            dims_a = list(range(-dims_val, 0))
            dims_b = list(range(dims_val))

    # Integer form: contract the last `dims` dims of `a` with the first `dims` of `b`.
    if isinstance(dims, (int, torch.SymInt)):
        if dims < 0:
            raise RuntimeError(f"tensordot expects dims >= 0, but got dims={dims}")
        # NOTE(review): the condition allows dims == min(ndim_a, ndim_b), while the
        # message below is worded with "<" -- confirm the intended wording.
        if dims > min(a.dim(), b.dim()):
            raise RuntimeError(
                f"tensordot expects dims < ndim_a or ndim_b, but got dims={dims}"
            )
        dims_a = list(range(-dims, 0))
        dims_b = list(range(dims))

    # Only pass `out` through when the caller supplied one.
    if out is None:
        return _VF.tensordot(a, b, dims_a, dims_b)  # type: ignore[attr-defined]
    else:
        return _VF.tensordot(a, b, dims_a, dims_b, out=out)  # type: ignore[attr-defined]
def cartesian_prod(*tensors: Tensor) -> Tensor:
    """Compute the cartesian product of the given sequence of tensors.

    The behavior is similar to python's `itertools.product`.

    Args:
        *tensors: any number of 1 dimensional tensors.

    Returns:
        Tensor: A tensor equivalent to converting all the input tensors into lists,
        do `itertools.product` on these lists, and finally convert the resulting list
        into tensor.

    Example::

        >>> import itertools
        >>> a = [1, 2, 3]
        >>> b = [4, 5]
        >>> list(itertools.product(a, b))
        [(1, 4), (1, 5), (2, 4), (2, 5), (3, 4), (3, 5)]
        >>> tensor_a = torch.tensor(a)
        >>> tensor_b = torch.tensor(b)
        >>> torch.cartesian_prod(tensor_a, tensor_b)
        tensor([[1, 4],
                [1, 5],
                [2, 4],
                [2, 5],
                [3, 4],
                [3, 5]])
    """
    # This wrapper exists to support variadic args: the underlying op takes
    # the tensors as a single sequence argument.
    if not has_torch_function(tensors):
        return _VF.cartesian_prod(tensors)  # type: ignore[attr-defined]
    return handle_torch_function(cartesian_prod, tensors, *tensors)
def block_diag(*tensors):
    """Create a block diagonal matrix from provided tensors.

    Args:
        *tensors: One or more tensors with 0, 1, or 2 dimensions.

    Returns:
        Tensor: A 2 dimensional tensor with all the input tensors arranged in
        order such that their upper left and lower right corners are
        diagonally adjacent. All other elements are set to 0.

    Example::

        >>> import torch
        >>> A = torch.tensor([[0, 1], [1, 0]])
        >>> B = torch.tensor([[3, 4, 5], [6, 7, 8]])
        >>> C = torch.tensor(7)
        >>> D = torch.tensor([1, 2, 3])
        >>> E = torch.tensor([[4], [5], [6]])
        >>> torch.block_diag(A, B, C, D, E)
        tensor([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
                [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 3, 4, 5, 0, 0, 0, 0, 0],
                [0, 0, 6, 7, 8, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 7, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 1, 2, 3, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 4],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 5],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 6]])
    """
    # This wrapper exists to support variadic args: the underlying op takes
    # the tensors as a single sequence argument.
    if not has_torch_function(tensors):
        return torch._C._VariableFunctions.block_diag(tensors)  # type: ignore[attr-defined]
    return handle_torch_function(block_diag, tensors, *tensors)
def cdist(x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary"):
    # type: (Tensor, Tensor, float, str) -> (Tensor)
    r"""Computes the batched p-norm distance between each pair of the two collections of row vectors.

    Args:
        x1 (Tensor): input tensor where the last two dimensions represent the points and the feature dimension respectively.
            The shape can be :math:`D_1 \times D_2 \times \cdots \times D_n \times P \times M`,
            where :math:`P` is the number of points and :math:`M` is the feature dimension.
        x2 (Tensor): input tensor where the last two dimensions also represent the points and the feature dimension respectively.
            The shape can be :math:`D_1' \times D_2' \times \cdots \times D_m' \times R \times M`,
            where :math:`R` is the number of points and :math:`M` is the feature dimension,
            which should match the feature dimension of `x1`.
        p: p value for the p-norm distance to calculate between each vector pair
            :math:`\in [0, \infty]`.
        compute_mode:
            'use_mm_for_euclid_dist_if_necessary' - will use matrix multiplication approach to calculate
            euclidean distance (p = 2) if P > 25 or R > 25
            'use_mm_for_euclid_dist' - will always use matrix multiplication approach to calculate
            euclidean distance (p = 2)
            'donot_use_mm_for_euclid_dist' - will never use matrix multiplication approach to calculate
            euclidean distance (p = 2)
            Default: use_mm_for_euclid_dist_if_necessary.

    If x1 has shape :math:`B \times P \times M` and x2 has shape :math:`B \times R \times M` then the
    output will have shape :math:`B \times P \times R`.

    This function is equivalent to `scipy.spatial.distance.cdist(input,'minkowski', p=p)`
    if :math:`p \in (0, \infty)`. When :math:`p = 0` it is equivalent to
    `scipy.spatial.distance.cdist(input, 'hamming') * M`. When :math:`p = \infty`, the closest
    scipy function is `scipy.spatial.distance.cdist(xn, lambda x, y: np.abs(x - y).max())`.

    Raises:
        ValueError: if :attr:`compute_mode` is not one of the three names above.

    Example:

        >>> a = torch.tensor([[0.9041, 0.0196], [-0.3108, -2.4423], [-0.4821, 1.059]])
        >>> a
        tensor([[ 0.9041,  0.0196],
                [-0.3108, -2.4423],
                [-0.4821,  1.0590]])
        >>> b = torch.tensor([[-2.1763, -0.4713], [-0.6986, 1.3702]])
        >>> b
        tensor([[-2.1763, -0.4713],
                [-0.6986,  1.3702]])
        >>> torch.cdist(a, b, p=2)
        tensor([[3.1193, 2.0959],
                [2.7138, 3.8322],
                [2.2830, 0.3791]])
    """
    if has_torch_function_variadic(x1, x2):
        return handle_torch_function(
            cdist, (x1, x2), x1, x2, p=p, compute_mode=compute_mode
        )

    # Each mode name maps to the last argument of the underlying op:
    # None, 1 or 2 respectively.
    if compute_mode == "use_mm_for_euclid_dist_if_necessary":
        return _VF.cdist(x1, x2, p, None)  # type: ignore[attr-defined]
    if compute_mode == "use_mm_for_euclid_dist":
        return _VF.cdist(x1, x2, p, 1)  # type: ignore[attr-defined]
    if compute_mode == "donot_use_mm_for_euclid_dist":
        return _VF.cdist(x1, x2, p, 2)  # type: ignore[attr-defined]
    raise ValueError(f"{compute_mode} is not a valid value for compute_mode")
def atleast_1d(*tensors):
    r"""
    Returns a 1-dimensional view of each input tensor with zero dimensions.
    Input tensors with one or more dimensions are returned as-is.

    Args:
        input (Tensor or sequence of Tensors): tensor(s) to be converted to at least 1-dimensional.

    Returns:
        output (Tensor or tuple of Tensors)

    Example::

        >>> x = torch.arange(2)
        >>> x
        tensor([0, 1])
        >>> torch.atleast_1d(x)
        tensor([0, 1])
        >>> x = torch.tensor(1.)
        >>> x
        tensor(1.)
        >>> torch.atleast_1d(x)
        tensor([1.])
        >>> x = torch.tensor(0.5)
        >>> y = torch.tensor(1.)
        >>> torch.atleast_1d((x, y))
        (tensor([0.5000]), tensor([1.]))
        >>> torch.atleast_1d()
        ()
    """
    # This wrapper exists to support variadic args.
    if has_torch_function(tensors):
        return handle_torch_function(atleast_1d, tensors, *tensors)
    # A lone positional argument (a tensor or a sequence of tensors) is
    # forwarded unwrapped; otherwise the whole argument tuple is forwarded.
    arg = tensors[0] if len(tensors) == 1 else tensors
    return _VF.atleast_1d(arg)  # type: ignore[attr-defined]
def atleast_2d(*tensors):
    r"""
    Returns a 2-dimensional view of each input tensor with zero dimensions.
    Input tensors with two or more dimensions are returned as-is.

    Args:
        input (Tensor or sequence of Tensors): tensor(s) to be converted to at least 2-dimensional.

    Returns:
        output (Tensor or tuple of Tensors)

    Example::

        >>> x = torch.tensor(1.)
        >>> x
        tensor(1.)
        >>> torch.atleast_2d(x)
        tensor([[1.]])
        >>> x = torch.arange(4).view(2, 2)
        >>> x
        tensor([[0, 1],
                [2, 3]])
        >>> torch.atleast_2d(x)
        tensor([[0, 1],
                [2, 3]])
        >>> x = torch.tensor(0.5)
        >>> y = torch.tensor(1.)
        >>> torch.atleast_2d((x, y))
        (tensor([[0.5000]]), tensor([[1.]]))
        >>> torch.atleast_2d()
        ()
    """
    # This wrapper exists to support variadic args.
    if has_torch_function(tensors):
        return handle_torch_function(atleast_2d, tensors, *tensors)
    # A lone positional argument (a tensor or a sequence of tensors) is
    # forwarded unwrapped; otherwise the whole argument tuple is forwarded.
    arg = tensors[0] if len(tensors) == 1 else tensors
    return _VF.atleast_2d(arg)  # type: ignore[attr-defined]
def atleast_3d(*tensors):
    r"""
    Returns a 3-dimensional view of each input tensor with zero dimensions.
    Input tensors with three or more dimensions are returned as-is.

    Args:
        input (Tensor or sequence of Tensors): tensor(s) to be converted to at least 3-dimensional.

    Returns:
        output (Tensor or tuple of Tensors)

    Example:

        >>> x = torch.tensor(0.5)
        >>> x
        tensor(0.5000)
        >>> torch.atleast_3d(x)
        tensor([[[0.5000]]])
        >>> y = torch.arange(4).view(2, 2)
        >>> y
        tensor([[0, 1],
                [2, 3]])
        >>> torch.atleast_3d(y)
        tensor([[[0],
                 [1]],
                <BLANKLINE>
                [[2],
                 [3]]])
        >>> x = torch.tensor(1).view(1, 1, 1)
        >>> x
        tensor([[[1]]])
        >>> torch.atleast_3d(x)
        tensor([[[1]]])
        >>> x = torch.tensor(0.5)
        >>> y = torch.tensor(1.0)
        >>> torch.atleast_3d((x, y))
        (tensor([[[0.5000]]]), tensor([[[1.]]]))
        >>> torch.atleast_3d()
        ()
    """
    # This wrapper exists to support variadic args.
    if has_torch_function(tensors):
        return handle_torch_function(atleast_3d, tensors, *tensors)
    # A lone positional argument (a tensor or a sequence of tensors) is
    # forwarded unwrapped; otherwise the whole argument tuple is forwarded.
    arg = tensors[0] if len(tensors) == 1 else tensors
    return _VF.atleast_3d(arg)  # type: ignore[attr-defined]
# Overload stubs for ``norm``. Static type checkers take the (empty)
# TYPE_CHECKING branch; at runtime the ``else`` branch declares four stubs that
# differ only in their ``# type:`` comments:
#   * ``p`` is either ``str`` or ``Optional[number]``
#   * ``dim`` is either ``Optional[List[int]]`` or ``Optional[int]``
# NOTE(review): ``overload`` is imported outside this chunk; presumably it is the
# TorchScript overload-registration decorator rather than ``typing.overload`` --
# confirm against the file's imports.
if TYPE_CHECKING:
    pass
    # There's no good way to use this type annotation; cannot rename norm() to
    # _norm_impl() in a way that doesn't break JIT overloads. So leave untyped
    # for mypy for now.
    #    def norm(input: Tensor,
    #             p: Optional[Union[str, Number]] = "fro",
    #             dim: Optional[Union[int, List[int]]] = None,
    #             keepdim: bool = False,
    #             out: Optional[Tensor] = None,
    #             dtype: _dtype = None) -> Tensor:
    #        return _norm_impl(input, p, dim, keepdim, out, dtype)
else:
    # TODO: type dim as BroadcastingList when
    # https://github.com/pytorch/pytorch/issues/33782 is fixed
    @overload
    def norm(
        input,
        p="fro",
        dim=None,
        keepdim=False,
        out=None,
        dtype=None,
    ):
        # type: (Tensor, str, Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor
        pass

    @overload
    def norm(  # noqa: F811
        input,
        p="fro",
        dim=None,
        keepdim=False,
        out=None,
        dtype=None,
    ):
        # type: (Tensor, Optional[number], Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor
        pass

    @overload
    def norm(  # noqa: F811
        input,
        p="fro",
        dim=None,
        keepdim=False,
        out=None,
        dtype=None,
    ):
        # type: (Tensor, Optional[number], Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor
        pass

    @overload
    def norm(  # noqa: F811
        input,
        p="fro",
        dim=None,
        keepdim=False,
        out=None,
        dtype=None,
    ):
        # type: (Tensor, str, Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor
        pass
def norm(  # noqa: F811
    input,
    p: float | str | None = "fro",
    dim=None,
    keepdim=False,
    out=None,
    dtype=None,
):
    r"""Returns the matrix norm or vector norm of a given tensor.

    .. warning::

        torch.norm is deprecated and may be removed in a future PyTorch release.
        Its documentation and behavior may be incorrect, and it is no longer
        actively maintained.

        Use :func:`torch.linalg.vector_norm` when computing vector norms and
        :func:`torch.linalg.matrix_norm` when computing matrix norms.
        For a function with a similar behavior as this one see :func:`torch.linalg.norm`.
        Note, however, the signature for these functions is slightly different than the
        signature for ``torch.norm``.

    Args:
        input (Tensor): The input tensor. Its data type must be either a floating
            point or complex type. For complex inputs, the norm is calculated using the
            absolute value of each element. If the input is complex and neither
            :attr:`dtype` nor :attr:`out` is specified, the result's data type will
            be the corresponding floating point type (e.g. float if :attr:`input` is
            complexfloat).

        p (int, float, inf, -inf, 'fro', 'nuc', optional): the order of norm. Default: ``'fro'``
            The following norms can be calculated:

            ======  ==============  ==========================
            ord     matrix norm     vector norm
            ======  ==============  ==========================
            'fro'   Frobenius norm  --
            'nuc'   nuclear norm    --
            Number  --              sum(abs(x)**ord)**(1./ord)
            ======  ==============  ==========================

            The vector norm can be calculated across any number of dimensions.
            The corresponding dimensions of :attr:`input` are flattened into
            one dimension, and the norm is calculated on the flattened
            dimension.

            Frobenius norm produces the same result as ``p=2`` in all cases
            except when :attr:`dim` is a list of three or more dims, in which
            case Frobenius norm throws an error.

            Nuclear norm can only be calculated across exactly two dimensions.

        dim (int, tuple of ints, list of ints, optional):
            Specifies which dimension or dimensions of :attr:`input` to
            calculate the norm across. If :attr:`dim` is ``None``, the norm will
            be calculated across all dimensions of :attr:`input`. If the norm
            type indicated by :attr:`p` does not support the specified number of
            dimensions, an error will occur.
        keepdim (bool, optional): whether the output tensors have :attr:`dim`
            retained or not. Ignored if :attr:`dim` = ``None`` and
            :attr:`out` = ``None``. Default: ``False``
        out (Tensor, optional): the output tensor. Ignored if
            :attr:`dim` = ``None`` and :attr:`out` = ``None``.
        dtype (:class:`torch.dtype`, optional): the desired data type of
            returned tensor. If specified, the input tensor is cast to
            :attr:`dtype` while performing the operation. Default: None.

    Raises:
        ValueError: on the fallback path (inputs that are not strided or whose
            device type is not one of those listed in the body), if
            :attr:`dtype` is given together with ``p='fro'`` or ``p='nuc'``.
        RuntimeError: on the same fallback path, if :attr:`p` is a string
            other than ``'fro'`` or ``'nuc'``.

    .. note::
        Even though ``p='fro'`` supports any number of dimensions, the true
        mathematical definition of Frobenius norm only applies to tensors with
        exactly two dimensions. :func:`torch.linalg.matrix_norm` with ``ord='fro'``
        aligns with the mathematical definition, since it can only be applied across
        exactly two dimensions.

    Example::

        >>> import torch
        >>> a = torch.arange(9, dtype= torch.float) - 4
        >>> b = a.reshape((3, 3))
        >>> torch.norm(a)
        tensor(7.7460)
        >>> torch.norm(b)
        tensor(7.7460)
        >>> torch.norm(a, float('inf'))
        tensor(4.)
        >>> torch.norm(b, float('inf'))
        tensor(4.)
        >>> c = torch.tensor([[ 1, 2, 3], [-1, 1, 4]] , dtype=torch.float)
        >>> torch.norm(c, dim=0)
        tensor([1.4142, 2.2361, 5.0000])
        >>> torch.norm(c, dim=1)
        tensor([3.7417, 4.2426])
        >>> torch.norm(c, p=1, dim=1)
        tensor([6., 6.])
        >>> d = torch.arange(8, dtype=torch.float).reshape(2, 2, 2)
        >>> torch.norm(d, dim=(1, 2))
        tensor([ 3.7417, 11.2250])
        >>> torch.norm(d[0, :, :]), torch.norm(d[1, :, :])
        (tensor(3.7417), tensor(11.2250))
    """

    # Let ``__torch_function__`` overrides on ``input`` handle the call first.
    if has_torch_function_unary(input):
        return handle_torch_function(
            norm, (input,), input, p=p, dim=dim, keepdim=keepdim, out=out, dtype=dtype
        )

    # NB. All the repeated code and weird python is to please TorchScript.
    #     For a more compact implementation see the relevant function in `_refs/__init__.py`

    # Fast path: strided tensors on the listed device types are routed to the
    # ``torch.linalg`` norms.
    # We don't do this for MPS or sparse tensors
    if input.layout == torch.strided and input.device.type in (
        "cpu",
        "cuda",
        "xpu",
        "meta",
        torch.utils.backend_registration._privateuse1_backend_name,
    ):
        # Normalize ``dim``: a single int becomes a one-element list, ``None``
        # stays ``None``, anything else is passed through untouched.
        if dim is not None:
            if isinstance(dim, (int, torch.SymInt)):
                _dim = [dim]
            else:
                _dim = dim
        else:
            _dim = None  # type: ignore[assignment]

        if isinstance(p, str):
            # 'fro' over at most two dims (or over everything) is computed as
            # the order-2 vector norm.
            if p == "fro" and (
                dim is None
                or isinstance(dim, (int, torch.SymInt))
                or len(dim) <= 2  # pyrefly: ignore  # bad-argument-type
            ):
                if out is None:
                    return torch.linalg.vector_norm(
                        input, 2, _dim, keepdim, dtype=dtype
                    )
                else:
                    return torch.linalg.vector_norm(
                        input, 2, _dim, keepdim, dtype=dtype, out=out
                    )

            # Here we either call the nuclear norm, or we call matrix_norm with some arguments
            # that will throw an error
            if _dim is None:
                _dim = list(range(input.ndim))
            if out is None:
                return torch.linalg.matrix_norm(input, p, _dim, keepdim, dtype=dtype)
            else:
                return torch.linalg.matrix_norm(
                    input, p, _dim, keepdim, dtype=dtype, out=out
                )
        else:
            # NB. p should be Union[str, number], not Optional!
            # ``p=None`` is treated as the 2-norm.
            _p = 2.0 if p is None else p
            if out is None:
                return torch.linalg.vector_norm(input, _p, _dim, keepdim, dtype=dtype)
            else:
                return torch.linalg.vector_norm(
                    input, _p, _dim, keepdim, dtype=dtype, out=out
                )

    # Fallback path for every other layout/device: call the ``_VF`` ops directly.
    ndim = input.dim()

    # catch default case
    # (a string ``p`` other than 'fro' falls through to the general handling below)
    if dim is None and out is None and dtype is None and p is not None:
        if isinstance(p, str):
            if p == "fro":
                return _VF.frobenius_norm(input, dim=(), keepdim=keepdim)
        if not isinstance(p, str):
            _dim = list(range(ndim))
            return _VF.norm(input, p, dim=_dim, keepdim=keepdim)  # type: ignore[attr-defined]

    # TODO: when https://github.com/pytorch/pytorch/issues/33782 is fixed
    # remove the overloads where dim is an int and replace with BroadcastingList1
    # and remove next four lines, replace _dim with dim
    if dim is not None:
        if isinstance(dim, (int, torch.SymInt)):
            _dim = [dim]
        else:
            _dim = dim
    else:
        _dim = None  # type: ignore[assignment]

    if isinstance(p, str):
        if p == "fro":
            if dtype is not None:
                raise ValueError("dtype argument is not supported in frobenius norm")

            # With no explicit ``dim`` the norm spans every dimension.
            if _dim is None:
                _dim = list(range(ndim))
            if out is None:
                return _VF.frobenius_norm(input, _dim, keepdim=keepdim)  # type: ignore[arg-type]
            else:
                return _VF.frobenius_norm(input, _dim, keepdim=keepdim, out=out)  # type: ignore[arg-type]
        elif p == "nuc":
            if dtype is not None:
                raise ValueError("dtype argument is not supported in nuclear norm")
            # The nuclear norm is called without a ``dim`` argument when none was given.
            if _dim is None:
                if out is None:
                    return _VF.nuclear_norm(input, keepdim=keepdim)  # type: ignore[arg-type]
                else:
                    return _VF.nuclear_norm(input, keepdim=keepdim, out=out)  # type: ignore[arg-type]
            else:
                if out is None:
                    return _VF.nuclear_norm(input, _dim, keepdim=keepdim)  # type: ignore[arg-type]
                else:
                    return _VF.nuclear_norm(input, _dim, keepdim=keepdim, out=out)  # type: ignore[arg-type]
        raise RuntimeError(f"only valid string values are 'fro' and 'nuc', found {p}")
    else:
        # Numeric (or ``None``) ``p``: ``out`` and ``dtype`` are only passed
        # along when they were actually provided.
        if _dim is None:
            _dim = list(range(ndim))

        if out is None:
            if dtype is None:
                return _VF.norm(input, p, _dim, keepdim=keepdim)  # type: ignore[attr-defined]
            else:
                return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype)  # type: ignore[attr-defined]
        else:
            if dtype is None:
                return _VF.norm(input, p, _dim, keepdim=keepdim, out=out)  # type: ignore[attr-defined]
            else:
                return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype, out=out)  # type: ignore[attr-defined]
- def unravel_index(
- indices: Tensor,
- shape: int | Sequence[int] | torch.Size,
- ) -> tuple[Tensor, ...]:
- r"""Converts a tensor of flat indices into a tuple of coordinate tensors that
- index into an arbitrary tensor of the specified shape.
- Args:
- indices (Tensor): An integer tensor containing indices into the
- flattened version of an arbitrary tensor of shape :attr:`shape`.
- All elements must be in the range ``[0, prod(shape) - 1]``.
- shape (int, sequence of ints, or torch.Size): The shape of the arbitrary
- tensor. All elements must be non-negative.
- Returns:
- tuple of Tensors: Each ``i``-th tensor in the output corresponds with
- dimension ``i`` of :attr:`shape`. Each tensor has the same shape as
- ``indices`` and contains one index into dimension ``i`` for each of the
- flat indices given by ``indices``.
- Example::
- >>> import torch
- >>> torch.unravel_index(torch.tensor(4), (3, 2))
- (tensor(2),
- tensor(0))
- >>> torch.unravel_index(torch.tensor([4, 1]), (3, 2))
- (tensor([2, 0]),
- tensor([0, 1]))
- >>> torch.unravel_index(torch.tensor([0, 1, 2, 3, 4, 5]), (3, 2))
- (tensor([0, 0, 1, 1, 2, 2]),
- tensor([0, 1, 0, 1, 0, 1]))
- >>> torch.unravel_index(torch.tensor([1234, 5678]), (10, 10, 10, 10))
- (tensor([1, 5]),
- tensor([2, 6]),
- tensor([3, 7]),
- tensor([4, 8]))
- >>> torch.unravel_index(torch.tensor([[1234], [5678]]), (10, 10, 10, 10))
- (tensor([[1], [5]]),
- tensor([[2], [6]]),
- tensor([[3], [7]]),
- tensor([[4], [8]]))
- >>> torch.unravel_index(torch.tensor([[1234], [5678]]), (100, 100))
- (tensor([[12], [56]]),
- tensor([[34], [78]]))
- """
- if has_torch_function_unary(indices):
- return handle_torch_function(unravel_index, (indices,), indices, shape=shape)
- res_tensor = _unravel_index(indices, shape)
- return res_tensor.unbind(-1)
- def _unravel_index(indices: Tensor, shape: int | Sequence[int]) -> Tensor:
- torch._check_type(
- not indices.is_complex()
- and not indices.is_floating_point()
- and indices.dtype != torch.bool,
- lambda: f"expected 'indices' to be integer dtype, but got {indices.dtype}",
- )
- torch._check_type(
- isinstance(shape, (int, torch.SymInt, Sequence)),
- lambda: f"expected 'shape' to be int or sequence of ints, but got {type(shape)}",
- )
- if isinstance(shape, (int, torch.SymInt)):
- shape = torch.Size([shape]) # pyrefly: ignore [bad-argument-type]
- else:
- for dim in shape:
- torch._check_type(
- isinstance(dim, (int, torch.SymInt)),
- lambda: f"expected 'shape' sequence to only contain ints, but got {type(dim)}",
- )
- shape = torch.Size(shape)
- torch._check_value(
- all(dim >= 0 for dim in shape),
- lambda: f"'shape' cannot have negative values, but got {tuple(shape)}",
- )
- coefs = list(
- reversed(
- list(
- itertools.accumulate(
- reversed(shape[1:] + torch.Size([1])), func=operator.mul
- )
- )
- )
- )
- return indices.unsqueeze(-1).floor_divide(
- torch.tensor(coefs, device=indices.device, dtype=torch.int64)
- ) % torch.tensor(shape, device=indices.device, dtype=torch.int64)
def chain_matmul(*matrices, out=None):
    r"""Returns the matrix product of the :math:`N` 2-D tensors. This product is efficiently computed
    using the matrix chain order algorithm, which selects the multiplication order that incurs the lowest
    cost in terms of arithmetic operations (`[CLRS]`_). Note that since this is a function to compute the
    product, :math:`N` needs to be greater than or equal to 2; if equal to 2 then a trivial matrix-matrix
    product is returned. If :math:`N` is 1, then this is a no-op - the original matrix is returned as is.

    .. warning::

        :func:`torch.chain_matmul` is deprecated and will be removed in a future PyTorch release.
        Use :func:`torch.linalg.multi_dot` instead, which accepts a list of two or more tensors
        rather than multiple arguments.

    Args:
        matrices (Tensors...): a sequence of 2 or more 2-D tensors whose product is to be determined.
        out (Tensor, optional): the output tensor. Ignored if :attr:`out` = ``None``.

    Returns:
        Tensor: if the :math:`i^{th}` tensor was of dimensions :math:`p_{i} \times p_{i + 1}`, then the product
        would be of dimensions :math:`p_{1} \times p_{N + 1}`.

    Example::

        >>> # xdoctest: +SKIP
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> a = torch.randn(3, 4)
        >>> b = torch.randn(4, 5)
        >>> c = torch.randn(5, 6)
        >>> d = torch.randn(6, 7)
        >>> # will raise a deprecation warning
        >>> torch.chain_matmul(a, b, c, d)
        tensor([[ -2.3375,  -3.9790,  -4.1119,  -6.6577,   9.5609, -11.5095,  -3.2614],
                [ 21.4038,   3.3378,  -8.4982,  -5.2457, -10.2561,  -2.4684,   2.7163],
                [ -0.9647,  -5.8917,  -2.3213,  -5.2284,  12.8615, -12.2816,  -2.5095]])

    .. _`[CLRS]`: https://mitpress.mit.edu/books/introduction-algorithms-third-edition
    """
    # This wrapper exists to support variadic args.
    if has_torch_function(matrices):
        # Forward ``out`` when the caller supplied one, so an override does not
        # silently lose it. It is left out when ``None`` so overrides that only
        # accept the matrices keep seeing exactly the call they saw before.
        if out is None:
            return handle_torch_function(chain_matmul, matrices, *matrices)
        return handle_torch_function(chain_matmul, matrices, *matrices, out=out)

    # ``out`` is only passed to the underlying op when it was actually given.
    if out is None:
        return _VF.chain_matmul(matrices)  # type: ignore[attr-defined]
    else:
        return _VF.chain_matmul(matrices, out=out)  # type: ignore[attr-defined]
- def _lu_impl(A, pivot=True, get_infos=False, out=None):
- # type: (Tensor, bool, bool, Any) -> tuple[Tensor, Tensor, Tensor]
- r"""Computes the LU factorization of a matrix or batches of matrices
- :attr:`A`. Returns a tuple containing the LU factorization and
- pivots of :attr:`A`. Pivoting is done if :attr:`pivot` is set to
- ``True``.
- .. warning::
- :func:`torch.lu` is deprecated in favor of :func:`torch.linalg.lu_factor`
- and :func:`torch.linalg.lu_factor_ex`. :func:`torch.lu` will be removed in a
- future PyTorch release.
- ``LU, pivots, info = torch.lu(A, compute_pivots)`` should be replaced with
- .. code:: python
- LU, pivots = torch.linalg.lu_factor(A, compute_pivots)
- ``LU, pivots, info = torch.lu(A, compute_pivots, get_infos=True)`` should be replaced with
- .. code:: python
- LU, pivots, info = torch.linalg.lu_factor_ex(A, compute_pivots)
- .. note::
- * The returned permutation matrix for every matrix in the batch is
- represented by a 1-indexed vector of size ``min(A.shape[-2], A.shape[-1])``.
- ``pivots[i] == j`` represents that in the ``i``-th step of the algorithm,
- the ``i``-th row was permuted with the ``j-1``-th row.
- * LU factorization with :attr:`pivot` = ``False`` is not available
- for CPU, and attempting to do so will throw an error. However,
- LU factorization with :attr:`pivot` = ``False`` is available for
- CUDA.
- * This function does not check if the factorization was successful
- or not if :attr:`get_infos` is ``True`` since the status of the
- factorization is present in the third element of the return tuple.
- * In the case of batches of square matrices with size less or equal
- to 32 on a CUDA device, the LU factorization is repeated for
- singular matrices due to the bug in the MAGMA library
- (see magma issue 13).
- * ``L``, ``U``, and ``P`` can be derived using :func:`torch.lu_unpack`.
- .. warning::
- The gradients of this function will only be finite when :attr:`A` is full rank.
- This is because the LU decomposition is just differentiable at full rank matrices.
- Furthermore, if :attr:`A` is close to not being full rank,
- the gradient will be numerically unstable as it depends on the computation of :math:`L^{-1}` and :math:`U^{-1}`.
- Args:
- A (Tensor): the tensor to factor of size :math:`(*, m, n)`
- pivot (bool, optional): Whether to compute the LU decomposition with partial pivoting, or the regular LU
- decomposition. :attr:`pivot`\ `= False` not supported on CPU. Default: `True`.
- get_infos (bool, optional): if set to ``True``, returns an info IntTensor.
- Default: ``False``
- out (tuple, optional): optional output tuple. If :attr:`get_infos` is ``True``,
- then the elements in the tuple are Tensor, IntTensor,
- and IntTensor. If :attr:`get_infos` is ``False``, then the
- elements in the tuple are Tensor, IntTensor. Default: ``None``
- Returns:
- (Tensor, IntTensor, IntTensor (optional)): A tuple of tensors containing
- - **factorization** (*Tensor*): the factorization of size :math:`(*, m, n)`
- - **pivots** (*IntTensor*): the pivots of size :math:`(*, \text{min}(m, n))`.
- ``pivots`` stores all the intermediate transpositions of rows.
- The final permutation ``perm`` could be reconstructed by
- applying ``swap(perm[i], perm[pivots[i] - 1])`` for ``i = 0, ..., pivots.size(-1) - 1``,
- where ``perm`` is initially the identity permutation of :math:`m` elements
- (essentially this is what :func:`torch.lu_unpack` is doing).
- - **infos** (*IntTensor*, *optional*): if :attr:`get_infos` is ``True``, this is a tensor of
- size :math:`(*)` where non-zero values indicate whether factorization for the matrix or
- each minibatch has succeeded or failed
- Example::
- >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
- >>> # xdoctest: +IGNORE_WANT("non-deterministic")
- >>> A = torch.randn(2, 3, 3)
- >>> A_LU, pivots = torch.lu(A)
- >>> A_LU
- tensor([[[ 1.3506, 2.5558, -0.0816],
- [ 0.1684, 1.1551, 0.1940],
- [ 0.1193, 0.6189, -0.5497]],
- [[ 0.4526, 1.2526, -0.3285],
- [-0.7988, 0.7175, -0.9701],
- [ 0.2634, -0.9255, -0.3459]]])
- >>> pivots
- tensor([[ 3, 3, 3],
- [ 3, 3, 3]], dtype=torch.int32)
- >>> A_LU, pivots, info = torch.lu(A, get_infos=True)
- >>> if info.nonzero().size(0) == 0:
- ... print('LU factorization succeeded for all samples!')
- LU factorization succeeded for all samples!
- """
- # If get_infos is True, then we don't need to check for errors and vice versa
- return torch._lu_with_info(A, pivot=pivot, check_errors=(not get_infos))
# Annotation alias for the ``out`` argument of ``_check_list_size``: static type
# checkers see the broader ``Sequence[Tensor]``, while at runtime the alias is
# the concrete ``list[Tensor]``.
# NOTE(review): presumably the runtime form is narrowed because TorchScript
# cannot handle ``Sequence`` -- confirm.
if TYPE_CHECKING:
    _ListOrSeq = Sequence[Tensor]
else:
    _ListOrSeq = list[Tensor]
- def _check_list_size(out_len: int, get_infos: bool, out: _ListOrSeq) -> None:
- get_infos_int = 1 if get_infos else 0
- if out_len - get_infos_int != 2:
- raise TypeError(
- f"expected tuple of {2 + int(get_infos)} elements but got {out_len}"
- )
- if not isinstance(out, (tuple, list)):
- raise TypeError(
- f"argument 'out' must be tuple of Tensors, not {type(out).__name__}"
- )
- def _lu_with_infos(A, pivot=True, get_infos=False, out=None):
- # type: (Tensor, bool, bool, Optional[tuple[Tensor, Tensor, Tensor]]) -> tuple[Tensor, Tensor, Tensor]
- if has_torch_function_unary(A):
- return handle_torch_function(
- lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out
- )
- result = _lu_impl(A, pivot, get_infos, out)
- if out is not None:
- _check_list_size(len(out), get_infos, out)
- for i in range(len(out)):
- out[i].resize_as_(result[i]).copy_(result[i])
- return out
- else:
- return result # A_LU, pivots, infos
- def _lu_no_infos(A, pivot=True, get_infos=False, out=None):
- # type: (Tensor, bool, bool, Optional[tuple[Tensor, Tensor]]) -> tuple[Tensor, Tensor]
- # need to check for torch_function here so that we exit if
- if has_torch_function_unary(A):
- return handle_torch_function(
- lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out
- )
- result = _lu_impl(A, pivot, get_infos, out)
- if out is not None:
- _check_list_size(len(out), get_infos, out)
- for i in range(len(out)):
- out[i].resize_as_(result[i]).copy_(result[i])
- return out
- else:
- return result[0], result[1] # A_LU, pivots
# The return type of lu depends on `get_infos`, so in order to resolve the output type
# of lu in TorchScript we need to statically know the value of `get_infos`
#
# ``get_infos`` (third positional argument, index 2, default ``False``) selects
# the implementation: ``_lu_with_infos`` when true, ``_lu_no_infos`` otherwise.
lu = boolean_dispatch(
    arg_name="get_infos",
    arg_index=2,
    default=False,
    if_true=_lu_with_infos,
    if_false=_lu_no_infos,
    module_name=__name__,
    func_name="lu",
)
# The user-facing documentation is written once, on ``_lu_impl``, and reused here.
lu.__doc__ = _lu_impl.__doc__
def align_tensors(*tensors):
    """Placeholder for a function that has not been implemented.

    Args:
        *tensors: accepted but never inspected.

    Raises:
        NotImplementedError: always. ``NotImplementedError`` is a subclass of
            ``RuntimeError``, so callers that catch ``RuntimeError`` keep
            working.
    """
    raise NotImplementedError("`align_tensors` not yet implemented.")
|