nisqa.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Code related to the main NISQA model definition is under the following copyright
  16. # Copyright (c) 2021 Gabriel Mittag, Quality and Usability Lab
  17. # Permission is hereby granted, free of charge, to any person obtaining a copy
  18. # of this software and associated documentation files (the "Software"), to deal
  19. # in the Software without restriction, including without limitation the rights
  20. # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  21. # copies of the Software, and to permit persons to whom the Software is
  22. # furnished to do so, subject to the following conditions:
  23. # The above copyright notice and this permission notice shall be included in all
  24. # copies or substantial portions of the Software.
  25. # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  26. # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  27. # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  28. # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  29. # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  30. # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  31. # SOFTWARE.
  32. import copy
  33. import math
  34. import os
  35. import warnings
  36. from functools import lru_cache
  37. from typing import Any
  38. import numpy as np
  39. import torch
  40. import torch.nn as nn
  41. from torch import Tensor
  42. from torch.nn.functional import adaptive_max_pool2d, relu, softmax
  43. from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
  44. from torchmetrics.utilities import rank_zero_info
  45. from torchmetrics.utilities.imports import _LIBROSA_AVAILABLE, _REQUESTS_AVAILABLE
# ``librosa`` (mel spectrogram computation) and ``requests`` (weight download) are optional dependencies.
# When either is missing, placeholders are bound so the module can still be imported;
# ``non_intrusive_speech_quality_assessment`` raises ModuleNotFoundError in that case.
if _LIBROSA_AVAILABLE and _REQUESTS_AVAILABLE:
    import librosa
    import requests
else:
    librosa, requests = None, None  # type:ignore

# Maps the doctested function to the packages its example needs.
# NOTE(review): presumably read by the doctest runner (pytest-doctestplus convention) to skip the example when these
# packages are missing — confirm against the test configuration.
__doctest_requires__ = {("non_intrusive_speech_quality_assessment",): ["librosa", "requests"]}

# Cache directory for the downloaded NISQA checkpoint; "~" is expanded with os.path.expanduser at the use sites.
NISQA_DIR = "~/.torchmetrics/NISQA"
  53. def non_intrusive_speech_quality_assessment(preds: Tensor, fs: int) -> Tensor:
  54. """`Non-Intrusive Speech Quality Assessment`_ (NISQA v2.0) [1], [2].
  55. .. hint::
  56. Usingsing this metric requires you to have ``librosa`` and ``requests`` installed. Install as
  57. ``pip install librosa requests``.
  58. Args:
  59. preds: float tensor with shape ``(...,time)``
  60. fs: sampling frequency of input
  61. Returns:
  62. Float tensor with shape ``(...,5)`` corresponding to overall MOS, noisiness, discontinuity, coloration and
  63. loudness in that order
  64. Raises:
  65. ModuleNotFoundError:
  66. If ``librosa`` or ``requests`` are not installed
  67. RuntimeError:
  68. If the input is too short, causing the number of mel spectrogram windows to be zero
  69. RuntimeError:
  70. If the input is too long, causing the number of mel spectrogram windows to exceed the maximum allowed
  71. Example:
  72. >>> import torch
  73. >>> from torchmetrics.functional.audio.nisqa import non_intrusive_speech_quality_assessment
  74. >>> _ = torch.manual_seed(42)
  75. >>> preds = torch.randn(16000)
  76. >>> non_intrusive_speech_quality_assessment(preds, 16000)
  77. tensor([1.0433, 1.9545, 2.6087, 1.3460, 1.7117])
  78. References:
  79. - [1] G. Mittag and S. Möller, "Non-intrusive speech quality assessment for super-wideband speech communication
  80. networks", in Proc. ICASSP, 2019.
  81. - [2] G. Mittag, B. Naderi, A. Chehadi and S. Möller, "NISQA: A deep CNN-self-attention model for
  82. multidimensional speech quality prediction with crowdsourced datasets", in Proc. INTERSPEECH, 2021.
  83. """
  84. if not _LIBROSA_AVAILABLE or not _REQUESTS_AVAILABLE:
  85. raise ModuleNotFoundError(
  86. "NISQA metric requires that librosa and requests are installed. Install as `pip install librosa requests`."
  87. )
  88. model, args = _load_nisqa_model()
  89. if not isinstance(fs, int) or fs <= 0:
  90. raise ValueError(f"Argument `fs` expected to be a positive integer, but got {fs}")
  91. model.eval()
  92. x = preds.reshape(-1, preds.shape[-1])
  93. x = _get_librosa_melspec(x.cpu().numpy(), fs, args)
  94. x, n_wins = _segment_specs(torch.from_numpy(x), args)
  95. with torch.no_grad():
  96. x = model(x, n_wins.expand(x.shape[0]))
  97. # ["mos_pred", "noi_pred", "dis_pred", "col_pred", "loud_pred"]
  98. # the dimensions are always listed in the papers as MOS, noisiness, coloration, discontinuity and loudness
  99. # but based on original code the actual model output order is MOS, noisiness, discontinuity, coloration, loudness
  100. return x.reshape((*preds.shape[:-1], 5))
  101. @lru_cache
  102. def _load_nisqa_model() -> tuple[nn.Module, dict[str, Any]]:
  103. """Load NISQA model and its parameters.
  104. Returns:
  105. Tuple ``(model,args)`` where ``model`` is the NISQA model and ``args`` is a dictionary with all its parameters
  106. """
  107. model_path = os.path.expanduser(os.path.join(NISQA_DIR, "nisqa.tar"))
  108. if not os.path.exists(model_path):
  109. _download_weights()
  110. checkpoint = torch.load(model_path, map_location="cpu", weights_only=True)
  111. args = checkpoint["args"]
  112. model = _NISQADIM(args)
  113. model.load_state_dict(checkpoint["model_state_dict"], strict=True)
  114. return model, args
  115. def _download_weights() -> None:
  116. """Download NISQA model weights."""
  117. url = "https://github.com/gabrielmittag/NISQA/raw/refs/heads/master/weights/nisqa.tar"
  118. nisqa_dir = os.path.expanduser(NISQA_DIR)
  119. os.makedirs(nisqa_dir, exist_ok=True)
  120. saveto = os.path.join(nisqa_dir, "nisqa.tar")
  121. if os.path.exists(saveto):
  122. return
  123. rank_zero_info(f"downloading {url} to {saveto}")
  124. myfile = requests.get(url)
  125. with open(saveto, "wb") as f:
  126. f.write(myfile.content)
  127. class _NISQADIM(nn.Module):
  128. # main NISQA model definition
  129. # ported from https://github.com/gabrielmittag/NISQA
  130. # Copyright (c) 2021 Gabriel Mittag, Quality and Usability Lab
  131. # MIT License
  132. def __init__(self, args: dict[str, Any]) -> None:
  133. super().__init__()
  134. self.cnn = _Framewise(args)
  135. self.time_dependency = _TimeDependency(args)
  136. pool = _Pooling(args)
  137. self.pool_layers = _get_clones(pool, 5)
  138. def forward(self, x: Tensor, n_wins: Tensor) -> Tensor:
  139. x = self.cnn(x, n_wins)
  140. x, n_wins = self.time_dependency(x, n_wins)
  141. out = [mod(x, n_wins) for mod in self.pool_layers]
  142. return torch.cat(out, dim=1)
  143. class _Framewise(nn.Module):
  144. # part of NISQA model definition
  145. def __init__(self, args: dict[str, Any]) -> None:
  146. super().__init__()
  147. self.model = _AdaptCNN(args)
  148. def forward(self, x: Tensor, n_wins: Tensor) -> Tensor:
  149. x_packed = pack_padded_sequence(x, n_wins, batch_first=True, enforce_sorted=False)
  150. x = self.model(x_packed.data.unsqueeze(1))
  151. x = x_packed._replace(data=x)
  152. x, _ = pad_packed_sequence(x, batch_first=True, padding_value=0.0, total_length=int(n_wins.max()))
  153. return x
  154. class _AdaptCNN(nn.Module):
  155. # part of NISQA model definition
  156. def __init__(self, args: dict[str, Any]) -> None:
  157. super().__init__()
  158. self.pool_1 = args["cnn_pool_1"]
  159. self.pool_2 = args["cnn_pool_2"]
  160. self.pool_3 = args["cnn_pool_3"]
  161. self.dropout = nn.Dropout2d(p=args["cnn_dropout"])
  162. cnn_pad = (1, 0) if args["cnn_kernel_size"][0] == 1 else (1, 1)
  163. self.conv1 = nn.Conv2d(1, args["cnn_c_out_1"], args["cnn_kernel_size"], padding=cnn_pad)
  164. self.bn1 = nn.BatchNorm2d(self.conv1.out_channels)
  165. self.conv2 = nn.Conv2d(self.conv1.out_channels, args["cnn_c_out_2"], args["cnn_kernel_size"], padding=cnn_pad)
  166. self.bn2 = nn.BatchNorm2d(self.conv2.out_channels)
  167. self.conv3 = nn.Conv2d(self.conv2.out_channels, args["cnn_c_out_3"], args["cnn_kernel_size"], padding=cnn_pad)
  168. self.bn3 = nn.BatchNorm2d(self.conv3.out_channels)
  169. self.conv4 = nn.Conv2d(self.conv3.out_channels, args["cnn_c_out_3"], args["cnn_kernel_size"], padding=cnn_pad)
  170. self.bn4 = nn.BatchNorm2d(self.conv4.out_channels)
  171. self.conv5 = nn.Conv2d(self.conv4.out_channels, args["cnn_c_out_3"], args["cnn_kernel_size"], padding=cnn_pad)
  172. self.bn5 = nn.BatchNorm2d(self.conv5.out_channels)
  173. self.conv6 = nn.Conv2d(
  174. self.conv5.out_channels,
  175. args["cnn_c_out_3"],
  176. (args["cnn_kernel_size"][0], args["cnn_pool_3"][1]),
  177. padding=(1, 0),
  178. )
  179. self.bn6 = nn.BatchNorm2d(self.conv6.out_channels)
  180. def forward(self, x: Tensor) -> Tensor:
  181. x = relu(self.bn1(self.conv1(x)))
  182. x = adaptive_max_pool2d(x, output_size=(self.pool_1))
  183. x = relu(self.bn2(self.conv2(x)))
  184. x = adaptive_max_pool2d(x, output_size=(self.pool_2))
  185. x = self.dropout(x)
  186. x = relu(self.bn3(self.conv3(x)))
  187. x = self.dropout(x)
  188. x = relu(self.bn4(self.conv4(x)))
  189. x = adaptive_max_pool2d(x, output_size=(self.pool_3))
  190. x = self.dropout(x)
  191. x = relu(self.bn5(self.conv5(x)))
  192. x = self.dropout(x)
  193. x = relu(self.bn6(self.conv6(x)))
  194. return x.view(-1, self.conv6.out_channels * self.pool_3[0])
  195. class _TimeDependency(nn.Module):
  196. # part of NISQA model definition
  197. def __init__(self, args: dict[str, Any]) -> None:
  198. super().__init__()
  199. self.model = _SelfAttention(args)
  200. def forward(self, x: Tensor, n_wins: Tensor) -> Tensor:
  201. return self.model(x, n_wins)
  202. class _SelfAttention(nn.Module):
  203. # part of NISQA model definition
  204. def __init__(self, args: dict[str, Any]) -> None:
  205. super().__init__()
  206. encoder_layer = _SelfAttentionLayer(args)
  207. self.norm1 = nn.LayerNorm(args["td_sa_d_model"])
  208. self.linear = nn.Linear(args["cnn_c_out_3"] * args["cnn_pool_3"][0], args["td_sa_d_model"])
  209. self.layers = _get_clones(encoder_layer, args["td_sa_num_layers"])
  210. self._reset_parameters()
  211. def _reset_parameters(self) -> None:
  212. for p in self.parameters():
  213. if p.dim() > 1:
  214. nn.init.xavier_uniform_(p)
  215. def forward(self, src: Tensor, n_wins: Tensor) -> tuple[Tensor, Tensor]:
  216. src = self.linear(src)
  217. output = src.transpose(1, 0)
  218. output = self.norm1(output)
  219. for mod in self.layers:
  220. output, n_wins = mod(output, n_wins)
  221. return output.transpose(1, 0), n_wins
  222. class _SelfAttentionLayer(nn.Module):
  223. # part of NISQA model definition
  224. def __init__(self, args: dict[str, Any]) -> None:
  225. super().__init__()
  226. self.self_attn = nn.MultiheadAttention(args["td_sa_d_model"], args["td_sa_nhead"], args["td_sa_dropout"])
  227. self.linear1 = nn.Linear(args["td_sa_d_model"], args["td_sa_h"])
  228. self.dropout = nn.Dropout(args["td_sa_dropout"])
  229. self.linear2 = nn.Linear(args["td_sa_h"], args["td_sa_d_model"])
  230. self.norm1 = nn.LayerNorm(args["td_sa_d_model"])
  231. self.norm2 = nn.LayerNorm(args["td_sa_d_model"])
  232. self.dropout1 = nn.Dropout(args["td_sa_dropout"])
  233. self.dropout2 = nn.Dropout(args["td_sa_dropout"])
  234. self.activation = relu
  235. def forward(self, src: Tensor, n_wins: Tensor) -> tuple[Tensor, Tensor]:
  236. mask = torch.arange(src.shape[0])[None, :] < n_wins[:, None]
  237. src2 = self.self_attn(src, src, src, key_padding_mask=~mask)[0]
  238. src = src + self.dropout1(src2)
  239. src = self.norm1(src)
  240. src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
  241. src = src + self.dropout2(src2)
  242. src = self.norm2(src)
  243. return src, n_wins
  244. class _Pooling(nn.Module):
  245. # part of NISQA model definition
  246. def __init__(self, args: dict[str, Any]) -> None:
  247. super().__init__()
  248. self.model = _PoolAttFF(args)
  249. def forward(self, x: Tensor, n_wins: Tensor) -> Tensor:
  250. return self.model(x, n_wins)
  251. class _PoolAttFF(torch.nn.Module):
  252. # part of NISQA model definition
  253. def __init__(self, args: dict[str, Any]) -> None:
  254. super().__init__()
  255. self.linear1 = nn.Linear(args["td_sa_d_model"], args["pool_att_h"])
  256. self.linear2 = nn.Linear(args["pool_att_h"], 1)
  257. self.linear3 = nn.Linear(args["td_sa_d_model"], 1)
  258. self.activation = relu
  259. self.dropout = nn.Dropout(args["pool_att_dropout"])
  260. def forward(self, x: Tensor, n_wins: Tensor) -> Tensor:
  261. att = self.linear2(self.dropout(self.activation(self.linear1(x))))
  262. att = att.transpose(2, 1)
  263. mask = torch.arange(att.shape[2])[None, :] < n_wins[:, None]
  264. att[~mask.unsqueeze(1)] = float("-inf")
  265. att = softmax(att, dim=2)
  266. x = torch.bmm(att, x)
  267. x = x.squeeze(1)
  268. return self.linear3(x)
  269. def _get_librosa_melspec(y: np.ndarray, sr: int, args: dict[str, Any]) -> np.ndarray:
  270. """Compute mel spectrogram from waveform using librosa.
  271. Args:
  272. y: waveform with shape ``(batch_size,time)``
  273. sr: sampling rate
  274. args: dictionary with all NISQA parameters
  275. Returns:
  276. Mel spectrogram with shape ``(batch_size,n_mels,n_frames)``
  277. """
  278. hop_length = int(sr * args["ms_hop_length"])
  279. win_length = int(sr * args["ms_win_length"])
  280. with warnings.catch_warnings():
  281. # ignore empty mel filter warning since this is expected when input signal is not fullband
  282. # see https://github.com/gabrielmittag/NISQA/issues/6#issuecomment-838157571
  283. warnings.filterwarnings("ignore", message="Empty filters detected in mel frequency basis")
  284. melspec = librosa.feature.melspectrogram(
  285. y=y,
  286. sr=sr,
  287. S=None,
  288. n_fft=args["ms_n_fft"],
  289. hop_length=hop_length,
  290. win_length=win_length,
  291. window="hann",
  292. center=True,
  293. pad_mode="reflect",
  294. power=1.0,
  295. n_mels=args["ms_n_mels"],
  296. fmin=0.0,
  297. fmax=args["ms_fmax"],
  298. htk=False,
  299. norm="slaney",
  300. )
  301. # batch processing of librosa.core.amplitude_to_db is not equivalent to individual processing due to top_db being
  302. # relative to max value
  303. # so process individually and then stack
  304. return np.stack([librosa.amplitude_to_db(m, ref=1.0, amin=1e-4, top_db=80.0) for m in melspec])
  305. def _segment_specs(x: Tensor, args: dict[str, Any]) -> tuple[Tensor, Tensor]:
  306. """Segment mel spectrogram into overlapping windows.
  307. Args:
  308. x: mel spectrogram with shape ``(batch_size,n_mels,n_frames)``
  309. args: dictionary with all NISQA parameters
  310. Returns:
  311. Tuple ``(x_padded,n_wins)```, where ``x_padded`` is the segmented mel spectrogram with shape
  312. ``(batch_size,max_length,n_mels,seg_length)`` where the second dimension is the number of windows and was
  313. padded to ``max_length``, and ``n_wins`` is the number of windows and is 0-dimensional
  314. """
  315. seg_length = args["ms_seg_length"]
  316. seg_hop = args["ms_seg_hop_length"]
  317. max_length = args["ms_max_segments"]
  318. n_wins = x.shape[2] - (seg_length - 1)
  319. if n_wins < 1:
  320. raise RuntimeError("Input signal is too short.")
  321. idx1 = torch.arange(seg_length)
  322. idx2 = torch.arange(n_wins)
  323. idx3 = idx1.unsqueeze(0) + idx2.unsqueeze(1)
  324. x = x.transpose(2, 1)[:, idx3, :].transpose(3, 2)
  325. x = x[:, ::seg_hop]
  326. n_wins = math.ceil(n_wins / seg_hop)
  327. if max_length < n_wins:
  328. raise RuntimeError("Maximum number of mel spectrogram windows exceeded. Use shorter audio.")
  329. x_padded = torch.zeros((x.shape[0], max_length, x.shape[2], x.shape[3]))
  330. x_padded[:, :n_wins] = x
  331. return x_padded, torch.tensor(n_wins)
  332. def _get_clones(module: nn.Module, n: int) -> nn.ModuleList:
  333. """Create ``n`` copies of a module."""
  334. return nn.ModuleList([copy.deepcopy(module) for i in range(n)])