# dnn_wpe.py (2.8 KB)
from typing import Optional
from typing import Tuple

import torch
from pytorch_wpe import wpe_one_iteration
from torch_complex.tensor import ComplexTensor

from funasr.frontends.utils.mask_estimator import MaskEstimator
from funasr.models.transformer.utils.nets_utils import make_pad_mask
  7. class DNN_WPE(torch.nn.Module):
  8. def __init__(
  9. self,
  10. wtype: str = "blstmp",
  11. widim: int = 257,
  12. wlayers: int = 3,
  13. wunits: int = 300,
  14. wprojs: int = 320,
  15. dropout_rate: float = 0.0,
  16. taps: int = 5,
  17. delay: int = 3,
  18. use_dnn_mask: bool = True,
  19. iterations: int = 1,
  20. normalization: bool = False,
  21. ):
  22. super().__init__()
  23. self.iterations = iterations
  24. self.taps = taps
  25. self.delay = delay
  26. self.normalization = normalization
  27. self.use_dnn_mask = use_dnn_mask
  28. self.inverse_power = True
  29. if self.use_dnn_mask:
  30. self.mask_est = MaskEstimator(
  31. wtype, widim, wlayers, wunits, wprojs, dropout_rate, nmask=1
  32. )
  33. def forward(
  34. self, data: ComplexTensor, ilens: torch.LongTensor
  35. ) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
  36. """The forward function
  37. Notation:
  38. B: Batch
  39. C: Channel
  40. T: Time or Sequence length
  41. F: Freq or Some dimension of the feature vector
  42. Args:
  43. data: (B, C, T, F)
  44. ilens: (B,)
  45. Returns:
  46. data: (B, C, T, F)
  47. ilens: (B,)
  48. """
  49. # (B, T, C, F) -> (B, F, C, T)
  50. enhanced = data = data.permute(0, 3, 2, 1)
  51. mask = None
  52. for i in range(self.iterations):
  53. # Calculate power: (..., C, T)
  54. power = enhanced.real**2 + enhanced.imag**2
  55. if i == 0 and self.use_dnn_mask:
  56. # mask: (B, F, C, T)
  57. (mask,), _ = self.mask_est(enhanced, ilens)
  58. if self.normalization:
  59. # Normalize along T
  60. mask = mask / mask.sum(dim=-1)[..., None]
  61. # (..., C, T) * (..., C, T) -> (..., C, T)
  62. power = power * mask
  63. # Averaging along the channel axis: (..., C, T) -> (..., T)
  64. power = power.mean(dim=-2)
  65. # enhanced: (..., C, T) -> (..., C, T)
  66. enhanced = wpe_one_iteration(
  67. data.contiguous(),
  68. power,
  69. taps=self.taps,
  70. delay=self.delay,
  71. inverse_power=self.inverse_power,
  72. )
  73. enhanced.masked_fill_(make_pad_mask(ilens, enhanced.real), 0)
  74. # (B, F, C, T) -> (B, T, C, F)
  75. enhanced = enhanced.permute(0, 3, 2, 1)
  76. if mask is not None:
  77. mask = mask.transpose(-1, -3)
  78. return enhanced, ilens, mask