mask_estimator.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. from typing import Tuple
  2. import numpy as np
  3. import torch
  4. from torch.nn import functional as F
  5. from torch_complex.tensor import ComplexTensor
  6. from funasr.models.transformer.utils.nets_utils import make_pad_mask
  7. from funasr.models.language_model.rnn.encoders import RNN
  8. from funasr.models.language_model.rnn.encoders import RNNP
  9. class MaskEstimator(torch.nn.Module):
  10. def __init__(self, type, idim, layers, units, projs, dropout, nmask=1):
  11. super().__init__()
  12. subsample = np.ones(layers + 1, dtype=np.int32)
  13. typ = type.lstrip("vgg").rstrip("p")
  14. if type[-1] == "p":
  15. self.brnn = RNNP(idim, layers, units, projs, subsample, dropout, typ=typ)
  16. else:
  17. self.brnn = RNN(idim, layers, units, projs, dropout, typ=typ)
  18. self.type = type
  19. self.nmask = nmask
  20. self.linears = torch.nn.ModuleList(
  21. [torch.nn.Linear(projs, idim) for _ in range(nmask)]
  22. )
  23. def forward(
  24. self, xs: ComplexTensor, ilens: torch.LongTensor
  25. ) -> Tuple[Tuple[torch.Tensor, ...], torch.LongTensor]:
  26. """The forward function
  27. Args:
  28. xs: (B, F, C, T)
  29. ilens: (B,)
  30. Returns:
  31. hs (torch.Tensor): The hidden vector (B, F, C, T)
  32. masks: A tuple of the masks. (B, F, C, T)
  33. ilens: (B,)
  34. """
  35. assert xs.size(0) == ilens.size(0), (xs.size(0), ilens.size(0))
  36. _, _, C, input_length = xs.size()
  37. # (B, F, C, T) -> (B, C, T, F)
  38. xs = xs.permute(0, 2, 3, 1)
  39. # Calculate amplitude: (B, C, T, F) -> (B, C, T, F)
  40. xs = (xs.real**2 + xs.imag**2) ** 0.5
  41. # xs: (B, C, T, F) -> xs: (B * C, T, F)
  42. xs = xs.contiguous().view(-1, xs.size(-2), xs.size(-1))
  43. # ilens: (B,) -> ilens_: (B * C)
  44. ilens_ = ilens[:, None].expand(-1, C).contiguous().view(-1)
  45. # xs: (B * C, T, F) -> xs: (B * C, T, D)
  46. xs, _, _ = self.brnn(xs, ilens_)
  47. # xs: (B * C, T, D) -> xs: (B, C, T, D)
  48. xs = xs.view(-1, C, xs.size(-2), xs.size(-1))
  49. masks = []
  50. for linear in self.linears:
  51. # xs: (B, C, T, D) -> mask:(B, C, T, F)
  52. mask = linear(xs)
  53. mask = torch.sigmoid(mask)
  54. # Zero padding
  55. mask.masked_fill(make_pad_mask(ilens, mask, length_dim=2), 0)
  56. # (B, C, T, F) -> (B, F, C, T)
  57. mask = mask.permute(0, 3, 1, 2)
  58. # Take cares of multi gpu cases: If input_length > max(ilens)
  59. if mask.size(-1) < input_length:
  60. mask = F.pad(mask, [0, input_length - mask.size(-1)], value=0)
  61. masks.append(mask)
  62. return tuple(masks), ilens