import logging
import time
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
from torch.cuda.amp import autocast

from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.models.ctc.ctc import CTC
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.metrics.compute_acc import th_accuracy

# from funasr.models.e2e_asr_common import ErrorCalculator
from funasr.train_utils.device_funcs import force_gatherable
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
from funasr.register import tables
@tables.register("model_classes", "Transformer")
class Transformer(nn.Module):
    """CTC-attention hybrid encoder-decoder ASR model."""

    def __init__(
        self,
        frontend: Optional[str] = None,
        frontend_conf: Optional[Dict] = None,
        specaug: Optional[str] = None,
        specaug_conf: Optional[Dict] = None,
        normalize: Optional[str] = None,
        normalize_conf: Optional[Dict] = None,
        encoder: Optional[str] = None,
        encoder_conf: Optional[Dict] = None,
        decoder: Optional[str] = None,
        decoder_conf: Optional[Dict] = None,
        ctc: Optional[str] = None,
        ctc_conf: Optional[Dict] = None,
        ctc_weight: float = 0.5,
        interctc_weight: float = 0.0,
        input_size: int = 80,
        vocab_size: int = -1,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        share_embedding: bool = False,
        **kwargs,
    ):
        super().__init__()

        if frontend is not None:
            frontend_class = tables.frontend_classes.get_class(frontend)
            frontend = frontend_class(**frontend_conf)
        if specaug is not None:
            specaug_class = tables.specaug_classes.get_class(specaug)
            specaug = specaug_class(**specaug_conf)
        if normalize is not None:
            normalize_class = tables.normalize_classes.get_class(normalize)
            normalize = normalize_class(**normalize_conf)
        encoder_class = tables.encoder_classes.get_class(encoder)
        encoder = encoder_class(input_size=input_size, **encoder_conf)
        encoder_output_size = encoder.output_size()
        if decoder is not None:
            decoder_class = tables.decoder_classes.get_class(decoder)
            decoder = decoder_class(
                vocab_size=vocab_size,
                encoder_output_size=encoder_output_size,
                **decoder_conf,
            )
        if ctc_weight > 0.0:
            if ctc_conf is None:
                ctc_conf = {}
            ctc = CTC(
                odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
            )

        self.blank_id = blank_id
        self.sos = sos if sos is not None else vocab_size - 1
        self.eos = eos if eos is not None else vocab_size - 1
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.encoder = encoder

        if not hasattr(self.encoder, "interctc_use_conditioning"):
            self.encoder.interctc_use_conditioning = False
        if self.encoder.interctc_use_conditioning:
            self.encoder.conditioning_layer = torch.nn.Linear(
                vocab_size, self.encoder.output_size()
            )
        self.interctc_weight = interctc_weight

        # _calc_att_loss/_calc_ctc_loss test this attribute, so it must exist
        # even while ErrorCalculator stays disabled below.
        self.error_calculator = None
        if ctc_weight == 1.0:
            self.decoder = None
        else:
            self.decoder = decoder
            self.criterion_att = LabelSmoothingLoss(
                size=vocab_size,
                padding_idx=ignore_id,
                smoothing=lsm_weight,
                normalize_length=length_normalized_loss,
            )
            # if report_cer or report_wer:
            #     self.error_calculator = ErrorCalculator(
            #         token_list, sym_space, sym_blank, report_cer, report_wer
            #     )
        if ctc_weight == 0.0:
            self.ctc = None
        else:
            self.ctc = ctc

        self.share_embedding = share_embedding
        if self.share_embedding:
            self.decoder.embed = None

        self.length_normalized_loss = length_normalized_loss
        self.beam_search = None
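
    # ------------------------------------------------------------------
    # Construction sketch (comment-only illustration, not original code).
    # Components are resolved by name through funasr's registry tables, so
    # a model is typically built from string identifiers plus config dicts:
    #
    #   model = Transformer(
    #       encoder="TransformerEncoder",    # registry key: assumed name
    #       encoder_conf={"output_size": 256, "num_blocks": 12},
    #       decoder="TransformerDecoder",    # registry key: assumed name
    #       decoder_conf={"num_blocks": 6},
    #       ctc_weight=0.3,
    #       vocab_size=5000,
    #       input_size=80,
    #   )
    #
    # The exact registry keys depend on the installed funasr version.
    # ------------------------------------------------------------------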
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + loss calculation.

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch,)
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        loss_att, acc_att, cer_att, wer_att = None, None, None, None
        loss_ctc, cer_ctc = None, None
        stats = dict()

        # 2a. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )
            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        # 2b. Intermediate CTC (optional)
        loss_interctc = 0.0
        if self.interctc_weight != 0.0 and intermediate_outs is not None:
            for layer_idx, intermediate_out in intermediate_outs:
                # We assume intermediate_out has the same length & padding
                # as encoder_out.
                loss_ic, cer_ic = self._calc_ctc_loss(
                    intermediate_out, encoder_out_lens, text, text_lengths
                )
                loss_interctc = loss_interctc + loss_ic
                # Collect intermediate CTC stats
                stats["loss_interctc_layer{}".format(layer_idx)] = (
                    loss_ic.detach() if loss_ic is not None else None
                )
                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
            loss_interctc = loss_interctc / len(intermediate_outs)
            # Blend the final-layer and intermediate CTC losses
            loss_ctc = (
                1 - self.interctc_weight
            ) * loss_ctc + self.interctc_weight * loss_interctc

        # 2c. Attention-decoder branch (skipped for pure-CTC models,
        # whose decoder is None)
        if self.ctc_weight != 1.0:
            loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

        # 3. Joint CTC-attention loss
        if self.ctc_weight == 0.0:
            loss = loss_att
        elif self.ctc_weight == 1.0:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att

        # Collect attention branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att

        # Collect total loss stats
        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = int((text_lengths + 1).sum())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
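
    # ------------------------------------------------------------------
    # Training-step sketch (comment-only illustration, not original code).
    # forward() returns (loss, stats, weight), so a bare-bones loop over a
    # hypothetical dataloader looks like:
    #
    #   optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    #   for speech, speech_lengths, text, text_lengths in loader:
    #       loss, stats, weight = model(speech, speech_lengths, text, text_lengths)
    #       optimizer.zero_grad()
    #       loss.backward()
    #       optimizer.step()
    #
    # stats carries detached scalars (loss_ctc, loss_att, acc, ...) for logging.
    # ------------------------------------------------------------------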
    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py.

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch,)
        """
        with autocast(False):
            # Data augmentation
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)
            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)

        # Forward encoder
        # feats: (Batch, Length, Dim) -> encoder_out: (Batch, Length2, Dim2)
        if self.encoder.interctc_use_conditioning:
            encoder_out, encoder_out_lens, _ = self.encoder(
                speech, speech_lengths, ctc=self.ctc
            )
        else:
            encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)

        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        if intermediate_outs is not None:
            return (encoder_out, intermediate_outs), encoder_out_lens
        return encoder_out, encoder_out_lens
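
    # ------------------------------------------------------------------
    # encode() usage sketch (comment-only illustration). With intermediate
    # CTC disabled it returns plain tensors; with it enabled, the first
    # return value is a tuple, which is why callers unpack defensively:
    #
    #   out, out_lens = model.encode(speech, speech_lengths)
    #   if isinstance(out, tuple):
    #       out, intermediate_outs = out  # intermediate_outs: [(layer_idx, tensor), ...]
    # ------------------------------------------------------------------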
    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1

        # 1. Forward decoder
        decoder_out, _ = self.decoder(
            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
        )

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )

        # 3. Compute cer/wer using the attention decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
        return loss_att, acc_att, cer_att, wer_att
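
    # ------------------------------------------------------------------
    # Teacher-forcing sketch (comment-only illustration). add_sos_eos
    # builds the shifted decoder input/target pair; e.g. for the label
    # sequence [7, 8, 9] with sos=1 and eos=2:
    #
    #   ys_in_pad  = [1, 7, 8, 9]   # <sos> prepended, fed to the decoder
    #   ys_out_pad = [7, 8, 9, 2]   # <eos> appended, used as the target
    #
    # which is why ys_in_lens = ys_pad_lens + 1 above.
    # ------------------------------------------------------------------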
    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        # Calc CTC loss
        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)

        # Calc CER using CTC (eval-time only)
        cer_ctc = None
        if not self.training and self.error_calculator is not None:
            ys_hat = self.ctc.argmax(encoder_out).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
        return loss_ctc, cer_ctc
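
    # ------------------------------------------------------------------
    # Eval-time CER sketch (comment-only illustration). With an
    # ErrorCalculator wired up, the CTC branch scores a greedy decode:
    #
    #   ys_hat = model.ctc.argmax(encoder_out).data  # (Batch, Length2)
    #   cer = error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
    # ------------------------------------------------------------------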
    def init_beam_search(
        self,
        **kwargs,
    ):
        from funasr.models.transformer.search import BeamSearch
        from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
        from funasr.models.transformer.scorers.length_bonus import LengthBonus

        # 1. Build scorers
        scorers = {}
        if self.ctc is not None:
            ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
            scorers.update(ctc=ctc)
        token_list = kwargs.get("token_list")
        scorers.update(
            length_bonus=LengthBonus(len(token_list)),
        )

        # 2. Build ngram model (ngram is not supported now)
        ngram = None
        scorers["ngram"] = ngram

        weights = dict(
            decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.0),
            ctc=kwargs.get("decoding_ctc_weight", 0.0),
            lm=kwargs.get("lm_weight", 0.0),
            ngram=kwargs.get("ngram_weight", 0.0),
            length_bonus=kwargs.get("penalty", 0.0),
        )
        beam_search = BeamSearch(
            beam_size=kwargs.get("beam_size", 2),
            weights=weights,
            scorers=scorers,
            sos=self.sos,
            eos=self.eos,
            vocab_size=len(token_list),
            token_list=token_list,
            pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
        )
        # beam_search.to(device=..., dtype=...).eval()  # device/dtype moves disabled
        self.beam_search = beam_search
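
    # ------------------------------------------------------------------
    # Beam-search weight sketch (comment-only illustration). Decoding
    # linearly interpolates the scorers; e.g. decoding_ctc_weight=0.3 gives:
    #
    #   weights = {"decoder": 0.7, "ctc": 0.3, "lm": 0.0,
    #              "ngram": 0.0, "length_bonus": 0.0}
    #
    # A positive "penalty" kwarg rewards longer hypotheses via LengthBonus.
    # ------------------------------------------------------------------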
    def generate(
        self,
        data_in: list,
        data_lengths: list = None,
        key: list = None,
        tokenizer=None,
        **kwargs,
    ):
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        # Init beam search lazily, only when CTC and/or LM rescoring is requested
        is_use_ctc = (
            kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc is not None
        )
        is_use_lm = (
            kwargs.get("lm_weight", 0.0) > 0.00001
            and kwargs.get("lm_file", None) is not None
        )
        if self.beam_search is None and (is_use_lm or is_use_ctc):
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)

        meta_data = {}
        # Extract fbank feats
        time1 = time.perf_counter()
        audio_sample_list = load_audio_text_image_video(
            data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000)
        )
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        speech, speech_lengths = extract_fbank(
            audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=self.frontend
        )
        time3 = time.perf_counter()
        meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
        meta_data["batch_data_time"] = (
            speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
        )

        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])

        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]

        # Pass the encoder output to the beam search
        nbest_hyps = self.beam_search(
            x=encoder_out[0],
            maxlenratio=kwargs.get("maxlenratio", 0.0),
            minlenratio=kwargs.get("minlenratio", 0.0),
        )
        nbest_hyps = nbest_hyps[: self.nbest]

        results = []
        b, n, d = encoder_out.size()
        for i in range(b):
            for nbest_idx, hyp in enumerate(nbest_hyps):
                ibest_writer = None
                if ibest_writer is None and kwargs.get("output_dir") is not None:
                    writer = DatadirWriter(kwargs.get("output_dir"))
                    ibest_writer = writer[f"{nbest_idx + 1}best_recog"]

                # Remove sos/eos and get token ids
                last_pos = -1
                if isinstance(hyp.yseq, list):
                    token_int = hyp.yseq[1:last_pos]
                else:
                    token_int = hyp.yseq[1:last_pos].tolist()

                # Remove the blank symbol id, which is assumed to be 0
                token_int = list(
                    filter(
                        lambda x: x != self.eos and x != self.sos and x != self.blank_id,
                        token_int,
                    )
                )

                # Change integer ids to tokens
                token = tokenizer.ids2tokens(token_int)
                text = tokenizer.tokens2text(token)
                text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                result_i = {
                    "key": key[i],
                    "token": token,
                    "text": text,
                    "text_postprocessed": text_postprocessed,
                }
                results.append(result_i)

                if ibest_writer is not None:
                    ibest_writer["token"][key[i]] = " ".join(token)
                    ibest_writer["text"][key[i]] = text
                    ibest_writer["text_postprocessed"][key[i]] = text_postprocessed

        return results, meta_data
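

# ---------------------------------------------------------------------------
# Minimal decoding sketch (an illustration, not part of the original file).
# It assumes funasr's high-level AutoModel wrapper; the checkpoint id and the
# wav path below are placeholders to substitute with real values.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from funasr import AutoModel

    model = AutoModel(model="<your-transformer-asr-checkpoint>")  # placeholder id
    results = model.generate(input="example.wav")  # hypothetical 16 kHz mono wav
    print(results)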