# e2e_asr_bat.py
  1. """Boundary Aware Transducer (BAT) model."""
  2. import logging
  3. from contextlib import contextmanager
  4. from typing import Dict, List, Optional, Tuple, Union
  5. import torch
  6. from packaging.version import parse as V
  7. from funasr.losses.label_smoothing_loss import (
  8. LabelSmoothingLoss, # noqa: H301
  9. )
  10. from funasr.models.frontend.abs_frontend import AbsFrontend
  11. from funasr.models.specaug.abs_specaug import AbsSpecAug
  12. from funasr.models.decoder.rnnt_decoder import RNNTDecoder
  13. from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
  14. from funasr.models.encoder.abs_encoder import AbsEncoder
  15. from funasr.models.joint_net.joint_network import JointNetwork
  16. from funasr.modules.nets_utils import get_transducer_task_io
  17. from funasr.modules.nets_utils import th_accuracy
  18. from funasr.modules.nets_utils import make_pad_mask
  19. from funasr.modules.add_sos_eos import add_sos_eos
  20. from funasr.layers.abs_normalize import AbsNormalize
  21. from funasr.torch_utils.device_funcs import force_gatherable
  22. from funasr.models.base_model import FunASRModel
# torch.cuda.amp.autocast exists from torch 1.6.0 onward; on older
# versions fall back to a no-op context manager with the same call shape
# so `with autocast(False):` below works on every supported torch.
if V(torch.__version__) >= V("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Dummy replacement: accepts the `enabled` flag but does nothing.
    @contextmanager
    def autocast(enabled=True):
        yield
class BATModel(FunASRModel):
    """Boundary Aware Transducer (BAT) model definition.

    Combines a pruned RNN-T (transducer) loss with a predictor module:
    the predictor's per-token peak positions are used in ``forward`` to
    bound the label range considered at each encoder frame, so the
    transducer lattice can be pruned before the joint network is applied.

    Args:
        vocab_size: Size of complete vocabulary (w/ EOS and blank included).
        token_list: List of tokens.
        frontend: Frontend module.
        specaug: SpecAugment module.
        normalize: Normalization module.
        encoder: Encoder module.
        decoder: Decoder (prediction network) module.
        joint_network: Joint Network module.
        att_decoder: Optional attention decoder. NOTE(review): stored
            nowhere and never used by this class — presumably kept for
            constructor compatibility; confirm against the task builder.
        predictor: Boundary/length predictor module (CIF-style, judging by
            the ``cif_*`` attributes fed from its outputs).
        transducer_weight: Weight of the Transducer loss.
        predictor_weight: Weight of the predictor length (L1) loss.
        cif_weight: Weight of the CIF cross-entropy loss.
        fastemit_lambda: FastEmit lambda value.
        auxiliary_ctc_weight: Weight of auxiliary CTC loss.
        auxiliary_ctc_dropout_rate: Dropout rate for auxiliary CTC loss inputs.
        auxiliary_lm_loss_weight: Weight of auxiliary LM loss.
        auxiliary_lm_loss_smoothing: Smoothing rate for LM loss' label smoothing.
        ignore_id: Initial padding ID.
        sym_space: Space symbol.
        sym_blank: Blank symbol.
        report_cer: Whether to report Character Error Rate during validation.
        report_wer: Whether to report Word Error Rate during validation.
        extract_feats_in_collect_stats: Whether to use extract_feats stats collection.
        lsm_weight: Label-smoothing weight for the CIF loss.
        length_normalized_loss: Whether the CIF loss is normalized by length.
        r_d: Symbols kept below the predicted peak (pruning window starts at
            ``peak - r_d``); window width is ``r_d + r_u``.
        r_u: Symbols kept above the predicted peak (see ``r_d``).
    """

    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        encoder: AbsEncoder,
        decoder: RNNTDecoder,
        joint_network: JointNetwork,
        att_decoder: Optional[AbsAttDecoder] = None,
        predictor=None,
        transducer_weight: float = 1.0,
        predictor_weight: float = 1.0,
        cif_weight: float = 1.0,
        fastemit_lambda: float = 0.0,
        auxiliary_ctc_weight: float = 0.0,
        auxiliary_ctc_dropout_rate: float = 0.0,
        auxiliary_lm_loss_weight: float = 0.0,
        auxiliary_lm_loss_smoothing: float = 0.0,
        ignore_id: int = -1,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        report_cer: bool = True,
        report_wer: bool = True,
        extract_feats_in_collect_stats: bool = True,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        r_d: int = 5,
        r_u: int = 5,
    ) -> None:
        """Construct a BATModel object."""
        super().__init__()
        # The following labels ID are reserved: 0 (blank) and vocab_size - 1 (sos/eos)
        self.blank_id = 0
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        # Copy so later mutations of the caller's list cannot leak in.
        self.token_list = token_list.copy()
        self.sym_space = sym_space
        self.sym_blank = sym_blank
        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.encoder = encoder
        self.decoder = decoder
        self.joint_network = joint_network
        # Built lazily (criterion in training code paths, error calculator
        # on first validation step in forward()).
        self.criterion_transducer = None
        self.error_calculator = None
        self.use_auxiliary_ctc = auxiliary_ctc_weight > 0
        self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0
        if self.use_auxiliary_ctc:
            # Projection of encoder output to vocab for the auxiliary CTC head.
            self.ctc_lin = torch.nn.Linear(encoder.output_size(), vocab_size)
            self.ctc_dropout_rate = auxiliary_ctc_dropout_rate
        if self.use_auxiliary_lm_loss:
            # Projection of decoder output to vocab for the auxiliary LM head.
            self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size)
            self.lm_loss_smoothing = auxiliary_lm_loss_smoothing
        self.transducer_weight = transducer_weight
        self.fastemit_lambda = fastemit_lambda
        self.auxiliary_ctc_weight = auxiliary_ctc_weight
        self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight
        self.report_cer = report_cer
        self.report_wer = report_wer
        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
        # L1 loss between predicted and true token counts (predictor loss).
        self.criterion_pre = torch.nn.L1Loss()
        self.predictor_weight = predictor_weight
        self.predictor = predictor
        self.cif_weight = cif_weight
        if self.cif_weight > 0:
            # CIF head: classify each acoustic embedding into the vocabulary.
            self.cif_output_layer = torch.nn.Linear(encoder.output_size(), vocab_size)
            self.criterion_cif = LabelSmoothingLoss(
                size=vocab_size,
                padding_idx=ignore_id,
                smoothing=lsm_weight,
                normalize_length=length_normalized_loss,
            )
        # Pruning bounds around the predicted peaks (see class docstring).
        self.r_d = r_d
        self.r_u = r_u
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Forward architecture and compute loss(es).

        Pipeline: encode -> prepare transducer I/O -> decoder (prediction
        network) -> predictor (peaks + length) -> build pruning ranges from
        the peaks -> pruned RNN-T loss (fast_rnnt) + auxiliary losses.

        Args:
            speech: Speech sequences. (B, S)
            speech_lengths: Speech sequences lengths. (B,)
            text: Label ID sequences. (B, L)
            text_lengths: Label ID sequences lengths. (B,)
            kwargs: Contains "utts_id".

        Return:
            loss: Main loss value.
            stats: Task statistics.
            weight: Task weights.
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]
        # Trim label padding beyond the longest sequence (data-parallel batches).
        text = text[:, : text_lengths.max()]
        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        # Undo streaming-chunk layout when the encoder produced one.
        if hasattr(self.encoder, 'overlap_chunk_cls') and self.encoder.overlap_chunk_cls is not None:
            encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out, encoder_out_lens,
                                                                                        chunk_outs=None)
        # Non-padding mask over encoder frames, shape (B, 1, T).
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(encoder_out.device)
        # 2. Transducer-related I/O preparation
        # decoder_in: blank-prefixed labels; target: padded labels;
        # t_len/u_len: frame and label lengths per utterance.
        decoder_in, target, t_len, u_len = get_transducer_task_io(
            text,
            encoder_out_lens,
            ignore_id=self.ignore_id,
        )
        # 3. Decoder (prediction network)
        self.decoder.set_device(encoder_out.device)
        decoder_out = self.decoder(decoder_in, u_len)
        # 4. Predictor: acoustic embeddings, predicted token count and,
        # crucially for BAT, the per-token peak frame indices.
        pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(encoder_out, text, encoder_out_mask, ignore_id=self.ignore_id)
        # L1 between predicted and true token counts.
        loss_pre = self.criterion_pre(text_lengths.type_as(pre_token_length), pre_token_length)
        if self.cif_weight > 0.0:
            # CIF branch: classify acoustic embeddings against the labels.
            cif_predict = self.cif_output_layer(pre_acoustic_embeds)
            loss_cif = self.criterion_cif(cif_predict, text)
        else:
            loss_cif = 0.0
        # 5. Losses
        # k2/fast_rnnt-style boundary rows; from the assignments below the
        # columns are [0, 0, U_i (labels), T_i (frames)] per utterance.
        boundary = torch.zeros((encoder_out.size(0), 4), dtype=torch.int64, device=encoder_out.device)
        boundary[:, 2] = u_len.long().detach()
        boundary[:, 3] = t_len.long().detach()
        # Pruning window at frame t starts r_d symbols below the predicted peak.
        pre_peak_index = torch.floor(pre_peak_index).long()
        s_begin = pre_peak_index - self.r_d
        T = encoder_out.size(1)
        B = encoder_out.size(0)
        U = decoder_out.size(1)  # NOTE(review): unused below.
        # mask: True for frames inside each utterance (t <= T_i - 1).
        mask = torch.arange(0, T, device=encoder_out.device).reshape(1, T).expand(B, T)
        mask = mask <= boundary[:, 3].reshape(B, 1) - 1
        # For padded frames, park the window at the last valid start
        # position U_i - (r_u + r_d) + 1.
        s_begin_padding = boundary[:, 2].reshape(B, 1) - (self.r_u+self.r_d) + 1
        # handle the cases where `len(symbols) < s_range`
        s_begin_padding = torch.clamp(s_begin_padding, min=0)
        s_begin = torch.where(mask, s_begin, s_begin_padding)
        # Clamp window starts so every window fits within [0, U_i - s_range].
        mask2 = s_begin < boundary[:, 2].reshape(B, 1) - (self.r_u+self.r_d) + 1
        s_begin = torch.where(mask2, s_begin, boundary[:, 2].reshape(B, 1) - (self.r_u+self.r_d) + 1)
        s_begin = torch.clamp(s_begin, min=0)
        # ranges: (B, T, s_range) label indices kept at each frame, where
        # s_range = min(r_u + r_d, shortest label length in the batch).
        ranges = s_begin.reshape((B, T, 1)).expand((B, T, min(self.r_u+self.r_d, min(u_len)))) + torch.arange(min(self.r_d+self.r_u, min(u_len)), device=encoder_out.device)
        # Imported lazily so fast_rnnt is only required when training BAT.
        import fast_rnnt
        am_pruned, lm_pruned = fast_rnnt.do_rnnt_pruning(
            am=self.joint_network.lin_enc(encoder_out),
            lm=self.joint_network.lin_dec(decoder_out),
            ranges=ranges,
        )
        # Inputs are already projected by lin_enc/lin_dec above.
        logits = self.joint_network(am_pruned, lm_pruned, project_input=False)
        # The pruned transducer loss must run in fp32 even under AMP.
        with torch.cuda.amp.autocast(enabled=False):
            loss_trans = fast_rnnt.rnnt_loss_pruned(
                logits=logits.float(),
                symbols=target.long(),
                ranges=ranges,
                termination_symbol=self.blank_id,
                boundary=boundary,
                reduction="sum",
            )
        cer_trans, wer_trans = None, None
        if not self.training and (self.report_cer or self.report_wer):
            # Lazily build the error calculator on the first validation step.
            if self.error_calculator is None:
                from funasr.modules.e2e_asr_common import ErrorCalculatorTransducer as ErrorCalculator
                self.error_calculator = ErrorCalculator(
                    self.decoder,
                    self.joint_network,
                    self.token_list,
                    self.sym_space,
                    self.sym_blank,
                    report_cer=self.report_cer,
                    report_wer=self.report_wer,
                )
            cer_trans, wer_trans = self.error_calculator(encoder_out, target, t_len)
        loss_ctc, loss_lm = 0.0, 0.0
        if self.use_auxiliary_ctc:
            loss_ctc = self._calc_ctc_loss(
                encoder_out,
                target,
                t_len,
                u_len,
            )
        if self.use_auxiliary_lm_loss:
            loss_lm = self._calc_lm_loss(decoder_out, target)
        # Weighted sum of all active losses.
        loss = (
            self.transducer_weight * loss_trans
            + self.auxiliary_ctc_weight * loss_ctc
            + self.auxiliary_lm_loss_weight * loss_lm
            + self.predictor_weight * loss_pre
            + self.cif_weight * loss_cif
        )
        # Optional losses are reported only when actually computed (> 0),
        # since the inactive branches hold plain floats, not tensors.
        stats = dict(
            loss=loss.detach(),
            loss_transducer=loss_trans.detach(),
            loss_pre=loss_pre.detach(),
            loss_cif=loss_cif.detach() if loss_cif > 0.0 else None,
            aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None,
            aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None,
            cer_transducer=cer_trans,
            wer_transducer=wer_trans,
        )
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
  260. def collect_feats(
  261. self,
  262. speech: torch.Tensor,
  263. speech_lengths: torch.Tensor,
  264. text: torch.Tensor,
  265. text_lengths: torch.Tensor,
  266. **kwargs,
  267. ) -> Dict[str, torch.Tensor]:
  268. """Collect features sequences and features lengths sequences.
  269. Args:
  270. speech: Speech sequences. (B, S)
  271. speech_lengths: Speech sequences lengths. (B,)
  272. text: Label ID sequences. (B, L)
  273. text_lengths: Label ID sequences lengths. (B,)
  274. kwargs: Contains "utts_id".
  275. Return:
  276. {}: "feats": Features sequences. (B, T, D_feats),
  277. "feats_lengths": Features sequences lengths. (B,)
  278. """
  279. if self.extract_feats_in_collect_stats:
  280. feats, feats_lengths = self._extract_feats(speech, speech_lengths)
  281. else:
  282. # Generate dummy stats if extract_feats_in_collect_stats is False
  283. logging.warning(
  284. "Generating dummy stats for feats and feats_lengths, "
  285. "because encoder_conf.extract_feats_in_collect_stats is "
  286. f"{self.extract_feats_in_collect_stats}"
  287. )
  288. feats, feats_lengths = speech, speech_lengths
  289. return {"feats": feats, "feats_lengths": feats_lengths}
  290. def encode(
  291. self,
  292. speech: torch.Tensor,
  293. speech_lengths: torch.Tensor,
  294. ) -> Tuple[torch.Tensor, torch.Tensor]:
  295. """Encoder speech sequences.
  296. Args:
  297. speech: Speech sequences. (B, S)
  298. speech_lengths: Speech sequences lengths. (B,)
  299. Return:
  300. encoder_out: Encoder outputs. (B, T, D_enc)
  301. encoder_out_lens: Encoder outputs lengths. (B,)
  302. """
  303. with autocast(False):
  304. # 1. Extract feats
  305. feats, feats_lengths = self._extract_feats(speech, speech_lengths)
  306. # 2. Data augmentation
  307. if self.specaug is not None and self.training:
  308. feats, feats_lengths = self.specaug(feats, feats_lengths)
  309. # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
  310. if self.normalize is not None:
  311. feats, feats_lengths = self.normalize(feats, feats_lengths)
  312. # 4. Forward encoder
  313. encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
  314. assert encoder_out.size(0) == speech.size(0), (
  315. encoder_out.size(),
  316. speech.size(0),
  317. )
  318. assert encoder_out.size(1) <= encoder_out_lens.max(), (
  319. encoder_out.size(),
  320. encoder_out_lens.max(),
  321. )
  322. return encoder_out, encoder_out_lens
  323. def _extract_feats(
  324. self, speech: torch.Tensor, speech_lengths: torch.Tensor
  325. ) -> Tuple[torch.Tensor, torch.Tensor]:
  326. """Extract features sequences and features sequences lengths.
  327. Args:
  328. speech: Speech sequences. (B, S)
  329. speech_lengths: Speech sequences lengths. (B,)
  330. Return:
  331. feats: Features sequences. (B, T, D_feats)
  332. feats_lengths: Features sequences lengths. (B,)
  333. """
  334. assert speech_lengths.dim() == 1, speech_lengths.shape
  335. # for data-parallel
  336. speech = speech[:, : speech_lengths.max()]
  337. if self.frontend is not None:
  338. feats, feats_lengths = self.frontend(speech, speech_lengths)
  339. else:
  340. feats, feats_lengths = speech, speech_lengths
  341. return feats, feats_lengths
  342. def _calc_ctc_loss(
  343. self,
  344. encoder_out: torch.Tensor,
  345. target: torch.Tensor,
  346. t_len: torch.Tensor,
  347. u_len: torch.Tensor,
  348. ) -> torch.Tensor:
  349. """Compute CTC loss.
  350. Args:
  351. encoder_out: Encoder output sequences. (B, T, D_enc)
  352. target: Target label ID sequences. (B, L)
  353. t_len: Encoder output sequences lengths. (B,)
  354. u_len: Target label ID sequences lengths. (B,)
  355. Return:
  356. loss_ctc: CTC loss value.
  357. """
  358. ctc_in = self.ctc_lin(
  359. torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate)
  360. )
  361. ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1)
  362. target_mask = target != 0
  363. ctc_target = target[target_mask].cpu()
  364. with torch.backends.cudnn.flags(deterministic=True):
  365. loss_ctc = torch.nn.functional.ctc_loss(
  366. ctc_in,
  367. ctc_target,
  368. t_len,
  369. u_len,
  370. zero_infinity=True,
  371. reduction="sum",
  372. )
  373. loss_ctc /= target.size(0)
  374. return loss_ctc
  375. def _calc_lm_loss(
  376. self,
  377. decoder_out: torch.Tensor,
  378. target: torch.Tensor,
  379. ) -> torch.Tensor:
  380. """Compute LM loss.
  381. Args:
  382. decoder_out: Decoder output sequences. (B, U, D_dec)
  383. target: Target label ID sequences. (B, L)
  384. Return:
  385. loss_lm: LM loss value.
  386. """
  387. lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size)
  388. lm_target = target.view(-1).type(torch.int64)
  389. with torch.no_grad():
  390. true_dist = lm_loss_in.clone()
  391. true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1))
  392. # Ignore blank ID (0)
  393. ignore = lm_target == 0
  394. lm_target = lm_target.masked_fill(ignore, 0)
  395. true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing))
  396. loss_lm = torch.nn.functional.kl_div(
  397. torch.log_softmax(lm_loss_in, dim=1),
  398. true_dist,
  399. reduction="none",
  400. )
  401. loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size(
  402. 0
  403. )
  404. return loss_lm