# __init__.py (54 KB)
  1. import torch
  2. from torch._C import _add_docstr, _fft # type: ignore[attr-defined]
  3. from torch._torch_docs import common_args, factory_common_args
  4. __all__ = [
  5. "fft",
  6. "ifft",
  7. "fft2",
  8. "ifft2",
  9. "fftn",
  10. "ifftn",
  11. "rfft",
  12. "irfft",
  13. "rfft2",
  14. "irfft2",
  15. "rfftn",
  16. "irfftn",
  17. "hfft",
  18. "ihfft",
  19. "fftfreq",
  20. "rfftfreq",
  21. "fftshift",
  22. "ifftshift",
  23. "Tensor",
  24. ]
  25. Tensor = torch.Tensor
  26. # Note: This not only adds the doc strings for the spectral ops, but
  27. # connects the torch.fft Python namespace to the torch._C._fft builtins.
  28. fft = _add_docstr(
  29. _fft.fft_fft,
  30. r"""
  31. fft(input, n=None, dim=-1, norm=None, *, out=None) -> Tensor
  32. Computes the one dimensional discrete Fourier transform of :attr:`input`.
  33. Note:
  34. The Fourier domain representation of any real signal satisfies the
  35. Hermitian property: `X[i] = conj(X[-i])`. This function always returns both
  36. the positive and negative frequency terms even though, for real inputs, the
  37. negative frequencies are redundant. :func:`~torch.fft.rfft` returns the
  38. more compact one-sided representation where only the positive frequencies
  39. are returned.
  40. Note:
  41. Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
  42. However it only supports powers of 2 signal length in every transformed dimension.
  43. Args:
  44. input (Tensor): the input tensor
  45. n (int, optional): Signal length. If given, the input will either be zero-padded
  46. or trimmed to this length before computing the FFT.
  47. dim (int, optional): The dimension along which to take the one dimensional FFT.
  48. norm (str, optional): Normalization mode. For the forward transform
  49. (:func:`~torch.fft.fft`), these correspond to:
  50. * ``"forward"`` - normalize by ``1/n``
  51. * ``"backward"`` - no normalization
  52. * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the FFT orthonormal)
  53. Calling the backward transform (:func:`~torch.fft.ifft`) with the same
  54. normalization mode will apply an overall normalization of ``1/n`` between
  55. the two transforms. This is required to make :func:`~torch.fft.ifft`
  56. the exact inverse.
  57. Default is ``"backward"`` (no normalization).
  58. Keyword args:
  59. {out}
  60. Example:
  61. >>> t = torch.arange(4)
  62. >>> t
  63. tensor([0, 1, 2, 3])
  64. >>> torch.fft.fft(t)
  65. tensor([ 6.+0.j, -2.+2.j, -2.+0.j, -2.-2.j])
  66. >>> t = torch.tensor([0.+1.j, 2.+3.j, 4.+5.j, 6.+7.j])
  67. >>> torch.fft.fft(t)
  68. tensor([12.+16.j, -8.+0.j, -4.-4.j, 0.-8.j])
  69. """.format(**common_args),
  70. )
  71. ifft = _add_docstr(
  72. _fft.fft_ifft,
  73. r"""
  74. ifft(input, n=None, dim=-1, norm=None, *, out=None) -> Tensor
  75. Computes the one dimensional inverse discrete Fourier transform of :attr:`input`.
  76. Note:
  77. Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
  78. However it only supports powers of 2 signal length in every transformed dimension.
  79. Args:
  80. input (Tensor): the input tensor
  81. n (int, optional): Signal length. If given, the input will either be zero-padded
  82. or trimmed to this length before computing the IFFT.
  83. dim (int, optional): The dimension along which to take the one dimensional IFFT.
  84. norm (str, optional): Normalization mode. For the backward transform
  85. (:func:`~torch.fft.ifft`), these correspond to:
  86. * ``"forward"`` - no normalization
  87. * ``"backward"`` - normalize by ``1/n``
  88. * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the IFFT orthonormal)
  89. Calling the forward transform (:func:`~torch.fft.fft`) with the same
  90. normalization mode will apply an overall normalization of ``1/n`` between
  91. the two transforms. This is required to make :func:`~torch.fft.ifft`
  92. the exact inverse.
  93. Default is ``"backward"`` (normalize by ``1/n``).
  94. Keyword args:
  95. {out}
  96. Example:
  97. >>> t = torch.tensor([ 6.+0.j, -2.+2.j, -2.+0.j, -2.-2.j])
  98. >>> torch.fft.ifft(t)
  99. tensor([0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j])
  100. """.format(**common_args),
  101. )
  102. fft2 = _add_docstr(
  103. _fft.fft_fft2,
  104. r"""
  105. fft2(input, s=None, dim=(-2, -1), norm=None, *, out=None) -> Tensor
  106. Computes the 2 dimensional discrete Fourier transform of :attr:`input`.
  107. Equivalent to :func:`~torch.fft.fftn` but FFTs only the last two dimensions by default.
  108. Note:
  109. The Fourier domain representation of any real signal satisfies the
  110. Hermitian property: ``X[i, j] = conj(X[-i, -j])``. This
  111. function always returns all positive and negative frequency terms even
  112. though, for real inputs, half of these values are redundant.
  113. :func:`~torch.fft.rfft2` returns the more compact one-sided representation
  114. where only the positive frequencies of the last dimension are returned.
  115. Note:
  116. Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
  117. However it only supports powers of 2 signal length in every transformed dimensions.
  118. Args:
  119. input (Tensor): the input tensor
  120. s (Tuple[int], optional): Signal size in the transformed dimensions.
  121. If given, each dimension ``dim[i]`` will either be zero-padded or
  122. trimmed to the length ``s[i]`` before computing the FFT.
  123. If a length ``-1`` is specified, no padding is done in that dimension.
  124. Default: ``s = [input.size(d) for d in dim]``
  125. dim (Tuple[int], optional): Dimensions to be transformed.
  126. Default: last two dimensions.
  127. norm (str, optional): Normalization mode. For the forward transform
  128. (:func:`~torch.fft.fft2`), these correspond to:
  129. * ``"forward"`` - normalize by ``1/n``
  130. * ``"backward"`` - no normalization
  131. * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the FFT orthonormal)
  132. Where ``n = prod(s)`` is the logical FFT size.
  133. Calling the backward transform (:func:`~torch.fft.ifft2`) with the same
  134. normalization mode will apply an overall normalization of ``1/n``
  135. between the two transforms. This is required to make
  136. :func:`~torch.fft.ifft2` the exact inverse.
  137. Default is ``"backward"`` (no normalization).
  138. Keyword args:
  139. {out}
  140. Example:
  141. >>> x = torch.rand(10, 10, dtype=torch.complex64)
  142. >>> fft2 = torch.fft.fft2(x)
  143. The discrete Fourier transform is separable, so :func:`~torch.fft.fft2`
  144. here is equivalent to two one-dimensional :func:`~torch.fft.fft` calls:
  145. >>> two_ffts = torch.fft.fft(torch.fft.fft(x, dim=0), dim=1)
  146. >>> torch.testing.assert_close(fft2, two_ffts, check_stride=False)
  147. """.format(**common_args),
  148. )
  149. ifft2 = _add_docstr(
  150. _fft.fft_ifft2,
  151. r"""
  152. ifft2(input, s=None, dim=(-2, -1), norm=None, *, out=None) -> Tensor
  153. Computes the 2 dimensional inverse discrete Fourier transform of :attr:`input`.
  154. Equivalent to :func:`~torch.fft.ifftn` but IFFTs only the last two dimensions by default.
  155. Note:
  156. Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
  157. However it only supports powers of 2 signal length in every transformed dimensions.
  158. Args:
  159. input (Tensor): the input tensor
  160. s (Tuple[int], optional): Signal size in the transformed dimensions.
  161. If given, each dimension ``dim[i]`` will either be zero-padded or
  162. trimmed to the length ``s[i]`` before computing the IFFT.
  163. If a length ``-1`` is specified, no padding is done in that dimension.
  164. Default: ``s = [input.size(d) for d in dim]``
  165. dim (Tuple[int], optional): Dimensions to be transformed.
  166. Default: last two dimensions.
  167. norm (str, optional): Normalization mode. For the backward transform
  168. (:func:`~torch.fft.ifft2`), these correspond to:
  169. * ``"forward"`` - no normalization
  170. * ``"backward"`` - normalize by ``1/n``
  171. * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the IFFT orthonormal)
  172. Where ``n = prod(s)`` is the logical IFFT size.
  173. Calling the forward transform (:func:`~torch.fft.fft2`) with the same
  174. normalization mode will apply an overall normalization of ``1/n`` between
  175. the two transforms. This is required to make :func:`~torch.fft.ifft2`
  176. the exact inverse.
  177. Default is ``"backward"`` (normalize by ``1/n``).
  178. Keyword args:
  179. {out}
  180. Example:
  181. >>> x = torch.rand(10, 10, dtype=torch.complex64)
  182. >>> ifft2 = torch.fft.ifft2(x)
  183. The discrete Fourier transform is separable, so :func:`~torch.fft.ifft2`
  184. here is equivalent to two one-dimensional :func:`~torch.fft.ifft` calls:
  185. >>> two_iffts = torch.fft.ifft(torch.fft.ifft(x, dim=0), dim=1)
  186. >>> torch.testing.assert_close(ifft2, two_iffts, check_stride=False)
  187. """.format(**common_args),
  188. )
  189. fftn = _add_docstr(
  190. _fft.fft_fftn,
  191. r"""
  192. fftn(input, s=None, dim=None, norm=None, *, out=None) -> Tensor
  193. Computes the N dimensional discrete Fourier transform of :attr:`input`.
  194. Note:
  195. The Fourier domain representation of any real signal satisfies the
  196. Hermitian property: ``X[i_1, ..., i_n] = conj(X[-i_1, ..., -i_n])``. This
  197. function always returns all positive and negative frequency terms even
  198. though, for real inputs, half of these values are redundant.
  199. :func:`~torch.fft.rfftn` returns the more compact one-sided representation
  200. where only the positive frequencies of the last dimension are returned.
  201. Note:
  202. Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
  203. However it only supports powers of 2 signal length in every transformed dimensions.
  204. Args:
  205. input (Tensor): the input tensor
  206. s (Tuple[int], optional): Signal size in the transformed dimensions.
  207. If given, each dimension ``dim[i]`` will either be zero-padded or
  208. trimmed to the length ``s[i]`` before computing the FFT.
  209. If a length ``-1`` is specified, no padding is done in that dimension.
  210. Default: ``s = [input.size(d) for d in dim]``
  211. dim (Tuple[int], optional): Dimensions to be transformed.
  212. Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given.
  213. norm (str, optional): Normalization mode. For the forward transform
  214. (:func:`~torch.fft.fftn`), these correspond to:
  215. * ``"forward"`` - normalize by ``1/n``
  216. * ``"backward"`` - no normalization
  217. * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the FFT orthonormal)
  218. Where ``n = prod(s)`` is the logical FFT size.
  219. Calling the backward transform (:func:`~torch.fft.ifftn`) with the same
  220. normalization mode will apply an overall normalization of ``1/n``
  221. between the two transforms. This is required to make
  222. :func:`~torch.fft.ifftn` the exact inverse.
  223. Default is ``"backward"`` (no normalization).
  224. Keyword args:
  225. {out}
  226. Example:
  227. >>> x = torch.rand(10, 10, dtype=torch.complex64)
  228. >>> fftn = torch.fft.fftn(x)
  229. The discrete Fourier transform is separable, so :func:`~torch.fft.fftn`
  230. here is equivalent to two one-dimensional :func:`~torch.fft.fft` calls:
  231. >>> two_ffts = torch.fft.fft(torch.fft.fft(x, dim=0), dim=1)
  232. >>> torch.testing.assert_close(fftn, two_ffts, check_stride=False)
  233. """.format(**common_args),
  234. )
  235. ifftn = _add_docstr(
  236. _fft.fft_ifftn,
  237. r"""
  238. ifftn(input, s=None, dim=None, norm=None, *, out=None) -> Tensor
  239. Computes the N dimensional inverse discrete Fourier transform of :attr:`input`.
  240. Note:
  241. Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
  242. However it only supports powers of 2 signal length in every transformed dimensions.
  243. Args:
  244. input (Tensor): the input tensor
  245. s (Tuple[int], optional): Signal size in the transformed dimensions.
  246. If given, each dimension ``dim[i]`` will either be zero-padded or
  247. trimmed to the length ``s[i]`` before computing the IFFT.
  248. If a length ``-1`` is specified, no padding is done in that dimension.
  249. Default: ``s = [input.size(d) for d in dim]``
  250. dim (Tuple[int], optional): Dimensions to be transformed.
  251. Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given.
  252. norm (str, optional): Normalization mode. For the backward transform
  253. (:func:`~torch.fft.ifftn`), these correspond to:
  254. * ``"forward"`` - no normalization
  255. * ``"backward"`` - normalize by ``1/n``
  256. * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the IFFT orthonormal)
  257. Where ``n = prod(s)`` is the logical IFFT size.
  258. Calling the forward transform (:func:`~torch.fft.fftn`) with the same
  259. normalization mode will apply an overall normalization of ``1/n`` between
  260. the two transforms. This is required to make :func:`~torch.fft.ifftn`
  261. the exact inverse.
  262. Default is ``"backward"`` (normalize by ``1/n``).
  263. Keyword args:
  264. {out}
  265. Example:
  266. >>> x = torch.rand(10, 10, dtype=torch.complex64)
  267. >>> ifftn = torch.fft.ifftn(x)
  268. The discrete Fourier transform is separable, so :func:`~torch.fft.ifftn`
  269. here is equivalent to two one-dimensional :func:`~torch.fft.ifft` calls:
  270. >>> two_iffts = torch.fft.ifft(torch.fft.ifft(x, dim=0), dim=1)
  271. >>> torch.testing.assert_close(ifftn, two_iffts, check_stride=False)
  272. """.format(**common_args),
  273. )
  274. rfft = _add_docstr(
  275. _fft.fft_rfft,
  276. r"""
  277. rfft(input, n=None, dim=-1, norm=None, *, out=None) -> Tensor
  278. Computes the one dimensional Fourier transform of real-valued :attr:`input`.
  279. The FFT of a real signal is Hermitian-symmetric, ``X[i] = conj(X[-i])`` so
  280. the output contains only the positive frequencies below the Nyquist frequency.
  281. To compute the full output, use :func:`~torch.fft.fft`
  282. Note:
  283. Supports torch.half on CUDA with GPU Architecture SM53 or greater.
  284. However it only supports powers of 2 signal length in every transformed dimension.
  285. Args:
  286. input (Tensor): the real input tensor
  287. n (int, optional): Signal length. If given, the input will either be zero-padded
  288. or trimmed to this length before computing the real FFT.
  289. dim (int, optional): The dimension along which to take the one dimensional real FFT.
  290. norm (str, optional): Normalization mode. For the forward transform
  291. (:func:`~torch.fft.rfft`), these correspond to:
  292. * ``"forward"`` - normalize by ``1/n``
  293. * ``"backward"`` - no normalization
  294. * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the FFT orthonormal)
  295. Calling the backward transform (:func:`~torch.fft.irfft`) with the same
  296. normalization mode will apply an overall normalization of ``1/n`` between
  297. the two transforms. This is required to make :func:`~torch.fft.irfft`
  298. the exact inverse.
  299. Default is ``"backward"`` (no normalization).
  300. Keyword args:
  301. {out}
  302. Example:
  303. >>> t = torch.arange(4)
  304. >>> t
  305. tensor([0, 1, 2, 3])
  306. >>> torch.fft.rfft(t)
  307. tensor([ 6.+0.j, -2.+2.j, -2.+0.j])
  308. Compare against the full output from :func:`~torch.fft.fft`:
  309. >>> torch.fft.fft(t)
  310. tensor([ 6.+0.j, -2.+2.j, -2.+0.j, -2.-2.j])
  311. Notice that the symmetric element ``T[-1] == T[1].conj()`` is omitted.
  312. At the Nyquist frequency ``T[-2] == T[2]`` is it's own symmetric pair,
  313. and therefore must always be real-valued.
  314. """.format(**common_args),
  315. )
  316. irfft = _add_docstr(
  317. _fft.fft_irfft,
  318. r"""
  319. irfft(input, n=None, dim=-1, norm=None, *, out=None) -> Tensor
  320. Computes the inverse of :func:`~torch.fft.rfft`.
  321. :attr:`input` is interpreted as a one-sided Hermitian signal in the Fourier
  322. domain, as produced by :func:`~torch.fft.rfft`. By the Hermitian property, the
  323. output will be real-valued.
  324. Note:
  325. Some input frequencies must be real-valued to satisfy the Hermitian
  326. property. In these cases the imaginary component will be ignored.
  327. For example, any imaginary component in the zero-frequency term cannot
  328. be represented in a real output and so will always be ignored.
  329. Note:
  330. The correct interpretation of the Hermitian input depends on the length of
  331. the original data, as given by :attr:`n`. This is because each input shape
  332. could correspond to either an odd or even length signal. By default, the
  333. signal is assumed to be even length and odd signals will not round-trip
  334. properly. So, it is recommended to always pass the signal length :attr:`n`.
  335. Note:
  336. Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
  337. However it only supports powers of 2 signal length in every transformed dimension.
  338. With default arguments, size of the transformed dimension should be (2^n + 1) as argument
  339. `n` defaults to even output size = 2 * (transformed_dim_size - 1)
  340. Args:
  341. input (Tensor): the input tensor representing a half-Hermitian signal
  342. n (int, optional): Output signal length. This determines the length of the
  343. output signal. If given, the input will either be zero-padded or trimmed to this
  344. length before computing the real IFFT.
  345. Defaults to even output: ``n=2*(input.size(dim) - 1)``.
  346. dim (int, optional): The dimension along which to take the one dimensional real IFFT.
  347. norm (str, optional): Normalization mode. For the backward transform
  348. (:func:`~torch.fft.irfft`), these correspond to:
  349. * ``"forward"`` - no normalization
  350. * ``"backward"`` - normalize by ``1/n``
  351. * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the real IFFT orthonormal)
  352. Calling the forward transform (:func:`~torch.fft.rfft`) with the same
  353. normalization mode will apply an overall normalization of ``1/n`` between
  354. the two transforms. This is required to make :func:`~torch.fft.irfft`
  355. the exact inverse.
  356. Default is ``"backward"`` (normalize by ``1/n``).
  357. Keyword args:
  358. {out}
  359. Example:
  360. >>> t = torch.linspace(0, 1, 5)
  361. >>> t
  362. tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])
  363. >>> T = torch.fft.rfft(t)
  364. >>> T
  365. tensor([ 2.5000+0.0000j, -0.6250+0.8602j, -0.6250+0.2031j])
  366. Without specifying the output length to :func:`~torch.fft.irfft`, the output
  367. will not round-trip properly because the input is odd-length:
  368. >>> torch.fft.irfft(T)
  369. tensor([0.1562, 0.3511, 0.7812, 1.2114])
  370. So, it is recommended to always pass the signal length :attr:`n`:
  371. >>> roundtrip = torch.fft.irfft(T, t.numel())
  372. >>> torch.testing.assert_close(roundtrip, t, check_stride=False)
  373. """.format(**common_args),
  374. )
  375. rfft2 = _add_docstr(
  376. _fft.fft_rfft2,
  377. r"""
  378. rfft2(input, s=None, dim=(-2, -1), norm=None, *, out=None) -> Tensor
  379. Computes the 2-dimensional discrete Fourier transform of real :attr:`input`.
  380. Equivalent to :func:`~torch.fft.rfftn` but FFTs only the last two dimensions by default.
  381. The FFT of a real signal is Hermitian-symmetric, ``X[i, j] = conj(X[-i, -j])``,
  382. so the full :func:`~torch.fft.fft2` output contains redundant information.
  383. :func:`~torch.fft.rfft2` instead omits the negative frequencies in the last
  384. dimension.
  385. Note:
  386. Supports torch.half on CUDA with GPU Architecture SM53 or greater.
  387. However it only supports powers of 2 signal length in every transformed dimensions.
  388. Args:
  389. input (Tensor): the input tensor
  390. s (Tuple[int], optional): Signal size in the transformed dimensions.
  391. If given, each dimension ``dim[i]`` will either be zero-padded or
  392. trimmed to the length ``s[i]`` before computing the real FFT.
  393. If a length ``-1`` is specified, no padding is done in that dimension.
  394. Default: ``s = [input.size(d) for d in dim]``
  395. dim (Tuple[int], optional): Dimensions to be transformed.
  396. Default: last two dimensions.
  397. norm (str, optional): Normalization mode. For the forward transform
  398. (:func:`~torch.fft.rfft2`), these correspond to:
  399. * ``"forward"`` - normalize by ``1/n``
  400. * ``"backward"`` - no normalization
  401. * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the real FFT orthonormal)
  402. Where ``n = prod(s)`` is the logical FFT size.
  403. Calling the backward transform (:func:`~torch.fft.irfft2`) with the same
  404. normalization mode will apply an overall normalization of ``1/n`` between
  405. the two transforms. This is required to make :func:`~torch.fft.irfft2`
  406. the exact inverse.
  407. Default is ``"backward"`` (no normalization).
  408. Keyword args:
  409. {out}
  410. Example:
  411. >>> t = torch.rand(10, 10)
  412. >>> rfft2 = torch.fft.rfft2(t)
  413. >>> rfft2.size()
  414. torch.Size([10, 6])
  415. Compared against the full output from :func:`~torch.fft.fft2`, we have all
  416. elements up to the Nyquist frequency.
  417. >>> fft2 = torch.fft.fft2(t)
  418. >>> torch.testing.assert_close(fft2[..., :6], rfft2, check_stride=False)
  419. The discrete Fourier transform is separable, so :func:`~torch.fft.rfft2`
  420. here is equivalent to a combination of :func:`~torch.fft.fft` and
  421. :func:`~torch.fft.rfft`:
  422. >>> two_ffts = torch.fft.fft(torch.fft.rfft(t, dim=1), dim=0)
  423. >>> torch.testing.assert_close(rfft2, two_ffts, check_stride=False)
  424. """.format(**common_args),
  425. )
  426. irfft2 = _add_docstr(
  427. _fft.fft_irfft2,
  428. r"""
  429. irfft2(input, s=None, dim=(-2, -1), norm=None, *, out=None) -> Tensor
  430. Computes the inverse of :func:`~torch.fft.rfft2`.
  431. Equivalent to :func:`~torch.fft.irfftn` but IFFTs only the last two dimensions by default.
  432. :attr:`input` is interpreted as a one-sided Hermitian signal in the Fourier
  433. domain, as produced by :func:`~torch.fft.rfft2`. By the Hermitian property, the
  434. output will be real-valued.
  435. Note:
  436. Some input frequencies must be real-valued to satisfy the Hermitian
  437. property. In these cases the imaginary component will be ignored.
  438. For example, any imaginary component in the zero-frequency term cannot
  439. be represented in a real output and so will always be ignored.
  440. Note:
  441. The correct interpretation of the Hermitian input depends on the length of
  442. the original data, as given by :attr:`s`. This is because each input shape
  443. could correspond to either an odd or even length signal. By default, the
  444. signal is assumed to be even length and odd signals will not round-trip
  445. properly. So, it is recommended to always pass the signal shape :attr:`s`.
  446. Note:
  447. Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
  448. However it only supports powers of 2 signal length in every transformed dimensions.
  449. With default arguments, the size of last dimension should be (2^n + 1) as argument
  450. `s` defaults to even output size = 2 * (last_dim_size - 1)
  451. Args:
  452. input (Tensor): the input tensor
  453. s (Tuple[int], optional): Signal size in the transformed dimensions.
  454. If given, each dimension ``dim[i]`` will either be zero-padded or
  455. trimmed to the length ``s[i]`` before computing the real FFT.
  456. If a length ``-1`` is specified, no padding is done in that dimension.
  457. Defaults to even output in the last dimension:
  458. ``s[-1] = 2*(input.size(dim[-1]) - 1)``.
  459. dim (Tuple[int], optional): Dimensions to be transformed.
  460. The last dimension must be the half-Hermitian compressed dimension.
  461. Default: last two dimensions.
  462. norm (str, optional): Normalization mode. For the backward transform
  463. (:func:`~torch.fft.irfft2`), these correspond to:
  464. * ``"forward"`` - no normalization
  465. * ``"backward"`` - normalize by ``1/n``
  466. * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the real IFFT orthonormal)
  467. Where ``n = prod(s)`` is the logical IFFT size.
  468. Calling the forward transform (:func:`~torch.fft.rfft2`) with the same
  469. normalization mode will apply an overall normalization of ``1/n`` between
  470. the two transforms. This is required to make :func:`~torch.fft.irfft2`
  471. the exact inverse.
  472. Default is ``"backward"`` (normalize by ``1/n``).
  473. Keyword args:
  474. {out}
  475. Example:
  476. >>> t = torch.rand(10, 9)
  477. >>> T = torch.fft.rfft2(t)
  478. Without specifying the output length to :func:`~torch.fft.irfft2`, the output
  479. will not round-trip properly because the input is odd-length in the last
  480. dimension:
  481. >>> torch.fft.irfft2(T).size()
  482. torch.Size([10, 8])
  483. So, it is recommended to always pass the signal shape :attr:`s`.
  484. >>> roundtrip = torch.fft.irfft2(T, t.size())
  485. >>> roundtrip.size()
  486. torch.Size([10, 9])
  487. >>> torch.testing.assert_close(roundtrip, t, check_stride=False)
  488. """.format(**common_args),
  489. )
  490. rfftn = _add_docstr(
  491. _fft.fft_rfftn,
  492. r"""
  493. rfftn(input, s=None, dim=None, norm=None, *, out=None) -> Tensor
  494. Computes the N-dimensional discrete Fourier transform of real :attr:`input`.
  495. The FFT of a real signal is Hermitian-symmetric,
  496. ``X[i_1, ..., i_n] = conj(X[-i_1, ..., -i_n])`` so the full
  497. :func:`~torch.fft.fftn` output contains redundant information.
  498. :func:`~torch.fft.rfftn` instead omits the negative frequencies in the
  499. last dimension.
  500. Note:
  501. Supports torch.half on CUDA with GPU Architecture SM53 or greater.
  502. However it only supports powers of 2 signal length in every transformed dimensions.
  503. Args:
  504. input (Tensor): the input tensor
  505. s (Tuple[int], optional): Signal size in the transformed dimensions.
  506. If given, each dimension ``dim[i]`` will either be zero-padded or
  507. trimmed to the length ``s[i]`` before computing the real FFT.
  508. If a length ``-1`` is specified, no padding is done in that dimension.
  509. Default: ``s = [input.size(d) for d in dim]``
  510. dim (Tuple[int], optional): Dimensions to be transformed.
  511. Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given.
  512. norm (str, optional): Normalization mode. For the forward transform
  513. (:func:`~torch.fft.rfftn`), these correspond to:
  514. * ``"forward"`` - normalize by ``1/n``
  515. * ``"backward"`` - no normalization
  516. * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the real FFT orthonormal)
  517. Where ``n = prod(s)`` is the logical FFT size.
  518. Calling the backward transform (:func:`~torch.fft.irfftn`) with the same
  519. normalization mode will apply an overall normalization of ``1/n`` between
  520. the two transforms. This is required to make :func:`~torch.fft.irfftn`
  521. the exact inverse.
  522. Default is ``"backward"`` (no normalization).
  523. Keyword args:
  524. {out}
  525. Example:
  526. >>> t = torch.rand(10, 10)
  527. >>> rfftn = torch.fft.rfftn(t)
  528. >>> rfftn.size()
  529. torch.Size([10, 6])
  530. Compared against the full output from :func:`~torch.fft.fftn`, we have all
  531. elements up to the Nyquist frequency.
  532. >>> fftn = torch.fft.fftn(t)
  533. >>> torch.testing.assert_close(fftn[..., :6], rfftn, check_stride=False)
  534. The discrete Fourier transform is separable, so :func:`~torch.fft.rfftn`
  535. here is equivalent to a combination of :func:`~torch.fft.fft` and
  536. :func:`~torch.fft.rfft`:
  537. >>> two_ffts = torch.fft.fft(torch.fft.rfft(t, dim=1), dim=0)
  538. >>> torch.testing.assert_close(rfftn, two_ffts, check_stride=False)
  539. """.format(**common_args),
  540. )
  541. irfftn = _add_docstr(
  542. _fft.fft_irfftn,
  543. r"""
  544. irfftn(input, s=None, dim=None, norm=None, *, out=None) -> Tensor
  545. Computes the inverse of :func:`~torch.fft.rfftn`.
  546. :attr:`input` is interpreted as a one-sided Hermitian signal in the Fourier
  547. domain, as produced by :func:`~torch.fft.rfftn`. By the Hermitian property, the
  548. output will be real-valued.
  549. Note:
  550. Some input frequencies must be real-valued to satisfy the Hermitian
  551. property. In these cases the imaginary component will be ignored.
  552. For example, any imaginary component in the zero-frequency term cannot
  553. be represented in a real output and so will always be ignored.
  554. Note:
  555. The correct interpretation of the Hermitian input depends on the length of
  556. the original data, as given by :attr:`s`. This is because each input shape
  557. could correspond to either an odd or even length signal. By default, the
  558. signal is assumed to be even length and odd signals will not round-trip
  559. properly. So, it is recommended to always pass the signal shape :attr:`s`.
  560. Note:
  561. Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
  562. However it only supports powers of 2 signal length in every transformed dimensions.
  563. With default arguments, the size of last dimension should be (2^n + 1) as argument
  564. `s` defaults to even output size = 2 * (last_dim_size - 1)
  565. Args:
  566. input (Tensor): the input tensor
  567. s (Tuple[int], optional): Signal size in the transformed dimensions.
  568. If given, each dimension ``dim[i]`` will either be zero-padded or
  569. trimmed to the length ``s[i]`` before computing the real FFT.
  570. If a length ``-1`` is specified, no padding is done in that dimension.
  571. Defaults to even output in the last dimension:
  572. ``s[-1] = 2*(input.size(dim[-1]) - 1)``.
  573. dim (Tuple[int], optional): Dimensions to be transformed.
  574. The last dimension must be the half-Hermitian compressed dimension.
  575. Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given.
  576. norm (str, optional): Normalization mode. For the backward transform
  577. (:func:`~torch.fft.irfftn`), these correspond to:
  578. * ``"forward"`` - no normalization
  579. * ``"backward"`` - normalize by ``1/n``
  580. * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the real IFFT orthonormal)
  581. Where ``n = prod(s)`` is the logical IFFT size.
  582. Calling the forward transform (:func:`~torch.fft.rfftn`) with the same
  583. normalization mode will apply an overall normalization of ``1/n`` between
  584. the two transforms. This is required to make :func:`~torch.fft.irfftn`
  585. the exact inverse.
  586. Default is ``"backward"`` (normalize by ``1/n``).
  587. Keyword args:
  588. {out}
  589. Example:
  590. >>> t = torch.rand(10, 9)
  591. >>> T = torch.fft.rfftn(t)
  592. Without specifying the output length to :func:`~torch.fft.irfft`, the output
  593. will not round-trip properly because the input is odd-length in the last
  594. dimension:
  595. >>> torch.fft.irfftn(T).size()
  596. torch.Size([10, 8])
  597. So, it is recommended to always pass the signal shape :attr:`s`.
  598. >>> roundtrip = torch.fft.irfftn(T, t.size())
  599. >>> roundtrip.size()
  600. torch.Size([10, 9])
  601. >>> torch.testing.assert_close(roundtrip, t, check_stride=False)
  602. """.format(**common_args),
  603. )
  604. hfft = _add_docstr(
  605. _fft.fft_hfft,
  606. r"""
  607. hfft(input, n=None, dim=-1, norm=None, *, out=None) -> Tensor
  608. Computes the one dimensional discrete Fourier transform of a Hermitian
  609. symmetric :attr:`input` signal.
  610. Note:
  611. :func:`~torch.fft.hfft`/:func:`~torch.fft.ihfft` are analogous to
  612. :func:`~torch.fft.rfft`/:func:`~torch.fft.irfft`. The real FFT expects
  613. a real signal in the time-domain and gives a Hermitian symmetry in the
  614. frequency-domain. The Hermitian FFT is the opposite; Hermitian symmetric in
  615. the time-domain and real-valued in the frequency-domain. For this reason,
  616. special care needs to be taken with the length argument :attr:`n`, in the
  617. same way as with :func:`~torch.fft.irfft`.
  618. Note:
  619. Because the signal is Hermitian in the time-domain, the result will be
  620. real in the frequency domain. Note that some input frequencies must be
  621. real-valued to satisfy the Hermitian property. In these cases the imaginary
  622. component will be ignored. For example, any imaginary component in
  623. ``input[0]`` would result in one or more complex frequency terms which
  624. cannot be represented in a real output and so will always be ignored.
  625. Note:
  626. The correct interpretation of the Hermitian input depends on the length of
  627. the original data, as given by :attr:`n`. This is because each input shape
  628. could correspond to either an odd or even length signal. By default, the
  629. signal is assumed to be even length and odd signals will not round-trip
  630. properly. So, it is recommended to always pass the signal length :attr:`n`.
  631. Note:
  632. Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
  633. However it only supports powers of 2 signal length in every transformed dimension.
  634. With default arguments, size of the transformed dimension should be (2^n + 1) as argument
  635. `n` defaults to even output size = 2 * (transformed_dim_size - 1)
  636. Args:
  637. input (Tensor): the input tensor representing a half-Hermitian signal
  638. n (int, optional): Output signal length. This determines the length of the
  639. real output. If given, the input will either be zero-padded or trimmed to this
  640. length before computing the Hermitian FFT.
  641. Defaults to even output: ``n=2*(input.size(dim) - 1)``.
  642. dim (int, optional): The dimension along which to take the one dimensional Hermitian FFT.
  643. norm (str, optional): Normalization mode. For the forward transform
  644. (:func:`~torch.fft.hfft`), these correspond to:
  645. * ``"forward"`` - normalize by ``1/n``
  646. * ``"backward"`` - no normalization
  647. * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the Hermitian FFT orthonormal)
  648. Calling the backward transform (:func:`~torch.fft.ihfft`) with the same
  649. normalization mode will apply an overall normalization of ``1/n`` between
  650. the two transforms. This is required to make :func:`~torch.fft.ihfft`
  651. the exact inverse.
  652. Default is ``"backward"`` (no normalization).
  653. Keyword args:
  654. {out}
  655. Example:
  656. Taking a real-valued frequency signal and bringing it into the time domain
  657. gives Hermitian symmetric output:
  658. >>> t = torch.linspace(0, 1, 5)
  659. >>> t
  660. tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])
  661. >>> T = torch.fft.ifft(t)
  662. >>> T
  663. tensor([ 0.5000-0.0000j, -0.1250-0.1720j, -0.1250-0.0406j, -0.1250+0.0406j,
  664. -0.1250+0.1720j])
  665. Note that ``T[1] == T[-1].conj()`` and ``T[2] == T[-2].conj()`` is
  666. redundant. We can thus compute the forward transform without considering
  667. negative frequencies:
  668. >>> torch.fft.hfft(T[:3], n=5)
  669. tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])
  670. Like with :func:`~torch.fft.irfft`, the output length must be given in order
  671. to recover an even length output:
  672. >>> torch.fft.hfft(T[:3])
  673. tensor([0.1250, 0.2809, 0.6250, 0.9691])
  674. """.format(**common_args),
  675. )
  676. ihfft = _add_docstr(
  677. _fft.fft_ihfft,
  678. r"""
  679. ihfft(input, n=None, dim=-1, norm=None, *, out=None) -> Tensor
  680. Computes the inverse of :func:`~torch.fft.hfft`.
  681. :attr:`input` must be a real-valued signal, interpreted in the Fourier domain.
  682. The IFFT of a real signal is Hermitian-symmetric, ``X[i] = conj(X[-i])``.
  683. :func:`~torch.fft.ihfft` represents this in the one-sided form where only the
  684. positive frequencies below the Nyquist frequency are included. To compute the
  685. full output, use :func:`~torch.fft.ifft`.
  686. Note:
  687. Supports torch.half on CUDA with GPU Architecture SM53 or greater.
  688. However it only supports powers of 2 signal length in every transformed dimension.
  689. Args:
  690. input (Tensor): the real input tensor
  691. n (int, optional): Signal length. If given, the input will either be zero-padded
  692. or trimmed to this length before computing the Hermitian IFFT.
  693. dim (int, optional): The dimension along which to take the one dimensional Hermitian IFFT.
  694. norm (str, optional): Normalization mode. For the backward transform
  695. (:func:`~torch.fft.ihfft`), these correspond to:
  696. * ``"forward"`` - no normalization
  697. * ``"backward"`` - normalize by ``1/n``
  698. * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the IFFT orthonormal)
  699. Calling the forward transform (:func:`~torch.fft.hfft`) with the same
  700. normalization mode will apply an overall normalization of ``1/n`` between
  701. the two transforms. This is required to make :func:`~torch.fft.ihfft`
  702. the exact inverse.
  703. Default is ``"backward"`` (normalize by ``1/n``).
  704. Keyword args:
  705. {out}
  706. Example:
  707. >>> t = torch.arange(5)
  708. >>> t
  709. tensor([0, 1, 2, 3, 4])
  710. >>> torch.fft.ihfft(t)
  711. tensor([ 2.0000-0.0000j, -0.5000-0.6882j, -0.5000-0.1625j])
  712. Compare against the full output from :func:`~torch.fft.ifft`:
  713. >>> torch.fft.ifft(t)
  714. tensor([ 2.0000-0.0000j, -0.5000-0.6882j, -0.5000-0.1625j, -0.5000+0.1625j,
  715. -0.5000+0.6882j])
  716. """.format(**common_args),
  717. )
  718. hfft2 = _add_docstr(
  719. _fft.fft_hfft2,
  720. r"""
  721. hfft2(input, s=None, dim=(-2, -1), norm=None, *, out=None) -> Tensor
  722. Computes the 2-dimensional discrete Fourier transform of a Hermitian symmetric
  723. :attr:`input` signal. Equivalent to :func:`~torch.fft.hfftn` but only
  724. transforms the last two dimensions by default.
  725. :attr:`input` is interpreted as a one-sided Hermitian signal in the time
  726. domain. By the Hermitian property, the Fourier transform will be real-valued.
  727. Note:
  728. Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
  729. However it only supports powers of 2 signal length in every transformed dimensions.
  730. With default arguments, the size of last dimension should be (2^n + 1) as argument
  731. `s` defaults to even output size = 2 * (last_dim_size - 1)
  732. Args:
  733. input (Tensor): the input tensor
  734. s (Tuple[int], optional): Signal size in the transformed dimensions.
  735. If given, each dimension ``dim[i]`` will either be zero-padded or
  736. trimmed to the length ``s[i]`` before computing the Hermitian FFT.
  737. If a length ``-1`` is specified, no padding is done in that dimension.
  738. Defaults to even output in the last dimension:
  739. ``s[-1] = 2*(input.size(dim[-1]) - 1)``.
  740. dim (Tuple[int], optional): Dimensions to be transformed.
  741. The last dimension must be the half-Hermitian compressed dimension.
  742. Default: last two dimensions.
  743. norm (str, optional): Normalization mode. For the forward transform
  744. (:func:`~torch.fft.hfft2`), these correspond to:
  745. * ``"forward"`` - normalize by ``1/n``
  746. * ``"backward"`` - no normalization
  747. * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the Hermitian FFT orthonormal)
  748. Where ``n = prod(s)`` is the logical FFT size.
  749. Calling the backward transform (:func:`~torch.fft.ihfft2`) with the same
  750. normalization mode will apply an overall normalization of ``1/n`` between
  751. the two transforms. This is required to make :func:`~torch.fft.ihfft2`
  752. the exact inverse.
  753. Default is ``"backward"`` (no normalization).
  754. Keyword args:
  755. {out}
  756. Example:
  757. Starting from a real frequency-space signal, we can generate a
  758. Hermitian-symmetric time-domain signal:
  759. >>> T = torch.rand(10, 9)
  760. >>> t = torch.fft.ihfft2(T)
  761. Without specifying the output length to :func:`~torch.fft.hfftn`, the
  762. output will not round-trip properly because the input is odd-length in the
  763. last dimension:
  764. >>> torch.fft.hfft2(t).size()
  765. torch.Size([10, 10])
  766. So, it is recommended to always pass the signal shape :attr:`s`.
  767. >>> roundtrip = torch.fft.hfft2(t, T.size())
  768. >>> roundtrip.size()
  769. torch.Size([10, 9])
  770. >>> torch.allclose(roundtrip, T)
  771. True
  772. """.format(**common_args),
  773. )
  774. ihfft2 = _add_docstr(
  775. _fft.fft_ihfft2,
  776. r"""
  777. ihfft2(input, s=None, dim=(-2, -1), norm=None, *, out=None) -> Tensor
  778. Computes the 2-dimensional inverse discrete Fourier transform of real
  779. :attr:`input`. Equivalent to :func:`~torch.fft.ihfftn` but transforms only the
  780. two last dimensions by default.
  781. Note:
  782. Supports torch.half on CUDA with GPU Architecture SM53 or greater.
  783. However it only supports powers of 2 signal length in every transformed dimensions.
  784. Args:
  785. input (Tensor): the input tensor
  786. s (Tuple[int], optional): Signal size in the transformed dimensions.
  787. If given, each dimension ``dim[i]`` will either be zero-padded or
  788. trimmed to the length ``s[i]`` before computing the Hermitian IFFT.
  789. If a length ``-1`` is specified, no padding is done in that dimension.
  790. Default: ``s = [input.size(d) for d in dim]``
  791. dim (Tuple[int], optional): Dimensions to be transformed.
  792. Default: last two dimensions.
  793. norm (str, optional): Normalization mode. For the backward transform
  794. (:func:`~torch.fft.ihfft2`), these correspond to:
  795. * ``"forward"`` - no normalization
  796. * ``"backward"`` - normalize by ``1/n``
  797. * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the Hermitian IFFT orthonormal)
  798. Where ``n = prod(s)`` is the logical IFFT size.
  799. Calling the forward transform (:func:`~torch.fft.hfft2`) with the same
  800. normalization mode will apply an overall normalization of ``1/n`` between
  801. the two transforms. This is required to make :func:`~torch.fft.ihfft2`
  802. the exact inverse.
  803. Default is ``"backward"`` (normalize by ``1/n``).
  804. Keyword args:
  805. {out}
  806. Example:
  807. >>> T = torch.rand(10, 10)
  808. >>> t = torch.fft.ihfft2(t)
  809. >>> t.size()
  810. torch.Size([10, 6])
  811. Compared against the full output from :func:`~torch.fft.ifft2`, the
  812. Hermitian time-space signal takes up only half the space.
  813. >>> fftn = torch.fft.ifft2(t)
  814. >>> torch.allclose(fftn[..., :6], rfftn)
  815. True
  816. The discrete Fourier transform is separable, so :func:`~torch.fft.ihfft2`
  817. here is equivalent to a combination of :func:`~torch.fft.ifft` and
  818. :func:`~torch.fft.ihfft`:
  819. >>> two_ffts = torch.fft.ifft(torch.fft.ihfft(t, dim=1), dim=0)
  820. >>> torch.allclose(t, two_ffts)
  821. True
  822. """.format(**common_args),
  823. )
  824. hfftn = _add_docstr(
  825. _fft.fft_hfftn,
  826. r"""
  827. hfftn(input, s=None, dim=None, norm=None, *, out=None) -> Tensor
  828. Computes the n-dimensional discrete Fourier transform of a Hermitian symmetric
  829. :attr:`input` signal.
  830. :attr:`input` is interpreted as a one-sided Hermitian signal in the time
  831. domain. By the Hermitian property, the Fourier transform will be real-valued.
  832. Note:
  833. :func:`~torch.fft.hfftn`/:func:`~torch.fft.ihfftn` are analogous to
  834. :func:`~torch.fft.rfftn`/:func:`~torch.fft.irfftn`. The real FFT expects
  835. a real signal in the time-domain and gives Hermitian symmetry in the
  836. frequency-domain. The Hermitian FFT is the opposite; Hermitian symmetric in
  837. the time-domain and real-valued in the frequency-domain. For this reason,
  838. special care needs to be taken with the shape argument :attr:`s`, in the
  839. same way as with :func:`~torch.fft.irfftn`.
  840. Note:
  841. Some input frequencies must be real-valued to satisfy the Hermitian
  842. property. In these cases the imaginary component will be ignored.
  843. For example, any imaginary component in the zero-frequency term cannot
  844. be represented in a real output and so will always be ignored.
  845. Note:
  846. The correct interpretation of the Hermitian input depends on the length of
  847. the original data, as given by :attr:`s`. This is because each input shape
  848. could correspond to either an odd or even length signal. By default, the
  849. signal is assumed to be even length and odd signals will not round-trip
  850. properly. It is recommended to always pass the signal shape :attr:`s`.
  851. Note:
  852. Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater.
  853. However it only supports powers of 2 signal length in every transformed dimensions.
  854. With default arguments, the size of last dimension should be (2^n + 1) as argument
  855. `s` defaults to even output size = 2 * (last_dim_size - 1)
  856. Args:
  857. input (Tensor): the input tensor
  858. s (Tuple[int], optional): Signal size in the transformed dimensions.
  859. If given, each dimension ``dim[i]`` will either be zero-padded or
  860. trimmed to the length ``s[i]`` before computing the real FFT.
  861. If a length ``-1`` is specified, no padding is done in that dimension.
  862. Defaults to even output in the last dimension:
  863. ``s[-1] = 2*(input.size(dim[-1]) - 1)``.
  864. dim (Tuple[int], optional): Dimensions to be transformed.
  865. The last dimension must be the half-Hermitian compressed dimension.
  866. Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given.
  867. norm (str, optional): Normalization mode. For the forward transform
  868. (:func:`~torch.fft.hfftn`), these correspond to:
  869. * ``"forward"`` - normalize by ``1/n``
  870. * ``"backward"`` - no normalization
  871. * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the Hermitian FFT orthonormal)
  872. Where ``n = prod(s)`` is the logical FFT size.
  873. Calling the backward transform (:func:`~torch.fft.ihfftn`) with the same
  874. normalization mode will apply an overall normalization of ``1/n`` between
  875. the two transforms. This is required to make :func:`~torch.fft.ihfftn`
  876. the exact inverse.
  877. Default is ``"backward"`` (no normalization).
  878. Keyword args:
  879. {out}
  880. Example:
  881. Starting from a real frequency-space signal, we can generate a
  882. Hermitian-symmetric time-domain signal:
  883. >>> T = torch.rand(10, 9)
  884. >>> t = torch.fft.ihfftn(T)
  885. Without specifying the output length to :func:`~torch.fft.hfftn`, the
  886. output will not round-trip properly because the input is odd-length in the
  887. last dimension:
  888. >>> torch.fft.hfftn(t).size()
  889. torch.Size([10, 10])
  890. So, it is recommended to always pass the signal shape :attr:`s`.
  891. >>> roundtrip = torch.fft.hfftn(t, T.size())
  892. >>> roundtrip.size()
  893. torch.Size([10, 9])
  894. >>> torch.allclose(roundtrip, T)
  895. True
  896. """.format(**common_args),
  897. )
  898. ihfftn = _add_docstr(
  899. _fft.fft_ihfftn,
  900. r"""
  901. ihfftn(input, s=None, dim=None, norm=None, *, out=None) -> Tensor
  902. Computes the N-dimensional inverse discrete Fourier transform of real :attr:`input`.
  903. :attr:`input` must be a real-valued signal, interpreted in the Fourier domain.
  904. The n-dimensional IFFT of a real signal is Hermitian-symmetric,
  905. ``X[i, j, ...] = conj(X[-i, -j, ...])``. :func:`~torch.fft.ihfftn` represents
  906. this in the one-sided form where only the positive frequencies below the
  907. Nyquist frequency are included in the last signal dimension. To compute the
  908. full output, use :func:`~torch.fft.ifftn`.
  909. Note:
  910. Supports torch.half on CUDA with GPU Architecture SM53 or greater.
  911. However it only supports powers of 2 signal length in every transformed dimensions.
  912. Args:
  913. input (Tensor): the input tensor
  914. s (Tuple[int], optional): Signal size in the transformed dimensions.
  915. If given, each dimension ``dim[i]`` will either be zero-padded or
  916. trimmed to the length ``s[i]`` before computing the Hermitian IFFT.
  917. If a length ``-1`` is specified, no padding is done in that dimension.
  918. Default: ``s = [input.size(d) for d in dim]``
  919. dim (Tuple[int], optional): Dimensions to be transformed.
  920. Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given.
  921. norm (str, optional): Normalization mode. For the backward transform
  922. (:func:`~torch.fft.ihfftn`), these correspond to:
  923. * ``"forward"`` - no normalization
  924. * ``"backward"`` - normalize by ``1/n``
  925. * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the Hermitian IFFT orthonormal)
  926. Where ``n = prod(s)`` is the logical IFFT size.
  927. Calling the forward transform (:func:`~torch.fft.hfftn`) with the same
  928. normalization mode will apply an overall normalization of ``1/n`` between
  929. the two transforms. This is required to make :func:`~torch.fft.ihfftn`
  930. the exact inverse.
  931. Default is ``"backward"`` (normalize by ``1/n``).
  932. Keyword args:
  933. {out}
  934. Example:
  935. >>> T = torch.rand(10, 10)
  936. >>> ihfftn = torch.fft.ihfftn(T)
  937. >>> ihfftn.size()
  938. torch.Size([10, 6])
  939. Compared against the full output from :func:`~torch.fft.ifftn`, we have all
  940. elements up to the Nyquist frequency.
  941. >>> ifftn = torch.fft.ifftn(t)
  942. >>> torch.allclose(ifftn[..., :6], ihfftn)
  943. True
  944. The discrete Fourier transform is separable, so :func:`~torch.fft.ihfftn`
  945. here is equivalent to a combination of :func:`~torch.fft.ihfft` and
  946. :func:`~torch.fft.ifft`:
  947. >>> two_iffts = torch.fft.ifft(torch.fft.ihfft(t, dim=1), dim=0)
  948. >>> torch.allclose(ihfftn, two_iffts)
  949. True
  950. """.format(**common_args),
  951. )
  952. fftfreq = _add_docstr(
  953. _fft.fft_fftfreq,
  954. r"""
  955. fftfreq(n, d=1.0, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor
  956. Computes the discrete Fourier Transform sample frequencies for a signal of size :attr:`n`.
  957. Note:
  958. By convention, :func:`~torch.fft.fft` returns positive frequency terms
  959. first, followed by the negative frequencies in reverse order, so that
  960. ``f[-i]`` for all :math:`0 < i \leq n/2`` in Python gives the negative
  961. frequency terms. For an FFT of length :attr:`n` and with inputs spaced in
  962. length unit :attr:`d`, the frequencies are::
  963. f = [0, 1, ..., (n - 1) // 2, -(n // 2), ..., -1] / (d * n)
  964. Note:
  965. For even lengths, the Nyquist frequency at ``f[n/2]`` can be thought of as
  966. either negative or positive. :func:`~torch.fft.fftfreq` follows NumPy's
  967. convention of taking it to be negative.
  968. Args:
  969. n (int): the FFT length
  970. d (float, optional): The sampling length scale.
  971. The spacing between individual samples of the FFT input.
  972. The default assumes unit spacing, dividing that result by the actual
  973. spacing gives the result in physical frequency units.
  974. Keyword Args:
  975. {out}
  976. {dtype}
  977. {layout}
  978. {device}
  979. {requires_grad}
  980. Example:
  981. >>> torch.fft.fftfreq(5)
  982. tensor([ 0.0000, 0.2000, 0.4000, -0.4000, -0.2000])
  983. For even input, we can see the Nyquist frequency at ``f[2]`` is given as
  984. negative:
  985. >>> torch.fft.fftfreq(4)
  986. tensor([ 0.0000, 0.2500, -0.5000, -0.2500])
  987. """.format(**factory_common_args),
  988. )
  989. rfftfreq = _add_docstr(
  990. _fft.fft_rfftfreq,
  991. r"""
  992. rfftfreq(n, d=1.0, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor
  993. Computes the sample frequencies for :func:`~torch.fft.rfft` with a signal of size :attr:`n`.
  994. Note:
  995. :func:`~torch.fft.rfft` returns Hermitian one-sided output, so only the
  996. positive frequency terms are returned. For a real FFT of length :attr:`n`
  997. and with inputs spaced in length unit :attr:`d`, the frequencies are::
  998. f = torch.arange((n + 1) // 2) / (d * n)
  999. Note:
  1000. For even lengths, the Nyquist frequency at ``f[n/2]`` can be thought of as
  1001. either negative or positive. Unlike :func:`~torch.fft.fftfreq`,
  1002. :func:`~torch.fft.rfftfreq` always returns it as positive.
  1003. Args:
  1004. n (int): the real FFT length
  1005. d (float, optional): The sampling length scale.
  1006. The spacing between individual samples of the FFT input.
  1007. The default assumes unit spacing, dividing that result by the actual
  1008. spacing gives the result in physical frequency units.
  1009. Keyword Args:
  1010. {out}
  1011. {dtype}
  1012. {layout}
  1013. {device}
  1014. {requires_grad}
  1015. Example:
  1016. >>> torch.fft.rfftfreq(5)
  1017. tensor([0.0000, 0.2000, 0.4000])
  1018. >>> torch.fft.rfftfreq(4)
  1019. tensor([0.0000, 0.2500, 0.5000])
  1020. Compared to the output from :func:`~torch.fft.fftfreq`, we see that the
  1021. Nyquist frequency at ``f[2]`` has changed sign:
  1022. >>> torch.fft.fftfreq(4)
  1023. tensor([ 0.0000, 0.2500, -0.5000, -0.2500])
  1024. """.format(**factory_common_args),
  1025. )
  1026. fftshift = _add_docstr(
  1027. _fft.fft_fftshift,
  1028. r"""
  1029. fftshift(input, dim=None) -> Tensor
  1030. Reorders n-dimensional FFT data, as provided by :func:`~torch.fft.fftn`, to have
  1031. negative frequency terms first.
  1032. This performs a periodic shift of n-dimensional data such that the origin
  1033. ``(0, ..., 0)`` is moved to the center of the tensor. Specifically, to
  1034. ``input.shape[dim] // 2`` in each selected dimension.
  1035. Note:
  1036. By convention, the FFT returns positive frequency terms first, followed by
  1037. the negative frequencies in reverse order, so that ``f[-i]`` for all
  1038. :math:`0 < i \leq n/2` in Python gives the negative frequency terms.
  1039. :func:`~torch.fft.fftshift` rearranges all frequencies into ascending order
  1040. from negative to positive with the zero-frequency term in the center.
  1041. Note:
  1042. For even lengths, the Nyquist frequency at ``f[n/2]`` can be thought of as
  1043. either negative or positive. :func:`~torch.fft.fftshift` always puts the
  1044. Nyquist term at the 0-index. This is the same convention used by
  1045. :func:`~torch.fft.fftfreq`.
  1046. Args:
  1047. input (Tensor): the tensor in FFT order
  1048. dim (int, Tuple[int], optional): The dimensions to rearrange.
  1049. Only dimensions specified here will be rearranged, any other dimensions
  1050. will be left in their original order.
  1051. Default: All dimensions of :attr:`input`.
  1052. Example:
  1053. >>> f = torch.fft.fftfreq(4)
  1054. >>> f
  1055. tensor([ 0.0000, 0.2500, -0.5000, -0.2500])
  1056. >>> torch.fft.fftshift(f)
  1057. tensor([-0.5000, -0.2500, 0.0000, 0.2500])
  1058. Also notice that the Nyquist frequency term at ``f[2]`` was moved to the
  1059. beginning of the tensor.
  1060. This also works for multi-dimensional transforms:
  1061. >>> x = torch.fft.fftfreq(5, d=1/5) + 0.1 * torch.fft.fftfreq(5, d=1/5).unsqueeze(1)
  1062. >>> x
  1063. tensor([[ 0.0000, 1.0000, 2.0000, -2.0000, -1.0000],
  1064. [ 0.1000, 1.1000, 2.1000, -1.9000, -0.9000],
  1065. [ 0.2000, 1.2000, 2.2000, -1.8000, -0.8000],
  1066. [-0.2000, 0.8000, 1.8000, -2.2000, -1.2000],
  1067. [-0.1000, 0.9000, 1.9000, -2.1000, -1.1000]])
  1068. >>> torch.fft.fftshift(x)
  1069. tensor([[-2.2000, -1.2000, -0.2000, 0.8000, 1.8000],
  1070. [-2.1000, -1.1000, -0.1000, 0.9000, 1.9000],
  1071. [-2.0000, -1.0000, 0.0000, 1.0000, 2.0000],
  1072. [-1.9000, -0.9000, 0.1000, 1.1000, 2.1000],
  1073. [-1.8000, -0.8000, 0.2000, 1.2000, 2.2000]])
  1074. :func:`~torch.fft.fftshift` can also be useful for spatial data. If our
  1075. data is defined on a centered grid (``[-(N//2), (N-1)//2]``) then we can
  1076. use the standard FFT defined on an uncentered grid (``[0, N)``) by first
  1077. applying an :func:`~torch.fft.ifftshift`.
  1078. >>> x_centered = torch.arange(-5, 5)
  1079. >>> x_uncentered = torch.fft.ifftshift(x_centered)
  1080. >>> fft_uncentered = torch.fft.fft(x_uncentered)
  1081. Similarly, we can convert the frequency domain components to centered
  1082. convention by applying :func:`~torch.fft.fftshift`.
  1083. >>> fft_centered = torch.fft.fftshift(fft_uncentered)
  1084. The inverse transform, from centered Fourier space back to centered spatial
  1085. data, can be performed by applying the inverse shifts in reverse order:
  1086. >>> x_centered_2 = torch.fft.fftshift(torch.fft.ifft(torch.fft.ifftshift(fft_centered)))
  1087. >>> torch.testing.assert_close(x_centered.to(torch.complex64), x_centered_2, check_stride=False)
  1088. """,
  1089. )
  1090. ifftshift = _add_docstr(
  1091. _fft.fft_ifftshift,
  1092. r"""
  1093. ifftshift(input, dim=None) -> Tensor
  1094. Inverse of :func:`~torch.fft.fftshift`.
  1095. Args:
  1096. input (Tensor): the tensor in FFT order
  1097. dim (int, Tuple[int], optional): The dimensions to rearrange.
  1098. Only dimensions specified here will be rearranged, any other dimensions
  1099. will be left in their original order.
  1100. Default: All dimensions of :attr:`input`.
  1101. Example:
  1102. >>> f = torch.fft.fftfreq(5)
  1103. >>> f
  1104. tensor([ 0.0000, 0.2000, 0.4000, -0.4000, -0.2000])
  1105. A round-trip through :func:`~torch.fft.fftshift` and
  1106. :func:`~torch.fft.ifftshift` gives the same result:
  1107. >>> shifted = torch.fft.fftshift(f)
  1108. >>> torch.fft.ifftshift(shifted)
  1109. tensor([ 0.0000, 0.2000, 0.4000, -0.4000, -0.2000])
  1110. """,
  1111. )