  1. """DNN beamformer module."""
  2. from typing import Tuple
  3. import torch
  4. from torch.nn import functional as F
  5. from funasr.modules.frontends.beamformer import apply_beamforming_vector
  6. from funasr.modules.frontends.beamformer import get_mvdr_vector
  7. from funasr.modules.frontends.beamformer import (
  8. get_power_spectral_density_matrix, # noqa: H301
  9. )
  10. from funasr.modules.frontends.mask_estimator import MaskEstimator
  11. from torch_complex.tensor import ComplexTensor
class DNN_Beamformer(torch.nn.Module):
    """DNN mask based Beamformer.

    A neural mask estimator predicts time-frequency masks for speech and
    noise; the masks are used to estimate power spectral density (PSD)
    matrices, from which an MVDR beamforming vector is derived and applied.

    Citation:
        Multichannel End-to-end Speech Recognition; T. Ochiai et al., 2017;
        https://arxiv.org/abs/1703.04783
    """

    def __init__(
        self,
        bidim,
        btype="blstmp",
        blayers=3,
        bunits=300,
        bprojs=320,
        bnmask=2,
        dropout_rate=0.0,
        badim=320,
        ref_channel: int = -1,
        beamformer_type="mvdr",
    ):
        # bidim: input feature dimension fed to the mask estimator
        #   (presumably the number of frequency bins — TODO confirm)
        # btype/blayers/bunits/bprojs/dropout_rate: MaskEstimator config
        # bnmask: number of masks to estimate; 2 means (speech, noise),
        #   >2 means multi-speaker (speech_1, ..., speech_{n-1}, noise)
        # badim: attention dimension of AttentionReference
        # ref_channel: fixed reference microphone index; a negative value
        #   switches to attention-based reference selection in forward()
        super().__init__()
        self.mask = MaskEstimator(
            btype, bidim, blayers, bunits, bprojs, dropout_rate, nmask=bnmask
        )
        self.ref = AttentionReference(bidim, badim)
        self.ref_channel = ref_channel
        self.nmask = bnmask
        # Only the MVDR beamformer is implemented; reject anything else early.
        if beamformer_type != "mvdr":
            raise ValueError(
                "Not supporting beamformer_type={}".format(beamformer_type)
            )
        self.beamformer_type = beamformer_type

    def forward(
        self, data: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
        """The forward function.

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq

        Args:
            data (ComplexTensor): (B, T, C, F)
            ilens (torch.Tensor): (B,)
        Returns:
            enhanced (ComplexTensor): (B, T, F); in the multi-speaker case
                (nmask > 2) this is a list of such tensors, one per speaker
            ilens (torch.Tensor): (B,)
            masks (torch.Tensor): speech mask(s) transposed to (..., T, C, F);
                a list in the multi-speaker case
        """

        def apply_beamforming(data, ilens, psd_speech, psd_noise):
            # Obtain the reference-channel weight vector u: (B, C)
            if self.ref_channel < 0:
                # Soft (attention-based) selection of the reference channel
                u, _ = self.ref(psd_speech, ilens)
            else:
                # (optional) Create onehot vector for fixed reference microphone
                u = torch.zeros(
                    *(data.size()[:-3] + (data.size(-2),)), device=data.device
                )
                u[..., self.ref_channel].fill_(1)
            # ws: MVDR filter coefficients derived from the PSD matrices
            ws = get_mvdr_vector(psd_speech, psd_noise, u)
            enhanced = apply_beamforming_vector(ws, data)
            return enhanced, ws

        # data (B, T, C, F) -> (B, F, C, T)
        data = data.permute(0, 3, 2, 1)

        # mask: (B, F, C, T); the estimator returns one mask per source
        masks, _ = self.mask(data, ilens)
        assert self.nmask == len(masks)

        if self.nmask == 2:  # (mask_speech, mask_noise)
            mask_speech, mask_noise = masks
            psd_speech = get_power_spectral_density_matrix(data, mask_speech)
            psd_noise = get_power_spectral_density_matrix(data, mask_noise)
            enhanced, ws = apply_beamforming(data, ilens, psd_speech, psd_noise)
            # (..., F, T) -> (..., T, F)
            enhanced = enhanced.transpose(-1, -2)
            mask_speech = mask_speech.transpose(-1, -3)
        else:  # multi-speaker case: (mask_speech1, ..., mask_noise)
            mask_speech = list(masks[:-1])
            mask_noise = masks[-1]
            psd_speeches = [
                get_power_spectral_density_matrix(data, mask) for mask in mask_speech
            ]
            psd_noise = get_power_spectral_density_matrix(data, mask_noise)
            enhanced = []
            ws = []
            for i in range(self.nmask - 1):
                # Temporarily remove speaker i's PSD from the list so the
                # remaining speakers can be summed into the noise estimate;
                # it is re-inserted at the same position below.
                psd_speech = psd_speeches.pop(i)
                # treat all other speakers' psd_speech as noises
                enh, w = apply_beamforming(
                    data, ilens, psd_speech, sum(psd_speeches) + psd_noise
                )
                psd_speeches.insert(i, psd_speech)
                # (..., F, T) -> (..., T, F)
                enh = enh.transpose(-1, -2)
                mask_speech[i] = mask_speech[i].transpose(-1, -3)
                enhanced.append(enh)
                ws.append(w)

        return enhanced, ilens, mask_speech
  107. class AttentionReference(torch.nn.Module):
  108. def __init__(self, bidim, att_dim):
  109. super().__init__()
  110. self.mlp_psd = torch.nn.Linear(bidim, att_dim)
  111. self.gvec = torch.nn.Linear(att_dim, 1)
  112. def forward(
  113. self, psd_in: ComplexTensor, ilens: torch.LongTensor, scaling: float = 2.0
  114. ) -> Tuple[torch.Tensor, torch.LongTensor]:
  115. """The forward function
  116. Args:
  117. psd_in (ComplexTensor): (B, F, C, C)
  118. ilens (torch.Tensor): (B,)
  119. scaling (float):
  120. Returns:
  121. u (torch.Tensor): (B, C)
  122. ilens (torch.Tensor): (B,)
  123. """
  124. B, _, C = psd_in.size()[:3]
  125. assert psd_in.size(2) == psd_in.size(3), psd_in.size()
  126. # psd_in: (B, F, C, C)
  127. psd = psd_in.masked_fill(
  128. torch.eye(C, dtype=torch.bool, device=psd_in.device), 0
  129. )
  130. # psd: (B, F, C, C) -> (B, C, F)
  131. psd = (psd.sum(dim=-1) / (C - 1)).transpose(-1, -2)
  132. # Calculate amplitude
  133. psd_feat = (psd.real**2 + psd.imag**2) ** 0.5
  134. # (B, C, F) -> (B, C, F2)
  135. mlp_psd = self.mlp_psd(psd_feat)
  136. # (B, C, F2) -> (B, C, 1) -> (B, C)
  137. e = self.gvec(torch.tanh(mlp_psd)).squeeze(-1)
  138. u = F.softmax(scaling * e, dim=-1)
  139. return u, ilens