#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import time
import torch
import logging
from torch.cuda.amp import autocast
from typing import Union, Dict, List, Tuple, Optional

from funasr.register import tables
from funasr.models.ctc.ctc import CTC
from funasr.utils import postprocess_utils
from funasr.metrics.compute_acc import th_accuracy
from funasr.utils.datadir_writer import DatadirWriter
from funasr.models.paraformer.cif_predictor import mae_loss
from funasr.train_utils.device_funcs import force_gatherable
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.models.scama.utils import sequence_mask


@tables.register("model_classes", "UniASR")
class UniASR(torch.nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    """
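
    # Two-pass layout (as wired up in __init__/forward below):
    #   model1: encoder  + predictor  + decoder   -- first decoding pass
    #   model2: encoder2 + predictor2 + decoder2  -- second pass, fed by
    #           stride_conv applied to [raw features ; model1 encoder output]
    # loss_weight_model1 balances the two branch losses during training; at
    # inference time, "decoding_model" selects which branch produces the text.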

    def __init__(
        self,
        specaug: str = None,
        specaug_conf: dict = None,
        normalize: str = None,
        normalize_conf: dict = None,
        encoder: str = None,
        encoder_conf: dict = None,
        encoder2: str = None,
        encoder2_conf: dict = None,
        decoder: str = None,
        decoder_conf: dict = None,
        decoder2: str = None,
        decoder2_conf: dict = None,
        predictor: str = None,
        predictor_conf: dict = None,
        predictor_bias: int = 0,
        predictor_weight: float = 0.0,
        predictor2: str = None,
        predictor2_conf: dict = None,
        predictor2_bias: int = 0,
        predictor2_weight: float = 0.0,
        ctc: str = None,
        ctc_conf: dict = None,
        ctc_weight: float = 0.5,
        ctc2: str = None,
        ctc2_conf: dict = None,
        ctc2_weight: float = 0.5,
        decoder_attention_chunk_type: str = 'chunk',
        decoder_attention_chunk_type2: str = 'chunk',
        stride_conv=None,
        stride_conv_conf: dict = None,
        loss_weight_model1: float = 0.5,
        input_size: int = 80,
        vocab_size: int = -1,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        share_embedding: bool = False,
        **kwargs,
    ):
        super().__init__()

        if specaug is not None:
            specaug_class = tables.specaug_classes.get(specaug)
            specaug = specaug_class(**specaug_conf)
        if normalize is not None:
            normalize_class = tables.normalize_classes.get(normalize)
            normalize = normalize_class(**normalize_conf)

        # Build model1: first-pass encoder, decoder and predictor.
        encoder_class = tables.encoder_classes.get(encoder)
        encoder = encoder_class(input_size=input_size, **encoder_conf)
        encoder_output_size = encoder.output_size()

        decoder_class = tables.decoder_classes.get(decoder)
        decoder = decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            **decoder_conf,
        )

        predictor_class = tables.predictor_classes.get(predictor)
        predictor = predictor_class(**predictor_conf)

        # Stride convolution bridging the two passes; its input is the raw
        # features concatenated with the model1 encoder output.
        from funasr.models.transformer.utils.subsampling import Conv1dSubsampling

        stride_conv = Conv1dSubsampling(
            **stride_conv_conf,
            idim=input_size + encoder_output_size,
            odim=input_size + encoder_output_size,
        )
        stride_conv_output_size = stride_conv.output_size()

        # Build model2: second-pass encoder, decoder and predictor.
        encoder_class = tables.encoder_classes.get(encoder2)
        encoder2 = encoder_class(input_size=stride_conv_output_size, **encoder2_conf)
        encoder2_output_size = encoder2.output_size()

        decoder_class = tables.decoder_classes.get(decoder2)
        decoder2 = decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=encoder2_output_size,
            **decoder2_conf,
        )

        predictor_class = tables.predictor_classes.get(predictor2)
        predictor2 = predictor_class(**predictor2_conf)

        self.blank_id = blank_id
        self.sos = sos
        self.eos = eos
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        self.ctc2_weight = ctc2_weight

        self.specaug = specaug
        self.normalize = normalize
        self.encoder = encoder
        self.error_calculator = None
        self.decoder = decoder
        self.ctc = None
        self.ctc2 = None

        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        self.predictor = predictor
        self.predictor_weight = predictor_weight
        self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
        self.encoder1_encoder2_joint_training = kwargs.get("encoder1_encoder2_joint_training", True)

        if self.encoder.overlap_chunk_cls is not None:
            from funasr.models.scama.chunk_utilis import build_scama_mask_for_cross_attention_decoder

            self.build_scama_mask_for_cross_attention_decoder_fn = build_scama_mask_for_cross_attention_decoder
            self.decoder_attention_chunk_type = decoder_attention_chunk_type

        self.encoder2 = encoder2
        self.decoder2 = decoder2
        self.ctc2_weight = ctc2_weight
        self.predictor2 = predictor2
        self.predictor2_weight = predictor2_weight
        self.decoder_attention_chunk_type2 = decoder_attention_chunk_type2
        self.stride_conv = stride_conv
        self.loss_weight_model1 = loss_weight_model1

        if self.encoder2.overlap_chunk_cls is not None:
            from funasr.models.scama.chunk_utilis import build_scama_mask_for_cross_attention_decoder

            self.build_scama_mask_for_cross_attention_decoder_fn2 = build_scama_mask_for_cross_attention_decoder
            self.decoder_attention_chunk_type2 = decoder_attention_chunk_type2

        self.length_normalized_loss = length_normalized_loss
        self.enable_maas_finetune = kwargs.get("enable_maas_finetune", False)
        self.freeze_encoder2 = kwargs.get("freeze_encoder2", False)
        self.beam_search = None

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch,)
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        decoding_ind = kwargs.get("decoding_ind", None)
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]
        ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)

        # 1. Encoder
        if self.enable_maas_finetune:
            with torch.no_grad():
                speech_raw, encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
        else:
            speech_raw, encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)

        loss_att, acc_att, cer_att, wer_att = None, None, None, None
        loss_ctc, cer_ctc = None, None
        stats = dict()
        loss_pre = None
        loss, loss1, loss2 = 0.0, 0.0, 0.0

        if self.loss_weight_model1 > 0.0:
            ## model1
            # 1. Attention + predictor branch
            if self.enable_maas_finetune:
                with torch.no_grad():
                    loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss(
                        encoder_out, encoder_out_lens, text, text_lengths
                    )
                    loss = loss_att + loss_pre * self.predictor_weight
                    # Collect Attn branch stats
                    stats["loss_att"] = loss_att.detach() if loss_att is not None else None
                    stats["acc"] = acc_att
                    stats["cer"] = cer_att
                    stats["wer"] = wer_att
                    stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
            else:
                loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss(
                    encoder_out, encoder_out_lens, text, text_lengths
                )
                loss = loss_att + loss_pre * self.predictor_weight
                # Collect Attn branch stats
                stats["loss_att"] = loss_att.detach() if loss_att is not None else None
                stats["acc"] = acc_att
                stats["cer"] = cer_att
                stats["wer"] = wer_att
                stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None

            loss1 = loss

        if self.loss_weight_model1 < 1.0:
            ## model2
            # encoder2
            if self.freeze_encoder2:
                with torch.no_grad():
                    encoder_out, encoder_out_lens = self.encode2(
                        encoder_out, encoder_out_lens, speech_raw, speech_lengths, ind=ind
                    )
            else:
                encoder_out, encoder_out_lens = self.encode2(
                    encoder_out, encoder_out_lens, speech_raw, speech_lengths, ind=ind
                )

            intermediate_outs = None
            if isinstance(encoder_out, tuple):
                intermediate_outs = encoder_out[1]
                encoder_out = encoder_out[0]

            loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss2(
                encoder_out, encoder_out_lens, text, text_lengths
            )
            loss = loss_att + loss_pre * self.predictor2_weight
            # Collect Attn branch stats
            stats["loss_att2"] = loss_att.detach() if loss_att is not None else None
            stats["acc2"] = acc_att
            stats["cer2"] = cer_att
            stats["wer2"] = wer_att
            stats["loss_pre2"] = loss_pre.detach().cpu() if loss_pre is not None else None

            loss2 = loss

        loss = loss1 * self.loss_weight_model1 + loss2 * (1 - self.loss_weight_model1)

        stats["loss1"] = torch.clone(loss1.detach())
        stats["loss2"] = torch.clone(loss2.detach())
        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = int((text_lengths + 1).sum())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
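
    # Example weighting: loss_weight_model1=1.0 trains model1 only (the model2
    # branch in forward is skipped), 0.0 trains model2 only, and any value in
    # between mixes the two branch losses as
    #   loss = loss_weight_model1 * loss1 + (1 - loss_weight_model1) * loss2.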

    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        if self.extract_feats_in_collect_stats:
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        else:
            # Generate dummy stats if extract_feats_in_collect_stats is False
            logging.warning(
                "Generating dummy stats for feats and feats_lengths, "
                "because encoder_conf.extract_feats_in_collect_stats is "
                f"{self.extract_feats_in_collect_stats}"
            )
            feats, feats_lengths = speech, speech_lengths
        return {"feats": feats, "feats_lengths": feats_lengths}

    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ):
        """Frontend + Encoder. Note that this method is used by asr_inference.py

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch,)
        """
        ind = kwargs.get("ind", 0)
        with autocast(False):
            # Data augmentation
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)
            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)

        speech_raw = speech.clone().to(speech.device)

        # 4. Forward encoder
        encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths, ind=ind)
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]

        return speech_raw, encoder_out, encoder_out_lens

    def encode2(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ):
        """Frontend + Encoder. Note that this method is used by asr_inference.py

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch,)
        """
        ind = kwargs.get("ind", 0)

        encoder_out_rm, encoder_out_lens_rm = self.encoder.overlap_chunk_cls.remove_chunk(
            encoder_out,
            encoder_out_lens,
            chunk_outs=None,
        )
        # residual_input
        encoder_out = torch.cat((speech, encoder_out_rm), dim=-1)
        encoder_out_lens = encoder_out_lens_rm

        if self.stride_conv is not None:
            speech, speech_lengths = self.stride_conv(encoder_out, encoder_out_lens)

        if not self.encoder1_encoder2_joint_training:
            speech = speech.detach()
            speech_lengths = speech_lengths.detach()

        # 4. Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        encoder_out, encoder_out_lens, _ = self.encoder2(speech, speech_lengths, ind=ind)
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]

        return encoder_out, encoder_out_lens

    def nll(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ) -> torch.Tensor:
        """Compute negative log likelihood (nll) from the transformer decoder.

        Normally, this function is called in batchify_nll.

        Args:
            encoder_out: (Batch, Length, Dim)
            encoder_out_lens: (Batch,)
            ys_pad: (Batch, Length)
            ys_pad_lens: (Batch,)
        """
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1

        # 1. Forward decoder
        decoder_out, _ = self.decoder(
            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
        )  # [batch, seqlen, dim]
        batch_size = decoder_out.size(0)
        decoder_num_class = decoder_out.size(2)

        # nll: negative log-likelihood
        nll = torch.nn.functional.cross_entropy(
            decoder_out.view(-1, decoder_num_class),
            ys_out_pad.view(-1),
            ignore_index=self.ignore_id,
            reduction="none",
        )
        nll = nll.view(batch_size, -1)
        nll = nll.sum(dim=1)
        assert nll.size(0) == batch_size
        return nll

    def batchify_nll(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
        batch_size: int = 100,
    ):
        """Compute negative log likelihood (nll) from the transformer decoder.

        To avoid OOM, this function separates the input into batches,
        calls nll for each batch, and combines the results.

        Args:
            encoder_out: (Batch, Length, Dim)
            encoder_out_lens: (Batch,)
            ys_pad: (Batch, Length)
            ys_pad_lens: (Batch,)
            batch_size: number of samples per batch when computing nll;
                change this to avoid OOM or to increase GPU memory usage
        """
        total_num = encoder_out.size(0)
        if total_num <= batch_size:
            nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
        else:
            nll = []
            start_idx = 0
            while True:
                end_idx = min(start_idx + batch_size, total_num)
                batch_encoder_out = encoder_out[start_idx:end_idx, :, :]
                batch_encoder_out_lens = encoder_out_lens[start_idx:end_idx]
                batch_ys_pad = ys_pad[start_idx:end_idx, :]
                batch_ys_pad_lens = ys_pad_lens[start_idx:end_idx]
                batch_nll = self.nll(
                    batch_encoder_out,
                    batch_encoder_out_lens,
                    batch_ys_pad,
                    batch_ys_pad_lens,
                )
                nll.append(batch_nll)
                start_idx = end_idx
                if start_idx == total_num:
                    break
            nll = torch.cat(nll)
        assert nll.size(0) == total_num
        return nll

    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1

        # 1. Forward decoder
        decoder_out, _ = self.decoder(
            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
        )

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )

        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        return loss_att, acc_att, cer_att, wer_att

    def _calc_att_predictor_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1
        encoder_out_mask = sequence_mask(
            encoder_out_lens,
            maxlen=encoder_out.size(1),
            dtype=encoder_out.dtype,
            device=encoder_out.device,
        )[:, None, :]

        mask_chunk_predictor = None
        if self.encoder.overlap_chunk_cls is not None:
            mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(
                None, device=encoder_out.device, batch_size=encoder_out.size(0)
            )
            mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(
                None, device=encoder_out.device, batch_size=encoder_out.size(0)
            )
            encoder_out = encoder_out * mask_shfit_chunk

        pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor(
            encoder_out,
            ys_out_pad,
            encoder_out_mask,
            ignore_id=self.ignore_id,
            mask_chunk_predictor=mask_chunk_predictor,
            target_label_length=ys_in_lens,
        )
        predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(
            pre_alphas, encoder_out_lens
        )

        scama_mask = None
        if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
            encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
            attention_chunk_center_bias = 0
            attention_chunk_size = encoder_chunk_size
            decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
            mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls.get_mask_shift_att_chunk_decoder(
                None, device=encoder_out.device, batch_size=encoder_out.size(0)
            )
            scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
                predictor_alignments=predictor_alignments,
                encoder_sequence_length=encoder_out_lens,
                chunk_size=1,
                encoder_chunk_size=encoder_chunk_size,
                attention_chunk_center_bias=attention_chunk_center_bias,
                attention_chunk_size=attention_chunk_size,
                attention_chunk_type=self.decoder_attention_chunk_type,
                step=None,
                predictor_mask_chunk_hopping=mask_chunk_predictor,
                decoder_att_look_back_factor=decoder_att_look_back_factor,
                mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
                target_length=ys_in_lens,
                is_training=self.training,
            )
        elif self.encoder.overlap_chunk_cls is not None:
            encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(
                encoder_out, encoder_out_lens, chunk_outs=None
            )

        # 1. Forward decoder
        decoder_out, _ = self.decoder(
            encoder_out,
            encoder_out_lens,
            ys_in_pad,
            ys_in_lens,
            chunk_mask=scama_mask,
            pre_acoustic_embeds=pre_acoustic_embeds,
        )

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )

        # predictor loss
        loss_pre = self.criterion_pre(ys_in_lens.type_as(pre_token_length), pre_token_length)

        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        return loss_att, acc_att, cer_att, wer_att, loss_pre

    def _calc_att_predictor_loss2(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1
        encoder_out_mask = sequence_mask(
            encoder_out_lens,
            maxlen=encoder_out.size(1),
            dtype=encoder_out.dtype,
            device=encoder_out.device,
        )[:, None, :]

        mask_chunk_predictor = None
        if self.encoder2.overlap_chunk_cls is not None:
            mask_chunk_predictor = self.encoder2.overlap_chunk_cls.get_mask_chunk_predictor(
                None, device=encoder_out.device, batch_size=encoder_out.size(0)
            )
            mask_shfit_chunk = self.encoder2.overlap_chunk_cls.get_mask_shfit_chunk(
                None, device=encoder_out.device, batch_size=encoder_out.size(0)
            )
            encoder_out = encoder_out * mask_shfit_chunk

        pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor2(
            encoder_out,
            ys_out_pad,
            encoder_out_mask,
            ignore_id=self.ignore_id,
            mask_chunk_predictor=mask_chunk_predictor,
            target_label_length=ys_in_lens,
        )
        predictor_alignments, predictor_alignments_len = self.predictor2.gen_frame_alignments(
            pre_alphas, encoder_out_lens
        )

        scama_mask = None
        if self.encoder2.overlap_chunk_cls is not None and self.decoder_attention_chunk_type2 == 'chunk':
            encoder_chunk_size = self.encoder2.overlap_chunk_cls.chunk_size_pad_shift_cur
            attention_chunk_center_bias = 0
            attention_chunk_size = encoder_chunk_size
            decoder_att_look_back_factor = self.encoder2.overlap_chunk_cls.decoder_att_look_back_factor_cur
            mask_shift_att_chunk_decoder = self.encoder2.overlap_chunk_cls.get_mask_shift_att_chunk_decoder(
                None, device=encoder_out.device, batch_size=encoder_out.size(0)
            )
            scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn2(
                predictor_alignments=predictor_alignments,
                encoder_sequence_length=encoder_out_lens,
                chunk_size=1,
                encoder_chunk_size=encoder_chunk_size,
                attention_chunk_center_bias=attention_chunk_center_bias,
                attention_chunk_size=attention_chunk_size,
                attention_chunk_type=self.decoder_attention_chunk_type2,
                step=None,
                predictor_mask_chunk_hopping=mask_chunk_predictor,
                decoder_att_look_back_factor=decoder_att_look_back_factor,
                mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
                target_length=ys_in_lens,
                is_training=self.training,
            )
        elif self.encoder2.overlap_chunk_cls is not None:
            encoder_out, encoder_out_lens = self.encoder2.overlap_chunk_cls.remove_chunk(
                encoder_out, encoder_out_lens, chunk_outs=None
            )

        # 1. Forward decoder
        decoder_out, _ = self.decoder2(
            encoder_out,
            encoder_out_lens,
            ys_in_pad,
            ys_in_lens,
            chunk_mask=scama_mask,
            pre_acoustic_embeds=pre_acoustic_embeds,
        )

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )

        # predictor loss
        loss_pre = self.criterion_pre(ys_in_lens.type_as(pre_token_length), pre_token_length)

        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        return loss_att, acc_att, cer_att, wer_att, loss_pre

    def calc_predictor_mask(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor = None,
        ys_pad_lens: torch.Tensor = None,
    ):
        # ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        # ys_in_lens = ys_pad_lens + 1
        ys_out_pad, ys_in_lens = None, None

        encoder_out_mask = sequence_mask(
            encoder_out_lens,
            maxlen=encoder_out.size(1),
            dtype=encoder_out.dtype,
            device=encoder_out.device,
        )[:, None, :]

        mask_chunk_predictor = None
        if self.encoder.overlap_chunk_cls is not None:
            mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(
                None, device=encoder_out.device, batch_size=encoder_out.size(0)
            )
            mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(
                None, device=encoder_out.device, batch_size=encoder_out.size(0)
            )
            encoder_out = encoder_out * mask_shfit_chunk

        pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor(
            encoder_out,
            ys_out_pad,
            encoder_out_mask,
            ignore_id=self.ignore_id,
            mask_chunk_predictor=mask_chunk_predictor,
            target_label_length=ys_in_lens,
        )
        predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(
            pre_alphas, encoder_out_lens
        )

        scama_mask = None
        if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
            encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
            attention_chunk_center_bias = 0
            attention_chunk_size = encoder_chunk_size
            decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
            mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls.get_mask_shift_att_chunk_decoder(
                None, device=encoder_out.device, batch_size=encoder_out.size(0)
            )
            scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
                predictor_alignments=predictor_alignments,
                encoder_sequence_length=encoder_out_lens,
                chunk_size=1,
                encoder_chunk_size=encoder_chunk_size,
                attention_chunk_center_bias=attention_chunk_center_bias,
                attention_chunk_size=attention_chunk_size,
                attention_chunk_type=self.decoder_attention_chunk_type,
                step=None,
                predictor_mask_chunk_hopping=mask_chunk_predictor,
                decoder_att_look_back_factor=decoder_att_look_back_factor,
                mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
                target_length=ys_in_lens,
                is_training=self.training,
            )
        elif self.encoder.overlap_chunk_cls is not None:
            encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(
                encoder_out, encoder_out_lens, chunk_outs=None
            )

        return pre_acoustic_embeds, pre_token_length, predictor_alignments, predictor_alignments_len, scama_mask

    def calc_predictor_mask2(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor = None,
        ys_pad_lens: torch.Tensor = None,
    ):
        # ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        # ys_in_lens = ys_pad_lens + 1
        ys_out_pad, ys_in_lens = None, None

        encoder_out_mask = sequence_mask(
            encoder_out_lens,
            maxlen=encoder_out.size(1),
            dtype=encoder_out.dtype,
            device=encoder_out.device,
        )[:, None, :]

        mask_chunk_predictor = None
        if self.encoder2.overlap_chunk_cls is not None:
            mask_chunk_predictor = self.encoder2.overlap_chunk_cls.get_mask_chunk_predictor(
                None, device=encoder_out.device, batch_size=encoder_out.size(0)
            )
            mask_shfit_chunk = self.encoder2.overlap_chunk_cls.get_mask_shfit_chunk(
                None, device=encoder_out.device, batch_size=encoder_out.size(0)
            )
            encoder_out = encoder_out * mask_shfit_chunk

        pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor2(
            encoder_out,
            ys_out_pad,
            encoder_out_mask,
            ignore_id=self.ignore_id,
            mask_chunk_predictor=mask_chunk_predictor,
            target_label_length=ys_in_lens,
        )
        predictor_alignments, predictor_alignments_len = self.predictor2.gen_frame_alignments(
            pre_alphas, encoder_out_lens
        )

        scama_mask = None
        if self.encoder2.overlap_chunk_cls is not None and self.decoder_attention_chunk_type2 == 'chunk':
            encoder_chunk_size = self.encoder2.overlap_chunk_cls.chunk_size_pad_shift_cur
            attention_chunk_center_bias = 0
            attention_chunk_size = encoder_chunk_size
            decoder_att_look_back_factor = self.encoder2.overlap_chunk_cls.decoder_att_look_back_factor_cur
            mask_shift_att_chunk_decoder = self.encoder2.overlap_chunk_cls.get_mask_shift_att_chunk_decoder(
                None, device=encoder_out.device, batch_size=encoder_out.size(0)
            )
            scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn2(
                predictor_alignments=predictor_alignments,
                encoder_sequence_length=encoder_out_lens,
                chunk_size=1,
                encoder_chunk_size=encoder_chunk_size,
                attention_chunk_center_bias=attention_chunk_center_bias,
                attention_chunk_size=attention_chunk_size,
                attention_chunk_type=self.decoder_attention_chunk_type2,
                step=None,
                predictor_mask_chunk_hopping=mask_chunk_predictor,
                decoder_att_look_back_factor=decoder_att_look_back_factor,
                mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
                target_length=ys_in_lens,
                is_training=self.training,
            )
        elif self.encoder2.overlap_chunk_cls is not None:
            encoder_out, encoder_out_lens = self.encoder2.overlap_chunk_cls.remove_chunk(
                encoder_out, encoder_out_lens, chunk_outs=None
            )

        return pre_acoustic_embeds, pre_token_length, predictor_alignments, predictor_alignments_len, scama_mask

    def init_beam_search(
        self,
        **kwargs,
    ):
        from funasr.models.uniasr.beam_search import BeamSearchScama
        from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
        from funasr.models.transformer.scorers.length_bonus import LengthBonus

        decoding_mode = kwargs.get("decoding_mode", "model1")
        if decoding_mode == "model1":
            decoder = self.decoder
        else:
            decoder = self.decoder2

        # 1. Build scorers
        scorers = {}
        if self.ctc is not None:
            ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
            scorers.update(ctc=ctc)
        token_list = kwargs.get("token_list")
        scorers.update(
            decoder=decoder,
            length_bonus=LengthBonus(len(token_list)),
        )

        # 2. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram

        weights = dict(
            decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.0),
            ctc=kwargs.get("decoding_ctc_weight", 0.0),
            lm=kwargs.get("lm_weight", 0.0),
            ngram=kwargs.get("ngram_weight", 0.0),
            length_bonus=kwargs.get("penalty", 0.0),
        )
        beam_search = BeamSearchScama(
            beam_size=kwargs.get("beam_size", 5),
            weights=weights,
            scorers=scorers,
            sos=self.sos,
            eos=self.eos,
            vocab_size=len(token_list),
            token_list=token_list,
            pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
        )
        self.beam_search = beam_search
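
    # Note: __init__ leaves self.ctc as None, so the CTCPrefixScorer branch in
    # init_beam_search is inactive unless a CTC module is attached externally;
    # decoding then relies on the attention decoder and the length bonus.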

    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        decoding_model = kwargs.get("decoding_model", "normal")
        token_num_relax = kwargs.get("token_num_relax", 5)
        if decoding_model == "fast":
            decoding_ind = 0
            decoding_mode = "model1"
        elif decoding_model == "offline":
            decoding_ind = 1
            decoding_mode = "model2"
        else:
            decoding_ind = 0
            decoding_mode = "model2"

        # init beam search
        if self.beam_search is None:
            logging.info("enable beam_search")
            self.init_beam_search(decoding_mode=decoding_mode, **kwargs)
            self.nbest = kwargs.get("nbest", 1)

        meta_data = {}
        if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank":  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
            if speech_lengths is None:
                speech_lengths = speech.shape[1]
        else:
            # extract fbank feats
            time1 = time.perf_counter()
            audio_sample_list = load_audio_text_image_video(
                data_in,
                fs=frontend.fs,
                audio_fs=kwargs.get("fs", 16000),
                data_type=kwargs.get("data_type", "sound"),
                tokenizer=tokenizer,
            )
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(
                audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
            )
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000

        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])
        speech_raw = speech.clone().to(device=kwargs["device"])

        # Encoder
        _, encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=decoding_ind)

        if decoding_mode == "model1":
            predictor_outs = self.calc_predictor_mask(encoder_out, encoder_out_lens)
        else:
            encoder_out, encoder_out_lens = self.encode2(
                encoder_out, encoder_out_lens, speech_raw, speech_lengths, ind=decoding_ind
            )
            predictor_outs = self.calc_predictor_mask2(encoder_out, encoder_out_lens)

        scama_mask = predictor_outs[4]
        pre_token_length = predictor_outs[1]
        pre_acoustic_embeds = predictor_outs[0]
        maxlen = pre_token_length.sum().item() + token_num_relax
        minlen = max(0, pre_token_length.sum().item() - token_num_relax)

        # c. Pass the encoder result to the beam search
        nbest_hyps = self.beam_search(
            x=encoder_out[0],
            scama_mask=scama_mask,
            pre_acoustic_embeds=pre_acoustic_embeds,
            maxlenratio=0.0,
            minlenratio=0.0,
            maxlen=int(maxlen),
            minlen=int(minlen),
        )

        nbest_hyps = nbest_hyps[: self.nbest]
        results = []
        for hyp in nbest_hyps:
            # remove sos/eos and get results
            last_pos = -1
            if isinstance(hyp.yseq, list):
                token_int = hyp.yseq[1:last_pos]
            else:
                token_int = hyp.yseq[1:last_pos].tolist()

            # remove blank symbol id, which is assumed to be 0
            token_int = list(filter(lambda x: x != 0, token_int))

            # Change integer-ids to tokens
            token = tokenizer.ids2tokens(token_int)
            text_postprocessed = tokenizer.tokens2text(token)
            if not hasattr(tokenizer, "bpemodel"):
                text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)

            result_i = {"key": key[0], "text": text_postprocessed}
            results.append(result_i)

        return results, meta_data
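
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The model id below is
# a placeholder: substitute any UniASR 2-pass checkpoint whose config registers
# "UniASR" as the model class. "decoding_model" is forwarded to
# UniASR.inference(): "fast" uses model1 only, "normal"/"offline" use model2.
#
#   from funasr import AutoModel
#
#   model = AutoModel(model="<path-or-id-of-a-UniASR-2pass-model>")
#   result = model.generate(input="example.wav", decoding_model="normal")
#   print(result)
# ---------------------------------------------------------------------------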