beamformer.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. import torch
  2. from torch_complex import functional as FC
  3. from torch_complex.tensor import ComplexTensor
  4. def get_power_spectral_density_matrix(
  5. xs: ComplexTensor, mask: torch.Tensor, normalization=True, eps: float = 1e-15
  6. ) -> ComplexTensor:
  7. """Return cross-channel power spectral density (PSD) matrix
  8. Args:
  9. xs (ComplexTensor): (..., F, C, T)
  10. mask (torch.Tensor): (..., F, C, T)
  11. normalization (bool):
  12. eps (float):
  13. Returns
  14. psd (ComplexTensor): (..., F, C, C)
  15. """
  16. # outer product: (..., C_1, T) x (..., C_2, T) -> (..., T, C, C_2)
  17. psd_Y = FC.einsum("...ct,...et->...tce", [xs, xs.conj()])
  18. # Averaging mask along C: (..., C, T) -> (..., T)
  19. mask = mask.mean(dim=-2)
  20. # Normalized mask along T: (..., T)
  21. if normalization:
  22. # If assuming the tensor is padded with zero, the summation along
  23. # the time axis is same regardless of the padding length.
  24. mask = mask / (mask.sum(dim=-1, keepdim=True) + eps)
  25. # psd: (..., T, C, C)
  26. psd = psd_Y * mask[..., None, None]
  27. # (..., T, C, C) -> (..., C, C)
  28. psd = psd.sum(dim=-3)
  29. return psd
  30. def get_mvdr_vector(
  31. psd_s: ComplexTensor,
  32. psd_n: ComplexTensor,
  33. reference_vector: torch.Tensor,
  34. eps: float = 1e-15,
  35. ) -> ComplexTensor:
  36. """Return the MVDR(Minimum Variance Distortionless Response) vector:
  37. h = (Npsd^-1 @ Spsd) / (Tr(Npsd^-1 @ Spsd)) @ u
  38. Reference:
  39. On optimal frequency-domain multichannel linear filtering
  40. for noise reduction; M. Souden et al., 2010;
  41. https://ieeexplore.ieee.org/document/5089420
  42. Args:
  43. psd_s (ComplexTensor): (..., F, C, C)
  44. psd_n (ComplexTensor): (..., F, C, C)
  45. reference_vector (torch.Tensor): (..., C)
  46. eps (float):
  47. Returns:
  48. beamform_vector (ComplexTensor)r: (..., F, C)
  49. """
  50. # Add eps
  51. C = psd_n.size(-1)
  52. eye = torch.eye(C, dtype=psd_n.dtype, device=psd_n.device)
  53. shape = [1 for _ in range(psd_n.dim() - 2)] + [C, C]
  54. eye = eye.view(*shape)
  55. psd_n += eps * eye
  56. # numerator: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3)
  57. numerator = FC.einsum("...ec,...cd->...ed", [psd_n.inverse(), psd_s])
  58. # ws: (..., C, C) / (...,) -> (..., C, C)
  59. ws = numerator / (FC.trace(numerator)[..., None, None] + eps)
  60. # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
  61. beamform_vector = FC.einsum("...fec,...c->...fe", [ws, reference_vector])
  62. return beamform_vector
  63. def apply_beamforming_vector(
  64. beamform_vector: ComplexTensor, mix: ComplexTensor
  65. ) -> ComplexTensor:
  66. # (..., C) x (..., C, T) -> (..., T)
  67. es = FC.einsum("...c,...ct->...t", [beamform_vector.conj(), mix])
  68. return es