| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172 |
- """DNN beamformer module."""
- from typing import Tuple
- import torch
- from torch.nn import functional as F
- from funasr.modules.frontends.beamformer import apply_beamforming_vector
- from funasr.modules.frontends.beamformer import get_mvdr_vector
- from funasr.modules.frontends.beamformer import (
- get_power_spectral_density_matrix, # noqa: H301
- )
- from funasr.modules.frontends.mask_estimator import MaskEstimator
- from torch_complex.tensor import ComplexTensor
class DNN_Beamformer(torch.nn.Module):
    """DNN mask based Beamformer.

    A ``MaskEstimator`` network predicts speech/noise time-frequency masks;
    the masks are turned into power spectral density (PSD) matrices and an
    MVDR beamforming vector is computed and applied per frequency bin.

    Citation:
        Multichannel End-to-end Speech Recognition; T. Ochiai et al., 2017;
        https://arxiv.org/abs/1703.04783
    """

    def __init__(
        self,
        bidim: int,
        btype: str = "blstmp",
        blayers: int = 3,
        bunits: int = 300,
        bprojs: int = 320,
        bnmask: int = 2,
        dropout_rate: float = 0.0,
        badim: int = 320,
        ref_channel: int = -1,
        beamformer_type: str = "mvdr",
    ):
        """Initialize the beamformer.

        Args:
            bidim: Input feature dimension (number of frequency bins).
            btype: RNN type of the mask estimator (e.g. "blstmp").
            blayers: Number of RNN layers in the mask estimator.
            bunits: Number of RNN units per layer.
            bprojs: Number of projection units.
            bnmask: Number of masks to estimate. 2 means (speech, noise);
                more than 2 is the multi-speaker case where the last mask
                is the noise mask.
            dropout_rate: Dropout rate inside the mask estimator.
            badim: Attention dimension of the reference-channel attention.
            ref_channel: Fixed reference microphone index. A negative value
                selects attention-based reference estimation instead.
            beamformer_type: Only "mvdr" is supported.

        Raises:
            ValueError: If ``beamformer_type`` is not "mvdr".
        """
        super().__init__()
        self.mask = MaskEstimator(
            btype, bidim, blayers, bunits, bprojs, dropout_rate, nmask=bnmask
        )
        # Attention-based reference channel selector (used when ref_channel < 0).
        self.ref = AttentionReference(bidim, badim)
        self.ref_channel = ref_channel

        self.nmask = bnmask

        if beamformer_type != "mvdr":
            raise ValueError(
                "Not supporting beamformer_type={}".format(beamformer_type)
            )
        self.beamformer_type = beamformer_type

    def forward(
        self, data: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
        """The forward function.

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq

        Args:
            data (ComplexTensor): (B, T, C, F)
            ilens (torch.Tensor): (B,)
        Returns:
            enhanced (ComplexTensor): (B, T, F); in the multi-speaker case
                (nmask > 2) a list of such tensors, one per speaker
            ilens (torch.Tensor): (B,)
            mask_speech (torch.Tensor): (B, T, C, F) speech mask(s); a list
                in the multi-speaker case
        """

        def apply_beamforming(data, ilens, psd_speech, psd_noise):
            # u: (B, C) -- reference weighting used by the MVDR solution.
            if self.ref_channel < 0:
                # Estimate the reference vector with attention over channels.
                u, _ = self.ref(psd_speech, ilens)
            else:
                # (optional) Create onehot vector for fixed reference microphone
                u = torch.zeros(
                    *(data.size()[:-3] + (data.size(-2),)), device=data.device
                )
                u[..., self.ref_channel].fill_(1)

            ws = get_mvdr_vector(psd_speech, psd_noise, u)
            enhanced = apply_beamforming_vector(ws, data)

            return enhanced, ws

        # data (B, T, C, F) -> (B, F, C, T)
        data = data.permute(0, 3, 2, 1)

        # mask: (B, F, C, T)
        masks, _ = self.mask(data, ilens)
        assert self.nmask == len(masks)

        if self.nmask == 2:  # (mask_speech, mask_noise)
            mask_speech, mask_noise = masks

            psd_speech = get_power_spectral_density_matrix(data, mask_speech)
            psd_noise = get_power_spectral_density_matrix(data, mask_noise)

            enhanced, ws = apply_beamforming(data, ilens, psd_speech, psd_noise)

            # (..., F, T) -> (..., T, F)
            enhanced = enhanced.transpose(-1, -2)
            mask_speech = mask_speech.transpose(-1, -3)
        else:  # multi-speaker case: (mask_speech1, ..., mask_noise)
            mask_speech = list(masks[:-1])
            mask_noise = masks[-1]

            psd_speeches = [
                get_power_spectral_density_matrix(data, mask) for mask in mask_speech
            ]
            psd_noise = get_power_spectral_density_matrix(data, mask_noise)

            enhanced = []
            ws = []
            for i in range(self.nmask - 1):
                psd_speech = psd_speeches.pop(i)
                # treat all other speakers' psd_speech as noises
                enh, w = apply_beamforming(
                    data, ilens, psd_speech, sum(psd_speeches) + psd_noise
                )
                # Put the popped PSD back so the list is intact for the
                # next speaker's iteration.
                psd_speeches.insert(i, psd_speech)

                # (..., F, T) -> (..., T, F)
                enh = enh.transpose(-1, -2)
                mask_speech[i] = mask_speech[i].transpose(-1, -3)

                enhanced.append(enh)
                ws.append(w)

        return enhanced, ilens, mask_speech
class AttentionReference(torch.nn.Module):
    """Attention-based reference microphone estimator.

    Scores each channel from the (off-diagonal) speech PSD matrix and
    returns a softmax weighting over channels, used in place of a fixed
    reference microphone.
    """

    def __init__(self, bidim, att_dim):
        super().__init__()
        # Projects per-channel PSD amplitudes into the attention space.
        self.mlp_psd = torch.nn.Linear(bidim, att_dim)
        # Reduces each attention vector to a scalar channel score.
        self.gvec = torch.nn.Linear(att_dim, 1)

    def forward(
        self, psd_in: ComplexTensor, ilens: torch.LongTensor, scaling: float = 2.0
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Compute the soft reference-channel vector.

        Args:
            psd_in (ComplexTensor): (B, F, C, C)
            ilens (torch.Tensor): (B,)
            scaling (float): sharpening factor applied before the softmax
        Returns:
            u (torch.Tensor): (B, C)
            ilens (torch.Tensor): (B,)
        """
        _, _, n_chan = psd_in.size()[:3]
        assert psd_in.size(2) == psd_in.size(3), psd_in.size()

        # Drop the diagonal (same-channel) terms of every (C, C) PSD matrix.
        diag_mask = torch.eye(n_chan, dtype=torch.bool, device=psd_in.device)
        off_diag = psd_in.masked_fill(diag_mask, 0)

        # Average the remaining entries per channel: (B, F, C, C) -> (B, C, F)
        averaged = (off_diag.sum(dim=-1) / (n_chan - 1)).transpose(-1, -2)

        # Amplitude of the complex averages.
        amplitude = (averaged.real ** 2 + averaged.imag ** 2) ** 0.5

        # (B, C, F) -> (B, C, att_dim) -> (B, C, 1) -> (B, C)
        scores = self.gvec(torch.tanh(self.mlp_psd(amplitude))).squeeze(-1)
        u = F.softmax(scaling * scores, dim=-1)
        return u, ilens
|