default.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. import copy
  2. from typing import Optional
  3. from typing import Tuple
  4. from typing import Union
  5. import humanfriendly
  6. import numpy as np
  7. import torch
  8. from torch_complex.tensor import ComplexTensor
  9. from typeguard import check_argument_types
  10. from funasr.layers.log_mel import LogMel
  11. from funasr.layers.stft import Stft
  12. from funasr.models.frontend.abs_frontend import AbsFrontend
  13. from funasr.modules.frontends.frontend import Frontend
  14. from funasr.utils.get_default_kwargs import get_default_kwargs
  15. class DefaultFrontend(AbsFrontend):
  16. """Conventional frontend structure for ASR.
  17. Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
  18. """
  19. def __init__(
  20. self,
  21. fs: Union[int, str] = 16000,
  22. n_fft: int = 512,
  23. win_length: int = None,
  24. hop_length: int = 128,
  25. window: Optional[str] = "hann",
  26. center: bool = True,
  27. normalized: bool = False,
  28. onesided: bool = True,
  29. n_mels: int = 80,
  30. fmin: int = None,
  31. fmax: int = None,
  32. htk: bool = False,
  33. frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
  34. apply_stft: bool = True,
  35. ):
  36. assert check_argument_types()
  37. super().__init__()
  38. if isinstance(fs, str):
  39. fs = humanfriendly.parse_size(fs)
  40. # Deepcopy (In general, dict shouldn't be used as default arg)
  41. frontend_conf = copy.deepcopy(frontend_conf)
  42. self.hop_length = hop_length
  43. if apply_stft:
  44. self.stft = Stft(
  45. n_fft=n_fft,
  46. win_length=win_length,
  47. hop_length=hop_length,
  48. center=center,
  49. window=window,
  50. normalized=normalized,
  51. onesided=onesided,
  52. )
  53. else:
  54. self.stft = None
  55. self.apply_stft = apply_stft
  56. if frontend_conf is not None:
  57. self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
  58. else:
  59. self.frontend = None
  60. self.logmel = LogMel(
  61. fs=fs,
  62. n_fft=n_fft,
  63. n_mels=n_mels,
  64. fmin=fmin,
  65. fmax=fmax,
  66. htk=htk,
  67. )
  68. self.n_mels = n_mels
  69. self.frontend_type = "default"
  70. def output_size(self) -> int:
  71. return self.n_mels
  72. def forward(
  73. self, input: torch.Tensor, input_lengths: torch.Tensor
  74. ) -> Tuple[torch.Tensor, torch.Tensor]:
  75. # 1. Domain-conversion: e.g. Stft: time -> time-freq
  76. if self.stft is not None:
  77. input_stft, feats_lens = self._compute_stft(input, input_lengths)
  78. else:
  79. input_stft = ComplexTensor(input[..., 0], input[..., 1])
  80. feats_lens = input_lengths
  81. # 2. [Option] Speech enhancement
  82. if self.frontend is not None:
  83. assert isinstance(input_stft, ComplexTensor), type(input_stft)
  84. # input_stft: (Batch, Length, [Channel], Freq)
  85. input_stft, _, mask = self.frontend(input_stft, feats_lens)
  86. # 3. [Multi channel case]: Select a channel
  87. if input_stft.dim() == 4:
  88. # h: (B, T, C, F) -> h: (B, T, F)
  89. if self.training:
  90. # Select 1ch randomly
  91. ch = np.random.randint(input_stft.size(2))
  92. input_stft = input_stft[:, :, ch, :]
  93. else:
  94. # Use the first channel
  95. input_stft = input_stft[:, :, 0, :]
  96. # 4. STFT -> Power spectrum
  97. # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
  98. input_power = input_stft.real ** 2 + input_stft.imag ** 2
  99. # 5. Feature transform e.g. Stft -> Log-Mel-Fbank
  100. # input_power: (Batch, [Channel,] Length, Freq)
  101. # -> input_feats: (Batch, Length, Dim)
  102. input_feats, _ = self.logmel(input_power, feats_lens)
  103. return input_feats, feats_lens
  104. def _compute_stft(
  105. self, input: torch.Tensor, input_lengths: torch.Tensor
  106. ) -> torch.Tensor:
  107. input_stft, feats_lens = self.stft(input, input_lengths)
  108. assert input_stft.dim() >= 4, input_stft.shape
  109. # "2" refers to the real/imag parts of Complex
  110. assert input_stft.shape[-1] == 2, input_stft.shape
  111. # Change torch.Tensor to ComplexTensor
  112. # input_stft: (..., F, 2) -> (..., F)
  113. input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
  114. return input_stft, feats_lens