# default.py
import copy
import logging
from typing import Optional, Tuple, Union

import humanfriendly
import numpy as np
import torch
from torch_complex.tensor import ComplexTensor

from funasr.layers.log_mel import LogMel
from funasr.layers.stft import Stft
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.modules.frontends.frontend import Frontend
from funasr.modules.nets_utils import make_pad_mask
from funasr.utils.get_default_kwargs import get_default_kwargs
  16. class DefaultFrontend(AbsFrontend):
  17. """Conventional frontend structure for ASR.
  18. Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
  19. """
  20. def __init__(
  21. self,
  22. fs: Union[int, str] = 16000,
  23. n_fft: int = 512,
  24. win_length: int = None,
  25. hop_length: int = 128,
  26. window: Optional[str] = "hann",
  27. center: bool = True,
  28. normalized: bool = False,
  29. onesided: bool = True,
  30. n_mels: int = 80,
  31. fmin: int = None,
  32. fmax: int = None,
  33. htk: bool = False,
  34. frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
  35. apply_stft: bool = True,
  36. use_channel: int = None,
  37. ):
  38. super().__init__()
  39. if isinstance(fs, str):
  40. fs = humanfriendly.parse_size(fs)
  41. # Deepcopy (In general, dict shouldn't be used as default arg)
  42. frontend_conf = copy.deepcopy(frontend_conf)
  43. self.hop_length = hop_length
  44. if apply_stft:
  45. self.stft = Stft(
  46. n_fft=n_fft,
  47. win_length=win_length,
  48. hop_length=hop_length,
  49. center=center,
  50. window=window,
  51. normalized=normalized,
  52. onesided=onesided,
  53. )
  54. else:
  55. self.stft = None
  56. self.apply_stft = apply_stft
  57. if frontend_conf is not None:
  58. self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
  59. else:
  60. self.frontend = None
  61. self.logmel = LogMel(
  62. fs=fs,
  63. n_fft=n_fft,
  64. n_mels=n_mels,
  65. fmin=fmin,
  66. fmax=fmax,
  67. htk=htk,
  68. )
  69. self.n_mels = n_mels
  70. self.use_channel = use_channel
  71. self.frontend_type = "default"
  72. def output_size(self) -> int:
  73. return self.n_mels
  74. def forward(
  75. self, input: torch.Tensor, input_lengths: torch.Tensor
  76. ) -> Tuple[torch.Tensor, torch.Tensor]:
  77. # 1. Domain-conversion: e.g. Stft: time -> time-freq
  78. if self.stft is not None:
  79. input_stft, feats_lens = self._compute_stft(input, input_lengths)
  80. else:
  81. input_stft = ComplexTensor(input[..., 0], input[..., 1])
  82. feats_lens = input_lengths
  83. # 2. [Option] Speech enhancement
  84. if self.frontend is not None:
  85. assert isinstance(input_stft, ComplexTensor), type(input_stft)
  86. # input_stft: (Batch, Length, [Channel], Freq)
  87. input_stft, _, mask = self.frontend(input_stft, feats_lens)
  88. # 3. [Multi channel case]: Select a channel
  89. if input_stft.dim() == 4:
  90. # h: (B, T, C, F) -> h: (B, T, F)
  91. if self.training:
  92. if self.use_channel is not None:
  93. input_stft = input_stft[:, :, self.use_channel, :]
  94. else:
  95. # Select 1ch randomly
  96. ch = np.random.randint(input_stft.size(2))
  97. input_stft = input_stft[:, :, ch, :]
  98. else:
  99. # Use the first channel
  100. input_stft = input_stft[:, :, 0, :]
  101. # 4. STFT -> Power spectrum
  102. # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
  103. input_power = input_stft.real ** 2 + input_stft.imag ** 2
  104. # 5. Feature transform e.g. Stft -> Log-Mel-Fbank
  105. # input_power: (Batch, [Channel,] Length, Freq)
  106. # -> input_feats: (Batch, Length, Dim)
  107. input_feats, _ = self.logmel(input_power, feats_lens)
  108. return input_feats, feats_lens
  109. def _compute_stft(
  110. self, input: torch.Tensor, input_lengths: torch.Tensor
  111. ) -> torch.Tensor:
  112. input_stft, feats_lens = self.stft(input, input_lengths)
  113. assert input_stft.dim() >= 4, input_stft.shape
  114. # "2" refers to the real/imag parts of Complex
  115. assert input_stft.shape[-1] == 2, input_stft.shape
  116. # Change torch.Tensor to ComplexTensor
  117. # input_stft: (..., F, 2) -> (..., F)
  118. input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
  119. return input_stft, feats_lens
  120. class MultiChannelFrontend(AbsFrontend):
  121. """Conventional frontend structure for ASR.
  122. Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
  123. """
  124. def __init__(
  125. self,
  126. fs: Union[int, str] = 16000,
  127. n_fft: int = 512,
  128. win_length: int = None,
  129. hop_length: int = None,
  130. frame_length: int = None,
  131. frame_shift: int = None,
  132. window: Optional[str] = "hann",
  133. center: bool = True,
  134. normalized: bool = False,
  135. onesided: bool = True,
  136. n_mels: int = 80,
  137. fmin: int = None,
  138. fmax: int = None,
  139. htk: bool = False,
  140. frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
  141. apply_stft: bool = True,
  142. use_channel: int = None,
  143. lfr_m: int = 1,
  144. lfr_n: int = 1,
  145. cmvn_file: str = None,
  146. mc: bool = True
  147. ):
  148. super().__init__()
  149. if isinstance(fs, str):
  150. fs = humanfriendly.parse_size(fs)
  151. # Deepcopy (In general, dict shouldn't be used as default arg)
  152. frontend_conf = copy.deepcopy(frontend_conf)
  153. if win_length is None and hop_length is None:
  154. self.win_length = frame_length * 16
  155. self.hop_length = frame_shift * 16
  156. elif frame_length is None and frame_shift is None:
  157. self.win_length = self.win_length
  158. self.hop_length = self.hop_length
  159. else:
  160. logging.error(
  161. "Only one of (win_length, hop_length) and (frame_length, frame_shift)"
  162. "can be set."
  163. )
  164. exit(1)
  165. if apply_stft:
  166. self.stft = Stft(
  167. n_fft=n_fft,
  168. win_length=self.win_length,
  169. hop_length=self.hop_length,
  170. center=center,
  171. window=window,
  172. normalized=normalized,
  173. onesided=onesided,
  174. )
  175. else:
  176. self.stft = None
  177. self.apply_stft = apply_stft
  178. if frontend_conf is not None:
  179. self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
  180. else:
  181. self.frontend = None
  182. self.logmel = LogMel(
  183. fs=fs,
  184. n_fft=n_fft,
  185. n_mels=n_mels,
  186. fmin=fmin,
  187. fmax=fmax,
  188. htk=htk,
  189. )
  190. self.n_mels = n_mels
  191. self.use_channel = use_channel
  192. self.mc = mc
  193. if not self.mc:
  194. if self.use_channel is not None:
  195. logging.info("use the channel %d" % (self.use_channel))
  196. else:
  197. logging.info("random select channel")
  198. self.cmvn_file = cmvn_file
  199. if self.cmvn_file is not None:
  200. mean, std = self._load_cmvn(self.cmvn_file)
  201. self.register_buffer("mean", torch.from_numpy(mean))
  202. self.register_buffer("std", torch.from_numpy(std))
  203. self.frontend_type = "multichannelfrontend"
  204. def output_size(self) -> int:
  205. return self.n_mels
  206. def forward(
  207. self, input: torch.Tensor, input_lengths: torch.Tensor
  208. ) -> Tuple[torch.Tensor, torch.Tensor]:
  209. # 1. Domain-conversion: e.g. Stft: time -> time-freq
  210. #import pdb;pdb.set_trace()
  211. if self.stft is not None:
  212. input_stft, feats_lens = self._compute_stft(input, input_lengths)
  213. else:
  214. input_stft = ComplexTensor(input[..., 0], input[..., 1])
  215. feats_lens = input_lengths
  216. # 2. [Option] Speech enhancement
  217. if self.frontend is not None:
  218. assert isinstance(input_stft, ComplexTensor), type(input_stft)
  219. # input_stft: (Batch, Length, [Channel], Freq)
  220. input_stft, _, mask = self.frontend(input_stft, feats_lens)
  221. # 3. [Multi channel case]: Select a channel(sa_asr)
  222. if input_stft.dim() == 4 and not self.mc:
  223. # h: (B, T, C, F) -> h: (B, T, F)
  224. if self.training:
  225. if self.use_channel is not None:
  226. input_stft = input_stft[:, :, self.use_channel, :]
  227. else:
  228. # Select 1ch randomly
  229. ch = np.random.randint(input_stft.size(2))
  230. input_stft = input_stft[:, :, ch, :]
  231. else:
  232. # Use the first channel
  233. input_stft = input_stft[:, :, 0, :]
  234. # 4. STFT -> Power spectrum
  235. # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
  236. input_power = input_stft.real ** 2 + input_stft.imag ** 2
  237. # 5. Feature transform e.g. Stft -> Log-Mel-Fbank
  238. # input_power: (Batch, [Channel,] Length, Freq)
  239. # -> input_feats: (Batch, Length, Dim)
  240. input_feats, _ = self.logmel(input_power, feats_lens)
  241. if self.mc:
  242. # MFCCA
  243. if input_feats.dim() ==4:
  244. bt = input_feats.size(0)
  245. channel_size = input_feats.size(2)
  246. input_feats = input_feats.transpose(1,2).reshape(bt*channel_size,-1,80).contiguous()
  247. feats_lens = feats_lens.repeat(1,channel_size).squeeze()
  248. else:
  249. channel_size = 1
  250. return input_feats, feats_lens, channel_size
  251. else:
  252. # 6. Apply CMVN
  253. if self.cmvn_file is not None:
  254. if feats_lens is None:
  255. feats_lens = input_feats.new_full([input_feats.size(0)], input_feats.size(1))
  256. self.mean = self.mean.to(input_feats.device, input_feats.dtype)
  257. self.std = self.std.to(input_feats.device, input_feats.dtype)
  258. mask = make_pad_mask(feats_lens, input_feats, 1)
  259. if input_feats.requires_grad:
  260. input_feats = input_feats + self.mean
  261. else:
  262. input_feats += self.mean
  263. if input_feats.requires_grad:
  264. input_feats = input_feats.masked_fill(mask, 0.0)
  265. else:
  266. input_feats.masked_fill_(mask, 0.0)
  267. input_feats *= self.std
  268. return input_feats, feats_lens
  269. def _compute_stft(
  270. self, input: torch.Tensor, input_lengths: torch.Tensor
  271. ) -> torch.Tensor:
  272. input_stft, feats_lens = self.stft(input, input_lengths)
  273. assert input_stft.dim() >= 4, input_stft.shape
  274. # "2" refers to the real/imag parts of Complex
  275. assert input_stft.shape[-1] == 2, input_stft.shape
  276. # Change torch.Tensor to ComplexTensor
  277. # input_stft: (..., F, 2) -> (..., F)
  278. input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
  279. return input_stft, feats_lens
  280. def _load_cmvn(self, cmvn_file):
  281. with open(cmvn_file, 'r', encoding='utf-8') as f:
  282. lines = f.readlines()
  283. means_list = []
  284. vars_list = []
  285. for i in range(len(lines)):
  286. line_item = lines[i].split()
  287. if line_item[0] == '<AddShift>':
  288. line_item = lines[i + 1].split()
  289. if line_item[0] == '<LearnRateCoef>':
  290. add_shift_line = line_item[3:(len(line_item) - 1)]
  291. means_list = list(add_shift_line)
  292. continue
  293. elif line_item[0] == '<Rescale>':
  294. line_item = lines[i + 1].split()
  295. if line_item[0] == '<LearnRateCoef>':
  296. rescale_line = line_item[3:(len(line_item) - 1)]
  297. vars_list = list(rescale_line)
  298. continue
  299. means = np.array(means_list).astype(np.float)
  300. vars = np.array(vars_list).astype(np.float)
  301. return means, vars