#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import logging
from typing import Union, Dict, List, Tuple, Optional
import time

import torch
import torch.nn as nn
from torch.cuda.amp import autocast

from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.models.ctc.ctc import CTC
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.metrics.compute_acc import th_accuracy
# from funasr.models.e2e_asr_common import ErrorCalculator
from funasr.train_utils.device_funcs import force_gatherable
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
from funasr.register import tables


@tables.register("model_classes", "LCBNet")
class LCBNet(nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    LCB-NET: LONG-CONTEXT BIASING FOR AUDIO-VISUAL SPEECH RECOGNITION
    https://arxiv.org/abs/2401.06390
    """
    def __init__(
        self,
        specaug: str = None,
        specaug_conf: dict = None,
        normalize: str = None,
        normalize_conf: dict = None,
        encoder: str = None,
        encoder_conf: dict = None,
        decoder: str = None,
        decoder_conf: dict = None,
        text_encoder: str = None,
        text_encoder_conf: dict = None,
        bias_predictor: str = None,
        bias_predictor_conf: dict = None,
        fusion_encoder: str = None,
        fusion_encoder_conf: dict = None,
        ctc: str = None,
        ctc_conf: dict = None,
        ctc_weight: float = 0.5,
        interctc_weight: float = 0.0,
        select_num: int = 2,
        select_length: int = 3,
        insert_blank: bool = True,
        input_size: int = 80,
        vocab_size: int = -1,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        # extract_feats_in_collect_stats: bool = True,
        share_embedding: bool = False,
        # preencoder: Optional[AbsPreEncoder] = None,
        # postencoder: Optional[AbsPostEncoder] = None,
        **kwargs,
    ):
        super().__init__()

        if specaug is not None:
            specaug_class = tables.specaug_classes.get(specaug)
            specaug = specaug_class(**specaug_conf)
        if normalize is not None:
            normalize_class = tables.normalize_classes.get(normalize)
            normalize = normalize_class(**normalize_conf)
        encoder_class = tables.encoder_classes.get(encoder)
        encoder = encoder_class(input_size=input_size, **encoder_conf)
        encoder_output_size = encoder.output_size()

        # lcbnet modules: text encoder, fusion encoder and bias predictor
        text_encoder_class = tables.encoder_classes.get(text_encoder)
        text_encoder = text_encoder_class(input_size=vocab_size, **text_encoder_conf)
        fusion_encoder_class = tables.encoder_classes.get(fusion_encoder)
        fusion_encoder = fusion_encoder_class(**fusion_encoder_conf)
        bias_predictor_class = tables.encoder_classes.get(bias_predictor)
        bias_predictor = bias_predictor_class(**bias_predictor_conf)

        if decoder is not None:
            decoder_class = tables.decoder_classes.get(decoder)
            decoder = decoder_class(
                vocab_size=vocab_size,
                encoder_output_size=encoder_output_size,
                **decoder_conf,
            )
        if ctc_weight > 0.0:
            if ctc_conf is None:
                ctc_conf = {}
            ctc = CTC(
                odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
            )

        self.blank_id = blank_id
        self.sos = sos if sos is not None else vocab_size - 1
        self.eos = eos if eos is not None else vocab_size - 1
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        self.specaug = specaug
        self.normalize = normalize
        self.encoder = encoder

        # lcbnet
        self.text_encoder = text_encoder
        self.fusion_encoder = fusion_encoder
        self.bias_predictor = bias_predictor
        self.select_num = select_num
        self.select_length = select_length
        self.insert_blank = insert_blank

        if not hasattr(self.encoder, "interctc_use_conditioning"):
            self.encoder.interctc_use_conditioning = False
        if self.encoder.interctc_use_conditioning:
            self.encoder.conditioning_layer = torch.nn.Linear(
                vocab_size, self.encoder.output_size()
            )
        self.interctc_weight = interctc_weight

        if ctc_weight == 1.0:
            self.decoder = None
        else:
            self.decoder = decoder

        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        #
        # if report_cer or report_wer:
        #     self.error_calculator = ErrorCalculator(
        #         token_list, sym_space, sym_blank, report_cer, report_wer
        #     )
        #
        self.error_calculator = None

        if ctc_weight == 0.0:
            self.ctc = None
        else:
            self.ctc = ctc

        self.share_embedding = share_embedding
        if self.share_embedding:
            self.decoder.embed = None

        self.length_normalized_loss = length_normalized_loss
        self.beam_search = None

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        loss_att, acc_att, cer_att, wer_att = None, None, None, None
        loss_ctc, cer_ctc = None, None
        stats = dict()

        # decoder: CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )
            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        # Intermediate CTC (optional)
        loss_interctc = 0.0
        if self.interctc_weight != 0.0 and intermediate_outs is not None:
            for layer_idx, intermediate_out in intermediate_outs:
                # we assume intermediate_out has the same length & padding
                # as those of encoder_out
                loss_ic, cer_ic = self._calc_ctc_loss(
                    intermediate_out, encoder_out_lens, text, text_lengths
                )
                loss_interctc = loss_interctc + loss_ic

                # Collect intermediate CTC stats
                stats["loss_interctc_layer{}".format(layer_idx)] = (
                    loss_ic.detach() if loss_ic is not None else None
                )
                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic

            loss_interctc = loss_interctc / len(intermediate_outs)

            # calculate whole encoder loss
            loss_ctc = (
                1 - self.interctc_weight
            ) * loss_ctc + self.interctc_weight * loss_interctc

        # decoder: Attention decoder branch
        # (guarded: with ctc_weight == 1.0 the decoder is None and cannot be run)
        if self.ctc_weight != 1.0:
            loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

        # 3. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss = loss_att
        elif self.ctc_weight == 1.0:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att

        # Collect Attn branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att

        # Collect total loss stats
        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = int((text_lengths + 1).sum())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
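
    # Note (editor's addition): for 0 < ctc_weight < 1 the objective computed
    # above is
    #     loss = ctc_weight * loss_ctc + (1 - ctc_weight) * loss_att,
    # where loss_ctc may already include the intermediate-CTC term,
    #     loss_ctc = (1 - interctc_weight) * loss_ctc
    #                + interctc_weight * loss_interctc.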

    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
        """
        with autocast(False):
            # Data augmentation
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)

            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)

        # Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        if self.encoder.interctc_use_conditioning:
            encoder_out, encoder_out_lens, _ = self.encoder(
                speech, speech_lengths, ctc=self.ctc
            )
        else:
            encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)

        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        if intermediate_outs is not None:
            return (encoder_out, intermediate_outs), encoder_out_lens

        return encoder_out, encoder_out_lens

    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1

        # 1. Forward decoder
        decoder_out, _ = self.decoder(
            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
        )

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )

        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        return loss_att, acc_att, cer_att, wer_att

    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        # Calc CTC loss
        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)

        # Calc CER using CTC
        cer_ctc = None
        if not self.training and self.error_calculator is not None:
            ys_hat = self.ctc.argmax(encoder_out).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
        return loss_ctc, cer_ctc

    def init_beam_search(self, **kwargs):
        from funasr.models.transformer.search import BeamSearch
        from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
        from funasr.models.transformer.scorers.length_bonus import LengthBonus

        # 1. Build ASR model
        scorers = {}
        if self.ctc is not None:
            ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
            scorers.update(ctc=ctc)
        token_list = kwargs.get("token_list")
        scorers.update(
            decoder=self.decoder,
            length_bonus=LengthBonus(len(token_list)),
        )

        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram

        weights = dict(
            decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.5),
            ctc=kwargs.get("decoding_ctc_weight", 0.5),
            lm=kwargs.get("lm_weight", 0.0),
            ngram=kwargs.get("ngram_weight", 0.0),
            length_bonus=kwargs.get("penalty", 0.0),
        )
        beam_search = BeamSearch(
            beam_size=kwargs.get("beam_size", 10),
            weights=weights,
            scorers=scorers,
            sos=self.sos,
            eos=self.eos,
            vocab_size=len(token_list),
            token_list=token_list,
            pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
        )

        self.beam_search = beam_search
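
    # Note (editor's addition): inference() below builds the beam search
    # lazily on first call; decoding_ctc_weight splits the scoring mass
    # between the attention decoder and the CTC prefix scorer, mirroring the
    # training-time ctc_weight interpolation in forward().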

    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        # init beamsearch
        if self.beam_search is None:
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)

        meta_data = {}
        if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank":
            # inputs are already fbank features
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
            if speech_lengths is None:
                # wrap the frame count in a tensor so the .to(device) call below works
                speech_lengths = speech.new_full((1,), speech.shape[1], dtype=torch.long)
        else:
            # extract fbank feats
            time1 = time.perf_counter()
            sample_list = load_audio_text_image_video(
                data_in,
                fs=frontend.fs,
                audio_fs=kwargs.get("fs", 16000),
                data_type=kwargs.get("data_type", "sound"),
                tokenizer=tokenizer,
            )
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            audio_sample_list = sample_list[0]
            ocr_sample_list = sample_list[1]  # long-context biasing text stream (unused in this decoding path)
            speech, speech_lengths = extract_fbank(
                audio_sample_list,
                data_type=kwargs.get("data_type", "sound"),
                frontend=frontend,
            )
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            frame_shift = 10
            meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift / 1000

        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])

        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]

        # Pass the encoder result to the beam search
        nbest_hyps = self.beam_search(
            x=encoder_out[0],
            maxlenratio=kwargs.get("maxlenratio", 0.0),
            minlenratio=kwargs.get("minlenratio", 0.0),
        )

        nbest_hyps = nbest_hyps[: self.nbest]

        results = []
        b, n, d = encoder_out.size()
        for i in range(b):
            for nbest_idx, hyp in enumerate(nbest_hyps):
                ibest_writer = None
                if kwargs.get("output_dir") is not None:
                    if not hasattr(self, "writer"):
                        self.writer = DatadirWriter(kwargs.get("output_dir"))
                    ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]

                # remove sos/eos and get results
                last_pos = -1
                if isinstance(hyp.yseq, list):
                    token_int = hyp.yseq[1:last_pos]
                else:
                    token_int = hyp.yseq[1:last_pos].tolist()

                # remove blank symbol id, which is assumed to be 0
                token_int = list(
                    filter(
                        lambda x: x != self.eos and x != self.sos and x != self.blank_id,
                        token_int,
                    )
                )

                # Change integer-ids to tokens
                token = tokenizer.ids2tokens(token_int)
                text = tokenizer.tokens2text(token)

                text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                result_i = {"key": key[i], "token": token, "text": text_postprocessed}
                results.append(result_i)

                if ibest_writer is not None:
                    ibest_writer["token"][key[i]] = " ".join(token)
                    ibest_writer["text"][key[i]] = text_postprocessed

        return results, meta_data
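

# ---------------------------------------------------------------------------
# Usage sketch (editor's addition, not part of the upstream file). A minimal,
# hypothetical way to drive a registered model class like LCBNet is through
# FunASR's AutoModel front end; the checkpoint path and input file below are
# placeholders, not shipped artifacts.
#
#     from funasr import AutoModel
#
#     model = AutoModel(model="/path/to/lcbnet_checkpoint")  # hypothetical path
#     results = model.generate(input="/path/to/wav_and_ocr.scp")
#     print(results)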