# e2e_asr_mfcca.py

from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import logging
import math
import pdb
import random

import torch

from funasr.metrics import ErrorCalculator
from funasr.metrics.compute_acc import th_accuracy
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.losses.label_smoothing_loss import (
    LabelSmoothingLoss,  # noqa: H301
)
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.frontends.abs_frontend import AbsFrontend
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.layers.abs_normalize import AbsNormalize
from funasr.train_utils.device_funcs import force_gatherable
from funasr.models.base_model import FunASRModel

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        yield


class MFCCA(FunASRModel):
    """
    Author: Audio, Speech and Language Processing Group (ASLP@NPU), Northwestern Polytechnical University
    MFCCA: Multi-Frame Cross-Channel Attention for multi-speaker ASR in multi-party meeting scenario
    https://arxiv.org/abs/2210.05265
    """

    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        encoder: AbsEncoder,
        decoder: AbsDecoder,
        ctc: CTC,
        rnnt_decoder: None = None,
        ctc_weight: float = 0.5,
        ignore_id: int = -1,
        lsm_weight: float = 0.0,
        mask_ratio: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        preencoder: Optional[AbsPreEncoder] = None,
    ):
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        assert rnnt_decoder is None, "Not implemented"
        super().__init__()

        # note that eos is the same as sos (equivalent ID)
        self.sos = vocab_size - 1
        self.eos = vocab_size - 1
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        self.token_list = token_list.copy()
        self.mask_ratio = mask_ratio

        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.preencoder = preencoder
        self.encoder = encoder
        # we set self.decoder = None in the CTC mode since
        # self.decoder parameters were never used and PyTorch complained
        # and threw an Exception in the multi-GPU experiment.
        # thanks Jeff Farris for pointing out the issue.
        if ctc_weight == 1.0:
            self.decoder = None
        else:
            self.decoder = decoder
        if ctc_weight == 0.0:
            self.ctc = None
        else:
            self.ctc = ctc
        self.rnnt_decoder = rnnt_decoder

        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        if report_cer or report_wer:
            self.error_calculator = ErrorCalculator(
                token_list, sym_space, sym_blank, report_cer, report_wer
            )
        else:
            self.error_calculator = None

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        # pdb.set_trace()
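        # Channel-retention augmentation: with probability ``mask_ratio``, keep only a
        # random subset of the 8 input channels (presumably so the model also copes
        # with smaller microphone arrays at inference time).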
        if speech.dim() == 3 and speech.size(2) == 8 and self.mask_ratio != 0:
            rate_num = random.random()
            # rate_num = 0.1
            if rate_num <= self.mask_ratio:
                retain_channel = math.ceil(random.random() * 8)
                if retain_channel > 1:
                    speech = speech[:, :, torch.randperm(8)[0:retain_channel].sort().values]
                else:
                    speech = speech[:, :, torch.randperm(8)[0]]
        # pdb.set_trace()
        batch_size = speech.shape[0]
        # for data-parallel
        text = text[:, : text_lengths.max()]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        # 2a. Attention-decoder branch
        if self.ctc_weight == 1.0:
            loss_att, acc_att, cer_att, wer_att = None, None, None, None
        else:
            loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

        # 2b. CTC branch
        if self.ctc_weight == 0.0:
            loss_ctc, cer_ctc = None, None
        else:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

        # 2c. RNN-T branch
        if self.rnnt_decoder is not None:
            _ = self._calc_rnnt_loss(encoder_out, encoder_out_lens, text, text_lengths)
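        # Final loss: interpolate the two branches with ctc_weight,
        # loss = ctc_weight * loss_ctc + (1 - ctc_weight) * loss_att,
        # falling back to a single branch when ctc_weight is exactly 0.0 or 1.0.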
        if self.ctc_weight == 0.0:
            loss = loss_att
        elif self.ctc_weight == 1.0:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att

        stats = dict(
            loss=loss.detach(),
            loss_att=loss_att.detach() if loss_att is not None else None,
            loss_ctc=loss_ctc.detach() if loss_ctc is not None else None,
            acc=acc_att,
            cer=cer_att,
            wer=wer_att,
            cer_ctc=cer_ctc,
        )

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        feats, feats_lengths, channel_size = self._extract_feats(speech, speech_lengths)
        return {"feats": feats, "feats_lengths": feats_lengths}

    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
        """
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths, channel_size = self._extract_feats(speech, speech_lengths)

            # 2. Data augmentation
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)

            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)

        # Pre-encoder, e.g. used for raw input data
        if self.preencoder is not None:
            feats, feats_lengths = self.preencoder(feats, feats_lengths)

        # pdb.set_trace()
        encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, channel_size)

        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
            speech.size(0),
        )
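        # The MFCCA encoder may keep the channel axis, in which case encoder_out is
        # 4-D (presumably (Batch, Channel, Frames, Dim)) and the frame axis is dim 2;
        # otherwise it is the usual 3-D (Batch, Frames, Dim) output.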
        if encoder_out.dim() == 4:
            assert encoder_out.size(2) <= encoder_out_lens.max(), (
                encoder_out.size(),
                encoder_out_lens.max(),
            )
        else:
            assert encoder_out.size(1) <= encoder_out_lens.max(), (
                encoder_out.size(),
                encoder_out_lens.max(),
            )

        return encoder_out, encoder_out_lens

    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
        assert speech_lengths.dim() == 1, speech_lengths.shape

        # for data-parallel
        speech = speech[:, : speech_lengths.max()]

        if self.frontend is not None:
            # Frontend
            #  e.g. STFT and Feature extract
            #       data_loader may send time-domain signal in this case
            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
            feats, feats_lengths, channel_size = self.frontend(speech, speech_lengths)
        else:
            # No frontend and no feature extract
            feats, feats_lengths = speech, speech_lengths
            channel_size = 1
        return feats, feats_lengths, channel_size

    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1

        # 1. Forward decoder
        decoder_out, _ = self.decoder(
            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
        )

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )

        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        return loss_att, acc_att, cer_att, wer_att

    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        # Calc CTC loss
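        # A 4-D encoder output still carries the channel axis; average over it
        # (dim 1) so the CTC head receives a single (Batch, Frames, Dim) tensor.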
        if encoder_out.dim() == 4:
            encoder_out = encoder_out.mean(1)
        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)

        # Calc CER using CTC
        cer_ctc = None
        if not self.training and self.error_calculator is not None:
            ys_hat = self.ctc.argmax(encoder_out).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
        return loss_ctc, cer_ctc

    def _calc_rnnt_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        raise NotImplementedError
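

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original FunASR file): reproduce the
# random channel-retention augmentation from MFCCA.forward on a dummy
# 8-channel feature tensor. Only torch/random/math are used; the mask_ratio
# value and tensor shape below are made up for demonstration, and the block
# is guarded by ``__main__`` so importing this module is unaffected.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    mask_ratio = 0.5  # hypothetical value; the model default is 0.0
    speech = torch.randn(2, 100, 8)  # (Batch, Frames, Channels)
    if speech.dim() == 3 and speech.size(2) == 8 and mask_ratio != 0:
        if random.random() <= mask_ratio:
            retain_channel = math.ceil(random.random() * 8)
            if retain_channel > 1:
                # keep a random, sorted subset of channels
                speech = speech[:, :, torch.randperm(8)[0:retain_channel].sort().values]
            else:
                # keep a single random channel (drops the channel axis)
                speech = speech[:, :, torch.randperm(8)[0]]
    print("augmented speech shape:", tuple(speech.shape))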