pit.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from itertools import permutations
  15. from typing import Any, Callable
  16. import numpy as np
  17. import torch
  18. from torch import Tensor
  19. from typing_extensions import Literal
  20. from torchmetrics.utilities import rank_zero_warn
  21. from torchmetrics.utilities.imports import _SCIPY_AVAILABLE
  22. # _ps_dict: cache of permutations
  23. # it's necessary to cache it, otherwise it will consume a large amount of time
  24. _ps_dict: dict = {} # _ps_dict[str(spk_num)+str(device)] = permutations
  25. def _gen_permutations(spk_num: int, device: torch.device) -> Tensor:
  26. key = str(spk_num) + str(device)
  27. if key not in _ps_dict:
  28. # ps: all the permutations, shape [perm_num, spk_num]
  29. # ps: In i-th permutation, the predcition corresponds to the j-th target is ps[j,i]
  30. ps = torch.tensor(list(permutations(range(spk_num))), device=device)
  31. _ps_dict[key] = ps
  32. else:
  33. ps = _ps_dict[key] # all the permutations, shape [perm_num, spk_num]
  34. return ps
  35. def _find_best_perm_by_linear_sum_assignment(
  36. metric_mtx: Tensor,
  37. eval_func: Callable,
  38. ) -> tuple[Tensor, Tensor]:
  39. """Solves the linear sum assignment problem.
  40. This implementation uses scipy and input is therefore transferred to cpu during calculations.
  41. Args:
  42. metric_mtx: the metric matrix, shape [batch_size, spk_num, spk_num]
  43. eval_func: the function to reduce the metric values of different the permutations
  44. Returns:
  45. best_metric: shape ``[batch]``
  46. best_perm: shape ``[batch, spk]``
  47. """
  48. from scipy.optimize import linear_sum_assignment
  49. mmtx = metric_mtx.detach().cpu()
  50. best_perm = torch.tensor(np.array([linear_sum_assignment(pwm, eval_func == torch.max)[1] for pwm in mmtx]))
  51. best_perm = best_perm.to(metric_mtx.device)
  52. best_metric = torch.gather(metric_mtx, 2, best_perm[:, :, None]).mean([-1, -2])
  53. return best_metric, best_perm # shape [batch], shape [batch, spk]
  54. def _find_best_perm_by_exhaustive_method(
  55. metric_mtx: Tensor,
  56. eval_func: Callable,
  57. ) -> tuple[Tensor, Tensor]:
  58. """Solves the linear sum assignment problem using exhaustive method.
  59. This is done by exhaustively calculating the metric values of all possible permutations, and returns the best metric
  60. values and the corresponding permutations.
  61. Args:
  62. metric_mtx: the metric matrix, shape ``[batch_size, spk_num, spk_num]``
  63. eval_func: the function to reduce the metric values of different the permutations
  64. Returns:
  65. best_metric: shape ``[batch]``
  66. best_perm: shape ``[batch, spk]``
  67. """
  68. # create/read/cache the permutations and its indexes
  69. # reading from cache would be much faster than creating in CPU then moving to GPU
  70. batch_size, spk_num = metric_mtx.shape[:2]
  71. ps = _gen_permutations(spk_num=spk_num, device=metric_mtx.device) # [perm_num, spk_num]
  72. # find the metric of each permutation
  73. perm_num = ps.shape[0]
  74. # shape of [batch_size, spk_num, perm_num]
  75. bps = ps.T[None, ...].expand(batch_size, spk_num, perm_num)
  76. # shape of [batch_size, spk_num, perm_num]
  77. metric_of_ps_details = torch.gather(metric_mtx, 2, bps)
  78. # shape of [batch_size, perm_num]
  79. metric_of_ps = metric_of_ps_details.mean(dim=1)
  80. # find the best metric and best permutation
  81. best_metric, best_indexes = eval_func(metric_of_ps, dim=1)
  82. best_indexes = best_indexes.detach()
  83. best_perm = ps[best_indexes, :]
  84. return best_metric, best_perm # shape [batch], shape [batch, spk]
  85. def permutation_invariant_training(
  86. preds: Tensor,
  87. target: Tensor,
  88. metric_func: Callable,
  89. mode: Literal["speaker-wise", "permutation-wise"] = "speaker-wise",
  90. eval_func: Literal["max", "min"] = "max",
  91. **kwargs: Any,
  92. ) -> tuple[Tensor, Tensor]:
  93. """Calculate `Permutation invariant training`_ (PIT).
  94. This metric can evaluate models for speaker independent multi-talker speech separation in a permutation
  95. invariant way.
  96. Args:
  97. preds: float tensor with shape ``(batch_size,num_speakers,...)``
  98. target: float tensor with shape ``(batch_size,num_speakers,...)``
  99. metric_func: a metric function accept a batch of target and estimate.
  100. if `mode`==`'speaker-wise'`, then ``metric_func(preds[:, i, ...], target[:, j, ...])`` is called
  101. and expected to return a batch of metric tensors ``(batch,)``;
  102. if `mode`==`'permutation-wise'`, then ``metric_func(preds[:, p, ...], target[:, :, ...])`` is called,
  103. where `p` is one possible permutation, e.g. [0,1] or [1,0] for 2-speaker case, and expected to return
  104. a batch of metric tensors ``(batch,)``;
  105. mode: can be `'speaker-wise'` or `'permutation-wise'`.
  106. eval_func: the function to find the best permutation, can be ``'min'`` or ``'max'``,
  107. i.e. the smaller the better or the larger the better.
  108. kwargs: Additional args for metric_func
  109. Returns:
  110. Tuple of two float tensors. First tensor with shape ``(batch,)`` contains the best metric value for each sample
  111. and second tensor with shape ``(batch,)`` contains the best permutation.
  112. Example:
  113. >>> from torchmetrics.functional.audio import scale_invariant_signal_distortion_ratio
  114. >>> # [batch, spk, time]
  115. >>> preds = torch.tensor([[[-0.0579, 0.3560, -0.9604], [-0.1719, 0.3205, 0.2951]]])
  116. >>> target = torch.tensor([[[ 1.0958, -0.1648, 0.5228], [-0.4100, 1.1942, -0.5103]]])
  117. >>> best_metric, best_perm = permutation_invariant_training(
  118. ... preds, target, scale_invariant_signal_distortion_ratio,
  119. ... mode="speaker-wise", eval_func="max")
  120. >>> best_metric
  121. tensor([-5.1091])
  122. >>> best_perm
  123. tensor([[0, 1]])
  124. >>> pit_permutate(preds, best_perm)
  125. tensor([[[-0.0579, 0.3560, -0.9604],
  126. [-0.1719, 0.3205, 0.2951]]])
  127. """
  128. if preds.shape[0:2] != target.shape[0:2]:
  129. raise RuntimeError(
  130. "Predictions and targets are expected to have the same shape at the batch and speaker dimensions"
  131. )
  132. if eval_func not in ["max", "min"]:
  133. raise ValueError(f'eval_func can only be "max" or "min" but got {eval_func}')
  134. if mode not in ["speaker-wise", "permutation-wise"]:
  135. raise ValueError(f'mode can only be "speaker-wise" or "permutation-wise" but got {mode}')
  136. if target.ndim < 2:
  137. raise ValueError(f"Inputs must be of shape [batch, spk, ...], got {target.shape} and {preds.shape} instead")
  138. eval_op = torch.max if eval_func == "max" else torch.min
  139. # calculate the metric matrix
  140. batch_size, spk_num = target.shape[0:2]
  141. if mode == "permutation-wise":
  142. perms = _gen_permutations(spk_num=spk_num, device=preds.device) # [perm_num, spk_num]
  143. perm_num = perms.shape[0]
  144. # shape of ppreds and ptarget: [batch_size*perm_num, spk_num, ...]
  145. ppreds = torch.index_select(preds, dim=1, index=perms.reshape(-1)).reshape(
  146. batch_size * perm_num, *preds.shape[1:]
  147. )
  148. ptarget = target.repeat_interleave(repeats=perm_num, dim=0)
  149. # shape of metric_of_ps [batch_size*perm_num] or [batch_size*perm_num, spk_num]
  150. metric_of_ps = metric_func(ppreds, ptarget, **kwargs)
  151. metric_of_ps = torch.mean(metric_of_ps.reshape(batch_size, len(perms), -1), dim=-1)
  152. # find the best metric and best permutation
  153. best_metric, best_indexes = eval_op(metric_of_ps, dim=1)
  154. best_indexes = best_indexes.detach()
  155. best_perm = perms[best_indexes, :]
  156. return best_metric, best_perm
  157. # speaker-wise
  158. first_ele = metric_func(preds[:, 0, ...], target[:, 0, ...], **kwargs) # needed for dtype and device
  159. metric_mtx = torch.empty((batch_size, spk_num, spk_num), dtype=first_ele.dtype, device=first_ele.device)
  160. metric_mtx[:, 0, 0] = first_ele
  161. for target_idx in range(spk_num): # we have spk_num speeches in target in each sample
  162. for preds_idx in range(spk_num): # we have spk_num speeches in preds in each sample
  163. if target_idx == 0 and preds_idx == 0: # already calculated
  164. continue
  165. metric_mtx[:, target_idx, preds_idx] = metric_func(
  166. preds[:, preds_idx, ...], target[:, target_idx, ...], **kwargs
  167. )
  168. # find best
  169. if spk_num < 3 or not _SCIPY_AVAILABLE:
  170. if spk_num >= 3 and not _SCIPY_AVAILABLE:
  171. rank_zero_warn(
  172. f"In pit metric for speaker-num {spk_num}>3, we recommend installing scipy for better performance"
  173. )
  174. best_metric, best_perm = _find_best_perm_by_exhaustive_method(metric_mtx, eval_op)
  175. else:
  176. best_metric, best_perm = _find_best_perm_by_linear_sum_assignment(metric_mtx, eval_op)
  177. return best_metric, best_perm
  178. def pit_permutate(preds: Tensor, perm: Tensor) -> Tensor:
  179. """Permutate estimate according to perm.
  180. Args:
  181. preds: the estimates you want to permutate, shape [batch, spk, ...]
  182. perm: the permutation returned from permutation_invariant_training, shape [batch, spk]
  183. Returns:
  184. Tensor: the permutated version of estimate
  185. """
  186. return torch.stack([torch.index_select(pred, 0, p) for pred, p in zip(preds, perm)])