default.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. import copy
  2. from typing import Optional
  3. from typing import Tuple
  4. from typing import Union
  5. import humanfriendly
  6. import numpy as np
  7. import torch
  8. from torch_complex.tensor import ComplexTensor
  9. from typeguard import check_argument_types
  10. from funasr.layers.log_mel import LogMel
  11. from funasr.layers.stft import Stft
  12. from funasr.models.frontend.abs_frontend import AbsFrontend
  13. from funasr.modules.frontends.frontend import Frontend
  14. from funasr.utils.get_default_kwargs import get_default_kwargs
  15. class DefaultFrontend(AbsFrontend):
  16. """Conventional frontend structure for ASR.
  17. Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
  18. """
  19. def __init__(
  20. self,
  21. fs: Union[int, str] = 16000,
  22. n_fft: int = 512,
  23. win_length: int = None,
  24. hop_length: int = 128,
  25. window: Optional[str] = "hann",
  26. center: bool = True,
  27. normalized: bool = False,
  28. onesided: bool = True,
  29. n_mels: int = 80,
  30. fmin: int = None,
  31. fmax: int = None,
  32. htk: bool = False,
  33. frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
  34. apply_stft: bool = True,
  35. ):
  36. assert check_argument_types()
  37. super().__init__()
  38. if isinstance(fs, str):
  39. fs = humanfriendly.parse_size(fs)
  40. # Deepcopy (In general, dict shouldn't be used as default arg)
  41. frontend_conf = copy.deepcopy(frontend_conf)
  42. self.hop_length = hop_length
  43. if apply_stft:
  44. self.stft = Stft(
  45. n_fft=n_fft,
  46. win_length=win_length,
  47. hop_length=hop_length,
  48. center=center,
  49. window=window,
  50. normalized=normalized,
  51. onesided=onesided,
  52. )
  53. else:
  54. self.stft = None
  55. self.apply_stft = apply_stft
  56. if frontend_conf is not None:
  57. self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
  58. else:
  59. self.frontend = None
  60. self.logmel = LogMel(
  61. fs=fs,
  62. n_fft=n_fft,
  63. n_mels=n_mels,
  64. fmin=fmin,
  65. fmax=fmax,
  66. htk=htk,
  67. )
  68. self.n_mels = n_mels
  69. self.frontend_type = "default"
  70. def output_size(self) -> int:
  71. return self.n_mels
  72. def forward(
  73. self, input: torch.Tensor, input_lengths: torch.Tensor
  74. ) -> Tuple[torch.Tensor, torch.Tensor]:
  75. # 1. Domain-conversion: e.g. Stft: time -> time-freq
  76. if self.stft is not None:
  77. input_stft, feats_lens = self._compute_stft(input, input_lengths)
  78. else:
  79. input_stft = ComplexTensor(input[..., 0], input[..., 1])
  80. feats_lens = input_lengths
  81. # 2. [Option] Speech enhancement
  82. if self.frontend is not None:
  83. assert isinstance(input_stft, ComplexTensor), type(input_stft)
  84. # input_stft: (Batch, Length, [Channel], Freq)
  85. input_stft, _, mask = self.frontend(input_stft, feats_lens)
  86. # 3. [Multi channel case]: Select a channel
  87. if input_stft.dim() == 4:
  88. # h: (B, T, C, F) -> h: (B, T, F)
  89. if self.training:
  90. # Select 1ch randomly
  91. ch = np.random.randint(input_stft.size(2))
  92. input_stft = input_stft[:, :, ch, :]
  93. else:
  94. # Use the first channel
  95. input_stft = input_stft[:, :, 0, :]
  96. # 4. STFT -> Power spectrum
  97. # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
  98. input_power = input_stft.real ** 2 + input_stft.imag ** 2
  99. # 5. Feature transform e.g. Stft -> Log-Mel-Fbank
  100. # input_power: (Batch, [Channel,] Length, Freq)
  101. # -> input_feats: (Batch, Length, Dim)
  102. input_feats, _ = self.logmel(input_power, feats_lens)
  103. return input_feats, feats_lens
  104. def _compute_stft(
  105. self, input: torch.Tensor, input_lengths: torch.Tensor
  106. ) -> torch.Tensor:
  107. input_stft, feats_lens = self.stft(input, input_lengths)
  108. assert input_stft.dim() >= 4, input_stft.shape
  109. # "2" refers to the real/imag parts of Complex
  110. assert input_stft.shape[-1] == 2, input_stft.shape
  111. # Change torch.Tensor to ComplexTensor
  112. # input_stft: (..., F, 2) -> (..., F)
  113. input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
  114. return input_stft, feats_lens
  115. class MultiChannelFrontend(AbsFrontend):
  116. """Conventional frontend structure for ASR.
  117. Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
  118. """
  119. def __init__(
  120. self,
  121. fs: Union[int, str] = 16000,
  122. n_fft: int = 512,
  123. win_length: int = None,
  124. hop_length: int = 128,
  125. window: Optional[str] = "hann",
  126. center: bool = True,
  127. normalized: bool = False,
  128. onesided: bool = True,
  129. n_mels: int = 80,
  130. fmin: int = None,
  131. fmax: int = None,
  132. htk: bool = False,
  133. frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
  134. apply_stft: bool = True,
  135. frame_length: int = None,
  136. frame_shift: int = None,
  137. lfr_m: int = None,
  138. lfr_n: int = None,
  139. ):
  140. assert check_argument_types()
  141. super().__init__()
  142. if isinstance(fs, str):
  143. fs = humanfriendly.parse_size(fs)
  144. # Deepcopy (In general, dict shouldn't be used as default arg)
  145. frontend_conf = copy.deepcopy(frontend_conf)
  146. self.hop_length = hop_length
  147. if apply_stft:
  148. self.stft = Stft(
  149. n_fft=n_fft,
  150. win_length=win_length,
  151. hop_length=hop_length,
  152. center=center,
  153. window=window,
  154. normalized=normalized,
  155. onesided=onesided,
  156. )
  157. else:
  158. self.stft = None
  159. self.apply_stft = apply_stft
  160. if frontend_conf is not None:
  161. self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
  162. else:
  163. self.frontend = None
  164. self.logmel = LogMel(
  165. fs=fs,
  166. n_fft=n_fft,
  167. n_mels=n_mels,
  168. fmin=fmin,
  169. fmax=fmax,
  170. htk=htk,
  171. )
  172. self.n_mels = n_mels
  173. self.frontend_type = "multichannelfrontend"
  174. def output_size(self) -> int:
  175. return self.n_mels
  176. def forward(
  177. self, input: torch.Tensor, input_lengths: torch.Tensor
  178. ) -> Tuple[torch.Tensor, torch.Tensor]:
  179. # 1. Domain-conversion: e.g. Stft: time -> time-freq
  180. #import pdb;pdb.set_trace()
  181. if self.stft is not None:
  182. input_stft, feats_lens = self._compute_stft(input, input_lengths)
  183. else:
  184. if isinstance(input, ComplexTensor):
  185. input_stft = input
  186. else:
  187. input_stft = ComplexTensor(input[..., 0], input[..., 1])
  188. feats_lens = input_lengths
  189. # 2. [Option] Speech enhancement
  190. if self.frontend is not None:
  191. assert isinstance(input_stft, ComplexTensor), type(input_stft)
  192. # input_stft: (Batch, Length, [Channel], Freq)
  193. input_stft, _, mask = self.frontend(input_stft, feats_lens)
  194. # 4. STFT -> Power spectrum
  195. # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
  196. input_power = input_stft.real ** 2 + input_stft.imag ** 2
  197. # 5. Feature transform e.g. Stft -> Log-Mel-Fbank
  198. # input_power: (Batch, [Channel,] Length, Freq)
  199. # -> input_feats: (Batch, Length, Dim)
  200. input_feats, _ = self.logmel(input_power, feats_lens)
  201. bt = input_feats.size(0)
  202. if input_feats.dim() ==4:
  203. channel_size = input_feats.size(2)
  204. # batch * channel * T * D
  205. #pdb.set_trace()
  206. input_feats = input_feats.transpose(1,2).reshape(bt*channel_size,-1,80).contiguous()
  207. # input_feats = input_feats.transpose(1,2)
  208. # batch * channel
  209. feats_lens = feats_lens.repeat(1,channel_size).squeeze()
  210. else:
  211. channel_size = 1
  212. return input_feats, feats_lens, channel_size
  213. def _compute_stft(
  214. self, input: torch.Tensor, input_lengths: torch.Tensor
  215. ) -> torch.Tensor:
  216. input_stft, feats_lens = self.stft(input, input_lengths)
  217. assert input_stft.dim() >= 4, input_stft.shape
  218. # "2" refers to the real/imag parts of Complex
  219. assert input_stft.shape[-1] == 2, input_stft.shape
  220. # Change torch.Tensor to ComplexTensor
  221. # input_stft: (..., F, 2) -> (..., F)
  222. input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
  223. return input_stft, feats_lens