default.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341
  1. import copy
  2. from typing import Optional
  3. from typing import Tuple
  4. from typing import Union
  5. import logging
  6. import humanfriendly
  7. import numpy as np
  8. import torch
  9. try:
  10. from torch_complex.tensor import ComplexTensor
  11. except:
  12. print("Please install torch_complex firstly")
  13. from funasr.layers.log_mel import LogMel
  14. from funasr.layers.stft import Stft
  15. from funasr.models.frontend.abs_frontend import AbsFrontend
  16. from funasr.models.frontend.frontends_utils.frontend import Frontend
  17. from funasr.utils.get_default_kwargs import get_default_kwargs
  18. from funasr.modules.nets_utils import make_pad_mask
  19. class DefaultFrontend(AbsFrontend):
  20. """Conventional frontend structure for ASR.
  21. Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
  22. """
  23. def __init__(
  24. self,
  25. fs: Union[int, str] = 16000,
  26. n_fft: int = 512,
  27. win_length: int = None,
  28. hop_length: int = 128,
  29. window: Optional[str] = "hann",
  30. center: bool = True,
  31. normalized: bool = False,
  32. onesided: bool = True,
  33. n_mels: int = 80,
  34. fmin: int = None,
  35. fmax: int = None,
  36. htk: bool = False,
  37. frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
  38. apply_stft: bool = True,
  39. use_channel: int = None,
  40. ):
  41. super().__init__()
  42. if isinstance(fs, str):
  43. fs = humanfriendly.parse_size(fs)
  44. # Deepcopy (In general, dict shouldn't be used as default arg)
  45. frontend_conf = copy.deepcopy(frontend_conf)
  46. self.hop_length = hop_length
  47. if apply_stft:
  48. self.stft = Stft(
  49. n_fft=n_fft,
  50. win_length=win_length,
  51. hop_length=hop_length,
  52. center=center,
  53. window=window,
  54. normalized=normalized,
  55. onesided=onesided,
  56. )
  57. else:
  58. self.stft = None
  59. self.apply_stft = apply_stft
  60. if frontend_conf is not None:
  61. self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
  62. else:
  63. self.frontend = None
  64. self.logmel = LogMel(
  65. fs=fs,
  66. n_fft=n_fft,
  67. n_mels=n_mels,
  68. fmin=fmin,
  69. fmax=fmax,
  70. htk=htk,
  71. )
  72. self.n_mels = n_mels
  73. self.use_channel = use_channel
  74. self.frontend_type = "default"
  75. def output_size(self) -> int:
  76. return self.n_mels
  77. def forward(
  78. self, input: torch.Tensor, input_lengths: torch.Tensor
  79. ) -> Tuple[torch.Tensor, torch.Tensor]:
  80. # 1. Domain-conversion: e.g. Stft: time -> time-freq
  81. if self.stft is not None:
  82. input_stft, feats_lens = self._compute_stft(input, input_lengths)
  83. else:
  84. input_stft = ComplexTensor(input[..., 0], input[..., 1])
  85. feats_lens = input_lengths
  86. # 2. [Option] Speech enhancement
  87. if self.frontend is not None:
  88. assert isinstance(input_stft, ComplexTensor), type(input_stft)
  89. # input_stft: (Batch, Length, [Channel], Freq)
  90. input_stft, _, mask = self.frontend(input_stft, feats_lens)
  91. # 3. [Multi channel case]: Select a channel
  92. if input_stft.dim() == 4:
  93. # h: (B, T, C, F) -> h: (B, T, F)
  94. if self.training:
  95. if self.use_channel is not None:
  96. input_stft = input_stft[:, :, self.use_channel, :]
  97. else:
  98. # Select 1ch randomly
  99. ch = np.random.randint(input_stft.size(2))
  100. input_stft = input_stft[:, :, ch, :]
  101. else:
  102. # Use the first channel
  103. input_stft = input_stft[:, :, 0, :]
  104. # 4. STFT -> Power spectrum
  105. # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
  106. input_power = input_stft.real ** 2 + input_stft.imag ** 2
  107. # 5. Feature transform e.g. Stft -> Log-Mel-Fbank
  108. # input_power: (Batch, [Channel,] Length, Freq)
  109. # -> input_feats: (Batch, Length, Dim)
  110. input_feats, _ = self.logmel(input_power, feats_lens)
  111. return input_feats, feats_lens
  112. def _compute_stft(
  113. self, input: torch.Tensor, input_lengths: torch.Tensor
  114. ) -> torch.Tensor:
  115. input_stft, feats_lens = self.stft(input, input_lengths)
  116. assert input_stft.dim() >= 4, input_stft.shape
  117. # "2" refers to the real/imag parts of Complex
  118. assert input_stft.shape[-1] == 2, input_stft.shape
  119. # Change torch.Tensor to ComplexTensor
  120. # input_stft: (..., F, 2) -> (..., F)
  121. input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
  122. return input_stft, feats_lens
  123. class MultiChannelFrontend(AbsFrontend):
  124. """Conventional frontend structure for ASR.
  125. Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
  126. """
  127. def __init__(
  128. self,
  129. fs: Union[int, str] = 16000,
  130. n_fft: int = 512,
  131. win_length: int = None,
  132. hop_length: int = None,
  133. frame_length: int = None,
  134. frame_shift: int = None,
  135. window: Optional[str] = "hann",
  136. center: bool = True,
  137. normalized: bool = False,
  138. onesided: bool = True,
  139. n_mels: int = 80,
  140. fmin: int = None,
  141. fmax: int = None,
  142. htk: bool = False,
  143. frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
  144. apply_stft: bool = True,
  145. use_channel: int = None,
  146. lfr_m: int = 1,
  147. lfr_n: int = 1,
  148. cmvn_file: str = None,
  149. mc: bool = True
  150. ):
  151. super().__init__()
  152. if isinstance(fs, str):
  153. fs = humanfriendly.parse_size(fs)
  154. # Deepcopy (In general, dict shouldn't be used as default arg)
  155. frontend_conf = copy.deepcopy(frontend_conf)
  156. if win_length is None and hop_length is None:
  157. self.win_length = frame_length * 16
  158. self.hop_length = frame_shift * 16
  159. elif frame_length is None and frame_shift is None:
  160. self.win_length = self.win_length
  161. self.hop_length = self.hop_length
  162. else:
  163. logging.error(
  164. "Only one of (win_length, hop_length) and (frame_length, frame_shift)"
  165. "can be set."
  166. )
  167. exit(1)
  168. if apply_stft:
  169. self.stft = Stft(
  170. n_fft=n_fft,
  171. win_length=self.win_length,
  172. hop_length=self.hop_length,
  173. center=center,
  174. window=window,
  175. normalized=normalized,
  176. onesided=onesided,
  177. )
  178. else:
  179. self.stft = None
  180. self.apply_stft = apply_stft
  181. if frontend_conf is not None:
  182. self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
  183. else:
  184. self.frontend = None
  185. self.logmel = LogMel(
  186. fs=fs,
  187. n_fft=n_fft,
  188. n_mels=n_mels,
  189. fmin=fmin,
  190. fmax=fmax,
  191. htk=htk,
  192. )
  193. self.n_mels = n_mels
  194. self.use_channel = use_channel
  195. self.mc = mc
  196. if not self.mc:
  197. if self.use_channel is not None:
  198. logging.info("use the channel %d" % (self.use_channel))
  199. else:
  200. logging.info("random select channel")
  201. self.cmvn_file = cmvn_file
  202. if self.cmvn_file is not None:
  203. mean, std = self._load_cmvn(self.cmvn_file)
  204. self.register_buffer("mean", torch.from_numpy(mean))
  205. self.register_buffer("std", torch.from_numpy(std))
  206. self.frontend_type = "multichannelfrontend"
  207. def output_size(self) -> int:
  208. return self.n_mels
  209. def forward(
  210. self, input: torch.Tensor, input_lengths: torch.Tensor
  211. ) -> Tuple[torch.Tensor, torch.Tensor]:
  212. # 1. Domain-conversion: e.g. Stft: time -> time-freq
  213. #import pdb;pdb.set_trace()
  214. if self.stft is not None:
  215. input_stft, feats_lens = self._compute_stft(input, input_lengths)
  216. else:
  217. input_stft = ComplexTensor(input[..., 0], input[..., 1])
  218. feats_lens = input_lengths
  219. # 2. [Option] Speech enhancement
  220. if self.frontend is not None:
  221. assert isinstance(input_stft, ComplexTensor), type(input_stft)
  222. # input_stft: (Batch, Length, [Channel], Freq)
  223. input_stft, _, mask = self.frontend(input_stft, feats_lens)
  224. # 3. [Multi channel case]: Select a channel(sa_asr)
  225. if input_stft.dim() == 4 and not self.mc:
  226. # h: (B, T, C, F) -> h: (B, T, F)
  227. if self.training:
  228. if self.use_channel is not None:
  229. input_stft = input_stft[:, :, self.use_channel, :]
  230. else:
  231. # Select 1ch randomly
  232. ch = np.random.randint(input_stft.size(2))
  233. input_stft = input_stft[:, :, ch, :]
  234. else:
  235. # Use the first channel
  236. input_stft = input_stft[:, :, 0, :]
  237. # 4. STFT -> Power spectrum
  238. # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
  239. input_power = input_stft.real ** 2 + input_stft.imag ** 2
  240. # 5. Feature transform e.g. Stft -> Log-Mel-Fbank
  241. # input_power: (Batch, [Channel,] Length, Freq)
  242. # -> input_feats: (Batch, Length, Dim)
  243. input_feats, _ = self.logmel(input_power, feats_lens)
  244. if self.mc:
  245. # MFCCA
  246. if input_feats.dim() ==4:
  247. bt = input_feats.size(0)
  248. channel_size = input_feats.size(2)
  249. input_feats = input_feats.transpose(1,2).reshape(bt*channel_size,-1,80).contiguous()
  250. feats_lens = feats_lens.repeat(1,channel_size).squeeze()
  251. else:
  252. channel_size = 1
  253. return input_feats, feats_lens, channel_size
  254. else:
  255. # 6. Apply CMVN
  256. if self.cmvn_file is not None:
  257. if feats_lens is None:
  258. feats_lens = input_feats.new_full([input_feats.size(0)], input_feats.size(1))
  259. self.mean = self.mean.to(input_feats.device, input_feats.dtype)
  260. self.std = self.std.to(input_feats.device, input_feats.dtype)
  261. mask = make_pad_mask(feats_lens, input_feats, 1)
  262. if input_feats.requires_grad:
  263. input_feats = input_feats + self.mean
  264. else:
  265. input_feats += self.mean
  266. if input_feats.requires_grad:
  267. input_feats = input_feats.masked_fill(mask, 0.0)
  268. else:
  269. input_feats.masked_fill_(mask, 0.0)
  270. input_feats *= self.std
  271. return input_feats, feats_lens
  272. def _compute_stft(
  273. self, input: torch.Tensor, input_lengths: torch.Tensor
  274. ) -> torch.Tensor:
  275. input_stft, feats_lens = self.stft(input, input_lengths)
  276. assert input_stft.dim() >= 4, input_stft.shape
  277. # "2" refers to the real/imag parts of Complex
  278. assert input_stft.shape[-1] == 2, input_stft.shape
  279. # Change torch.Tensor to ComplexTensor
  280. # input_stft: (..., F, 2) -> (..., F)
  281. input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
  282. return input_stft, feats_lens
  283. def _load_cmvn(self, cmvn_file):
  284. with open(cmvn_file, 'r', encoding='utf-8') as f:
  285. lines = f.readlines()
  286. means_list = []
  287. vars_list = []
  288. for i in range(len(lines)):
  289. line_item = lines[i].split()
  290. if line_item[0] == '<AddShift>':
  291. line_item = lines[i + 1].split()
  292. if line_item[0] == '<LearnRateCoef>':
  293. add_shift_line = line_item[3:(len(line_item) - 1)]
  294. means_list = list(add_shift_line)
  295. continue
  296. elif line_item[0] == '<Rescale>':
  297. line_item = lines[i + 1].split()
  298. if line_item[0] == '<LearnRateCoef>':
  299. rescale_line = line_item[3:(len(line_item) - 1)]
  300. vars_list = list(rescale_line)
  301. continue
  302. means = np.array(means_list).astype(np.float)
  303. vars = np.array(vars_list).astype(np.float)
  304. return means, vars