# default.py
  1. import copy
  2. from typing import Optional
  3. from typing import Tuple
  4. from typing import Union
  5. import logging
  6. import humanfriendly
  7. import numpy as np
  8. import torch
  9. import torch.nn as nn
  10. try:
  11. from torch_complex.tensor import ComplexTensor
  12. except:
  13. print("Please install torch_complex firstly")
  14. from funasr.frontends.utils.log_mel import LogMel
  15. from funasr.frontends.utils.stft import Stft
  16. from funasr.frontends.utils.frontend import Frontend
  17. from funasr.models.transformer.utils.nets_utils import make_pad_mask
  18. class DefaultFrontend(nn.Module):
  19. """Conventional frontend structure for ASR.
  20. Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
  21. """
  22. def __init__(
  23. self,
  24. fs: Union[int, str] = 16000,
  25. n_fft: int = 512,
  26. win_length: int = None,
  27. hop_length: int = 128,
  28. window: Optional[str] = "hann",
  29. center: bool = True,
  30. normalized: bool = False,
  31. onesided: bool = True,
  32. n_mels: int = 80,
  33. fmin: int = None,
  34. fmax: int = None,
  35. htk: bool = False,
  36. frontend_conf: Optional[dict] = None,
  37. apply_stft: bool = True,
  38. use_channel: int = None,
  39. ):
  40. super().__init__()
  41. if isinstance(fs, str):
  42. fs = humanfriendly.parse_size(fs)
  43. # Deepcopy (In general, dict shouldn't be used as default arg)
  44. frontend_conf = copy.deepcopy(frontend_conf)
  45. self.hop_length = hop_length
  46. if apply_stft:
  47. self.stft = Stft(
  48. n_fft=n_fft,
  49. win_length=win_length,
  50. hop_length=hop_length,
  51. center=center,
  52. window=window,
  53. normalized=normalized,
  54. onesided=onesided,
  55. )
  56. else:
  57. self.stft = None
  58. self.apply_stft = apply_stft
  59. if frontend_conf is not None:
  60. self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
  61. else:
  62. self.frontend = None
  63. self.logmel = LogMel(
  64. fs=fs,
  65. n_fft=n_fft,
  66. n_mels=n_mels,
  67. fmin=fmin,
  68. fmax=fmax,
  69. htk=htk,
  70. )
  71. self.n_mels = n_mels
  72. self.use_channel = use_channel
  73. self.frontend_type = "default"
  74. def output_size(self) -> int:
  75. return self.n_mels
  76. def forward(
  77. self, input: torch.Tensor, input_lengths: torch.Tensor
  78. ) -> Tuple[torch.Tensor, torch.Tensor]:
  79. # 1. Domain-conversion: e.g. Stft: time -> time-freq
  80. if self.stft is not None:
  81. input_stft, feats_lens = self._compute_stft(input, input_lengths)
  82. else:
  83. input_stft = ComplexTensor(input[..., 0], input[..., 1])
  84. feats_lens = input_lengths
  85. # 2. [Option] Speech enhancement
  86. if self.frontend is not None:
  87. assert isinstance(input_stft, ComplexTensor), type(input_stft)
  88. # input_stft: (Batch, Length, [Channel], Freq)
  89. input_stft, _, mask = self.frontend(input_stft, feats_lens)
  90. # 3. [Multi channel case]: Select a channel
  91. if input_stft.dim() == 4:
  92. # h: (B, T, C, F) -> h: (B, T, F)
  93. if self.training:
  94. if self.use_channel is not None:
  95. input_stft = input_stft[:, :, self.use_channel, :]
  96. else:
  97. # Select 1ch randomly
  98. ch = np.random.randint(input_stft.size(2))
  99. input_stft = input_stft[:, :, ch, :]
  100. else:
  101. # Use the first channel
  102. input_stft = input_stft[:, :, 0, :]
  103. # 4. STFT -> Power spectrum
  104. # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
  105. input_power = input_stft.real ** 2 + input_stft.imag ** 2
  106. # 5. Feature transform e.g. Stft -> Log-Mel-Fbank
  107. # input_power: (Batch, [Channel,] Length, Freq)
  108. # -> input_feats: (Batch, Length, Dim)
  109. input_feats, _ = self.logmel(input_power, feats_lens)
  110. return input_feats, feats_lens
  111. def _compute_stft(
  112. self, input: torch.Tensor, input_lengths: torch.Tensor
  113. ) -> torch.Tensor:
  114. input_stft, feats_lens = self.stft(input, input_lengths)
  115. assert input_stft.dim() >= 4, input_stft.shape
  116. # "2" refers to the real/imag parts of Complex
  117. assert input_stft.shape[-1] == 2, input_stft.shape
  118. # Change torch.Tensor to ComplexTensor
  119. # input_stft: (..., F, 2) -> (..., F)
  120. input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
  121. return input_stft, feats_lens
  122. class MultiChannelFrontend(nn.Module):
  123. """Conventional frontend structure for ASR.
  124. Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
  125. """
  126. def __init__(
  127. self,
  128. fs: Union[int, str] = 16000,
  129. n_fft: int = 512,
  130. win_length: int = None,
  131. hop_length: int = None,
  132. frame_length: int = None,
  133. frame_shift: int = None,
  134. window: Optional[str] = "hann",
  135. center: bool = True,
  136. normalized: bool = False,
  137. onesided: bool = True,
  138. n_mels: int = 80,
  139. fmin: int = None,
  140. fmax: int = None,
  141. htk: bool = False,
  142. frontend_conf: Optional[dict] = None,
  143. apply_stft: bool = True,
  144. use_channel: int = None,
  145. lfr_m: int = 1,
  146. lfr_n: int = 1,
  147. cmvn_file: str = None,
  148. mc: bool = True
  149. ):
  150. super().__init__()
  151. if isinstance(fs, str):
  152. fs = humanfriendly.parse_size(fs)
  153. # Deepcopy (In general, dict shouldn't be used as default arg)
  154. frontend_conf = copy.deepcopy(frontend_conf)
  155. if win_length is None and hop_length is None:
  156. self.win_length = frame_length * 16
  157. self.hop_length = frame_shift * 16
  158. elif frame_length is None and frame_shift is None:
  159. self.win_length = self.win_length
  160. self.hop_length = self.hop_length
  161. else:
  162. logging.error(
  163. "Only one of (win_length, hop_length) and (frame_length, frame_shift)"
  164. "can be set."
  165. )
  166. exit(1)
  167. if apply_stft:
  168. self.stft = Stft(
  169. n_fft=n_fft,
  170. win_length=self.win_length,
  171. hop_length=self.hop_length,
  172. center=center,
  173. window=window,
  174. normalized=normalized,
  175. onesided=onesided,
  176. )
  177. else:
  178. self.stft = None
  179. self.apply_stft = apply_stft
  180. if frontend_conf is not None:
  181. self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
  182. else:
  183. self.frontend = None
  184. self.logmel = LogMel(
  185. fs=fs,
  186. n_fft=n_fft,
  187. n_mels=n_mels,
  188. fmin=fmin,
  189. fmax=fmax,
  190. htk=htk,
  191. )
  192. self.n_mels = n_mels
  193. self.use_channel = use_channel
  194. self.mc = mc
  195. if not self.mc:
  196. if self.use_channel is not None:
  197. logging.info("use the channel %d" % (self.use_channel))
  198. else:
  199. logging.info("random select channel")
  200. self.cmvn_file = cmvn_file
  201. if self.cmvn_file is not None:
  202. mean, std = self._load_cmvn(self.cmvn_file)
  203. self.register_buffer("mean", torch.from_numpy(mean))
  204. self.register_buffer("std", torch.from_numpy(std))
  205. self.frontend_type = "multichannelfrontend"
  206. def output_size(self) -> int:
  207. return self.n_mels
  208. def forward(
  209. self, input: torch.Tensor, input_lengths: torch.Tensor
  210. ) -> Tuple[torch.Tensor, torch.Tensor]:
  211. # 1. Domain-conversion: e.g. Stft: time -> time-freq
  212. #import pdb;pdb.set_trace()
  213. if self.stft is not None:
  214. input_stft, feats_lens = self._compute_stft(input, input_lengths)
  215. else:
  216. input_stft = ComplexTensor(input[..., 0], input[..., 1])
  217. feats_lens = input_lengths
  218. # 2. [Option] Speech enhancement
  219. if self.frontend is not None:
  220. assert isinstance(input_stft, ComplexTensor), type(input_stft)
  221. # input_stft: (Batch, Length, [Channel], Freq)
  222. input_stft, _, mask = self.frontend(input_stft, feats_lens)
  223. # 3. [Multi channel case]: Select a channel(sa_asr)
  224. if input_stft.dim() == 4 and not self.mc:
  225. # h: (B, T, C, F) -> h: (B, T, F)
  226. if self.training:
  227. if self.use_channel is not None:
  228. input_stft = input_stft[:, :, self.use_channel, :]
  229. else:
  230. # Select 1ch randomly
  231. ch = np.random.randint(input_stft.size(2))
  232. input_stft = input_stft[:, :, ch, :]
  233. else:
  234. # Use the first channel
  235. input_stft = input_stft[:, :, 0, :]
  236. # 4. STFT -> Power spectrum
  237. # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
  238. input_power = input_stft.real ** 2 + input_stft.imag ** 2
  239. # 5. Feature transform e.g. Stft -> Log-Mel-Fbank
  240. # input_power: (Batch, [Channel,] Length, Freq)
  241. # -> input_feats: (Batch, Length, Dim)
  242. input_feats, _ = self.logmel(input_power, feats_lens)
  243. if self.mc:
  244. # MFCCA
  245. if input_feats.dim() ==4:
  246. bt = input_feats.size(0)
  247. channel_size = input_feats.size(2)
  248. input_feats = input_feats.transpose(1,2).reshape(bt*channel_size,-1,80).contiguous()
  249. feats_lens = feats_lens.repeat(1,channel_size).squeeze()
  250. else:
  251. channel_size = 1
  252. return input_feats, feats_lens, channel_size
  253. else:
  254. # 6. Apply CMVN
  255. if self.cmvn_file is not None:
  256. if feats_lens is None:
  257. feats_lens = input_feats.new_full([input_feats.size(0)], input_feats.size(1))
  258. self.mean = self.mean.to(input_feats.device, input_feats.dtype)
  259. self.std = self.std.to(input_feats.device, input_feats.dtype)
  260. mask = make_pad_mask(feats_lens, input_feats, 1)
  261. if input_feats.requires_grad:
  262. input_feats = input_feats + self.mean
  263. else:
  264. input_feats += self.mean
  265. if input_feats.requires_grad:
  266. input_feats = input_feats.masked_fill(mask, 0.0)
  267. else:
  268. input_feats.masked_fill_(mask, 0.0)
  269. input_feats *= self.std
  270. return input_feats, feats_lens
  271. def _compute_stft(
  272. self, input: torch.Tensor, input_lengths: torch.Tensor
  273. ) -> torch.Tensor:
  274. input_stft, feats_lens = self.stft(input, input_lengths)
  275. assert input_stft.dim() >= 4, input_stft.shape
  276. # "2" refers to the real/imag parts of Complex
  277. assert input_stft.shape[-1] == 2, input_stft.shape
  278. # Change torch.Tensor to ComplexTensor
  279. # input_stft: (..., F, 2) -> (..., F)
  280. input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
  281. return input_stft, feats_lens
  282. def _load_cmvn(self, cmvn_file):
  283. with open(cmvn_file, 'r', encoding='utf-8') as f:
  284. lines = f.readlines()
  285. means_list = []
  286. vars_list = []
  287. for i in range(len(lines)):
  288. line_item = lines[i].split()
  289. if line_item[0] == '<AddShift>':
  290. line_item = lines[i + 1].split()
  291. if line_item[0] == '<LearnRateCoef>':
  292. add_shift_line = line_item[3:(len(line_item) - 1)]
  293. means_list = list(add_shift_line)
  294. continue
  295. elif line_item[0] == '<Rescale>':
  296. line_item = lines[i + 1].split()
  297. if line_item[0] == '<LearnRateCoef>':
  298. rescale_line = line_item[3:(len(line_item) - 1)]
  299. vars_list = list(rescale_line)
  300. continue
  301. means = np.array(means_list).astype(np.float)
  302. vars = np.array(vars_list).astype(np.float)
  303. return means, vars