# default.py
  1. import copy
  2. from typing import Optional
  3. from typing import Tuple
  4. from typing import Union
  5. import logging
  6. import numpy as np
  7. import torch
  8. import torch.nn as nn
  9. try:
  10. from torch_complex.tensor import ComplexTensor
  11. except:
  12. print("Please install torch_complex firstly")
  13. from funasr.frontends.utils.log_mel import LogMel
  14. from funasr.frontends.utils.stft import Stft
  15. from funasr.frontends.utils.frontend import Frontend
  16. from funasr.models.transformer.utils.nets_utils import make_pad_mask
  17. from funasr.register import tables
  18. @tables.register("frontend_classes", "DefaultFrontend")
  19. class DefaultFrontend(nn.Module):
  20. """Conventional frontend structure for ASR.
  21. Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
  22. """
  23. def __init__(
  24. self,
  25. fs: int = 16000,
  26. n_fft: int = 512,
  27. win_length: int = None,
  28. hop_length: int = 128,
  29. window: Optional[str] = "hann",
  30. center: bool = True,
  31. normalized: bool = False,
  32. onesided: bool = True,
  33. n_mels: int = 80,
  34. fmin: int = None,
  35. fmax: int = None,
  36. htk: bool = False,
  37. frontend_conf: Optional[dict] = None,
  38. apply_stft: bool = True,
  39. use_channel: int = None,
  40. **kwargs,
  41. ):
  42. super().__init__()
  43. # Deepcopy (In general, dict shouldn't be used as default arg)
  44. frontend_conf = copy.deepcopy(frontend_conf)
  45. self.hop_length = hop_length
  46. self.fs = fs
  47. if apply_stft:
  48. self.stft = Stft(
  49. n_fft=n_fft,
  50. win_length=win_length,
  51. hop_length=hop_length,
  52. center=center,
  53. window=window,
  54. normalized=normalized,
  55. onesided=onesided,
  56. )
  57. else:
  58. self.stft = None
  59. self.apply_stft = apply_stft
  60. if frontend_conf is not None:
  61. self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
  62. else:
  63. self.frontend = None
  64. self.logmel = LogMel(
  65. fs=fs,
  66. n_fft=n_fft,
  67. n_mels=n_mels,
  68. fmin=fmin,
  69. fmax=fmax,
  70. htk=htk,
  71. )
  72. self.n_mels = n_mels
  73. self.use_channel = use_channel
  74. self.frontend_type = "default"
  75. def output_size(self) -> int:
  76. return self.n_mels
  77. def forward(
  78. self, input: torch.Tensor, input_lengths: torch.Tensor
  79. ) -> Tuple[torch.Tensor, torch.Tensor]:
  80. # 1. Domain-conversion: e.g. Stft: time -> time-freq
  81. if self.stft is not None:
  82. input_stft, feats_lens = self._compute_stft(input, input_lengths)
  83. else:
  84. input_stft = ComplexTensor(input[..., 0], input[..., 1])
  85. feats_lens = input_lengths
  86. # 2. [Option] Speech enhancement
  87. if self.frontend is not None:
  88. assert isinstance(input_stft, ComplexTensor), type(input_stft)
  89. # input_stft: (Batch, Length, [Channel], Freq)
  90. input_stft, _, mask = self.frontend(input_stft, feats_lens)
  91. # 3. [Multi channel case]: Select a channel
  92. if input_stft.dim() == 4:
  93. # h: (B, T, C, F) -> h: (B, T, F)
  94. if self.training:
  95. if self.use_channel is not None:
  96. input_stft = input_stft[:, :, self.use_channel, :]
  97. else:
  98. # Select 1ch randomly
  99. ch = np.random.randint(input_stft.size(2))
  100. input_stft = input_stft[:, :, ch, :]
  101. else:
  102. # Use the first channel
  103. input_stft = input_stft[:, :, 0, :]
  104. # 4. STFT -> Power spectrum
  105. # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
  106. input_power = input_stft.real ** 2 + input_stft.imag ** 2
  107. # 5. Feature transform e.g. Stft -> Log-Mel-Fbank
  108. # input_power: (Batch, [Channel,] Length, Freq)
  109. # -> input_feats: (Batch, Length, Dim)
  110. input_feats, _ = self.logmel(input_power, feats_lens)
  111. return input_feats, feats_lens
  112. def _compute_stft(
  113. self, input: torch.Tensor, input_lengths: torch.Tensor
  114. ) -> torch.Tensor:
  115. input_stft, feats_lens = self.stft(input, input_lengths)
  116. assert input_stft.dim() >= 4, input_stft.shape
  117. # "2" refers to the real/imag parts of Complex
  118. assert input_stft.shape[-1] == 2, input_stft.shape
  119. # Change torch.Tensor to ComplexTensor
  120. # input_stft: (..., F, 2) -> (..., F)
  121. input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
  122. return input_stft, feats_lens
  123. class MultiChannelFrontend(nn.Module):
  124. """Conventional frontend structure for ASR.
  125. Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
  126. """
  127. def __init__(
  128. self,
  129. fs: int = 16000,
  130. n_fft: int = 512,
  131. win_length: int = None,
  132. hop_length: int = None,
  133. frame_length: int = None,
  134. frame_shift: int = None,
  135. window: Optional[str] = "hann",
  136. center: bool = True,
  137. normalized: bool = False,
  138. onesided: bool = True,
  139. n_mels: int = 80,
  140. fmin: int = None,
  141. fmax: int = None,
  142. htk: bool = False,
  143. frontend_conf: Optional[dict] = None,
  144. apply_stft: bool = True,
  145. use_channel: int = None,
  146. lfr_m: int = 1,
  147. lfr_n: int = 1,
  148. cmvn_file: str = None,
  149. mc: bool = True
  150. ):
  151. super().__init__()
  152. # Deepcopy (In general, dict shouldn't be used as default arg)
  153. frontend_conf = copy.deepcopy(frontend_conf)
  154. if win_length is None and hop_length is None:
  155. self.win_length = frame_length * 16
  156. self.hop_length = frame_shift * 16
  157. elif frame_length is None and frame_shift is None:
  158. self.win_length = self.win_length
  159. self.hop_length = self.hop_length
  160. else:
  161. logging.error(
  162. "Only one of (win_length, hop_length) and (frame_length, frame_shift)"
  163. "can be set."
  164. )
  165. exit(1)
  166. if apply_stft:
  167. self.stft = Stft(
  168. n_fft=n_fft,
  169. win_length=self.win_length,
  170. hop_length=self.hop_length,
  171. center=center,
  172. window=window,
  173. normalized=normalized,
  174. onesided=onesided,
  175. )
  176. else:
  177. self.stft = None
  178. self.apply_stft = apply_stft
  179. if frontend_conf is not None:
  180. self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
  181. else:
  182. self.frontend = None
  183. self.logmel = LogMel(
  184. fs=fs,
  185. n_fft=n_fft,
  186. n_mels=n_mels,
  187. fmin=fmin,
  188. fmax=fmax,
  189. htk=htk,
  190. )
  191. self.n_mels = n_mels
  192. self.use_channel = use_channel
  193. self.mc = mc
  194. if not self.mc:
  195. if self.use_channel is not None:
  196. logging.info("use the channel %d" % (self.use_channel))
  197. else:
  198. logging.info("random select channel")
  199. self.cmvn_file = cmvn_file
  200. if self.cmvn_file is not None:
  201. mean, std = self._load_cmvn(self.cmvn_file)
  202. self.register_buffer("mean", torch.from_numpy(mean))
  203. self.register_buffer("std", torch.from_numpy(std))
  204. self.frontend_type = "multichannelfrontend"
  205. def output_size(self) -> int:
  206. return self.n_mels
  207. def forward(
  208. self, input: torch.Tensor, input_lengths: torch.Tensor
  209. ) -> Tuple[torch.Tensor, torch.Tensor]:
  210. # 1. Domain-conversion: e.g. Stft: time -> time-freq
  211. #import pdb;pdb.set_trace()
  212. if self.stft is not None:
  213. input_stft, feats_lens = self._compute_stft(input, input_lengths)
  214. else:
  215. input_stft = ComplexTensor(input[..., 0], input[..., 1])
  216. feats_lens = input_lengths
  217. # 2. [Option] Speech enhancement
  218. if self.frontend is not None:
  219. assert isinstance(input_stft, ComplexTensor), type(input_stft)
  220. # input_stft: (Batch, Length, [Channel], Freq)
  221. input_stft, _, mask = self.frontend(input_stft, feats_lens)
  222. # 3. [Multi channel case]: Select a channel(sa_asr)
  223. if input_stft.dim() == 4 and not self.mc:
  224. # h: (B, T, C, F) -> h: (B, T, F)
  225. if self.training:
  226. if self.use_channel is not None:
  227. input_stft = input_stft[:, :, self.use_channel, :]
  228. else:
  229. # Select 1ch randomly
  230. ch = np.random.randint(input_stft.size(2))
  231. input_stft = input_stft[:, :, ch, :]
  232. else:
  233. # Use the first channel
  234. input_stft = input_stft[:, :, 0, :]
  235. # 4. STFT -> Power spectrum
  236. # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
  237. input_power = input_stft.real ** 2 + input_stft.imag ** 2
  238. # 5. Feature transform e.g. Stft -> Log-Mel-Fbank
  239. # input_power: (Batch, [Channel,] Length, Freq)
  240. # -> input_feats: (Batch, Length, Dim)
  241. input_feats, _ = self.logmel(input_power, feats_lens)
  242. if self.mc:
  243. # MFCCA
  244. if input_feats.dim() ==4:
  245. bt = input_feats.size(0)
  246. channel_size = input_feats.size(2)
  247. input_feats = input_feats.transpose(1,2).reshape(bt*channel_size,-1,80).contiguous()
  248. feats_lens = feats_lens.repeat(1,channel_size).squeeze()
  249. else:
  250. channel_size = 1
  251. return input_feats, feats_lens, channel_size
  252. else:
  253. # 6. Apply CMVN
  254. if self.cmvn_file is not None:
  255. if feats_lens is None:
  256. feats_lens = input_feats.new_full([input_feats.size(0)], input_feats.size(1))
  257. self.mean = self.mean.to(input_feats.device, input_feats.dtype)
  258. self.std = self.std.to(input_feats.device, input_feats.dtype)
  259. mask = make_pad_mask(feats_lens, input_feats, 1)
  260. if input_feats.requires_grad:
  261. input_feats = input_feats + self.mean
  262. else:
  263. input_feats += self.mean
  264. if input_feats.requires_grad:
  265. input_feats = input_feats.masked_fill(mask, 0.0)
  266. else:
  267. input_feats.masked_fill_(mask, 0.0)
  268. input_feats *= self.std
  269. return input_feats, feats_lens
  270. def _compute_stft(
  271. self, input: torch.Tensor, input_lengths: torch.Tensor
  272. ) -> torch.Tensor:
  273. input_stft, feats_lens = self.stft(input, input_lengths)
  274. assert input_stft.dim() >= 4, input_stft.shape
  275. # "2" refers to the real/imag parts of Complex
  276. assert input_stft.shape[-1] == 2, input_stft.shape
  277. # Change torch.Tensor to ComplexTensor
  278. # input_stft: (..., F, 2) -> (..., F)
  279. input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
  280. return input_stft, feats_lens
  281. def _load_cmvn(self, cmvn_file):
  282. with open(cmvn_file, 'r', encoding='utf-8') as f:
  283. lines = f.readlines()
  284. means_list = []
  285. vars_list = []
  286. for i in range(len(lines)):
  287. line_item = lines[i].split()
  288. if line_item[0] == '<AddShift>':
  289. line_item = lines[i + 1].split()
  290. if line_item[0] == '<LearnRateCoef>':
  291. add_shift_line = line_item[3:(len(line_item) - 1)]
  292. means_list = list(add_shift_line)
  293. continue
  294. elif line_item[0] == '<Rescale>':
  295. line_item = lines[i + 1].split()
  296. if line_item[0] == '<LearnRateCoef>':
  297. rescale_line = line_item[3:(len(line_item) - 1)]
  298. vars_list = list(rescale_line)
  299. continue
  300. means = np.array(means_list).astype(np.float)
  301. vars = np.array(vars_list).astype(np.float)
  302. return means, vars