frontend.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. from typing import List
  2. from typing import Optional
  3. from typing import Tuple
  4. from typing import Union
  5. import numpy
  6. import torch
  7. import torch.nn as nn
  8. from torch_complex.tensor import ComplexTensor
  9. from funasr.modules.frontends.dnn_beamformer import DNN_Beamformer
  10. from funasr.modules.frontends.dnn_wpe import DNN_WPE
  # Multichannel enhancement frontend: optional DNN-WPE, then optional beamformer.
class Frontend(nn.Module):
    """Optional WPE dereverberation followed by an optional DNN beamformer.

    Enhancement is applied only to 4-D input of shape (B, T, C, F); 3-D input
    of shape (B, T, F) is returned unchanged.

    Args:
        idim: Input dimension, forwarded to DNN_WPE as ``widim`` and to
            DNN_Beamformer as ``bidim``.
        use_wpe: Build and enable the DNN_WPE stage.
        wtype, wlayers, wunits, wprojs, wdropout_rate: Network options
            forwarded to DNN_WPE (``wdropout_rate`` as ``dropout_rate``).
        taps, delay: Forwarded to DNN_WPE as ``taps`` / ``delay``.
        use_dnn_mask_for_wpe: If True, DNN_WPE is built with
            ``use_dnn_mask=True`` and one iteration; otherwise with
            ``use_dnn_mask=False`` and two iterations (conventional WPE).
        use_beamformer: Build and enable the DNN_Beamformer stage.
        btype, blayers, bunits, bprojs, badim, ref_channel, bdropout_rate:
            Options forwarded to DNN_Beamformer (``bdropout_rate`` as
            ``dropout_rate``).
        bnmask: Forwarded to DNN_Beamformer (presumably the number of
            estimated masks -- confirm in DNN_Beamformer). Values > 2 set
            ``use_frontend_for_all``, which removes the "no enhancement"
            option when sampling stages during training.
    """

    def __init__(
        self,
        idim: int,
        # WPE options
        use_wpe: bool = False,
        wtype: str = "blstmp",
        wlayers: int = 3,
        wunits: int = 300,
        wprojs: int = 320,
        wdropout_rate: float = 0.0,
        taps: int = 5,
        delay: int = 3,
        use_dnn_mask_for_wpe: bool = True,
        # Beamformer options
        use_beamformer: bool = False,
        btype: str = "blstmp",
        blayers: int = 3,
        bunits: int = 300,
        bprojs: int = 320,
        bnmask: int = 2,
        badim: int = 320,
        ref_channel: int = -1,
        bdropout_rate: float = 0.0,
    ):
        super().__init__()
        self.use_beamformer = use_beamformer
        self.use_wpe = use_wpe
        self.use_dnn_mask_for_wpe = use_dnn_mask_for_wpe
        # Use the frontend for all the data, e.g. in the case of multi-speaker
        # speech separation: training then never samples "no enhancement".
        self.use_frontend_for_all = bnmask > 2

        if self.use_wpe:
            if self.use_dnn_mask_for_wpe:
                # Use DNN for power estimation
                # (Not observed significant gains)
                iterations = 1
            else:
                # Performing as conventional WPE, without DNN Estimator
                iterations = 2

            self.wpe = DNN_WPE(
                wtype=wtype,
                widim=idim,
                wunits=wunits,
                wprojs=wprojs,
                wlayers=wlayers,
                taps=taps,
                delay=delay,
                dropout_rate=wdropout_rate,
                iterations=iterations,
                use_dnn_mask=use_dnn_mask_for_wpe,
            )
        else:
            self.wpe = None

        if self.use_beamformer:
            self.beamformer = DNN_Beamformer(
                btype=btype,
                bidim=idim,
                bunits=bunits,
                bprojs=bprojs,
                blayers=blayers,
                bnmask=bnmask,
                dropout_rate=bdropout_rate,
                badim=badim,
                ref_channel=ref_channel,
            )
        else:
            self.beamformer = None

    # ComplexTensor annotations are quoted so they are not evaluated when the
    # class is defined; type checkers resolve them the same way.
    def forward(
        self,
        x: "ComplexTensor",
        ilens: Union[torch.LongTensor, numpy.ndarray, List[int]],
    ) -> Tuple["ComplexTensor", torch.LongTensor, Optional["ComplexTensor"]]:
        """Run the enabled enhancement stages on a batch.

        Args:
            x: Input of shape (B, T, F) or (B, T, C, F). Only 4-D input is
                enhanced; 3-D input is returned as-is.
            ilens: One length per batch item. A non-tensor value is converted
                with ``numpy.asarray`` and moved to ``x.device``.

        Returns:
            ``(h, ilens, mask)``: the output spectrum, the lengths returned by
            the last stage that ran, and that stage's mask (``None`` when no
            stage ran).

        Raises:
            ValueError: If ``x`` is neither 3-D nor 4-D.
        """
        assert len(x) == len(ilens), (len(x), len(ilens))
        # (B, T, F) or (B, T, C, F)
        if x.dim() not in (3, 4):
            raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
        if not torch.is_tensor(ilens):
            ilens = torch.from_numpy(numpy.asarray(ilens)).to(x.device)

        mask = None
        h = x
        if h.dim() == 4:
            if self.training:
                # Sample one option per call: pass-through (unless
                # use_frontend_for_all), WPE only, or beamformer only. The two
                # stages are never combined during training.
                choices = [] if self.use_frontend_for_all else [(False, False)]
                if self.use_wpe:
                    choices.append((True, False))
                if self.use_beamformer:
                    choices.append((False, True))
                if not choices:
                    # use_frontend_for_all with no stage enabled: nothing can
                    # be applied, so pass through (as eval mode does) instead
                    # of letting numpy.random.randint(0) raise.
                    choices.append((False, False))
                use_wpe, use_beamformer = choices[numpy.random.randint(len(choices))]
            else:
                # Eval: apply every enabled stage, in order.
                use_wpe = self.use_wpe
                use_beamformer = self.use_beamformer

            # 1. WPE
            if use_wpe:
                # h: (B, T, C, F) -> h: (B, T, C, F)
                h, ilens, mask = self.wpe(h, ilens)

            # 2. Beamformer
            if use_beamformer:
                # h: (B, T, C, F) -> h: (B, T, F)
                h, ilens, mask = self.beamformer(h, ilens)

        return h, ilens, mask
  # Factory: build a Frontend from an options namespace.
def frontend_for(args, idim):
    """Construct a ``Frontend`` from an options object.

    Each Frontend keyword is read from the attribute of the same name on
    ``args``, except ``taps`` and ``delay``, which come from ``args.wpe_taps``
    and ``args.wpe_delay``.

    Args:
        args: Object exposing the WPE and beamformer options as attributes
            (presumably an ``argparse.Namespace`` -- not enforced here).
        idim: Input dimension forwarded to ``Frontend``.

    Returns:
        A newly constructed ``Frontend``.
    """
    # (Frontend keyword, attribute on ``args``), read in this order.
    keyword_to_attr = (
        # WPE options
        ("use_wpe", "use_wpe"),
        ("wtype", "wtype"),
        ("wlayers", "wlayers"),
        ("wunits", "wunits"),
        ("wprojs", "wprojs"),
        ("wdropout_rate", "wdropout_rate"),
        ("taps", "wpe_taps"),
        ("delay", "wpe_delay"),
        ("use_dnn_mask_for_wpe", "use_dnn_mask_for_wpe"),
        # Beamformer options
        ("use_beamformer", "use_beamformer"),
        ("btype", "btype"),
        ("blayers", "blayers"),
        ("bunits", "bunits"),
        ("bprojs", "bprojs"),
        ("bnmask", "bnmask"),
        ("badim", "badim"),
        ("ref_channel", "ref_channel"),
        ("bdropout_rate", "bdropout_rate"),
    )
    options = {keyword: getattr(args, attr) for keyword, attr in keyword_to_attr}
    return Frontend(idim=idim, **options)