default.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. import copy
  2. from typing import Optional
  3. from typing import Tuple
  4. from typing import Union
  5. import humanfriendly
  6. import numpy as np
  7. import torch
  8. from torch_complex.tensor import ComplexTensor
  9. from typeguard import check_argument_types
  10. from funasr.layers.log_mel import LogMel
  11. from funasr.layers.stft import Stft
  12. from funasr.models.frontend.abs_frontend import AbsFrontend
  13. from funasr.modules.frontends.frontend import Frontend
  14. from funasr.utils.get_default_kwargs import get_default_kwargs
  15. class DefaultFrontend(AbsFrontend):
  16. """Conventional frontend structure for ASR.
  17. Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
  18. """
  19. def __init__(
  20. self,
  21. fs: Union[int, str] = 16000,
  22. n_fft: int = 512,
  23. win_length: int = None,
  24. hop_length: int = 128,
  25. window: Optional[str] = "hann",
  26. center: bool = True,
  27. normalized: bool = False,
  28. onesided: bool = True,
  29. n_mels: int = 80,
  30. fmin: int = None,
  31. fmax: int = None,
  32. htk: bool = False,
  33. frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
  34. apply_stft: bool = True,
  35. use_channel: int = None,
  36. ):
  37. assert check_argument_types()
  38. super().__init__()
  39. if isinstance(fs, str):
  40. fs = humanfriendly.parse_size(fs)
  41. # Deepcopy (In general, dict shouldn't be used as default arg)
  42. frontend_conf = copy.deepcopy(frontend_conf)
  43. self.hop_length = hop_length
  44. if apply_stft:
  45. self.stft = Stft(
  46. n_fft=n_fft,
  47. win_length=win_length,
  48. hop_length=hop_length,
  49. center=center,
  50. window=window,
  51. normalized=normalized,
  52. onesided=onesided,
  53. )
  54. else:
  55. self.stft = None
  56. self.apply_stft = apply_stft
  57. if frontend_conf is not None:
  58. self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
  59. else:
  60. self.frontend = None
  61. self.logmel = LogMel(
  62. fs=fs,
  63. n_fft=n_fft,
  64. n_mels=n_mels,
  65. fmin=fmin,
  66. fmax=fmax,
  67. htk=htk,
  68. )
  69. self.n_mels = n_mels
  70. self.frontend_type = "default"
  71. self.use_channel = use_channel
  72. def output_size(self) -> int:
  73. return self.n_mels
  74. def forward(
  75. self, input: torch.Tensor, input_lengths: torch.Tensor
  76. ) -> Tuple[torch.Tensor, torch.Tensor]:
  77. # 1. Domain-conversion: e.g. Stft: time -> time-freq
  78. if self.stft is not None:
  79. input_stft, feats_lens = self._compute_stft(input, input_lengths)
  80. else:
  81. input_stft = ComplexTensor(input[..., 0], input[..., 1])
  82. feats_lens = input_lengths
  83. # 2. [Option] Speech enhancement
  84. if self.frontend is not None:
  85. assert isinstance(input_stft, ComplexTensor), type(input_stft)
  86. # input_stft: (Batch, Length, [Channel], Freq)
  87. input_stft, _, mask = self.frontend(input_stft, feats_lens)
  88. # 3. [Multi channel case]: Select a channel
  89. if input_stft.dim() == 4:
  90. # h: (B, T, C, F) -> h: (B, T, F)
  91. if self.training:
  92. if self.use_channel is not None:
  93. input_stft = input_stft[:, :, self.use_channel, :]
  94. else:
  95. # Select 1ch randomly
  96. ch = np.random.randint(input_stft.size(2))
  97. input_stft = input_stft[:, :, ch, :]
  98. else:
  99. # Use the first channel
  100. input_stft = input_stft[:, :, 0, :]
  101. # 4. STFT -> Power spectrum
  102. # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
  103. input_power = input_stft.real ** 2 + input_stft.imag ** 2
  104. # 5. Feature transform e.g. Stft -> Log-Mel-Fbank
  105. # input_power: (Batch, [Channel,] Length, Freq)
  106. # -> input_feats: (Batch, Length, Dim)
  107. input_feats, _ = self.logmel(input_power, feats_lens)
  108. return input_feats, feats_lens
  109. def _compute_stft(
  110. self, input: torch.Tensor, input_lengths: torch.Tensor
  111. ) -> torch.Tensor:
  112. input_stft, feats_lens = self.stft(input, input_lengths)
  113. assert input_stft.dim() >= 4, input_stft.shape
  114. # "2" refers to the real/imag parts of Complex
  115. assert input_stft.shape[-1] == 2, input_stft.shape
  116. # Change torch.Tensor to ComplexTensor
  117. # input_stft: (..., F, 2) -> (..., F)
  118. input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
  119. return input_stft, feats_lens
  120. class MultiChannelFrontend(AbsFrontend):
  121. """Conventional frontend structure for ASR.
  122. Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
  123. """
  124. def __init__(
  125. self,
  126. fs: Union[int, str] = 16000,
  127. n_fft: int = 512,
  128. win_length: int = None,
  129. hop_length: int = 128,
  130. window: Optional[str] = "hann",
  131. center: bool = True,
  132. normalized: bool = False,
  133. onesided: bool = True,
  134. n_mels: int = 80,
  135. fmin: int = None,
  136. fmax: int = None,
  137. htk: bool = False,
  138. frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
  139. apply_stft: bool = True,
  140. frame_length: int = None,
  141. frame_shift: int = None,
  142. lfr_m: int = None,
  143. lfr_n: int = None,
  144. ):
  145. assert check_argument_types()
  146. super().__init__()
  147. if isinstance(fs, str):
  148. fs = humanfriendly.parse_size(fs)
  149. # Deepcopy (In general, dict shouldn't be used as default arg)
  150. frontend_conf = copy.deepcopy(frontend_conf)
  151. self.hop_length = hop_length
  152. if apply_stft:
  153. self.stft = Stft(
  154. n_fft=n_fft,
  155. win_length=win_length,
  156. hop_length=hop_length,
  157. center=center,
  158. window=window,
  159. normalized=normalized,
  160. onesided=onesided,
  161. )
  162. else:
  163. self.stft = None
  164. self.apply_stft = apply_stft
  165. if frontend_conf is not None:
  166. self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
  167. else:
  168. self.frontend = None
  169. self.logmel = LogMel(
  170. fs=fs,
  171. n_fft=n_fft,
  172. n_mels=n_mels,
  173. fmin=fmin,
  174. fmax=fmax,
  175. htk=htk,
  176. )
  177. self.n_mels = n_mels
  178. self.frontend_type = "multichannelfrontend"
  179. def output_size(self) -> int:
  180. return self.n_mels
  181. def forward(
  182. self, input: torch.Tensor, input_lengths: torch.Tensor
  183. ) -> Tuple[torch.Tensor, torch.Tensor]:
  184. # 1. Domain-conversion: e.g. Stft: time -> time-freq
  185. #import pdb;pdb.set_trace()
  186. if self.stft is not None:
  187. input_stft, feats_lens = self._compute_stft(input, input_lengths)
  188. else:
  189. if isinstance(input, ComplexTensor):
  190. input_stft = input
  191. else:
  192. input_stft = ComplexTensor(input[..., 0], input[..., 1])
  193. feats_lens = input_lengths
  194. # 2. [Option] Speech enhancement
  195. if self.frontend is not None:
  196. assert isinstance(input_stft, ComplexTensor), type(input_stft)
  197. # input_stft: (Batch, Length, [Channel], Freq)
  198. input_stft, _, mask = self.frontend(input_stft, feats_lens)
  199. # 4. STFT -> Power spectrum
  200. # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
  201. input_power = input_stft.real ** 2 + input_stft.imag ** 2
  202. # 5. Feature transform e.g. Stft -> Log-Mel-Fbank
  203. # input_power: (Batch, [Channel,] Length, Freq)
  204. # -> input_feats: (Batch, Length, Dim)
  205. input_feats, _ = self.logmel(input_power, feats_lens)
  206. bt = input_feats.size(0)
  207. if input_feats.dim() ==4:
  208. channel_size = input_feats.size(2)
  209. # batch * channel * T * D
  210. #pdb.set_trace()
  211. input_feats = input_feats.transpose(1,2).reshape(bt*channel_size,-1,80).contiguous()
  212. # input_feats = input_feats.transpose(1,2)
  213. # batch * channel
  214. feats_lens = feats_lens.repeat(1,channel_size).squeeze()
  215. else:
  216. channel_size = 1
  217. return input_feats, feats_lens, channel_size
  218. def _compute_stft(
  219. self, input: torch.Tensor, input_lengths: torch.Tensor
  220. ) -> torch.Tensor:
  221. input_stft, feats_lens = self.stft(input, input_lengths)
  222. assert input_stft.dim() >= 4, input_stft.shape
  223. # "2" refers to the real/imag parts of Complex
  224. assert input_stft.shape[-1] == 2, input_stft.shape
  225. # Change torch.Tensor to ComplexTensor
  226. # input_stft: (..., F, 2) -> (..., F)
  227. input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
  228. return input_stft, feats_lens