default.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337
  1. import copy
  2. from typing import Optional
  3. from typing import Tuple
  4. from typing import Union
  5. import logging
  6. import numpy as np
  7. import torch
  8. import torch.nn as nn
  9. try:
  10. from torch_complex.tensor import ComplexTensor
  11. except:
  12. print("Please install torch_complex firstly")
  13. from funasr.frontends.utils.log_mel import LogMel
  14. from funasr.frontends.utils.stft import Stft
  15. from funasr.frontends.utils.frontend import Frontend
  16. from funasr.models.transformer.utils.nets_utils import make_pad_mask
  17. from funasr.register import tables
  18. @tables.register("frontend_classes", "DefaultFrontend")
  19. class DefaultFrontend(nn.Module):
  20. """Conventional frontend structure for ASR.
  21. Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
  22. """
  23. def __init__(
  24. self,
  25. fs: int = 16000,
  26. n_fft: int = 512,
  27. win_length: int = None,
  28. hop_length: int = 128,
  29. window: Optional[str] = "hann",
  30. center: bool = True,
  31. normalized: bool = False,
  32. onesided: bool = True,
  33. n_mels: int = 80,
  34. fmin: int = None,
  35. fmax: int = None,
  36. htk: bool = False,
  37. frontend_conf: Optional[dict] = None,
  38. apply_stft: bool = True,
  39. use_channel: int = None,
  40. **kwargs,
  41. ):
  42. super().__init__()
  43. # Deepcopy (In general, dict shouldn't be used as default arg)
  44. frontend_conf = copy.deepcopy(frontend_conf)
  45. self.hop_length = hop_length
  46. if apply_stft:
  47. self.stft = Stft(
  48. n_fft=n_fft,
  49. win_length=win_length,
  50. hop_length=hop_length,
  51. center=center,
  52. window=window,
  53. normalized=normalized,
  54. onesided=onesided,
  55. )
  56. else:
  57. self.stft = None
  58. self.apply_stft = apply_stft
  59. if frontend_conf is not None:
  60. self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
  61. else:
  62. self.frontend = None
  63. self.logmel = LogMel(
  64. fs=fs,
  65. n_fft=n_fft,
  66. n_mels=n_mels,
  67. fmin=fmin,
  68. fmax=fmax,
  69. htk=htk,
  70. )
  71. self.n_mels = n_mels
  72. self.use_channel = use_channel
  73. self.frontend_type = "default"
  74. def output_size(self) -> int:
  75. return self.n_mels
  76. def forward(
  77. self, input: torch.Tensor, input_lengths: torch.Tensor
  78. ) -> Tuple[torch.Tensor, torch.Tensor]:
  79. # 1. Domain-conversion: e.g. Stft: time -> time-freq
  80. if self.stft is not None:
  81. input_stft, feats_lens = self._compute_stft(input, input_lengths)
  82. else:
  83. input_stft = ComplexTensor(input[..., 0], input[..., 1])
  84. feats_lens = input_lengths
  85. # 2. [Option] Speech enhancement
  86. if self.frontend is not None:
  87. assert isinstance(input_stft, ComplexTensor), type(input_stft)
  88. # input_stft: (Batch, Length, [Channel], Freq)
  89. input_stft, _, mask = self.frontend(input_stft, feats_lens)
  90. # 3. [Multi channel case]: Select a channel
  91. if input_stft.dim() == 4:
  92. # h: (B, T, C, F) -> h: (B, T, F)
  93. if self.training:
  94. if self.use_channel is not None:
  95. input_stft = input_stft[:, :, self.use_channel, :]
  96. else:
  97. # Select 1ch randomly
  98. ch = np.random.randint(input_stft.size(2))
  99. input_stft = input_stft[:, :, ch, :]
  100. else:
  101. # Use the first channel
  102. input_stft = input_stft[:, :, 0, :]
  103. # 4. STFT -> Power spectrum
  104. # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
  105. input_power = input_stft.real ** 2 + input_stft.imag ** 2
  106. # 5. Feature transform e.g. Stft -> Log-Mel-Fbank
  107. # input_power: (Batch, [Channel,] Length, Freq)
  108. # -> input_feats: (Batch, Length, Dim)
  109. input_feats, _ = self.logmel(input_power, feats_lens)
  110. return input_feats, feats_lens
  111. def _compute_stft(
  112. self, input: torch.Tensor, input_lengths: torch.Tensor
  113. ) -> torch.Tensor:
  114. input_stft, feats_lens = self.stft(input, input_lengths)
  115. assert input_stft.dim() >= 4, input_stft.shape
  116. # "2" refers to the real/imag parts of Complex
  117. assert input_stft.shape[-1] == 2, input_stft.shape
  118. # Change torch.Tensor to ComplexTensor
  119. # input_stft: (..., F, 2) -> (..., F)
  120. input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
  121. return input_stft, feats_lens
  122. class MultiChannelFrontend(nn.Module):
  123. """Conventional frontend structure for ASR.
  124. Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
  125. """
  126. def __init__(
  127. self,
  128. fs: int = 16000,
  129. n_fft: int = 512,
  130. win_length: int = None,
  131. hop_length: int = None,
  132. frame_length: int = None,
  133. frame_shift: int = None,
  134. window: Optional[str] = "hann",
  135. center: bool = True,
  136. normalized: bool = False,
  137. onesided: bool = True,
  138. n_mels: int = 80,
  139. fmin: int = None,
  140. fmax: int = None,
  141. htk: bool = False,
  142. frontend_conf: Optional[dict] = None,
  143. apply_stft: bool = True,
  144. use_channel: int = None,
  145. lfr_m: int = 1,
  146. lfr_n: int = 1,
  147. cmvn_file: str = None,
  148. mc: bool = True
  149. ):
  150. super().__init__()
  151. # Deepcopy (In general, dict shouldn't be used as default arg)
  152. frontend_conf = copy.deepcopy(frontend_conf)
  153. if win_length is None and hop_length is None:
  154. self.win_length = frame_length * 16
  155. self.hop_length = frame_shift * 16
  156. elif frame_length is None and frame_shift is None:
  157. self.win_length = self.win_length
  158. self.hop_length = self.hop_length
  159. else:
  160. logging.error(
  161. "Only one of (win_length, hop_length) and (frame_length, frame_shift)"
  162. "can be set."
  163. )
  164. exit(1)
  165. if apply_stft:
  166. self.stft = Stft(
  167. n_fft=n_fft,
  168. win_length=self.win_length,
  169. hop_length=self.hop_length,
  170. center=center,
  171. window=window,
  172. normalized=normalized,
  173. onesided=onesided,
  174. )
  175. else:
  176. self.stft = None
  177. self.apply_stft = apply_stft
  178. if frontend_conf is not None:
  179. self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
  180. else:
  181. self.frontend = None
  182. self.logmel = LogMel(
  183. fs=fs,
  184. n_fft=n_fft,
  185. n_mels=n_mels,
  186. fmin=fmin,
  187. fmax=fmax,
  188. htk=htk,
  189. )
  190. self.n_mels = n_mels
  191. self.use_channel = use_channel
  192. self.mc = mc
  193. if not self.mc:
  194. if self.use_channel is not None:
  195. logging.info("use the channel %d" % (self.use_channel))
  196. else:
  197. logging.info("random select channel")
  198. self.cmvn_file = cmvn_file
  199. if self.cmvn_file is not None:
  200. mean, std = self._load_cmvn(self.cmvn_file)
  201. self.register_buffer("mean", torch.from_numpy(mean))
  202. self.register_buffer("std", torch.from_numpy(std))
  203. self.frontend_type = "multichannelfrontend"
  204. def output_size(self) -> int:
  205. return self.n_mels
  206. def forward(
  207. self, input: torch.Tensor, input_lengths: torch.Tensor
  208. ) -> Tuple[torch.Tensor, torch.Tensor]:
  209. # 1. Domain-conversion: e.g. Stft: time -> time-freq
  210. #import pdb;pdb.set_trace()
  211. if self.stft is not None:
  212. input_stft, feats_lens = self._compute_stft(input, input_lengths)
  213. else:
  214. input_stft = ComplexTensor(input[..., 0], input[..., 1])
  215. feats_lens = input_lengths
  216. # 2. [Option] Speech enhancement
  217. if self.frontend is not None:
  218. assert isinstance(input_stft, ComplexTensor), type(input_stft)
  219. # input_stft: (Batch, Length, [Channel], Freq)
  220. input_stft, _, mask = self.frontend(input_stft, feats_lens)
  221. # 3. [Multi channel case]: Select a channel(sa_asr)
  222. if input_stft.dim() == 4 and not self.mc:
  223. # h: (B, T, C, F) -> h: (B, T, F)
  224. if self.training:
  225. if self.use_channel is not None:
  226. input_stft = input_stft[:, :, self.use_channel, :]
  227. else:
  228. # Select 1ch randomly
  229. ch = np.random.randint(input_stft.size(2))
  230. input_stft = input_stft[:, :, ch, :]
  231. else:
  232. # Use the first channel
  233. input_stft = input_stft[:, :, 0, :]
  234. # 4. STFT -> Power spectrum
  235. # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
  236. input_power = input_stft.real ** 2 + input_stft.imag ** 2
  237. # 5. Feature transform e.g. Stft -> Log-Mel-Fbank
  238. # input_power: (Batch, [Channel,] Length, Freq)
  239. # -> input_feats: (Batch, Length, Dim)
  240. input_feats, _ = self.logmel(input_power, feats_lens)
  241. if self.mc:
  242. # MFCCA
  243. if input_feats.dim() ==4:
  244. bt = input_feats.size(0)
  245. channel_size = input_feats.size(2)
  246. input_feats = input_feats.transpose(1,2).reshape(bt*channel_size,-1,80).contiguous()
  247. feats_lens = feats_lens.repeat(1,channel_size).squeeze()
  248. else:
  249. channel_size = 1
  250. return input_feats, feats_lens, channel_size
  251. else:
  252. # 6. Apply CMVN
  253. if self.cmvn_file is not None:
  254. if feats_lens is None:
  255. feats_lens = input_feats.new_full([input_feats.size(0)], input_feats.size(1))
  256. self.mean = self.mean.to(input_feats.device, input_feats.dtype)
  257. self.std = self.std.to(input_feats.device, input_feats.dtype)
  258. mask = make_pad_mask(feats_lens, input_feats, 1)
  259. if input_feats.requires_grad:
  260. input_feats = input_feats + self.mean
  261. else:
  262. input_feats += self.mean
  263. if input_feats.requires_grad:
  264. input_feats = input_feats.masked_fill(mask, 0.0)
  265. else:
  266. input_feats.masked_fill_(mask, 0.0)
  267. input_feats *= self.std
  268. return input_feats, feats_lens
  269. def _compute_stft(
  270. self, input: torch.Tensor, input_lengths: torch.Tensor
  271. ) -> torch.Tensor:
  272. input_stft, feats_lens = self.stft(input, input_lengths)
  273. assert input_stft.dim() >= 4, input_stft.shape
  274. # "2" refers to the real/imag parts of Complex
  275. assert input_stft.shape[-1] == 2, input_stft.shape
  276. # Change torch.Tensor to ComplexTensor
  277. # input_stft: (..., F, 2) -> (..., F)
  278. input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
  279. return input_stft, feats_lens
  280. def _load_cmvn(self, cmvn_file):
  281. with open(cmvn_file, 'r', encoding='utf-8') as f:
  282. lines = f.readlines()
  283. means_list = []
  284. vars_list = []
  285. for i in range(len(lines)):
  286. line_item = lines[i].split()
  287. if line_item[0] == '<AddShift>':
  288. line_item = lines[i + 1].split()
  289. if line_item[0] == '<LearnRateCoef>':
  290. add_shift_line = line_item[3:(len(line_item) - 1)]
  291. means_list = list(add_shift_line)
  292. continue
  293. elif line_item[0] == '<Rescale>':
  294. line_item = lines[i + 1].split()
  295. if line_item[0] == '<LearnRateCoef>':
  296. rescale_line = line_item[3:(len(line_item) - 1)]
  297. vars_list = list(rescale_line)
  298. continue
  299. means = np.array(means_list).astype(np.float)
  300. vars = np.array(vars_list).astype(np.float)
  301. return means, vars