windows.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883
  1. # mypy: allow-untyped-defs
  2. from collections.abc import Callable, Iterable
  3. from math import sqrt
  4. from typing import TypeVar
  5. import torch
  6. from torch import Tensor
  7. from torch._torch_docs import factory_common_args, merge_dicts, parse_kwargs
  8. __all__ = [
  9. "bartlett",
  10. "blackman",
  11. "cosine",
  12. "exponential",
  13. "gaussian",
  14. "general_cosine",
  15. "general_hamming",
  16. "hamming",
  17. "hann",
  18. "kaiser",
  19. "nuttall",
  20. ]
  21. _T = TypeVar("_T")
  22. window_common_args = merge_dicts(
  23. parse_kwargs(
  24. """
  25. M (int): the length of the window.
  26. In other words, the number of points of the returned window.
  27. sym (bool, optional): If `False`, returns a periodic window suitable for use in spectral analysis.
  28. If `True`, returns a symmetric window suitable for use in filter design. Default: `True`.
  29. """
  30. ),
  31. factory_common_args,
  32. {
  33. "normalization": "The window is normalized to 1 (maximum value is 1). However, the 1 doesn't appear if "
  34. ":attr:`M` is even and :attr:`sym` is `True`.",
  35. },
  36. )
  37. def _add_docstr(*args: str) -> Callable[[_T], _T]:
  38. r"""Adds docstrings to a given decorated function.
  39. Specially useful when then docstrings needs string interpolation, e.g., with
  40. str.format().
  41. REMARK: Do not use this function if the docstring doesn't need string
  42. interpolation, just write a conventional docstring.
  43. Args:
  44. args (str):
  45. """
  46. def decorator(o: _T) -> _T:
  47. o.__doc__ = "".join(args)
  48. return o
  49. return decorator
  50. def _window_function_checks(
  51. function_name: str, M: int, dtype: torch.dtype, layout: torch.layout
  52. ) -> None:
  53. r"""Performs common checks for all the defined windows.
  54. This function should be called before computing any window.
  55. Args:
  56. function_name (str): name of the window function.
  57. M (int): length of the window.
  58. dtype (:class:`torch.dtype`): the desired data type of returned tensor.
  59. layout (:class:`torch.layout`): the desired layout of returned tensor.
  60. """
  61. if M < 0:
  62. raise ValueError(
  63. f"{function_name} requires non-negative window length, got M={M}"
  64. )
  65. if layout is not torch.strided:
  66. raise ValueError(
  67. f"{function_name} is implemented for strided tensors only, got: {layout}"
  68. )
  69. if dtype not in [torch.float32, torch.float64]:
  70. raise ValueError(
  71. f"{function_name} expects float32 or float64 dtypes, got: {dtype}"
  72. )
  73. @_add_docstr(
  74. r"""
  75. Computes a window with an exponential waveform.
  76. Also known as Poisson window.
  77. The exponential window is defined as follows:
  78. .. math::
  79. w_n = \exp{\left(-\frac{|n - c|}{\tau}\right)}
  80. where `c` is the ``center`` of the window.
  81. """,
  82. r"""
  83. {normalization}
  84. Args:
  85. {M}
  86. Keyword args:
  87. center (float, optional): where the center of the window will be located.
  88. Default: `M / 2` if `sym` is `False`, else `(M - 1) / 2`.
  89. tau (float, optional): the decay value.
  90. Tau is generally associated with a percentage, that means, that the value should
  91. vary within the interval (0, 100]. If tau is 100, it is considered the uniform window.
  92. Default: 1.0.
  93. {sym}
  94. {dtype}
  95. {layout}
  96. {device}
  97. {requires_grad}
  98. Examples::
  99. >>> # Generates a symmetric exponential window of size 10 and with a decay value of 1.0.
  100. >>> # The center will be at (M - 1) / 2, where M is 10.
  101. >>> torch.signal.windows.exponential(10)
  102. tensor([0.0111, 0.0302, 0.0821, 0.2231, 0.6065, 0.6065, 0.2231, 0.0821, 0.0302, 0.0111])
  103. >>> # Generates a periodic exponential window and decay factor equal to .5
  104. >>> torch.signal.windows.exponential(10, sym=False,tau=.5)
  105. tensor([4.5400e-05, 3.3546e-04, 2.4788e-03, 1.8316e-02, 1.3534e-01, 1.0000e+00, 1.3534e-01, 1.8316e-02, 2.4788e-03, 3.3546e-04])
  106. """.format(**window_common_args),
  107. )
  108. def exponential(
  109. M: int,
  110. *,
  111. center: float | None = None,
  112. tau: float = 1.0,
  113. sym: bool = True,
  114. dtype: torch.dtype | None = None,
  115. layout: torch.layout = torch.strided,
  116. device: torch.device | None = None,
  117. requires_grad: bool = False,
  118. ) -> Tensor:
  119. if dtype is None:
  120. dtype = torch.get_default_dtype()
  121. _window_function_checks("exponential", M, dtype, layout)
  122. if tau <= 0:
  123. raise ValueError(f"Tau must be positive, got: {tau} instead.")
  124. if sym and center is not None:
  125. raise ValueError("Center must be None for symmetric windows")
  126. if M == 0:
  127. return torch.empty(
  128. (0,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad
  129. )
  130. if center is None:
  131. center = (M if not sym and M > 1 else M - 1) / 2.0
  132. constant = 1 / tau
  133. k = torch.linspace(
  134. start=-center * constant,
  135. end=(-center + (M - 1)) * constant,
  136. steps=M,
  137. dtype=dtype,
  138. layout=layout,
  139. device=device,
  140. requires_grad=requires_grad,
  141. )
  142. return torch.exp(-torch.abs(k))
  143. @_add_docstr(
  144. r"""
  145. Computes a window with a simple cosine waveform, following the same implementation as SciPy.
  146. This window is also known as the sine window.
  147. The cosine window is defined as follows:
  148. .. math::
  149. w_n = \sin\left(\frac{\pi (n + 0.5)}{M}\right)
  150. This formula differs from the typical cosine window formula by incorporating a 0.5 term in the numerator,
  151. which shifts the sample positions. This adjustment results in a window that starts and ends with non-zero values.
  152. """,
  153. r"""
  154. {normalization}
  155. Args:
  156. {M}
  157. Keyword args:
  158. {sym}
  159. {dtype}
  160. {layout}
  161. {device}
  162. {requires_grad}
  163. Examples::
  164. >>> # Generates a symmetric cosine window.
  165. >>> torch.signal.windows.cosine(10)
  166. tensor([0.1564, 0.4540, 0.7071, 0.8910, 0.9877, 0.9877, 0.8910, 0.7071, 0.4540, 0.1564])
  167. >>> # Generates a periodic cosine window.
  168. >>> torch.signal.windows.cosine(10, sym=False)
  169. tensor([0.1423, 0.4154, 0.6549, 0.8413, 0.9595, 1.0000, 0.9595, 0.8413, 0.6549, 0.4154])
  170. """.format(
  171. **window_common_args,
  172. ),
  173. )
  174. def cosine(
  175. M: int,
  176. *,
  177. sym: bool = True,
  178. dtype: torch.dtype | None = None,
  179. layout: torch.layout = torch.strided,
  180. device: torch.device | None = None,
  181. requires_grad: bool = False,
  182. ) -> Tensor:
  183. if dtype is None:
  184. dtype = torch.get_default_dtype()
  185. _window_function_checks("cosine", M, dtype, layout)
  186. if M == 0:
  187. return torch.empty(
  188. (0,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad
  189. )
  190. start = 0.5
  191. constant = torch.pi / (M + 1 if not sym and M > 1 else M)
  192. k = torch.linspace(
  193. start=start * constant,
  194. end=(start + (M - 1)) * constant,
  195. steps=M,
  196. dtype=dtype,
  197. layout=layout,
  198. device=device,
  199. requires_grad=requires_grad,
  200. )
  201. return torch.sin(k)
  202. @_add_docstr(
  203. r"""
  204. Computes a window with a gaussian waveform.
  205. The gaussian window is defined as follows:
  206. .. math::
  207. w_n = \exp{\left(-\left(\frac{n}{2\sigma}\right)^2\right)}
  208. """,
  209. r"""
  210. {normalization}
  211. Args:
  212. {M}
  213. Keyword args:
  214. std (float, optional): the standard deviation of the gaussian. It controls how narrow or wide the window is.
  215. Default: 1.0.
  216. {sym}
  217. {dtype}
  218. {layout}
  219. {device}
  220. {requires_grad}
  221. Examples::
  222. >>> # Generates a symmetric gaussian window with a standard deviation of 1.0.
  223. >>> torch.signal.windows.gaussian(10)
  224. tensor([4.0065e-05, 2.1875e-03, 4.3937e-02, 3.2465e-01, 8.8250e-01, 8.8250e-01, 3.2465e-01, 4.3937e-02, 2.1875e-03, 4.0065e-05])
  225. >>> # Generates a periodic gaussian window and standard deviation equal to 0.9.
  226. >>> torch.signal.windows.gaussian(10, sym=False,std=0.9)
  227. tensor([1.9858e-07, 5.1365e-05, 3.8659e-03, 8.4658e-02, 5.3941e-01, 1.0000e+00, 5.3941e-01, 8.4658e-02, 3.8659e-03, 5.1365e-05])
  228. """.format(
  229. **window_common_args,
  230. ),
  231. )
  232. def gaussian(
  233. M: int,
  234. *,
  235. std: float = 1.0,
  236. sym: bool = True,
  237. dtype: torch.dtype | None = None,
  238. layout: torch.layout = torch.strided,
  239. device: torch.device | None = None,
  240. requires_grad: bool = False,
  241. ) -> Tensor:
  242. if dtype is None:
  243. dtype = torch.get_default_dtype()
  244. _window_function_checks("gaussian", M, dtype, layout)
  245. if std <= 0:
  246. raise ValueError(f"Standard deviation must be positive, got: {std} instead.")
  247. if M == 0:
  248. return torch.empty(
  249. (0,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad
  250. )
  251. start = -(M if not sym and M > 1 else M - 1) / 2.0
  252. constant = 1 / (std * sqrt(2))
  253. k = torch.linspace(
  254. start=start * constant,
  255. end=(start + (M - 1)) * constant,
  256. steps=M,
  257. dtype=dtype,
  258. layout=layout,
  259. device=device,
  260. requires_grad=requires_grad,
  261. )
  262. return torch.exp(-(k**2)) # pyrefly: ignore [unsupported-operation]
  263. @_add_docstr(
  264. r"""
  265. Computes the Kaiser window.
  266. The Kaiser window is defined as follows:
  267. .. math::
  268. w_n = I_0 \left( \beta \sqrt{1 - \left( {\frac{n - N/2}{N/2}} \right) ^2 } \right) / I_0( \beta )
  269. where ``I_0`` is the zeroth order modified Bessel function of the first kind (see :func:`torch.special.i0`), and
  270. ``N = M - 1 if sym else M``.
  271. """,
  272. r"""
  273. {normalization}
  274. Args:
  275. {M}
  276. Keyword args:
  277. beta (float, optional): shape parameter for the window. Must be non-negative. Default: 12.0
  278. {sym}
  279. {dtype}
  280. {layout}
  281. {device}
  282. {requires_grad}
  283. Examples::
  284. >>> # Generates a symmetric gaussian window with a standard deviation of 1.0.
  285. >>> torch.signal.windows.kaiser(5)
  286. tensor([4.0065e-05, 2.1875e-03, 4.3937e-02, 3.2465e-01, 8.8250e-01, 8.8250e-01, 3.2465e-01, 4.3937e-02, 2.1875e-03, 4.0065e-05])
  287. >>> # Generates a periodic gaussian window and standard deviation equal to 0.9.
  288. >>> torch.signal.windows.kaiser(5, sym=False,std=0.9)
  289. tensor([1.9858e-07, 5.1365e-05, 3.8659e-03, 8.4658e-02, 5.3941e-01, 1.0000e+00, 5.3941e-01, 8.4658e-02, 3.8659e-03, 5.1365e-05])
  290. """.format(
  291. **window_common_args,
  292. ),
  293. )
  294. def kaiser(
  295. M: int,
  296. *,
  297. beta: float = 12.0,
  298. sym: bool = True,
  299. dtype: torch.dtype | None = None,
  300. layout: torch.layout = torch.strided,
  301. device: torch.device | None = None,
  302. requires_grad: bool = False,
  303. ) -> Tensor:
  304. if dtype is None:
  305. dtype = torch.get_default_dtype()
  306. _window_function_checks("kaiser", M, dtype, layout)
  307. if beta < 0:
  308. raise ValueError(f"beta must be non-negative, got: {beta} instead.")
  309. if M == 0:
  310. return torch.empty(
  311. (0,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad
  312. )
  313. if M == 1:
  314. return torch.ones(
  315. (1,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad
  316. )
  317. # Avoid NaNs by casting `beta` to the appropriate dtype.
  318. # pyrefly: ignore [bad-assignment]
  319. beta = torch.tensor(beta, dtype=dtype, device=device)
  320. start = -beta
  321. constant = 2.0 * beta / (M if not sym else M - 1)
  322. end = torch.minimum(
  323. # pyrefly: ignore [bad-argument-type]
  324. beta,
  325. # pyrefly: ignore [bad-argument-type]
  326. start + (M - 1) * constant,
  327. )
  328. k = torch.linspace(
  329. start=start,
  330. end=end,
  331. steps=M,
  332. dtype=dtype,
  333. layout=layout,
  334. device=device,
  335. requires_grad=requires_grad,
  336. )
  337. return torch.i0(torch.sqrt(beta * beta - torch.pow(k, 2))) / torch.i0(
  338. # pyrefly: ignore [bad-argument-type]
  339. beta
  340. )
  341. @_add_docstr(
  342. r"""
  343. Computes the Hamming window.
  344. The Hamming window is defined as follows:
  345. .. math::
  346. w_n = \alpha - \beta\ \cos \left( \frac{2 \pi n}{M - 1} \right)
  347. """,
  348. r"""
  349. {normalization}
  350. Arguments:
  351. {M}
  352. Keyword args:
  353. {sym}
  354. alpha (float, optional): The coefficient :math:`\alpha` in the equation above.
  355. beta (float, optional): The coefficient :math:`\beta` in the equation above.
  356. {dtype}
  357. {layout}
  358. {device}
  359. {requires_grad}
  360. Examples::
  361. >>> # Generates a symmetric Hamming window.
  362. >>> torch.signal.windows.hamming(10)
  363. tensor([0.0800, 0.1876, 0.4601, 0.7700, 0.9723, 0.9723, 0.7700, 0.4601, 0.1876, 0.0800])
  364. >>> # Generates a periodic Hamming window.
  365. >>> torch.signal.windows.hamming(10, sym=False)
  366. tensor([0.0800, 0.1679, 0.3979, 0.6821, 0.9121, 1.0000, 0.9121, 0.6821, 0.3979, 0.1679])
  367. """.format(**window_common_args),
  368. )
  369. def hamming(
  370. M: int,
  371. *,
  372. sym: bool = True,
  373. dtype: torch.dtype | None = None,
  374. layout: torch.layout = torch.strided,
  375. device: torch.device | None = None,
  376. requires_grad: bool = False,
  377. ) -> Tensor:
  378. return general_hamming(
  379. M,
  380. sym=sym,
  381. dtype=dtype,
  382. layout=layout,
  383. device=device,
  384. requires_grad=requires_grad,
  385. )
  386. @_add_docstr(
  387. r"""
  388. Computes the Hann window.
  389. The Hann window is defined as follows:
  390. .. math::
  391. w_n = \frac{1}{2}\ \left[1 - \cos \left( \frac{2 \pi n}{M - 1} \right)\right] =
  392. \sin^2 \left( \frac{\pi n}{M - 1} \right)
  393. """,
  394. r"""
  395. {normalization}
  396. Arguments:
  397. {M}
  398. Keyword args:
  399. {sym}
  400. {dtype}
  401. {layout}
  402. {device}
  403. {requires_grad}
  404. Examples::
  405. >>> # Generates a symmetric Hann window.
  406. >>> torch.signal.windows.hann(10)
  407. tensor([0.0000, 0.1170, 0.4132, 0.7500, 0.9698, 0.9698, 0.7500, 0.4132, 0.1170, 0.0000])
  408. >>> # Generates a periodic Hann window.
  409. >>> torch.signal.windows.hann(10, sym=False)
  410. tensor([0.0000, 0.0955, 0.3455, 0.6545, 0.9045, 1.0000, 0.9045, 0.6545, 0.3455, 0.0955])
  411. """.format(**window_common_args),
  412. )
  413. def hann(
  414. M: int,
  415. *,
  416. sym: bool = True,
  417. dtype: torch.dtype | None = None,
  418. layout: torch.layout = torch.strided,
  419. device: torch.device | None = None,
  420. requires_grad: bool = False,
  421. ) -> Tensor:
  422. return general_hamming(
  423. M,
  424. alpha=0.5,
  425. sym=sym,
  426. dtype=dtype,
  427. layout=layout,
  428. device=device,
  429. requires_grad=requires_grad,
  430. )
  431. @_add_docstr(
  432. r"""
  433. Computes the Blackman window.
  434. The Blackman window is defined as follows:
  435. .. math::
  436. w_n = 0.42 - 0.5 \cos \left( \frac{2 \pi n}{M - 1} \right) + 0.08 \cos \left( \frac{4 \pi n}{M - 1} \right)
  437. """,
  438. r"""
  439. {normalization}
  440. Arguments:
  441. {M}
  442. Keyword args:
  443. {sym}
  444. {dtype}
  445. {layout}
  446. {device}
  447. {requires_grad}
  448. Examples::
  449. >>> # Generates a symmetric Blackman window.
  450. >>> torch.signal.windows.blackman(5)
  451. tensor([-1.4901e-08, 3.4000e-01, 1.0000e+00, 3.4000e-01, -1.4901e-08])
  452. >>> # Generates a periodic Blackman window.
  453. >>> torch.signal.windows.blackman(5, sym=False)
  454. tensor([-1.4901e-08, 2.0077e-01, 8.4923e-01, 8.4923e-01, 2.0077e-01])
  455. """.format(**window_common_args),
  456. )
  457. def blackman(
  458. M: int,
  459. *,
  460. sym: bool = True,
  461. dtype: torch.dtype | None = None,
  462. layout: torch.layout = torch.strided,
  463. device: torch.device | None = None,
  464. requires_grad: bool = False,
  465. ) -> Tensor:
  466. if dtype is None:
  467. dtype = torch.get_default_dtype()
  468. _window_function_checks("blackman", M, dtype, layout)
  469. return general_cosine(
  470. M,
  471. a=[0.42, 0.5, 0.08],
  472. sym=sym,
  473. dtype=dtype,
  474. layout=layout,
  475. device=device,
  476. requires_grad=requires_grad,
  477. )
  478. @_add_docstr(
  479. r"""
  480. Computes the Bartlett window.
  481. The Bartlett window is defined as follows:
  482. .. math::
  483. w_n = 1 - \left| \frac{2n}{M - 1} - 1 \right| = \begin{cases}
  484. \frac{2n}{M - 1} & \text{if } 0 \leq n \leq \frac{M - 1}{2} \\
  485. 2 - \frac{2n}{M - 1} & \text{if } \frac{M - 1}{2} < n < M \\ \end{cases}
  486. """,
  487. r"""
  488. {normalization}
  489. Arguments:
  490. {M}
  491. Keyword args:
  492. {sym}
  493. {dtype}
  494. {layout}
  495. {device}
  496. {requires_grad}
  497. Examples::
  498. >>> # Generates a symmetric Bartlett window.
  499. >>> torch.signal.windows.bartlett(10)
  500. tensor([0.0000, 0.2222, 0.4444, 0.6667, 0.8889, 0.8889, 0.6667, 0.4444, 0.2222, 0.0000])
  501. >>> # Generates a periodic Bartlett window.
  502. >>> torch.signal.windows.bartlett(10, sym=False)
  503. tensor([0.0000, 0.2000, 0.4000, 0.6000, 0.8000, 1.0000, 0.8000, 0.6000, 0.4000, 0.2000])
  504. """.format(**window_common_args),
  505. )
  506. def bartlett(
  507. M: int,
  508. *,
  509. sym: bool = True,
  510. dtype: torch.dtype | None = None,
  511. layout: torch.layout = torch.strided,
  512. device: torch.device | None = None,
  513. requires_grad: bool = False,
  514. ) -> Tensor:
  515. if dtype is None:
  516. dtype = torch.get_default_dtype()
  517. _window_function_checks("bartlett", M, dtype, layout)
  518. if M == 0:
  519. return torch.empty(
  520. (0,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad
  521. )
  522. if M == 1:
  523. return torch.ones(
  524. (1,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad
  525. )
  526. start = -1
  527. constant = 2 / (M if not sym else M - 1)
  528. k = torch.linspace(
  529. start=start,
  530. end=start + (M - 1) * constant,
  531. steps=M,
  532. dtype=dtype,
  533. layout=layout,
  534. device=device,
  535. requires_grad=requires_grad,
  536. )
  537. return 1 - torch.abs(k)
  538. @_add_docstr(
  539. r"""
  540. Computes the general cosine window.
  541. The general cosine window is defined as follows:
  542. .. math::
  543. w_n = \sum^{M-1}_{i=0} (-1)^i a_i \cos{ \left( \frac{2 \pi i n}{M - 1}\right)}
  544. """,
  545. r"""
  546. {normalization}
  547. Arguments:
  548. {M}
  549. Keyword args:
  550. a (Iterable): the coefficients associated to each of the cosine functions.
  551. {sym}
  552. {dtype}
  553. {layout}
  554. {device}
  555. {requires_grad}
  556. Examples::
  557. >>> # Generates a symmetric general cosine window with 3 coefficients.
  558. >>> torch.signal.windows.general_cosine(10, a=[0.46, 0.23, 0.31], sym=True)
  559. tensor([0.5400, 0.3376, 0.1288, 0.4200, 0.9136, 0.9136, 0.4200, 0.1288, 0.3376, 0.5400])
  560. >>> # Generates a periodic general cosine window with 2 coefficients.
  561. >>> torch.signal.windows.general_cosine(10, a=[0.5, 1 - 0.5], sym=False)
  562. tensor([0.0000, 0.0955, 0.3455, 0.6545, 0.9045, 1.0000, 0.9045, 0.6545, 0.3455, 0.0955])
  563. """.format(**window_common_args),
  564. )
  565. def general_cosine(
  566. M,
  567. *,
  568. a: Iterable,
  569. sym: bool = True,
  570. dtype: torch.dtype | None = None,
  571. layout: torch.layout = torch.strided,
  572. device: torch.device | None = None,
  573. requires_grad: bool = False,
  574. ) -> Tensor:
  575. if dtype is None:
  576. dtype = torch.get_default_dtype()
  577. _window_function_checks("general_cosine", M, dtype, layout)
  578. if M == 0:
  579. return torch.empty(
  580. (0,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad
  581. )
  582. if M == 1:
  583. return torch.ones(
  584. (1,), dtype=dtype, layout=layout, device=device, requires_grad=requires_grad
  585. )
  586. if not isinstance(a, Iterable):
  587. raise TypeError("Coefficients must be a list/tuple")
  588. if not a:
  589. raise ValueError("Coefficients cannot be empty")
  590. constant = 2 * torch.pi / (M if not sym else M - 1)
  591. k = torch.linspace(
  592. start=0,
  593. end=(M - 1) * constant,
  594. steps=M,
  595. dtype=dtype,
  596. layout=layout,
  597. device=device,
  598. requires_grad=requires_grad,
  599. )
  600. a_i = torch.tensor(
  601. [(-1) ** i * w for i, w in enumerate(a)],
  602. device=device,
  603. dtype=dtype,
  604. requires_grad=requires_grad,
  605. )
  606. i = torch.arange(
  607. a_i.shape[0],
  608. dtype=a_i.dtype,
  609. device=a_i.device,
  610. requires_grad=a_i.requires_grad,
  611. )
  612. return (a_i.unsqueeze(-1) * torch.cos(i.unsqueeze(-1) * k)).sum(0)
  613. @_add_docstr(
  614. r"""
  615. Computes the general Hamming window.
  616. The general Hamming window is defined as follows:
  617. .. math::
  618. w_n = \alpha - (1 - \alpha) \cos{ \left( \frac{2 \pi n}{M-1} \right)}
  619. """,
  620. r"""
  621. {normalization}
  622. Arguments:
  623. {M}
  624. Keyword args:
  625. alpha (float, optional): the window coefficient. Default: 0.54.
  626. {sym}
  627. {dtype}
  628. {layout}
  629. {device}
  630. {requires_grad}
  631. Examples::
  632. >>> # Generates a symmetric Hamming window with the general Hamming window.
  633. >>> torch.signal.windows.general_hamming(10, sym=True)
  634. tensor([0.0800, 0.1876, 0.4601, 0.7700, 0.9723, 0.9723, 0.7700, 0.4601, 0.1876, 0.0800])
  635. >>> # Generates a periodic Hann window with the general Hamming window.
  636. >>> torch.signal.windows.general_hamming(10, alpha=0.5, sym=False)
  637. tensor([0.0000, 0.0955, 0.3455, 0.6545, 0.9045, 1.0000, 0.9045, 0.6545, 0.3455, 0.0955])
  638. """.format(**window_common_args),
  639. )
  640. def general_hamming(
  641. M,
  642. *,
  643. alpha: float = 0.54,
  644. sym: bool = True,
  645. dtype: torch.dtype | None = None,
  646. layout: torch.layout = torch.strided,
  647. device: torch.device | None = None,
  648. requires_grad: bool = False,
  649. ) -> Tensor:
  650. return general_cosine(
  651. M,
  652. a=[alpha, 1.0 - alpha],
  653. sym=sym,
  654. dtype=dtype,
  655. layout=layout,
  656. device=device,
  657. requires_grad=requires_grad,
  658. )
  659. @_add_docstr(
  660. r"""
  661. Computes the minimum 4-term Blackman-Harris window according to Nuttall.
  662. .. math::
  663. w_n = 1 - 0.36358 \cos{(z_n)} + 0.48917 \cos{(2z_n)} - 0.13659 \cos{(3z_n)} + 0.01064 \cos{(4z_n)}
  664. where :math:`z_n = \frac{2 \pi n}{M}`.
  665. """,
  666. """
  667. {normalization}
  668. Arguments:
  669. {M}
  670. Keyword args:
  671. {sym}
  672. {dtype}
  673. {layout}
  674. {device}
  675. {requires_grad}
  676. References::
  677. - A. Nuttall, "Some windows with very good sidelobe behavior,"
  678. IEEE Transactions on Acoustics, Speech, and Signal Processing, vol. 29, no. 1, pp. 84-91,
  679. Feb 1981. https://doi.org/10.1109/TASSP.1981.1163506
  680. - Heinzel G. et al., "Spectrum and spectral density estimation by the Discrete Fourier transform (DFT),
  681. including a comprehensive list of window functions and some new flat-top windows",
  682. February 15, 2002 https://holometer.fnal.gov/GH_FFT.pdf
  683. Examples::
  684. >>> # Generates a symmetric Nutall window.
  685. >>> torch.signal.windows.general_hamming(5, sym=True)
  686. tensor([3.6280e-04, 2.2698e-01, 1.0000e+00, 2.2698e-01, 3.6280e-04])
  687. >>> # Generates a periodic Nuttall window.
  688. >>> torch.signal.windows.general_hamming(5, sym=False)
  689. tensor([3.6280e-04, 1.1052e-01, 7.9826e-01, 7.9826e-01, 1.1052e-01])
  690. """.format(**window_common_args),
  691. )
  692. def nuttall(
  693. M: int,
  694. *,
  695. sym: bool = True,
  696. dtype: torch.dtype | None = None,
  697. layout: torch.layout = torch.strided,
  698. device: torch.device | None = None,
  699. requires_grad: bool = False,
  700. ) -> Tensor:
  701. return general_cosine(
  702. M,
  703. a=[0.3635819, 0.4891775, 0.1365995, 0.0106411],
  704. sym=sym,
  705. dtype=dtype,
  706. layout=layout,
  707. device=device,
  708. requires_grad=requires_grad,
  709. )