default.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342
  1. import copy
  2. from typing import Optional
  3. from typing import Tuple
  4. from typing import Union
  5. import logging
  6. import numpy as np
  7. import torch
  8. import torch.nn as nn
  9. try:
  10. from torch_complex.tensor import ComplexTensor
  11. except:
  12. print("Please install torch_complex firstly")
  13. from funasr.frontends.utils.log_mel import LogMel
  14. from funasr.frontends.utils.stft import Stft
  15. from funasr.frontends.utils.frontend import Frontend
  16. from funasr.models.transformer.utils.nets_utils import make_pad_mask
  17. from funasr.register import tables
  18. @tables.register("frontend_classes", "DefaultFrontend")
  19. class DefaultFrontend(nn.Module):
  20. """Conventional frontend structure for ASR.
  21. Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
  22. """
  23. def __init__(
  24. self,
  25. fs: int = 16000,
  26. n_fft: int = 512,
  27. win_length: int = None,
  28. hop_length: int = 128,
  29. window: Optional[str] = "hann",
  30. center: bool = True,
  31. normalized: bool = False,
  32. onesided: bool = True,
  33. n_mels: int = 80,
  34. fmin: int = None,
  35. fmax: int = None,
  36. htk: bool = False,
  37. frontend_conf: Optional[dict] = None,
  38. apply_stft: bool = True,
  39. use_channel: int = None,
  40. **kwargs,
  41. ):
  42. super().__init__()
  43. # Deepcopy (In general, dict shouldn't be used as default arg)
  44. frontend_conf = copy.deepcopy(frontend_conf)
  45. self.hop_length = hop_length
  46. self.fs = fs
  47. if apply_stft:
  48. self.stft = Stft(
  49. n_fft=n_fft,
  50. win_length=win_length,
  51. hop_length=hop_length,
  52. center=center,
  53. window=window,
  54. normalized=normalized,
  55. onesided=onesided,
  56. )
  57. else:
  58. self.stft = None
  59. self.apply_stft = apply_stft
  60. if frontend_conf is not None:
  61. self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
  62. else:
  63. self.frontend = None
  64. self.logmel = LogMel(
  65. fs=fs,
  66. n_fft=n_fft,
  67. n_mels=n_mels,
  68. fmin=fmin,
  69. fmax=fmax,
  70. htk=htk,
  71. )
  72. self.n_mels = n_mels
  73. self.use_channel = use_channel
  74. self.frontend_type = "default"
  75. def output_size(self) -> int:
  76. return self.n_mels
  77. def forward(
  78. self, input: torch.Tensor, input_lengths: Union[torch.Tensor, list]
  79. ) -> Tuple[torch.Tensor, torch.Tensor]:
  80. if isinstance(input_lengths, list):
  81. input_lengths = torch.tensor(input_lengths)
  82. if input.dtype == torch.float64:
  83. input = input.float()
  84. # 1. Domain-conversion: e.g. Stft: time -> time-freq
  85. if self.stft is not None:
  86. input_stft, feats_lens = self._compute_stft(input, input_lengths)
  87. else:
  88. input_stft = ComplexTensor(input[..., 0], input[..., 1])
  89. feats_lens = input_lengths
  90. # 2. [Option] Speech enhancement
  91. if self.frontend is not None:
  92. assert isinstance(input_stft, ComplexTensor), type(input_stft)
  93. # input_stft: (Batch, Length, [Channel], Freq)
  94. input_stft, _, mask = self.frontend(input_stft, feats_lens)
  95. # 3. [Multi channel case]: Select a channel
  96. if input_stft.dim() == 4:
  97. # h: (B, T, C, F) -> h: (B, T, F)
  98. if self.training:
  99. if self.use_channel is not None:
  100. input_stft = input_stft[:, :, self.use_channel, :]
  101. else:
  102. # Select 1ch randomly
  103. ch = np.random.randint(input_stft.size(2))
  104. input_stft = input_stft[:, :, ch, :]
  105. else:
  106. # Use the first channel
  107. input_stft = input_stft[:, :, 0, :]
  108. # 4. STFT -> Power spectrum
  109. # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
  110. input_power = input_stft.real ** 2 + input_stft.imag ** 2
  111. # 5. Feature transform e.g. Stft -> Log-Mel-Fbank
  112. # input_power: (Batch, [Channel,] Length, Freq)
  113. # -> input_feats: (Batch, Length, Dim)
  114. input_feats, _ = self.logmel(input_power, feats_lens)
  115. return input_feats, feats_lens
  116. def _compute_stft(
  117. self, input: torch.Tensor, input_lengths: torch.Tensor
  118. ) -> torch.Tensor:
  119. input_stft, feats_lens = self.stft(input, input_lengths)
  120. assert input_stft.dim() >= 4, input_stft.shape
  121. # "2" refers to the real/imag parts of Complex
  122. assert input_stft.shape[-1] == 2, input_stft.shape
  123. # Change torch.Tensor to ComplexTensor
  124. # input_stft: (..., F, 2) -> (..., F)
  125. input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
  126. return input_stft, feats_lens
  127. class MultiChannelFrontend(nn.Module):
  128. """Conventional frontend structure for ASR.
  129. Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
  130. """
  131. def __init__(
  132. self,
  133. fs: int = 16000,
  134. n_fft: int = 512,
  135. win_length: int = None,
  136. hop_length: int = None,
  137. frame_length: int = None,
  138. frame_shift: int = None,
  139. window: Optional[str] = "hann",
  140. center: bool = True,
  141. normalized: bool = False,
  142. onesided: bool = True,
  143. n_mels: int = 80,
  144. fmin: int = None,
  145. fmax: int = None,
  146. htk: bool = False,
  147. frontend_conf: Optional[dict] = None,
  148. apply_stft: bool = True,
  149. use_channel: int = None,
  150. lfr_m: int = 1,
  151. lfr_n: int = 1,
  152. cmvn_file: str = None,
  153. mc: bool = True
  154. ):
  155. super().__init__()
  156. # Deepcopy (In general, dict shouldn't be used as default arg)
  157. frontend_conf = copy.deepcopy(frontend_conf)
  158. if win_length is None and hop_length is None:
  159. self.win_length = frame_length * 16
  160. self.hop_length = frame_shift * 16
  161. elif frame_length is None and frame_shift is None:
  162. self.win_length = self.win_length
  163. self.hop_length = self.hop_length
  164. else:
  165. logging.error(
  166. "Only one of (win_length, hop_length) and (frame_length, frame_shift)"
  167. "can be set."
  168. )
  169. exit(1)
  170. if apply_stft:
  171. self.stft = Stft(
  172. n_fft=n_fft,
  173. win_length=self.win_length,
  174. hop_length=self.hop_length,
  175. center=center,
  176. window=window,
  177. normalized=normalized,
  178. onesided=onesided,
  179. )
  180. else:
  181. self.stft = None
  182. self.apply_stft = apply_stft
  183. if frontend_conf is not None:
  184. self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
  185. else:
  186. self.frontend = None
  187. self.logmel = LogMel(
  188. fs=fs,
  189. n_fft=n_fft,
  190. n_mels=n_mels,
  191. fmin=fmin,
  192. fmax=fmax,
  193. htk=htk,
  194. )
  195. self.n_mels = n_mels
  196. self.use_channel = use_channel
  197. self.mc = mc
  198. if not self.mc:
  199. if self.use_channel is not None:
  200. logging.info("use the channel %d" % (self.use_channel))
  201. else:
  202. logging.info("random select channel")
  203. self.cmvn_file = cmvn_file
  204. if self.cmvn_file is not None:
  205. mean, std = self._load_cmvn(self.cmvn_file)
  206. self.register_buffer("mean", torch.from_numpy(mean))
  207. self.register_buffer("std", torch.from_numpy(std))
  208. self.frontend_type = "multichannelfrontend"
  209. def output_size(self) -> int:
  210. return self.n_mels
  211. def forward(
  212. self, input: torch.Tensor, input_lengths: torch.Tensor
  213. ) -> Tuple[torch.Tensor, torch.Tensor]:
  214. # 1. Domain-conversion: e.g. Stft: time -> time-freq
  215. #import pdb;pdb.set_trace()
  216. if self.stft is not None:
  217. input_stft, feats_lens = self._compute_stft(input, input_lengths)
  218. else:
  219. input_stft = ComplexTensor(input[..., 0], input[..., 1])
  220. feats_lens = input_lengths
  221. # 2. [Option] Speech enhancement
  222. if self.frontend is not None:
  223. assert isinstance(input_stft, ComplexTensor), type(input_stft)
  224. # input_stft: (Batch, Length, [Channel], Freq)
  225. input_stft, _, mask = self.frontend(input_stft, feats_lens)
  226. # 3. [Multi channel case]: Select a channel(sa_asr)
  227. if input_stft.dim() == 4 and not self.mc:
  228. # h: (B, T, C, F) -> h: (B, T, F)
  229. if self.training:
  230. if self.use_channel is not None:
  231. input_stft = input_stft[:, :, self.use_channel, :]
  232. else:
  233. # Select 1ch randomly
  234. ch = np.random.randint(input_stft.size(2))
  235. input_stft = input_stft[:, :, ch, :]
  236. else:
  237. # Use the first channel
  238. input_stft = input_stft[:, :, 0, :]
  239. # 4. STFT -> Power spectrum
  240. # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
  241. input_power = input_stft.real ** 2 + input_stft.imag ** 2
  242. # 5. Feature transform e.g. Stft -> Log-Mel-Fbank
  243. # input_power: (Batch, [Channel,] Length, Freq)
  244. # -> input_feats: (Batch, Length, Dim)
  245. input_feats, _ = self.logmel(input_power, feats_lens)
  246. if self.mc:
  247. # MFCCA
  248. if input_feats.dim() ==4:
  249. bt = input_feats.size(0)
  250. channel_size = input_feats.size(2)
  251. input_feats = input_feats.transpose(1,2).reshape(bt*channel_size,-1,80).contiguous()
  252. feats_lens = feats_lens.repeat(1,channel_size).squeeze()
  253. else:
  254. channel_size = 1
  255. return input_feats, feats_lens, channel_size
  256. else:
  257. # 6. Apply CMVN
  258. if self.cmvn_file is not None:
  259. if feats_lens is None:
  260. feats_lens = input_feats.new_full([input_feats.size(0)], input_feats.size(1))
  261. self.mean = self.mean.to(input_feats.device, input_feats.dtype)
  262. self.std = self.std.to(input_feats.device, input_feats.dtype)
  263. mask = make_pad_mask(feats_lens, input_feats, 1)
  264. if input_feats.requires_grad:
  265. input_feats = input_feats + self.mean
  266. else:
  267. input_feats += self.mean
  268. if input_feats.requires_grad:
  269. input_feats = input_feats.masked_fill(mask, 0.0)
  270. else:
  271. input_feats.masked_fill_(mask, 0.0)
  272. input_feats *= self.std
  273. return input_feats, feats_lens
  274. def _compute_stft(
  275. self, input: torch.Tensor, input_lengths: torch.Tensor
  276. ) -> torch.Tensor:
  277. input_stft, feats_lens = self.stft(input, input_lengths)
  278. assert input_stft.dim() >= 4, input_stft.shape
  279. # "2" refers to the real/imag parts of Complex
  280. assert input_stft.shape[-1] == 2, input_stft.shape
  281. # Change torch.Tensor to ComplexTensor
  282. # input_stft: (..., F, 2) -> (..., F)
  283. input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
  284. return input_stft, feats_lens
  285. def _load_cmvn(self, cmvn_file):
  286. with open(cmvn_file, 'r', encoding='utf-8') as f:
  287. lines = f.readlines()
  288. means_list = []
  289. vars_list = []
  290. for i in range(len(lines)):
  291. line_item = lines[i].split()
  292. if line_item[0] == '<AddShift>':
  293. line_item = lines[i + 1].split()
  294. if line_item[0] == '<LearnRateCoef>':
  295. add_shift_line = line_item[3:(len(line_item) - 1)]
  296. means_list = list(add_shift_line)
  297. continue
  298. elif line_item[0] == '<Rescale>':
  299. line_item = lines[i + 1].split()
  300. if line_item[0] == '<LearnRateCoef>':
  301. rescale_line = line_item[3:(len(line_item) - 1)]
  302. vars_list = list(rescale_line)
  303. continue
  304. means = np.array(means_list).astype(np.float)
  305. vars = np.array(vars_list).astype(np.float)
  306. return means, vars