import logging
import time
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.cuda.amp import autocast

from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.metrics.compute_acc import th_accuracy
from funasr.models.ctc.ctc import CTC
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos

# from funasr.models.e2e_asr_common import ErrorCalculator
from funasr.register import tables
from funasr.train_utils.device_funcs import force_gatherable
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank

@tables.register("model_classes", "Transformer")
class Transformer(nn.Module):
    """CTC-attention hybrid Encoder-Decoder model."""
    def __init__(
        self,
        specaug: str = None,
        specaug_conf: dict = None,
        normalize: str = None,
        normalize_conf: dict = None,
        encoder: str = None,
        encoder_conf: dict = None,
        decoder: str = None,
        decoder_conf: dict = None,
        ctc: str = None,
        ctc_conf: dict = None,
        ctc_weight: float = 0.5,
        interctc_weight: float = 0.0,
        input_size: int = 80,
        vocab_size: int = -1,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        # extract_feats_in_collect_stats: bool = True,
        share_embedding: bool = False,
        # preencoder: Optional[AbsPreEncoder] = None,
        # postencoder: Optional[AbsPostEncoder] = None,
        **kwargs,
    ):
        super().__init__()

        if specaug is not None:
            specaug_class = tables.specaug_classes.get(specaug)
            specaug = specaug_class(**specaug_conf)
        if normalize is not None:
            normalize_class = tables.normalize_classes.get(normalize)
            normalize = normalize_class(**normalize_conf)
        encoder_class = tables.encoder_classes.get(encoder)
        encoder = encoder_class(input_size=input_size, **encoder_conf)
        encoder_output_size = encoder.output_size()
        if decoder is not None:
            decoder_class = tables.decoder_classes.get(decoder)
            decoder = decoder_class(
                vocab_size=vocab_size,
                encoder_output_size=encoder_output_size,
                **decoder_conf,
            )
        if ctc_weight > 0.0:
            if ctc_conf is None:
                ctc_conf = {}
            ctc = CTC(odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf)

        self.blank_id = blank_id
        self.sos = sos if sos is not None else vocab_size - 1
        self.eos = eos if eos is not None else vocab_size - 1
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        self.specaug = specaug
        self.normalize = normalize
        self.encoder = encoder
        if not hasattr(self.encoder, "interctc_use_conditioning"):
            self.encoder.interctc_use_conditioning = False
        if self.encoder.interctc_use_conditioning:
            self.encoder.conditioning_layer = torch.nn.Linear(
                vocab_size, self.encoder.output_size()
            )
        self.interctc_weight = interctc_weight
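        # Self-conditioned intermediate CTC: when the encoder supports it,
        # intermediate CTC posteriors are fed back into the encoder states
        # through `conditioning_layer`; `interctc_weight` controls how the
        # auxiliary intermediate losses are mixed into loss_ctc in forward().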
        # self.error_calculator = None
        if ctc_weight == 1.0:
            self.decoder = None
        else:
            self.decoder = decoder
        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        # if report_cer or report_wer:
        #     self.error_calculator = ErrorCalculator(
        #         token_list, sym_space, sym_blank, report_cer, report_wer
        #     )
        self.error_calculator = None
        if ctc_weight == 0.0:
            self.ctc = None
        else:
            self.ctc = ctc
        self.share_embedding = share_embedding
        if self.share_embedding:
            self.decoder.embed = None
        self.length_normalized_loss = length_normalized_loss
        self.beam_search = None
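        # `beam_search` is built lazily by init_beam_search() on the first
        # call to inference().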

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + loss calculation.

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch,)
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        # Squeeze possible (Batch, 1) length tensors to (Batch,)
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]
        batch_size = speech.shape[0]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        loss_att, acc_att, cer_att, wer_att = None, None, None, None
        loss_ctc, cer_ctc = None, None
        stats = dict()

        # 2a. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )
            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        # Intermediate CTC (optional)
        loss_interctc = 0.0
        if self.interctc_weight != 0.0 and intermediate_outs is not None:
            for layer_idx, intermediate_out in intermediate_outs:
                # we assume intermediate_out has the same length & padding
                # as those of encoder_out
                loss_ic, cer_ic = self._calc_ctc_loss(
                    intermediate_out, encoder_out_lens, text, text_lengths
                )
                loss_interctc = loss_interctc + loss_ic
                # Collect intermediate CTC stats
                stats["loss_interctc_layer{}".format(layer_idx)] = (
                    loss_ic.detach() if loss_ic is not None else None
                )
                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
            loss_interctc = loss_interctc / len(intermediate_outs)

            # Mix the averaged intermediate losses into the whole-encoder loss
            loss_ctc = (
                1 - self.interctc_weight
            ) * loss_ctc + self.interctc_weight * loss_interctc
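            # Example: with interctc_weight = 0.5 the CTC term becomes
            # 0.5 * loss_ctc + 0.5 * mean(loss_interctc over tapped layers).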

        # 2b. Attention-decoder branch
        # Guard: when ctc_weight == 1.0 the decoder is None, so skip it.
        if self.ctc_weight != 1.0:
            loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

        # 3. CTC-attention loss combination
        if self.ctc_weight == 0.0:
            loss = loss_att
        elif self.ctc_weight == 1.0:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att

        # Collect attention branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att

        # Collect total loss stats
        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = int((text_lengths + 1).sum())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py.

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch,)
        """
        with autocast(False):
            # Data augmentation
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)
            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)

        # Forward encoder
        # feats: (Batch, Length, Dim) -> encoder_out: (Batch, Length2, Dim2)
        if self.encoder.interctc_use_conditioning:
            encoder_out, encoder_out_lens, _ = self.encoder(
                speech, speech_lengths, ctc=self.ctc
            )
        else:
            encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        if intermediate_outs is not None:
            return (encoder_out, intermediate_outs), encoder_out_lens
        return encoder_out, encoder_out_lens

    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
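        # Teacher forcing: add_sos_eos prepends <sos> to the decoder inputs
        # and appends <eos> to the targets, hence ys_in_lens = ys_pad_lens + 1.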
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1

        # 1. Forward decoder
        decoder_out, _ = self.decoder(encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens)

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )

        # Compute cer/wer using the attention decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
        return loss_att, acc_att, cer_att, wer_att

    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        # Calc CTC loss
        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)

        # Calc CER using CTC
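        # (framewise greedy argmax path; the error calculator collapses
        # repeated labels and blanks before scoring)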
        cer_ctc = None
        if not self.training and self.error_calculator is not None:
            ys_hat = self.ctc.argmax(encoder_out).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
        return loss_ctc, cer_ctc

    def init_beam_search(
        self,
        **kwargs,
    ):
        from funasr.models.transformer.search import BeamSearch
        from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
        from funasr.models.transformer.scorers.length_bonus import LengthBonus

        # 1. Build scorers
        scorers = {}
        if self.ctc is not None:
            ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
            scorers.update(ctc=ctc)
        token_list = kwargs.get("token_list")
        scorers.update(
            decoder=self.decoder,
            length_bonus=LengthBonus(len(token_list)),
        )

        # 2. Build ngram model (not supported yet)
        ngram = None
        scorers["ngram"] = ngram
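
        # Each scorer contributes weights[k] * score_k(hyp) to a hypothesis
        # score during beam search; the decoder and ctc weights below are
        # complementary (they sum to 1.0).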
        weights = dict(
            decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.5),
            ctc=kwargs.get("decoding_ctc_weight", 0.5),
            lm=kwargs.get("lm_weight", 0.0),
            ngram=kwargs.get("ngram_weight", 0.0),
            length_bonus=kwargs.get("penalty", 0.0),
        )
        beam_search = BeamSearch(
            beam_size=kwargs.get("beam_size", 10),
            weights=weights,
            scorers=scorers,
            sos=self.sos,
            eos=self.eos,
            vocab_size=len(token_list),
            token_list=token_list,
            pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
        )
        self.beam_search = beam_search

    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        # Initialize beam search lazily on the first call
        if self.beam_search is None:
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)

        meta_data = {}
        if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank":
            # Input is already fbank features
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
            if speech_lengths is None:
                # Wrap in a tensor so the later .to(device) call works
                speech_lengths = torch.tensor([speech.shape[1]])
        else:
            # Extract fbank features
            time1 = time.perf_counter()
            audio_sample_list = load_audio_text_image_video(
                data_in,
                fs=frontend.fs,
                audio_fs=kwargs.get("fs", 16000),
                data_type=kwargs.get("data_type", "sound"),
                tokenizer=tokenizer,
            )
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(
                audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
            )
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            meta_data["batch_data_time"] = (
                speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
            )

        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])

        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]

        # Pass the encoder output to the beam search
        nbest_hyps = self.beam_search(
            x=encoder_out[0],
            maxlenratio=kwargs.get("maxlenratio", 0.0),
            minlenratio=kwargs.get("minlenratio", 0.0),
        )
        nbest_hyps = nbest_hyps[: self.nbest]

        results = []
        b, n, d = encoder_out.size()
        for i in range(b):
            for nbest_idx, hyp in enumerate(nbest_hyps):
                ibest_writer = None
                if kwargs.get("output_dir") is not None:
                    if not hasattr(self, "writer"):
                        self.writer = DatadirWriter(kwargs.get("output_dir"))
                    ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]

                # Remove sos/eos and get the token ids
                last_pos = -1
                if isinstance(hyp.yseq, list):
                    token_int = hyp.yseq[1:last_pos]
                else:
                    token_int = hyp.yseq[1:last_pos].tolist()
                # Remove the blank symbol id, which is assumed to be 0
                token_int = list(
                    filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int)
                )

                # Change integer ids to tokens
                token = tokenizer.ids2tokens(token_int)
                text = tokenizer.tokens2text(token)
                text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                result_i = {"key": key[i], "token": token, "text": text_postprocessed}
                results.append(result_i)

                if ibest_writer is not None:
                    ibest_writer["token"][key[i]] = " ".join(token)
                    ibest_writer["text"][key[i]] = text_postprocessed
        return results, meta_data
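

if __name__ == "__main__":
    # Minimal training-step smoke test; a sketch, not part of the original
    # module. It assumes the registry keys "TransformerEncoder" and
    # "TransformerDecoder" (the names used by stock FunASR configs) and
    # small illustrative hyperparameters; adjust to your installation.
    model = Transformer(
        encoder="TransformerEncoder",
        encoder_conf={"output_size": 256, "attention_heads": 4, "linear_units": 1024, "num_blocks": 2},
        decoder="TransformerDecoder",
        decoder_conf={"attention_heads": 4, "linear_units": 1024, "num_blocks": 2},
        ctc_weight=0.3,
        input_size=80,
        vocab_size=100,
    )
    speech = torch.randn(2, 200, 80)       # (Batch, Frames, FbankDim)
    speech_lengths = torch.tensor([200, 180])
    text = torch.randint(3, 100, (2, 12))  # token ids above blank/sos/eos
    text[1, 10:] = -1                      # pad the shorter target with ignore_id
    text_lengths = torch.tensor([12, 10])
    loss, stats, weight = model(speech, speech_lengths, text, text_lengths)
    print(loss, stats)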