# File: e2e_ss.py (2.8 KB)
  1. import math
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. import copy
  6. from funasr.models.base_model import FunASRModel
  7. from funasr.models.encoder.mossformer_encoder import MossFormerEncoder, MossFormer_MaskNet
  8. from funasr.models.decoder.mossformer_decoder import MossFormerDecoder
  9. class MossFormer(FunASRModel):
  10. """The MossFormer model for separating input mixed speech into different speaker's speech.
  11. Arguments
  12. ---------
  13. in_channels : int
  14. Number of channels at the output of the encoder.
  15. out_channels : int
  16. Number of channels that would be inputted to the intra and inter blocks.
  17. num_blocks : int
  18. Number of layers of Dual Computation Block.
  19. norm : str
  20. Normalization type.
  21. num_spks : int
  22. Number of sources (speakers).
  23. skip_around_intra : bool
  24. Skip connection around intra.
  25. use_global_pos_enc : bool
  26. Global positional encodings.
  27. max_length : int
  28. Maximum sequence length.
  29. kernel_size: int
  30. Encoder and decoder kernel size
  31. """
  32. def __init__(
  33. self,
  34. in_channels=512,
  35. out_channels=512,
  36. num_blocks=24,
  37. kernel_size=16,
  38. norm="ln",
  39. num_spks=2,
  40. skip_around_intra=True,
  41. use_global_pos_enc=True,
  42. max_length=20000,
  43. ):
  44. super(MossFormer, self).__init__()
  45. self.num_spks = num_spks
  46. # Encoding
  47. self.enc = MossFormerEncoder(kernel_size=kernel_size, out_channels=in_channels, in_channels=1)
  48. ##Compute Mask
  49. self.mask_net = MossFormer_MaskNet(
  50. in_channels=in_channels,
  51. out_channels=out_channels,
  52. num_blocks=num_blocks,
  53. norm=norm,
  54. num_spks=num_spks,
  55. skip_around_intra=skip_around_intra,
  56. use_global_pos_enc=use_global_pos_enc,
  57. max_length=max_length,
  58. )
  59. self.dec = MossFormerDecoder(
  60. in_channels=out_channels,
  61. out_channels=1,
  62. kernel_size=kernel_size,
  63. stride = kernel_size//2,
  64. bias=False
  65. )
  66. def forward(self, input):
  67. x = self.enc(input)
  68. mask = self.mask_net(x)
  69. x = torch.stack([x] * self.num_spks)
  70. sep_x = x * mask
  71. # Decoding
  72. est_source = torch.cat(
  73. [
  74. self.dec(sep_x[i]).unsqueeze(-1)
  75. for i in range(self.num_spks)
  76. ],
  77. dim=-1,
  78. )
  79. T_origin = input.size(1)
  80. T_est = est_source.size(1)
  81. if T_origin > T_est:
  82. est_source = F.pad(est_source, (0, 0, 0, T_origin - T_est))
  83. else:
  84. est_source = est_source[:, :T_origin, :]
  85. out = []
  86. for spk in range(self.num_spks):
  87. out.append(est_source[:,:,spk])
  88. return out